rpc: add option to configure TextMapPropagator on client (#35132)

This adds a client option to configure trace context propagation via the
`traceparent` HTTP header.
I'm adding this so that prysm can enable distributed tracing on their
engine API client.
This commit is contained in:
Felix Lange 2026-06-10 19:04:54 +02:00 committed by GitHub
parent 39b17c5585
commit e444c267a2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 94 additions and 0 deletions

View file

@ -20,6 +20,7 @@ import (
"net/http"
"github.com/gorilla/websocket"
"go.opentelemetry.io/otel/propagation"
)
// ClientOption is a configuration option for the RPC client.
@ -32,6 +33,7 @@ type clientConfig struct {
httpClient *http.Client
httpHeaders http.Header
httpAuth HTTPAuth
tmprop propagation.TextMapPropagator
// WebSocket options
wsDialer *websocket.Dialer
@ -95,6 +97,22 @@ func WithHeaders(headers http.Header) ClientOption {
})
}
// WithTextMapPropagator configures OpenTelemetry trace propagation.
// Note, by default, trace context is NOT propagated by rpc.Client.
// To enable propagation via the `traceparent` header, you must explicitly
// enable it by setting a propagator, e.g.
//
// prop := propagation.TraceContext{}
// c, err := rpc.DialOptions(ctx, "http://", rpc.WithTextMapPropagator(prop))
func WithTextMapPropagator(tmp propagation.TextMapPropagator) ClientOption {
if tmp == nil {
panic("nil TextMapPropagator configured")
}
return optionFunc(func(cfg *clientConfig) {
cfg.tmprop = tmp
})
}
// WithHTTPClient configures the http.Client used by the RPC client.
func WithHTTPClient(c *http.Client) ClientOption {
return optionFunc(func(cfg *clientConfig) {

View file

@ -52,6 +52,7 @@ type httpConn struct {
mu sync.Mutex // protects headers
headers http.Header
auth HTTPAuth
tmprop propagation.TextMapPropagator
}
// httpConn implements ServerCodec, but it is treated specially by Client
@ -168,6 +169,7 @@ func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc {
headers: headers,
url: endpoint,
auth: cfg.httpAuth,
tmprop: cfg.tmprop,
closeCh: make(chan interface{}),
}
@ -230,6 +232,9 @@ func (hc *httpConn) doRequest(ctx context.Context, body []byte) (io.ReadCloser,
req.Header = hc.headers.Clone()
hc.mu.Unlock()
setHeaders(req.Header, headersFromContext(ctx))
if hc.tmprop != nil {
hc.tmprop.Inject(ctx, propagation.HeaderCarrier(req.Header))
}
if hc.auth != nil {
if err := hc.auth(req.Header); err != nil {

View file

@ -18,6 +18,7 @@ package rpc
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
@ -541,6 +542,76 @@ func TestTracingBatchHTTPTooLarge(t *testing.T) {
}
}
// newHeaderRecordingServer creates an HTTP test server that responds to any
// JSON-RPC call and records the traceparent header of incoming requests.
func newHeaderRecordingServer(t *testing.T, headerCh chan<- string) *httptest.Server {
t.Helper()
httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
headerCh <- r.Header.Get("traceparent")
var msg jsonrpcMessage
if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
t.Errorf("invalid request body: %v", err)
}
resp := jsonrpcMessage{Version: vsn, ID: msg.ID, Result: []byte("null")}
w.Header().Set("Content-Type", contentType)
json.NewEncoder(w).Encode(&resp)
}))
t.Cleanup(httpsrv.Close)
return httpsrv
}
// TestTracingClientPropagation verifies that the client injects the W3C
// traceparent header into outgoing HTTP requests when configured with the
// WithTextMapPropagator option.
func TestTracingClientPropagation(t *testing.T) {
t.Parallel()
headerCh := make(chan string, 1)
httpsrv := newHeaderRecordingServer(t, headerCh)
client, err := DialOptions(context.Background(), httpsrv.URL, WithTextMapPropagator(propagation.TraceContext{}))
if err != nil {
t.Fatalf("failed to dial: %v", err)
}
t.Cleanup(client.Close)
// Build a context carrying a sampled remote span context.
const (
traceID = "4bf92f3577b34da6a3ce929d0e0e4736"
spanID = "00f067aa0ba902b7"
)
tid, err := trace.TraceIDFromHex(traceID)
if err != nil {
t.Fatal(err)
}
sid, err := trace.SpanIDFromHex(spanID)
if err != nil {
t.Fatal(err)
}
sc := trace.NewSpanContext(trace.SpanContextConfig{
TraceID: tid,
SpanID: sid,
TraceFlags: trace.FlagsSampled,
})
ctx := trace.ContextWithSpanContext(context.Background(), sc)
if err := client.CallContext(ctx, nil, "test_foo"); err != nil {
t.Fatalf("RPC call failed: %v", err)
}
want := "00-" + traceID + "-" + spanID + "-01"
if got := <-headerCh; got != want {
t.Errorf("traceparent header: got %q, want %q", got, want)
}
// A call without a span context in ctx must not produce a traceparent header.
if err := client.CallContext(context.Background(), nil, "test_foo"); err != nil {
t.Fatalf("RPC call failed: %v", err)
}
if got := <-headerCh; got != "" {
t.Errorf("traceparent header without span context: got %q, want none", got)
}
}
// TestTracingHTTPTimeout verifies that when a non-batch call exceeds the HTTP
// server's WriteTimeout, the SERVER span ends with error status (carrying the
// timeout error message).