Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ require (
go.opentelemetry.io/otel v1.38.0
go.opentelemetry.io/otel/exporters/prometheus v0.50.0
go.opentelemetry.io/otel/metric v1.38.0
go.opentelemetry.io/otel/sdk v1.38.0
go.opentelemetry.io/otel/sdk/metric v1.38.0
go.opentelemetry.io/otel/trace v1.38.0
gocloud.dev v0.32.0
Expand Down Expand Up @@ -108,7 +109,6 @@ require (
go.opencensus.io v0.24.0 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.53.0 // indirect
go.opentelemetry.io/otel/sdk v1.38.0 // indirect
golang.org/x/crypto v0.44.0 // indirect
golang.org/x/net v0.47.0 // indirect
golang.org/x/oauth2 v0.32.0 // indirect
Expand Down
16 changes: 14 additions & 2 deletions proxy/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (
"fmt"
"io"

"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"

Expand Down Expand Up @@ -248,7 +249,10 @@ func dispatch(ctx context.Context, stream pb.Proxy_ProxyServer, requestChan chan
// be removed from the stream set
doneChan := make(chan uint64)
recorder := metrics.RecorderFromContextOrNoop(ctx)
rootSpan := trace.SpanFromContext(ctx)
var addedPeerToContext bool
var targetCount int
defer func() { rootSpan.SetAttributes(attrProxyTargetCount.Int(targetCount)) }()
for {
select {
case <-ctx.Done():
Expand Down Expand Up @@ -279,15 +283,23 @@ func dispatch(ctx context.Context, stream pb.Proxy_ProxyServer, requestChan chan
// Peer information might not be properly populated until rpcauth
// evaluates the initial received message, so let's grab fresh
// peer information when we know we've gotten at least one message.
ctx = rpcauth.AddPeerToContext(ctx, rpcauth.PeerInputFromContext(stream.Context()))
peerInfo := rpcauth.PeerInputFromContext(stream.Context())
ctx = rpcauth.AddPeerToContext(ctx, peerInfo)
enrichRootSpan(ctx, peerInfo)
addedPeerToContext = true
}
// We have a new request
switch req.Request.(type) {
case *pb.ProxyRequest_StartStream:
if err := streamSet.Add(ctx, req.GetStartStream(), replyChan, doneChan); err != nil {
ss := req.GetStartStream()
if err := streamSet.Add(ctx, ss, replyChan, doneChan); err != nil {
return err
}
targetCount++
rootSpan.AddEvent("dispatch.start_stream", trace.WithAttributes(
attrTargetAddress.String(ss.GetTarget()),
attrTargetMethod.String(ss.GetMethodName()),
))
case *pb.ProxyRequest_StreamData:
if err := streamSet.Send(ctx, req.GetStreamData()); err != nil {
return err
Expand Down
69 changes: 63 additions & 6 deletions proxy/server/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ import (
"time"

"github.com/go-logr/logr"
"go.opentelemetry.io/otel/attribute"
otelcodes "go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
Expand Down Expand Up @@ -192,44 +195,74 @@ func (s *TargetStream) Send(req proto.Message) error {
// messages for sending to a proxy client, including the final
// status of the target stream
func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) {
group, ctx := errgroup.WithContext(s.ctx)
// Create a child span for this target stream.
spanAttrs := []attribute.KeyValue{
attrTargetAddress.String(s.target),
attrTargetMethod.String(s.serviceMethod.FullName()),
attrTargetStreamID.String(fmt.Sprintf("%d", s.streamID)),
attrTargetStreamType.String(streamType(s.serviceMethod.ClientStreams(), s.serviceMethod.ServerStreams())),
attrTargetAuthzDryRun.Bool(s.authzDryRun),
}
if s.dialTimeout != nil {
spanAttrs = append(spanAttrs, attrTargetDialTimeoutMs.Int64(s.dialTimeout.Milliseconds()))
}
spanCtx, span := getTracer().Start(s.ctx, "proxy.target"+s.serviceMethod.FullName(),
trace.WithAttributes(spanAttrs...),
)
defer span.End()

group, ctx := errgroup.WithContext(spanCtx)

peer := rpcauth.PeerInputFromContext(ctx)
if peer != nil && peer.Principal != nil {
// Unconditionally add information on the original caller to outgoing RPCs
ctx = proxiedidentity.AppendToMetadataInOutgoingContext(ctx, peer.Principal)
span.SetAttributes(attrTargetProxiedPrincipal.String(peer.Principal.ID))
}

group.Go(func() error {
dialCtx, cancel := context.WithCancel(ctx)
// Sub-span for dial + stream creation.
dialCtx, dialSpan := getTracer().Start(ctx, "proxy.target.dial",
trace.WithAttributes(attrTargetAddress.String(s.target)),
)
dialCtx, dialCancel := context.WithCancel(dialCtx)
var opts []grpc.DialOption
if s.dialTimeout != nil {
dialCtx, cancel = context.WithTimeout(ctx, *s.dialTimeout)
dialCtx, dialCancel = context.WithTimeout(dialCtx, *s.dialTimeout)
opts = append(opts, grpc.WithBlock())
}
var err error
defer cancel()
defer dialCancel()
grpcConn, err := s.dialer.DialContext(dialCtx, s.target, opts...)
if err != nil {
// We cannot create a new stream to the target. So we need to cancel this stream.
s.logger.Info("unable to create stream", "status", err)
s.cancelFunc()
dialSpan.RecordError(err)
dialSpan.SetStatus(otelcodes.Error, "dial failed")
dialSpan.End()
return fmt.Errorf("could not connect to target from the proxy: %w", err)
}
s.grpcConn = grpcConn
grpcStream, err := s.grpcConn.NewStream(ctx, s.serviceMethod.StreamDesc(), s.serviceMethod.FullName())
if err != nil {
// We cannot create a new stream to the target. So we need to cancel this stream.
s.logger.Info("unable to create stream", "status", err)
dialSpan.RecordError(err)
dialSpan.SetStatus(otelcodes.Error, "new stream failed")
dialSpan.End()
return fmt.Errorf("could not connect to target from the proxy: %w", err)
}
dialSpan.End()
span.AddEvent("stream.connected")

// We've successfully connected and can replace the initial unconnected stream
// with the target stream.
s.setStream(grpcStream)

// Receives messages from the server stream
group.Go(func() error {
var receivedFirst bool
for {
msg := s.serviceMethod.NewReply()
err := grpcStream.RecvMsg(msg)
Expand Down Expand Up @@ -262,7 +295,10 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) {
s.CloseWith(err)
return fmt.Errorf("proxy could not receive response from the target: %w", err)
}
// otherwise, this is a streamData reply
if !receivedFirst {
span.AddEvent("stream.first_response")
receivedFirst = true
}
packed, err := anypb.New(msg)
if err != nil {
return err
Expand Down Expand Up @@ -326,9 +362,19 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) {

// If authz fails, close immediately with an error
if err := s.authorizer.Eval(ctx, authinput); err != nil {
span.AddEvent("authz.evaluated", trace.WithAttributes(
attrAuthzResult.String("denied"),
attrAuthzMethod.String(s.Method()),
))
span.RecordError(err)
span.SetStatus(otelcodes.Error, "authz denied")
s.CloseWith(err)
return err
}
span.AddEvent("authz.evaluated", trace.WithAttributes(
attrAuthzResult.String("allowed"),
attrAuthzMethod.String(s.Method()),
))

if s.authzDryRun {
// TODO: make authz dry run request to server
Expand Down Expand Up @@ -380,12 +426,23 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) {
case err = <-s.errChan:
default:
}

// Record final stream status on span.
grpcStatus := status.Convert(err)
span.AddEvent("stream.finished", trace.WithAttributes(
attrGRPCStatusCode.String(grpcStatus.Code().String()),
))
if err != nil {
span.RecordError(err)
span.SetStatus(otelcodes.Error, grpcStatus.Message())
}

s.logger.Info("finished", "status", err)
reply := &pb.ProxyReply{
Reply: &pb.ProxyReply_ServerClose{
ServerClose: &pb.ServerClose{
StreamIds: []uint64{s.streamID},
Status: convertStatus(status.Convert(err)),
Status: convertStatus(grpcStatus),
},
},
}
Expand Down
152 changes: 152 additions & 0 deletions proxy/server/tracing.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/* Copyright (c) 2019 Snowflake Inc. All rights reserved.

Licensed under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
*/

package server

import (
"context"
"strings"
"sync"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc/metadata"

"github.com/Snowflake-Labs/sansshell/auth/rpcauth"
)

// tracerName is the instrumentation name handed to otel.Tracer when the
// package tracer is first created in init.
const tracerName = "sansshell-proxy"

var (
// tracerMu guards every read and write of currentTracer.
tracerMu sync.RWMutex
// currentTracer is the tracer handed out by getTracer. It is seeded in init
// and can be swapped (and later restored) through setTracer.
currentTracer trace.Tracer
)

// init seeds currentTracer with otel.Tracer(tracerName) at package load.
func init() { currentTracer = otel.Tracer(tracerName) }

// getTracer returns the tracer currently installed for this package. The
// read happens under tracerMu's read lock so it can run concurrently with
// other readers while still being ordered against setTracer.
func getTracer() trace.Tracer {
	tracerMu.RLock()
	defer tracerMu.RUnlock()
	return currentTracer
}

// setTracer installs t as the active tracer and hands back a restore function
// that reinstalls whichever tracer was active at the time of the call. Both
// the swap and the restore take tracerMu's write lock, so it is safe for
// concurrent use.
func setTracer(t trace.Tracer) (restore func()) {
	// swap installs next under the write lock and reports what it replaced.
	swap := func(next trace.Tracer) trace.Tracer {
		tracerMu.Lock()
		defer tracerMu.Unlock()
		old := currentTracer
		currentTracer = next
		return old
	}

	prev := swap(t)
	return func() {
		swap(prev)
	}
}

// Span attribute keys for caller identity. They are populated by
// callerAttrsFromPeer; each one is emitted only when the corresponding peer
// field is present and non-empty.
const (
attrCallerPrincipal attribute.Key = "sansshell.caller.principal"
// Groups are recorded as a single comma-joined string.
attrCallerGroups attribute.Key = "sansshell.caller.groups"
attrCallerAddress attribute.Key = "sansshell.caller.address"
attrCallerCertCN attribute.Key = "sansshell.caller.cert.cn"
attrCallerCertSPIFFE attribute.Key = "sansshell.caller.cert.spiffe_id"
)

// Span attribute keys for target stream information. They describe the
// per-target span started in TargetStream.Run; the address and method keys
// are also attached to the "dispatch.start_stream" event in dispatch.
const (
attrTargetAddress attribute.Key = "sansshell.target.address"
attrTargetMethod attribute.Key = "sansshell.target.method"
// The stream ID is recorded as its decimal string form.
attrTargetStreamID attribute.Key = "sansshell.target.stream_id"
// One of the values produced by streamType.
attrTargetStreamType attribute.Key = "sansshell.target.stream_type"
// Only set when the target stream has a dial timeout configured.
attrTargetDialTimeoutMs attribute.Key = "sansshell.target.dial_timeout_ms"
attrTargetAuthzDryRun attribute.Key = "sansshell.target.authz_dry_run"
// Only set when the peer in the context carries a principal.
attrTargetProxiedPrincipal attribute.Key = "sansshell.target.proxied_principal"
)

// Span attribute keys for aggregate proxy-level information.
const (
// Number of StartStream requests successfully added during dispatch; set
// when dispatch returns.
attrProxyTargetCount attribute.Key = "sansshell.proxy.target_count"
// First value of the rpcauth.ReqJustKey incoming metadata entry, if any.
attrProxyJustification attribute.Key = "sansshell.proxy.justification"
)

// Span attribute keys for authz events. Both are attached to the
// "authz.evaluated" span event; the result is "allowed" or "denied".
const (
attrAuthzResult attribute.Key = "sansshell.authz.result"
attrAuthzMethod attribute.Key = "sansshell.authz.method"
)

// Span attribute key for stream finish status, attached to the
// "stream.finished" span event with the gRPC code's string name as value.
// NOTE(review): unlike the other keys this one is not under the sansshell.*
// namespace, and OpenTelemetry's RPC semantic conventions appear to use a
// different key with a numeric value — confirm the divergence is intentional.
const attrGRPCStatusCode attribute.Key = "grpc.status_code"

// streamType maps the two streaming flags of a gRPC method to a short label:
// "bidi" when both sides stream, "client_stream" or "server_stream" when only
// one side does, and "unary" when neither does.
func streamType(clientStreams, serverStreams bool) string {
	if clientStreams {
		if serverStreams {
			return "bidi"
		}
		return "client_stream"
	}
	if serverStreams {
		return "server_stream"
	}
	return "unary"
}

// callerAttrsFromPeer converts the caller information in peer into span
// attributes. It returns nil for a nil peer, and otherwise emits one attribute
// per populated field: network address, principal ID, principal groups
// (comma-joined), certificate common name and certificate SPIFFE ID. Fields
// that are absent or empty are left out, so the result may be nil.
func callerAttrsFromPeer(peer *rpcauth.PeerAuthInput) []attribute.KeyValue {
	if peer == nil {
		return nil
	}

	var out []attribute.KeyValue
	// addNonEmpty appends key=value only when value carries content.
	addNonEmpty := func(key attribute.Key, value string) {
		if value != "" {
			out = append(out, key.String(value))
		}
	}

	if netInfo := peer.Net; netInfo != nil {
		addNonEmpty(attrCallerAddress, netInfo.Address)
	}
	if principal := peer.Principal; principal != nil {
		addNonEmpty(attrCallerPrincipal, principal.ID)
		// Groups are gated on the slice length, not on the joined string.
		if len(principal.Groups) != 0 {
			out = append(out, attrCallerGroups.String(strings.Join(principal.Groups, ",")))
		}
	}
	if cert := peer.Cert; cert != nil {
		addNonEmpty(attrCallerCertCN, cert.Subject.CommonName)
		addNonEmpty(attrCallerCertSPIFFE, cert.SPIFFEID)
	}
	return out
}

// enrichRootSpan decorates the span carried by ctx with caller identity
// attributes derived from peer and, when the incoming gRPC metadata holds a
// justification entry, with its first value. It does nothing if the span is
// not recording. dispatch calls it once peer info is available.
func enrichRootSpan(ctx context.Context, peer *rpcauth.PeerAuthInput) {
	root := trace.SpanFromContext(ctx)
	if !root.IsRecording() {
		return
	}

	callerAttrs := callerAttrsFromPeer(peer)
	if len(callerAttrs) != 0 {
		root.SetAttributes(callerAttrs...)
	}

	incoming, found := metadata.FromIncomingContext(ctx)
	if !found {
		return
	}
	justifications := incoming.Get(rpcauth.ReqJustKey)
	if len(justifications) == 0 {
		return
	}
	// Only the first justification value is recorded.
	root.SetAttributes(attrProxyJustification.String(justifications[0]))
}
Loading
Loading