From 36ae722fcb152205e199cf52548e42a7ca684357 Mon Sep 17 00:00:00 2001 From: Przemek Kowalewski Date: Wed, 25 Mar 2026 17:31:10 +0100 Subject: [PATCH] Feat: Improve tracing for proxy server --- go.mod | 2 +- proxy/server/server.go | 16 +- proxy/server/target.go | 69 +++++++- proxy/server/tracing.go | 152 ++++++++++++++++++ proxy/server/tracing_test.go | 301 +++++++++++++++++++++++++++++++++++ 5 files changed, 531 insertions(+), 9 deletions(-) create mode 100644 proxy/server/tracing.go create mode 100644 proxy/server/tracing_test.go diff --git a/go.mod b/go.mod index cd1f1d2e..1781c16a 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( go.opentelemetry.io/otel v1.38.0 go.opentelemetry.io/otel/exporters/prometheus v0.50.0 go.opentelemetry.io/otel/metric v1.38.0 + go.opentelemetry.io/otel/sdk v1.38.0 go.opentelemetry.io/otel/sdk/metric v1.38.0 go.opentelemetry.io/otel/trace v1.38.0 gocloud.dev v0.32.0 @@ -108,7 +109,6 @@ require ( go.opencensus.io v0.24.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 // indirect - go.opentelemetry.io/otel/sdk v1.38.0 // indirect golang.org/x/crypto v0.44.0 // indirect golang.org/x/net v0.47.0 // indirect golang.org/x/oauth2 v0.32.0 // indirect diff --git a/proxy/server/server.go b/proxy/server/server.go index f9346648..6c024a81 100644 --- a/proxy/server/server.go +++ b/proxy/server/server.go @@ -23,6 +23,7 @@ import ( "fmt" "io" + "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" "google.golang.org/grpc" @@ -248,7 +249,10 @@ func dispatch(ctx context.Context, stream pb.Proxy_ProxyServer, requestChan chan // be removed from the stream set doneChan := make(chan uint64) recorder := metrics.RecorderFromContextOrNoop(ctx) + rootSpan := trace.SpanFromContext(ctx) var addedPeerToContext bool + var targetCount int + defer func() { rootSpan.SetAttributes(attrProxyTargetCount.Int(targetCount)) }() for { select { case <-ctx.Done(): @@ -279,15 +283,23 @@ func dispatch(ctx context.Context, stream pb.Proxy_ProxyServer, requestChan chan // Peer information might not be properly populated until rpcauth // evaluates the initial received message, so let's grab fresh // peer information when we know we've gotten at least one message. - ctx = rpcauth.AddPeerToContext(ctx, rpcauth.PeerInputFromContext(stream.Context())) + peerInfo := rpcauth.PeerInputFromContext(stream.Context()) + ctx = rpcauth.AddPeerToContext(ctx, peerInfo) + enrichRootSpan(ctx, peerInfo) addedPeerToContext = true } // We have a new request switch req.Request.(type) { case *pb.ProxyRequest_StartStream: - if err := streamSet.Add(ctx, req.GetStartStream(), replyChan, doneChan); err != nil { + ss := req.GetStartStream() + if err := streamSet.Add(ctx, ss, replyChan, doneChan); err != nil { return err } + targetCount++ + rootSpan.AddEvent("dispatch.start_stream", trace.WithAttributes( + attrTargetAddress.String(ss.GetTarget()), + attrTargetMethod.String(ss.GetMethodName()), + )) case *pb.ProxyRequest_StreamData: if err := streamSet.Send(ctx, req.GetStreamData()); err != nil { return err diff --git a/proxy/server/target.go b/proxy/server/target.go index b85e2f8a..da6a2f44 100644 --- a/proxy/server/target.go +++ b/proxy/server/target.go @@ -27,6 +27,9 @@ import ( "time" "github.com/go-logr/logr" + "go.opentelemetry.io/otel/attribute" + otelcodes "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -192,28 +195,52 @@ func (s *TargetStream) Send(req proto.Message) error { // messages for sending to a proxy client, including the final // status of the target stream func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) { - group, ctx := errgroup.WithContext(s.ctx) + // Create a child span for this target stream. + spanAttrs := []attribute.KeyValue{ + attrTargetAddress.String(s.target), + attrTargetMethod.String(s.serviceMethod.FullName()), + attrTargetStreamID.String(fmt.Sprintf("%d", s.streamID)), + attrTargetStreamType.String(streamType(s.serviceMethod.ClientStreams(), s.serviceMethod.ServerStreams())), + attrTargetAuthzDryRun.Bool(s.authzDryRun), + } + if s.dialTimeout != nil { + spanAttrs = append(spanAttrs, attrTargetDialTimeoutMs.Int64(s.dialTimeout.Milliseconds())) + } + spanCtx, span := getTracer().Start(s.ctx, "proxy.target"+s.serviceMethod.FullName(), + trace.WithAttributes(spanAttrs...), + ) + defer span.End() + + group, ctx := errgroup.WithContext(spanCtx) peer := rpcauth.PeerInputFromContext(ctx) if peer != nil && peer.Principal != nil { // Unconditionally add information on the original caller to outgoing RPCs ctx = proxiedidentity.AppendToMetadataInOutgoingContext(ctx, peer.Principal) + span.SetAttributes(attrTargetProxiedPrincipal.String(peer.Principal.ID)) } group.Go(func() error { - dialCtx, cancel := context.WithCancel(ctx) + // Sub-span for dial + stream creation. + dialCtx, dialSpan := getTracer().Start(ctx, "proxy.target.dial", + trace.WithAttributes(attrTargetAddress.String(s.target)), + ) + dialCtx, dialCancel := context.WithCancel(dialCtx) var opts []grpc.DialOption if s.dialTimeout != nil { - dialCtx, cancel = context.WithTimeout(ctx, *s.dialTimeout) + dialCtx, dialCancel = context.WithTimeout(dialCtx, *s.dialTimeout) opts = append(opts, grpc.WithBlock()) } var err error - defer cancel() + defer dialCancel() grpcConn, err := s.dialer.DialContext(dialCtx, s.target, opts...) if err != nil { // We cannot create a new stream to the target. So we need to cancel this stream. s.logger.Info("unable to create stream", "status", err) s.cancelFunc() + dialSpan.RecordError(err) + dialSpan.SetStatus(otelcodes.Error, "dial failed") + dialSpan.End() return fmt.Errorf("could not connect to target from the proxy: %w", err) } s.grpcConn = grpcConn @@ -221,8 +248,13 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) { if err != nil { // We cannot create a new stream to the target. So we need to cancel this stream. s.logger.Info("unable to create stream", "status", err) + dialSpan.RecordError(err) + dialSpan.SetStatus(otelcodes.Error, "new stream failed") + dialSpan.End() return fmt.Errorf("could not connect to target from the proxy: %w", err) } + dialSpan.End() + span.AddEvent("stream.connected") // We've successfully connected and can replace the initial unconnected stream // with the target stream. @@ -230,6 +262,7 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) { // Receives messages from the server stream group.Go(func() error { + var receivedFirst bool for { msg := s.serviceMethod.NewReply() err := grpcStream.RecvMsg(msg) @@ -262,7 +295,10 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) { s.CloseWith(err) return fmt.Errorf("proxy could not receive response from the target: %w", err) } - // otherwise, this is a streamData reply + if !receivedFirst { + span.AddEvent("stream.first_response") + receivedFirst = true + } packed, err := anypb.New(msg) if err != nil { return err @@ -326,9 +362,19 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) { // If authz fails, close immediately with an error if err := s.authorizer.Eval(ctx, authinput); err != nil { + span.AddEvent("authz.evaluated", trace.WithAttributes( + attrAuthzResult.String("denied"), + attrAuthzMethod.String(s.Method()), + )) + span.RecordError(err) + span.SetStatus(otelcodes.Error, "authz denied") s.CloseWith(err) return err } + span.AddEvent("authz.evaluated", trace.WithAttributes( + attrAuthzResult.String("allowed"), + attrAuthzMethod.String(s.Method()), + )) if s.authzDryRun { // TODO: make authz dry run request to server @@ -380,12 +426,23 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) { case err = <-s.errChan: default: } + + // Record final stream status on span. + grpcStatus := status.Convert(err) + span.AddEvent("stream.finished", trace.WithAttributes( + attrGRPCStatusCode.String(grpcStatus.Code().String()), + )) + if err != nil { + span.RecordError(err) + span.SetStatus(otelcodes.Error, grpcStatus.Message()) + } + s.logger.Info("finished", "status", err) reply := &pb.ProxyReply{ Reply: &pb.ProxyReply_ServerClose{ ServerClose: &pb.ServerClose{ StreamIds: []uint64{s.streamID}, - Status: convertStatus(status.Convert(err)), + Status: convertStatus(grpcStatus), }, }, } diff --git a/proxy/server/tracing.go b/proxy/server/tracing.go new file mode 100644 index 00000000..171112c0 --- /dev/null +++ b/proxy/server/tracing.go @@ -0,0 +1,152 @@ +/* Copyright (c) 2019 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +package server + +import ( + "context" + "strings" + "sync" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "google.golang.org/grpc/metadata" + + "github.com/Snowflake-Labs/sansshell/auth/rpcauth" +) + +const tracerName = "sansshell-proxy" + +var ( + tracerMu sync.RWMutex + currentTracer trace.Tracer +) + +func init() { currentTracer = otel.Tracer(tracerName) } + +func getTracer() trace.Tracer { + tracerMu.RLock() + t := currentTracer + tracerMu.RUnlock() + return t +} + +// setTracer replaces the active tracer and returns a function that restores +// the previous one. Safe for concurrent use. +func setTracer(t trace.Tracer) (restore func()) { + tracerMu.Lock() + prev := currentTracer + currentTracer = t + tracerMu.Unlock() + return func() { + tracerMu.Lock() + currentTracer = prev + tracerMu.Unlock() + } +} + +// Span attribute keys for caller identity. +const ( + attrCallerPrincipal attribute.Key = "sansshell.caller.principal" + attrCallerGroups attribute.Key = "sansshell.caller.groups" + attrCallerAddress attribute.Key = "sansshell.caller.address" + attrCallerCertCN attribute.Key = "sansshell.caller.cert.cn" + attrCallerCertSPIFFE attribute.Key = "sansshell.caller.cert.spiffe_id" +) + +// Span attribute keys for target stream information. +const ( + attrTargetAddress attribute.Key = "sansshell.target.address" + attrTargetMethod attribute.Key = "sansshell.target.method" + attrTargetStreamID attribute.Key = "sansshell.target.stream_id" + attrTargetStreamType attribute.Key = "sansshell.target.stream_type" + attrTargetDialTimeoutMs attribute.Key = "sansshell.target.dial_timeout_ms" + attrTargetAuthzDryRun attribute.Key = "sansshell.target.authz_dry_run" + attrTargetProxiedPrincipal attribute.Key = "sansshell.target.proxied_principal" +) + +// Span attribute keys for aggregate proxy-level information. +const ( + attrProxyTargetCount attribute.Key = "sansshell.proxy.target_count" + attrProxyJustification attribute.Key = "sansshell.proxy.justification" +) + +// Span attribute keys for authz events. +const ( + attrAuthzResult attribute.Key = "sansshell.authz.result" + attrAuthzMethod attribute.Key = "sansshell.authz.method" +) + +// Span attribute key for stream finish status. +const attrGRPCStatusCode attribute.Key = "grpc.status_code" + +func streamType(clientStreams, serverStreams bool) string { + switch { + case clientStreams && serverStreams: + return "bidi" + case clientStreams: + return "client_stream" + case serverStreams: + return "server_stream" + default: + return "unary" + } +} + +func callerAttrsFromPeer(peer *rpcauth.PeerAuthInput) []attribute.KeyValue { + if peer == nil { + return nil + } + var attrs []attribute.KeyValue + if peer.Net != nil && peer.Net.Address != "" { + attrs = append(attrs, attrCallerAddress.String(peer.Net.Address)) + } + if peer.Principal != nil { + if peer.Principal.ID != "" { + attrs = append(attrs, attrCallerPrincipal.String(peer.Principal.ID)) + } + if len(peer.Principal.Groups) > 0 { + attrs = append(attrs, attrCallerGroups.String(strings.Join(peer.Principal.Groups, ","))) + } + } + if peer.Cert != nil { + if peer.Cert.Subject.CommonName != "" { + attrs = append(attrs, attrCallerCertCN.String(peer.Cert.Subject.CommonName)) + } + if peer.Cert.SPIFFEID != "" { + attrs = append(attrs, attrCallerCertSPIFFE.String(peer.Cert.SPIFFEID)) + } + } + return attrs +} + +// enrichRootSpan sets caller identity and justification attributes on the +// current span in ctx. It is called from dispatch once peer info is available. +func enrichRootSpan(ctx context.Context, peer *rpcauth.PeerAuthInput) { + span := trace.SpanFromContext(ctx) + if !span.IsRecording() { + return + } + if attrs := callerAttrsFromPeer(peer); len(attrs) > 0 { + span.SetAttributes(attrs...) + } + if md, ok := metadata.FromIncomingContext(ctx); ok { + if vals := md.Get(rpcauth.ReqJustKey); len(vals) > 0 { + span.SetAttributes(attrProxyJustification.String(vals[0])) + } + } +} diff --git a/proxy/server/tracing_test.go b/proxy/server/tracing_test.go new file mode 100644 index 00000000..7c90bdf2 --- /dev/null +++ b/proxy/server/tracing_test.go @@ -0,0 +1,301 @@ +/* Copyright (c) 2019 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +package server + +import ( + "context" + "testing" + + "go.opentelemetry.io/otel/sdk/trace/tracetest" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/test/bufconn" + + sdktrace "go.opentelemetry.io/otel/sdk/trace" + + "github.com/Snowflake-Labs/sansshell/auth/rpcauth" + pb "github.com/Snowflake-Labs/sansshell/proxy" + tdpb "github.com/Snowflake-Labs/sansshell/proxy/testdata" + "github.com/Snowflake-Labs/sansshell/proxy/testutil" +) + +func setupTracing(t *testing.T) (*tracetest.InMemoryExporter, *sdktrace.TracerProvider) { + t.Helper() + exporter := tracetest.NewInMemoryExporter() + tp := sdktrace.NewTracerProvider(sdktrace.WithSyncer(exporter)) + restore := setTracer(tp.Tracer(tracerName)) + t.Cleanup(func() { + _ = tp.Shutdown(context.Background()) + restore() + }) + return exporter, tp +} + +// runTargetStream creates a TargetStream via NewTargetStream and executes a +// unary RPC by calling Run() directly. A buffered replyChan is used so that +// the unconditional send at the end of Run() never blocks, sidestepping the +// pre-existing dispatch-level deadlock that surfaces when going through the +// full proxy pipeline. +func runTargetStream(t *testing.T, targets map[string]*bufconn.Listener, target, method string, authz rpcauth.RPCAuthorizer, req *tdpb.TestRequest) { + t.Helper() + ctx := context.Background() + + dialer := NewDialer(testutil.WithBufDialer(targets), grpc.WithTransportCredentials(insecure.NewCredentials())) + svcMap := LoadGlobalServiceMap() + svcMethod := svcMap[method] + if svcMethod == nil { + t.Fatalf("unknown service method %s", method) + } + + peerInfo := &rpcauth.PeerAuthInput{ + Net: &rpcauth.NetAuthInput{Network: "bufconn", Address: "test-caller"}, + } + ctx = rpcauth.AddPeerToContext(ctx, peerInfo) + + ts, err := NewTargetStream(ctx, target, dialer, nil, svcMethod, authz, false) + if err != nil { + t.Fatal(err) + } + if err := ts.Send(req); err != nil { + t.Fatal(err) + } + ts.CloseSend() + + replyChan := make(chan *pb.ProxyReply, 10) + ts.Run(1, replyChan) +} + +// ---------- span assertion helpers ---------- + +func findSpan(spans tracetest.SpanStubs, name string) *tracetest.SpanStub { + for i := range spans { + if spans[i].Name == name { + return &spans[i] + } + } + return nil +} + +func findTargetSpan(spans tracetest.SpanStubs, name, target string) *tracetest.SpanStub { + for i := range spans { + if spans[i].Name == name && spanHasAttribute(&spans[i], "sansshell.target.address", target) { + return &spans[i] + } + } + return nil +} + +func spanHasAttribute(span *tracetest.SpanStub, key, value string) bool { + for _, attr := range span.Attributes { + if string(attr.Key) == key && attr.Value.AsString() == value { + return true + } + } + return false +} + +func spanHasEvent(span *tracetest.SpanStub, name string) bool { + for _, ev := range span.Events { + if ev.Name == name { + return true + } + } + return false +} + +func eventHasAttribute(span *tracetest.SpanStub, eventName, key, value string) bool { + for _, ev := range span.Events { + if ev.Name == eventName { + for _, attr := range ev.Attributes { + if string(attr.Key) == key && attr.Value.AsString() == value { + return true + } + } + } + } + return false +} + +// ---------- tests ---------- + +func TestTracing_TargetSpanCreated(t *testing.T) { + exporter, _ := setupTracing(t) + testServerMap := testutil.StartTestDataServers(t, "foo:123") + authz := testutil.NewAllowAllRPCAuthorizer(context.Background(), t) + + runTargetStream(t, testServerMap, "foo:123", "/Testdata.TestService/TestUnary", authz, &tdpb.TestRequest{Input: "hello"}) + + spans := exporter.GetSpans() + targetSpan := findTargetSpan(spans, "proxy.target/Testdata.TestService/TestUnary", "foo:123") + if targetSpan == nil { + t.Fatalf("expected target span 'proxy.target/Testdata.TestService/TestUnary' with target foo:123 not found in %d span(s)", len(spans)) + } + if !spanHasAttribute(targetSpan, "sansshell.target.method", "/Testdata.TestService/TestUnary") { + t.Error("target span missing sansshell.target.method") + } + if !spanHasAttribute(targetSpan, "sansshell.target.stream_type", "unary") { + t.Error("target span missing sansshell.target.stream_type=unary") + } +} + +func TestTracing_DialSubSpan(t *testing.T) { + exporter, _ := setupTracing(t) + testServerMap := testutil.StartTestDataServers(t, "foo:123") + authz := testutil.NewAllowAllRPCAuthorizer(context.Background(), t) + + runTargetStream(t, testServerMap, "foo:123", "/Testdata.TestService/TestUnary", authz, &tdpb.TestRequest{Input: "x"}) + + spans := exporter.GetSpans() + dialSpan := findTargetSpan(spans, "proxy.target.dial", "foo:123") + if dialSpan == nil { + t.Fatal("expected dial span 'proxy.target.dial' with target foo:123 not found") + } + if dialSpan.Status.Code != 0 { + t.Errorf("dial span has unexpected error status: %v", dialSpan.Status) + } +} + +func TestTracing_StreamEvents(t *testing.T) { + exporter, _ := setupTracing(t) + testServerMap := testutil.StartTestDataServers(t, "foo:123") + authz := testutil.NewAllowAllRPCAuthorizer(context.Background(), t) + + runTargetStream(t, testServerMap, "foo:123", "/Testdata.TestService/TestUnary", authz, &tdpb.TestRequest{Input: "hi"}) + + spans := exporter.GetSpans() + targetSpan := findTargetSpan(spans, "proxy.target/Testdata.TestService/TestUnary", "foo:123") + if targetSpan == nil { + t.Fatal("target span not found") + } + + if !spanHasEvent(targetSpan, "stream.connected") { + t.Error("target span missing stream.connected event") + } + if !spanHasEvent(targetSpan, "stream.first_response") { + t.Error("target span missing stream.first_response event") + } + if !spanHasEvent(targetSpan, "stream.finished") { + t.Error("target span missing stream.finished event") + } + if !eventHasAttribute(targetSpan, "stream.finished", "grpc.status_code", "OK") { + t.Error("stream.finished event missing grpc.status_code=OK") + } + if !spanHasEvent(targetSpan, "authz.evaluated") { + t.Error("target span missing authz.evaluated event") + } + if !eventHasAttribute(targetSpan, "authz.evaluated", "sansshell.authz.result", "allowed") { + t.Error("authz.evaluated event missing sansshell.authz.result=allowed") + } +} + +func TestTracing_AuthzDeniedEvent(t *testing.T) { + exporter, _ := setupTracing(t) + ctx := context.Background() + + policy := ` +package sansshell.authz + +default allow = false + +allow { + input.method = "/Testdata.TestService/TestUnary" + input.message.input = "allowed" +} +` + authz := testutil.NewOpaRPCAuthorizer(ctx, t, policy) + testServerMap := testutil.StartTestDataServers(t, "foo:123") + + runTargetStream(t, testServerMap, "foo:123", "/Testdata.TestService/TestUnary", authz, &tdpb.TestRequest{Input: "denied_input"}) + + spans := exporter.GetSpans() + targetSpan := findTargetSpan(spans, "proxy.target/Testdata.TestService/TestUnary", "foo:123") + if targetSpan == nil { + t.Fatal("target span not found") + } + if !eventHasAttribute(targetSpan, "authz.evaluated", "sansshell.authz.result", "denied") { + t.Error("expected authz.evaluated event with result=denied") + } +} + +func TestTracing_FanOutSpans(t *testing.T) { + exporter, _ := setupTracing(t) + testServerMap := testutil.StartTestDataServers(t, "foo:123", "bar:456") + authz := testutil.NewAllowAllRPCAuthorizer(context.Background(), t) + + runTargetStream(t, testServerMap, "foo:123", "/Testdata.TestService/TestUnary", authz, &tdpb.TestRequest{Input: "a"}) + runTargetStream(t, testServerMap, "bar:456", "/Testdata.TestService/TestUnary", authz, &tdpb.TestRequest{Input: "b"}) + + spans := exporter.GetSpans() + + expectedTargets := map[string]bool{"foo:123": false, "bar:456": false} + for _, s := range spans { + if s.Name == "proxy.target/Testdata.TestService/TestUnary" { + for _, attr := range s.Attributes { + if string(attr.Key) == "sansshell.target.address" { + if _, ok := expectedTargets[attr.Value.AsString()]; ok { + expectedTargets[attr.Value.AsString()] = true + } + } + } + } + } + for target, found := range expectedTargets { + if !found { + t.Errorf("expected target span for %s not found", target) + } + } + + if findTargetSpan(spans, "proxy.target.dial", "foo:123") == nil { + t.Error("expected dial span for foo:123 not found") + } + if findTargetSpan(spans, "proxy.target.dial", "bar:456") == nil { + t.Error("expected dial span for bar:456 not found") + } +} + +func TestTracing_EnrichRootSpan(t *testing.T) { + exporter, tp := setupTracing(t) + + tracer := tp.Tracer("test") + ctx, span := tracer.Start(context.Background(), "test-root") + + peerInfo := &rpcauth.PeerAuthInput{ + Net: &rpcauth.NetAuthInput{Network: "tcp", Address: "10.0.0.1"}, + Principal: &rpcauth.PrincipalAuthInput{ID: "user@example.com"}, + } + md := metadata.Pairs(rpcauth.ReqJustKey, "ticket-123") + ctx = metadata.NewIncomingContext(ctx, md) + + enrichRootSpan(ctx, peerInfo) + span.End() + + spans := exporter.GetSpans() + rootSpan := findSpan(spans, "test-root") + if rootSpan == nil { + t.Fatal("root span not found") + } + if !spanHasAttribute(rootSpan, "sansshell.caller.principal", "user@example.com") { + t.Error("missing sansshell.caller.principal") + } + if !spanHasAttribute(rootSpan, "sansshell.caller.address", "10.0.0.1") { + t.Error("missing sansshell.caller.address") + } + if !spanHasAttribute(rootSpan, "sansshell.proxy.justification", "ticket-123") { + t.Error("missing sansshell.proxy.justification") + } +}