diff --git a/rpc/client_opt.go b/rpc/client_opt.go index 3fa045a9b9..482755c888 100644 --- a/rpc/client_opt.go +++ b/rpc/client_opt.go @@ -20,6 +20,7 @@ import ( "net/http" "github.com/gorilla/websocket" + "go.opentelemetry.io/otel/propagation" ) // ClientOption is a configuration option for the RPC client. @@ -32,6 +33,7 @@ type clientConfig struct { httpClient *http.Client httpHeaders http.Header httpAuth HTTPAuth + tmprop propagation.TextMapPropagator // WebSocket options wsDialer *websocket.Dialer @@ -95,6 +97,22 @@ func WithHeaders(headers http.Header) ClientOption { }) } +// WithTextMapPropagator configures OpenTelemetry trace propagation. +// Note, by default, trace context is NOT propagated by rpc.Client. +// To enable propagation via the `traceparent` header, you must explicitly +// enable it by setting a propagator, e.g. +// +// prop := propagation.TraceContext{} +// c, err := rpc.DialOptions(ctx, "http://", rpc.WithTextMapPropagator(prop)) +func WithTextMapPropagator(tmp propagation.TextMapPropagator) ClientOption { + if tmp == nil { + panic("nil TextMapPropagator configured") + } + return optionFunc(func(cfg *clientConfig) { + cfg.tmprop = tmp + }) +} + // WithHTTPClient configures the http.Client used by the RPC client. func WithHTTPClient(c *http.Client) ClientOption { return optionFunc(func(cfg *clientConfig) { diff --git a/rpc/http.go b/rpc/http.go index 2bd761e9cd..6340175736 100644 --- a/rpc/http.go +++ b/rpc/http.go @@ -52,6 +52,7 @@ type httpConn struct { mu sync.Mutex // protects headers headers http.Header auth HTTPAuth + tmprop propagation.TextMapPropagator } // httpConn implements ServerCodec, but it is treated specially by Client @@ -168,6 +169,7 @@ func newClientTransportHTTP(endpoint string, cfg *clientConfig) reconnectFunc { headers: headers, url: endpoint, auth: cfg.httpAuth, + tmprop: cfg.tmprop, closeCh: make(chan interface{}), } @@ -230,6 +232,9 @@ func (hc *httpConn) doRequest(ctx context.Context, body []byte) (io.ReadCloser, req.Header = hc.headers.Clone() hc.mu.Unlock() setHeaders(req.Header, headersFromContext(ctx)) + if hc.tmprop != nil { + hc.tmprop.Inject(ctx, propagation.HeaderCarrier(req.Header)) + } if hc.auth != nil { if err := hc.auth(req.Header); err != nil { diff --git a/rpc/tracing_test.go b/rpc/tracing_test.go index 58dc2f1758..fa2ef09e76 100644 --- a/rpc/tracing_test.go +++ b/rpc/tracing_test.go @@ -18,6 +18,7 @@ package rpc import ( "context" + "encoding/json" "io" "net/http" "net/http/httptest" @@ -541,6 +542,76 @@ func TestTracingBatchHTTPTooLarge(t *testing.T) { } } +// newHeaderRecordingServer creates an HTTP test server that responds to any +// JSON-RPC call and records the traceparent header of incoming requests. +func newHeaderRecordingServer(t *testing.T, headerCh chan<- string) *httptest.Server { + t.Helper() + httpsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + headerCh <- r.Header.Get("traceparent") + var msg jsonrpcMessage + if err := json.NewDecoder(r.Body).Decode(&msg); err != nil { + t.Errorf("invalid request body: %v", err) + } + resp := jsonrpcMessage{Version: vsn, ID: msg.ID, Result: []byte("null")} + w.Header().Set("Content-Type", contentType) + json.NewEncoder(w).Encode(&resp) + })) + t.Cleanup(httpsrv.Close) + return httpsrv +} + +// TestTracingClientPropagation verifies that the client injects the W3C +// traceparent header into outgoing HTTP requests when configured with the +// WithTextMapPropagator option. +func TestTracingClientPropagation(t *testing.T) { + t.Parallel() + + headerCh := make(chan string, 1) + httpsrv := newHeaderRecordingServer(t, headerCh) + + client, err := DialOptions(context.Background(), httpsrv.URL, WithTextMapPropagator(propagation.TraceContext{})) + if err != nil { + t.Fatalf("failed to dial: %v", err) + } + t.Cleanup(client.Close) + + // Build a context carrying a sampled remote span context. + const ( + traceID = "4bf92f3577b34da6a3ce929d0e0e4736" + spanID = "00f067aa0ba902b7" + ) + tid, err := trace.TraceIDFromHex(traceID) + if err != nil { + t.Fatal(err) + } + sid, err := trace.SpanIDFromHex(spanID) + if err != nil { + t.Fatal(err) + } + sc := trace.NewSpanContext(trace.SpanContextConfig{ + TraceID: tid, + SpanID: sid, + TraceFlags: trace.FlagsSampled, + }) + ctx := trace.ContextWithSpanContext(context.Background(), sc) + + if err := client.CallContext(ctx, nil, "test_foo"); err != nil { + t.Fatalf("RPC call failed: %v", err) + } + want := "00-" + traceID + "-" + spanID + "-01" + if got := <-headerCh; got != want { + t.Errorf("traceparent header: got %q, want %q", got, want) + } + + // A call without a span context in ctx must not produce a traceparent header. + if err := client.CallContext(context.Background(), nil, "test_foo"); err != nil { + t.Fatalf("RPC call failed: %v", err) + } + if got := <-headerCh; got != "" { + t.Errorf("traceparent header without span context: got %q, want none", got) + } +} + // TestTracingHTTPTimeout verifies that when a non-batch call exceeds the HTTP // server's WriteTimeout, the SERVER span ends with error status (carrying the // timeout error message).