From 0e36f1d782e911323b025fd93358f4b537846eaa Mon Sep 17 00:00:00 2001 From: Mateusz Jankowski Date: Fri, 27 Feb 2026 11:17:19 +0100 Subject: [PATCH] Add possibility to specify additional credential source in sanssh to select client cert for mTLS at Proxy level --- cmd/proxy-server/server/server.go | 45 ++- cmd/proxy-server/server/server_test.go | 191 +++++++++ cmd/sanssh/client/client.go | 5 + proxy/proxy.pb.go | 526 ++++++++++++++++--------- proxy/proxy.proto | 6 + proxy/proxy/proxy.go | 14 +- proxy/proxy_grpc.pb.go | 2 +- proxy/server/server.go | 16 +- proxy/server/target.go | 24 +- proxy/server/target_test.go | 29 +- 10 files changed, 651 insertions(+), 207 deletions(-) create mode 100644 cmd/proxy-server/server/server_test.go diff --git a/cmd/proxy-server/server/server.go b/cmd/proxy-server/server/server.go index e9424ba8..0357786b 100644 --- a/cmd/proxy-server/server/server.go +++ b/cmd/proxy-server/server/server.go @@ -73,6 +73,7 @@ type runState struct { authzHooks []rpcauth.RPCAuthzHook services []func(*grpc.Server) metricsRecorder metrics.MetricsRecorder + namedCredSources map[string]string // hint name -> mtls loader name } type Option interface { @@ -308,6 +309,20 @@ func WithOtelTracing(interceptorOpts ...otelgrpc.Option) Option { }) } +// WithNamedClientCredSource registers an additional client credential source +// that the proxy can use when a client sends a matching force_credential in +// StartStream. hintName is the value clients will send; credSource is the name +// registered with the mtls package for loading client credentials. +func WithNamedClientCredSource(hintName, credSource string) Option { + return optionFunc(func(_ context.Context, r *runState) error { + if r.namedCredSources == nil { + r.namedCredSources = make(map[string]string) + } + r.namedCredSources[hintName] = credSource + return nil + }) +} + // Run takes the given context and RunState along with any authz hooks and starts up a sansshell proxy server // using the flags above to provide credentials. An address hook (based on the remote host) with always be added. // As this is intended to be called from main() it doesn't return errors and will instead exit on any errors. @@ -383,21 +398,23 @@ func Run(ctx context.Context, opts ...Option) { unaryClient = append(unaryClient, clientAuthz.AuthorizeClient) streamClient = append(streamClient, clientAuthz.AuthorizeClientStream) } - dialOpts := []grpc.DialOption{ - grpc.WithTransportCredentials(clientCreds), + sharedDialOpts := []grpc.DialOption{ grpc.WithChainUnaryInterceptor(unaryClient...), grpc.WithChainStreamInterceptor(streamClient...), // Use 16MB instead of the default 4MB to allow larger responses grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(16 * 1024 * 1024)), } if rs.statsClientHandler != nil { - dialOpts = append(dialOpts, grpc.WithStatsHandler(rs.statsClientHandler)) + sharedDialOpts = append(sharedDialOpts, grpc.WithStatsHandler(rs.statsClientHandler)) } - targetDialer := server.NewDialer(dialOpts...) + defaultDialOpts := append([]grpc.DialOption{grpc.WithTransportCredentials(clientCreds)}, sharedDialOpts...) + targetDialer := server.NewDialer(defaultDialOpts...) + + dialers := buildDialers(ctx, rs, targetDialer, sharedDialOpts) svcMap := server.LoadGlobalServiceMap() rs.logger.Info("loaded service map", "serviceMap", svcMap) - server := server.New(targetDialer, authz) + server := server.NewWithDialersAndServiceMap(dialers, authz, svcMap) // Even though the proxy RPC is streaming we have unary RPCs (logging, reflection) we // also need to properly auth and log. @@ -456,6 +473,24 @@ func Run(ctx context.Context, opts ...Option) { } } +// buildDialers constructs the named dialers map from the default dialer and +// any additional credential sources registered via WithNamedClientCredSource. +// If a named source fails to load, it is logged and skipped; the default +// dialer is always present under key "". +func buildDialers(ctx context.Context, rs *runState, defaultDialer server.TargetDialer, sharedDialOpts []grpc.DialOption) map[string]server.TargetDialer { + dialers := map[string]server.TargetDialer{"": defaultDialer} + for hint, src := range rs.namedCredSources { + creds, err := mtls.LoadClientCredentials(ctx, src) + if err != nil { + rs.logger.Error(err, "failed to load named client cred source, skipping", "hint", hint, "source", src) + continue + } + hintDialOpts := append([]grpc.DialOption{grpc.WithTransportCredentials(creds)}, sharedDialOpts...) + dialers[hint] = server.NewDialer(hintDialOpts...) + } + return dialers +} + // extractClientTransportCredentialsFromRunState extracts transport credentials from runState. Will error if both credSource and tlsConfig are specified func extractClientTransportCredentialsFromRunState(ctx context.Context, rs *runState) (credentials.TransportCredentials, error) { var creds credentials.TransportCredentials diff --git a/cmd/proxy-server/server/server_test.go b/cmd/proxy-server/server/server_test.go new file mode 100644 index 00000000..484f8f92 --- /dev/null +++ b/cmd/proxy-server/server/server_test.go @@ -0,0 +1,191 @@ +/* Copyright (c) 2019 Snowflake Inc. All rights reserved. + + Licensed under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. +*/ + +package server + +import ( + "context" + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "os" + "testing" + + "github.com/go-logr/logr" + "google.golang.org/grpc" + + "github.com/Snowflake-Labs/sansshell/auth/mtls" + proxyserver "github.com/Snowflake-Labs/sansshell/proxy/server" +) + +const ( + failLoaderName = "test-buildDialers-fail" + okLoaderName = "test-buildDialers-ok" +) + +func TestMain(m *testing.M) { + if err := mtls.Register(failLoaderName, failingLoader{}); err != nil { + fmt.Fprintf(os.Stderr, "mtls.Register(%s): %v\n", failLoaderName, err) + os.Exit(1) + } + if err := mtls.Register(okLoaderName, successLoader{}); err != nil { + fmt.Fprintf(os.Stderr, "mtls.Register(%s): %v\n", okLoaderName, err) + os.Exit(1) + } + os.Exit(m.Run()) +} + +type fakeDialer struct{} + +func (fakeDialer) DialContext(_ context.Context, _ string, _ ...grpc.DialOption) (proxyserver.ClientConnCloser, error) { + return nil, errors.New("not implemented") +} + +func TestWithNamedClientCredSourceSingle(t *testing.T) { + rs := &runState{} + opt := WithNamedClientCredSource("pg", "some-loader") + if err := opt.apply(context.Background(), rs); err != nil { + t.Fatalf("apply: %v", err) + } + if got, ok := rs.namedCredSources["pg"]; !ok || got != "some-loader" { + t.Fatalf("expected namedCredSources[\"pg\"] = \"some-loader\", got %q (ok=%v)", got, ok) + } +} + +func TestWithNamedClientCredSourceMultiple(t *testing.T) { + rs := &runState{} + for _, pair := range []struct{ hint, src string }{ + {"pg", "loader-a"}, + {"redis", "loader-b"}, + } { + if err := WithNamedClientCredSource(pair.hint, pair.src).apply(context.Background(), rs); err != nil { + t.Fatalf("apply(%q): %v", pair.hint, err) + } + } + if len(rs.namedCredSources) != 2 { + t.Fatalf("expected 2 entries, got %d", len(rs.namedCredSources)) + } + if rs.namedCredSources["pg"] != "loader-a" { + t.Fatalf("pg: got %q", rs.namedCredSources["pg"]) + } + if rs.namedCredSources["redis"] != "loader-b" { + t.Fatalf("redis: got %q", rs.namedCredSources["redis"]) + } +} + +func TestWithNamedClientCredSourceOverwrite(t *testing.T) { + rs := &runState{} + if err := WithNamedClientCredSource("pg", "old").apply(context.Background(), rs); err != nil { + t.Fatal(err) + } + if err := WithNamedClientCredSource("pg", "new").apply(context.Background(), rs); err != nil { + t.Fatal(err) + } + if rs.namedCredSources["pg"] != "new" { + t.Fatalf("expected overwrite to \"new\", got %q", rs.namedCredSources["pg"]) + } +} + +// --- buildDialers tests --- + +type failingLoader struct{} + +func (failingLoader) LoadClientCA(context.Context) (*x509.CertPool, error) { + return nil, errors.New("no CA") +} +func (failingLoader) LoadRootCA(context.Context) (*x509.CertPool, error) { + return nil, errors.New("no root CA") +} +func (failingLoader) LoadClientCertificate(context.Context) (tls.Certificate, error) { + return tls.Certificate{}, errors.New("no client cert") +} +func (failingLoader) LoadServerCertificate(context.Context) (tls.Certificate, error) { + return tls.Certificate{}, errors.New("no server cert") +} +func (failingLoader) CertsRefreshed() bool { return false } +func (failingLoader) GetClientCertInfo(context.Context, string) (*mtls.ClientCertInfo, error) { + return nil, nil +} + +type successLoader struct{} + +func (successLoader) LoadClientCA(context.Context) (*x509.CertPool, error) { + return x509.NewCertPool(), nil +} +func (successLoader) LoadRootCA(context.Context) (*x509.CertPool, error) { + return x509.NewCertPool(), nil +} +func (successLoader) LoadClientCertificate(context.Context) (tls.Certificate, error) { + return tls.Certificate{}, nil +} +func (successLoader) LoadServerCertificate(context.Context) (tls.Certificate, error) { + return tls.Certificate{}, nil +} +func (successLoader) CertsRefreshed() bool { return false } +func (successLoader) GetClientCertInfo(context.Context, string) (*mtls.ClientCertInfo, error) { + return nil, nil +} + +func TestBuildDialersDefaultOnly(t *testing.T) { + rs := &runState{logger: logr.Discard()} + dialers := buildDialers(context.Background(), rs, fakeDialer{}, nil) + if _, ok := dialers[""]; !ok { + t.Fatal("expected default dialer under key \"\"") + } + if len(dialers) != 1 { + t.Fatalf("expected 1 dialer, got %d", len(dialers)) + } +} + +func TestBuildDialersSkipsFailedCredSources(t *testing.T) { + rs := &runState{ + logger: logr.Discard(), + namedCredSources: map[string]string{"bad-hint": failLoaderName}, + } + dialers := buildDialers(context.Background(), rs, fakeDialer{}, nil) + + if _, ok := dialers[""]; !ok { + t.Fatal("expected default dialer under key \"\"") + } + if _, ok := dialers["bad-hint"]; ok { + t.Fatal("expected failed hint to be absent from dialers map") + } + if len(dialers) != 1 { + t.Fatalf("expected 1 dialer (default only), got %d", len(dialers)) + } +} + +func TestBuildDialersRegistersSuccessfulHint(t *testing.T) { + rs := &runState{ + logger: logr.Discard(), + namedCredSources: map[string]string{"good-hint": okLoaderName}, + } + shared := []grpc.DialOption{ + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(4 * 1024 * 1024)), + } + dialers := buildDialers(context.Background(), rs, fakeDialer{}, shared) + + if _, ok := dialers[""]; !ok { + t.Fatal("expected default dialer under key \"\"") + } + if _, ok := dialers["good-hint"]; !ok { + t.Fatal("expected successful hint to be present in dialers map") + } + if len(dialers) != 2 { + t.Fatalf("expected 2 dialers, got %d", len(dialers)) + } +} diff --git a/cmd/sanssh/client/client.go b/cmd/sanssh/client/client.go index 3b49b4a9..2421d188 100644 --- a/cmd/sanssh/client/client.go +++ b/cmd/sanssh/client/client.go @@ -82,6 +82,10 @@ type RunState struct { EnableMPA bool // If true, the command is authz dry run and real action should not be executed AuthzDryRun bool + // ForceCredential is passed to the proxy to force a specific client + // credential when dialing targets. The proxy will fail with an error if + // the requested credential is not configured. Empty means default. + ForceCredential string // Interspectors for unary calls to the connection to the proxy ClientUnaryInterceptors []proxy.UnaryInterceptor @@ -376,6 +380,7 @@ func Run(ctx context.Context, rs RunState) { } conn.AuthzDryRun = rs.AuthzDryRun + conn.ForceCredential = rs.ForceCredential if rs.EnableMPA { conn.UnaryInterceptors = []proxy.UnaryInterceptor{mpahooks.ProxyClientUnaryInterceptor(state)} diff --git a/proxy/proxy.pb.go b/proxy/proxy.pb.go index b83de3c2..31e89c35 100644 --- a/proxy/proxy.pb.go +++ b/proxy/proxy.pb.go @@ -15,8 +15,8 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.10 -// protoc v6.33.4 +// protoc-gen-go v1.34.2 +// protoc v7.34.0 // source: proxy.proto package proxy @@ -28,7 +28,6 @@ import ( durationpb "google.golang.org/protobuf/types/known/durationpb" reflect "reflect" sync "sync" - unsafe "unsafe" ) const ( @@ -39,23 +38,26 @@ const ( ) type ProxyRequest struct { - state protoimpl.MessageState `protogen:"open.v1"` - // Types that are valid to be assigned to Request: + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to Request: // // *ProxyRequest_StartStream // *ProxyRequest_StreamData // *ProxyRequest_ClientClose // *ProxyRequest_ClientCancel - Request isProxyRequest_Request `protobuf_oneof:"request"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + Request isProxyRequest_Request `protobuf_oneof:"request"` } func (x *ProxyRequest) Reset() { *x = ProxyRequest{} - mi := &file_proxy_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *ProxyRequest) String() string { @@ -66,7 +68,7 @@ func (*ProxyRequest) ProtoMessage() {} func (x *ProxyRequest) ProtoReflect() protoreflect.Message { mi := &file_proxy_proto_msgTypes[0] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -81,45 +83,37 @@ func (*ProxyRequest) Descriptor() ([]byte, []int) { return file_proxy_proto_rawDescGZIP(), []int{0} } -func (x *ProxyRequest) GetRequest() isProxyRequest_Request { - if x != nil { - return x.Request +func (m *ProxyRequest) GetRequest() isProxyRequest_Request { + if m != nil { + return m.Request } return nil } func (x *ProxyRequest) GetStartStream() *StartStream { - if x != nil { - if x, ok := x.Request.(*ProxyRequest_StartStream); ok { - return x.StartStream - } + if x, ok := x.GetRequest().(*ProxyRequest_StartStream); ok { + return x.StartStream } return nil } func (x *ProxyRequest) GetStreamData() *StreamData { - if x != nil { - if x, ok := x.Request.(*ProxyRequest_StreamData); ok { - return x.StreamData - } + if x, ok := x.GetRequest().(*ProxyRequest_StreamData); ok { + return x.StreamData } return nil } func (x *ProxyRequest) GetClientClose() *ClientClose { - if x != nil { - if x, ok := x.Request.(*ProxyRequest_ClientClose); ok { - return x.ClientClose - } + if x, ok := x.GetRequest().(*ProxyRequest_ClientClose); ok { + return x.ClientClose } return nil } func (x *ProxyRequest) GetClientCancel() *ClientCancel { - if x != nil { - if x, ok := x.Request.(*ProxyRequest_ClientCancel); ok { - return x.ClientCancel - } + if x, ok := x.GetRequest().(*ProxyRequest_ClientCancel); ok { + return x.ClientCancel } return nil } @@ -159,22 +153,25 @@ func (*ProxyRequest_ClientClose) isProxyRequest_Request() {} func (*ProxyRequest_ClientCancel) isProxyRequest_Request() {} type ProxyReply struct { - state protoimpl.MessageState `protogen:"open.v1"` - // Types that are valid to be assigned to Reply: + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // Types that are assignable to Reply: // // *ProxyReply_StartStreamReply // *ProxyReply_StreamData // *ProxyReply_ServerClose - Reply isProxyReply_Reply `protobuf_oneof:"reply"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + Reply isProxyReply_Reply `protobuf_oneof:"reply"` } func (x *ProxyReply) Reset() { *x = ProxyReply{} - mi := &file_proxy_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *ProxyReply) String() string { @@ -185,7 +182,7 @@ func (*ProxyReply) ProtoMessage() {} func (x *ProxyReply) ProtoReflect() protoreflect.Message { mi := &file_proxy_proto_msgTypes[1] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -200,36 +197,30 @@ func (*ProxyReply) Descriptor() ([]byte, []int) { return file_proxy_proto_rawDescGZIP(), []int{1} } -func (x *ProxyReply) GetReply() isProxyReply_Reply { - if x != nil { - return x.Reply +func (m *ProxyReply) GetReply() isProxyReply_Reply { + if m != nil { + return m.Reply } return nil } func (x *ProxyReply) GetStartStreamReply() *StartStreamReply { - if x != nil { - if x, ok := x.Reply.(*ProxyReply_StartStreamReply); ok { - return x.StartStreamReply - } + if x, ok := x.GetReply().(*ProxyReply_StartStreamReply); ok { + return x.StartStreamReply } return nil } func (x *ProxyReply) GetStreamData() *StreamData { - if x != nil { - if x, ok := x.Reply.(*ProxyReply_StreamData); ok { - return x.StreamData - } + if x, ok := x.GetReply().(*ProxyReply_StreamData); ok { + return x.StreamData } return nil } func (x *ProxyReply) GetServerClose() *ServerClose { - if x != nil { - if x, ok := x.Reply.(*ProxyReply_ServerClose); ok { - return x.ServerClose - } + if x, ok := x.GetReply().(*ProxyReply_ServerClose); ok { + return x.ServerClose } return nil } @@ -264,7 +255,10 @@ func (*ProxyReply_ServerClose) isProxyReply_Reply() {} // that will be echoed in the returned reply to allow clients // to correlate this request with the associated stream id. type StartStream struct { - state protoimpl.MessageState `protogen:"open.v1"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + // The stream target, as accepted by grpc.Dial. Target string `protobuf:"bytes,1,opt,name=target,proto3" json:"target,omitempty"` // The fully-qualified method name (e.g. "/Package.Service/Method") @@ -281,16 +275,21 @@ type StartStream struct { // own routine so this will not block the overall progress for a stream. DialTimeout *durationpb.Duration `protobuf:"bytes,4,opt,name=dial_timeout,json=dialTimeout,proto3" json:"dial_timeout,omitempty"` // Perform authz dry run instead actual execution. - AuthzDryRun bool `protobuf:"varint,5,opt,name=authz_dry_run,json=authzDryRun,proto3" json:"authz_dry_run,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + AuthzDryRun bool `protobuf:"varint,5,opt,name=authz_dry_run,json=authzDryRun,proto3" json:"authz_dry_run,omitempty"` + // Optional. Forces the proxy to use a specific client credential when + // dialing this target. If empty or unset, the proxy uses its default + // credential. The proxy will reject unrecognized non-empty values with + // InvalidArgument rather than falling back to the default. + ForceCredential string `protobuf:"bytes,6,opt,name=force_credential,json=forceCredential,proto3" json:"force_credential,omitempty"` } func (x *StartStream) Reset() { *x = StartStream{} - mi := &file_proxy_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *StartStream) String() string { @@ -301,7 +300,7 @@ func (*StartStream) ProtoMessage() {} func (x *StartStream) ProtoReflect() protoreflect.Message { mi := &file_proxy_proto_msgTypes[2] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -351,26 +350,36 @@ func (x *StartStream) GetAuthzDryRun() bool { return false } +func (x *StartStream) GetForceCredential() string { + if x != nil { + return x.ForceCredential + } + return "" +} + type StartStreamReply struct { - state protoimpl.MessageState `protogen:"open.v1"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + // The target string originally supplied to StartStream Target string `protobuf:"bytes,1,opt,name=target,proto3" json:"target,omitempty"` // The nonce value supplied by the client in StartStream. Nonce uint32 `protobuf:"varint,2,opt,name=nonce,proto3" json:"nonce,omitempty"` - // Types that are valid to be assigned to Reply: + // Types that are assignable to Reply: // // *StartStreamReply_StreamId // *StartStreamReply_ErrorStatus - Reply isStartStreamReply_Reply `protobuf_oneof:"reply"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + Reply isStartStreamReply_Reply `protobuf_oneof:"reply"` } func (x *StartStreamReply) Reset() { *x = StartStreamReply{} - mi := &file_proxy_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *StartStreamReply) String() string { @@ -381,7 +390,7 @@ func (*StartStreamReply) ProtoMessage() {} func (x *StartStreamReply) ProtoReflect() protoreflect.Message { mi := &file_proxy_proto_msgTypes[3] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -410,27 +419,23 @@ func (x *StartStreamReply) GetNonce() uint32 { return 0 } -func (x *StartStreamReply) GetReply() isStartStreamReply_Reply { - if x != nil { - return x.Reply +func (m *StartStreamReply) GetReply() isStartStreamReply_Reply { + if m != nil { + return m.Reply } return nil } func (x *StartStreamReply) GetStreamId() uint64 { - if x != nil { - if x, ok := x.Reply.(*StartStreamReply_StreamId); ok { - return x.StreamId - } + if x, ok := x.GetReply().(*StartStreamReply_StreamId); ok { + return x.StreamId } return 0 } func (x *StartStreamReply) GetErrorStatus() *Status { - if x != nil { - if x, ok := x.Reply.(*StartStreamReply_ErrorStatus); ok { - return x.ErrorStatus - } + if x, ok := x.GetReply().(*StartStreamReply_ErrorStatus); ok { + return x.ErrorStatus } return nil } @@ -462,18 +467,21 @@ func (*StartStreamReply_ErrorStatus) isStartStreamReply_Reply() {} // Note that clients do not need to send a ClientClose for // streams where client_streams is false. type ClientClose struct { - state protoimpl.MessageState `protogen:"open.v1"` - // The server-asssigned stream id(s) to close. - StreamIds []uint64 `protobuf:"varint,1,rep,packed,name=stream_ids,json=streamIds,proto3" json:"stream_ids,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // The server-asssigned stream id(s) to close. + StreamIds []uint64 `protobuf:"varint,1,rep,packed,name=stream_ids,json=streamIds,proto3" json:"stream_ids,omitempty"` } func (x *ClientClose) Reset() { *x = ClientClose{} - mi := &file_proxy_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *ClientClose) String() string { @@ -484,7 +492,7 @@ func (*ClientClose) ProtoMessage() {} func (x *ClientClose) ProtoReflect() protoreflect.Message { mi := &file_proxy_proto_msgTypes[4] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -509,18 +517,21 @@ func (x *ClientClose) GetStreamIds() []uint64 { // ClientCancel is sent by the proxy client to request // cancellation of the given stream(s). type ClientCancel struct { - state protoimpl.MessageState `protogen:"open.v1"` - // The server-assigned stream id(s) to cancel. - StreamIds []uint64 `protobuf:"varint,1,rep,packed,name=stream_ids,json=streamIds,proto3" json:"stream_ids,omitempty"` - unknownFields protoimpl.UnknownFields + state protoimpl.MessageState sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + // The server-assigned stream id(s) to cancel. + StreamIds []uint64 `protobuf:"varint,1,rep,packed,name=stream_ids,json=streamIds,proto3" json:"stream_ids,omitempty"` } func (x *ClientCancel) Reset() { *x = ClientCancel{} - mi := &file_proxy_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *ClientCancel) String() string { @@ -531,7 +542,7 @@ func (*ClientCancel) ProtoMessage() {} func (x *ClientCancel) ProtoReflect() protoreflect.Message { mi := &file_proxy_proto_msgTypes[5] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -556,22 +567,25 @@ func (x *ClientCancel) GetStreamIds() []uint64 { // StreamData is used by both clients and servers to transmit // data for an established stream. type StreamData struct { - state protoimpl.MessageState `protogen:"open.v1"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + // The stream identifier, as returned in StartStreamReply // This can be repeated, to indicate that the same data is relevant // to multiple established streams. StreamIds []uint64 `protobuf:"varint,1,rep,packed,name=stream_ids,json=streamIds,proto3" json:"stream_ids,omitempty"` // The message payload - Payload *anypb.Any `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + Payload *anypb.Any `protobuf:"bytes,2,opt,name=payload,proto3" json:"payload,omitempty"` } func (x *StreamData) Reset() { *x = StreamData{} - mi := &file_proxy_proto_msgTypes[6] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *StreamData) String() string { @@ -582,7 +596,7 @@ func (*StreamData) ProtoMessage() {} func (x *StreamData) ProtoReflect() protoreflect.Message { mi := &file_proxy_proto_msgTypes[6] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -614,22 +628,25 @@ func (x *StreamData) GetPayload() *anypb.Any { // A server end-of-stream response, containing the final status // of the stream. type ServerClose struct { - state protoimpl.MessageState `protogen:"open.v1"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + // The stream identifier, as returned in StartStreamReply // This can be repeated, to indicate that the same status is // applicable to multiple streams. StreamIds []uint64 `protobuf:"varint,1,rep,packed,name=stream_ids,json=streamIds,proto3" json:"stream_ids,omitempty"` // The final status of the stream. - Status *Status `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + Status *Status `protobuf:"bytes,2,opt,name=status,proto3" json:"status,omitempty"` } func (x *ServerClose) Reset() { *x = ServerClose{} - mi := &file_proxy_proto_msgTypes[7] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *ServerClose) String() string { @@ -640,7 +657,7 @@ func (*ServerClose) ProtoMessage() {} func (x *ServerClose) ProtoReflect() protoreflect.Message { mi := &file_proxy_proto_msgTypes[7] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -671,22 +688,25 @@ func (x *ServerClose) GetStatus() *Status { // A wire-compatible version of google.rpc.Status type Status struct { - state protoimpl.MessageState `protogen:"open.v1"` + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + // The status code (one of google.rpc.Code) Code int32 `protobuf:"varint,1,opt,name=code,proto3" json:"code,omitempty"` // A developer-targeted error message. Message string `protobuf:"bytes,2,opt,name=message,proto3" json:"message,omitempty"` // List of messages carrying error details. - Details []*anypb.Any `protobuf:"bytes,3,rep,name=details,proto3" json:"details,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + Details []*anypb.Any `protobuf:"bytes,3,rep,name=details,proto3" json:"details,omitempty"` } func (x *Status) Reset() { *x = Status{} - mi := &file_proxy_proto_msgTypes[8] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) + if protoimpl.UnsafeEnabled { + mi := &file_proxy_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } } func (x *Status) String() string { @@ -697,7 +717,7 @@ func (*Status) ProtoMessage() {} func (x *Status) ProtoReflect() protoreflect.Message { mi := &file_proxy_proto_msgTypes[8] - if x != nil { + if protoimpl.UnsafeEnabled && x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -735,66 +755,107 @@ func (x *Status) GetDetails() []*anypb.Any { var File_proxy_proto protoreflect.FileDescriptor -const file_proxy_proto_rawDesc = "" + - "\n" + - "\vproxy.proto\x12\x05Proxy\x1a\x19google/protobuf/any.proto\x1a\x1egoogle/protobuf/duration.proto\"\xfd\x01\n" + - "\fProxyRequest\x127\n" + - "\fstart_stream\x18\x01 \x01(\v2\x12.Proxy.StartStreamH\x00R\vstartStream\x124\n" + - "\vstream_data\x18\x02 \x01(\v2\x11.Proxy.StreamDataH\x00R\n" + - "streamData\x127\n" + - "\fclient_close\x18\x03 \x01(\v2\x12.Proxy.ClientCloseH\x00R\vclientClose\x12:\n" + - "\rclient_cancel\x18\x04 \x01(\v2\x13.Proxy.ClientCancelH\x00R\fclientCancelB\t\n" + - "\arequest\"\xcd\x01\n" + - "\n" + - "ProxyReply\x12G\n" + - "\x12start_stream_reply\x18\x01 \x01(\v2\x17.Proxy.StartStreamReplyH\x00R\x10startStreamReply\x124\n" + - "\vstream_data\x18\x02 \x01(\v2\x11.Proxy.StreamDataH\x00R\n" + - "streamData\x127\n" + - "\fserver_close\x18\x03 \x01(\v2\x12.Proxy.ServerCloseH\x00R\vserverCloseB\a\n" + - "\x05reply\"\xbe\x01\n" + - "\vStartStream\x12\x16\n" + - "\x06target\x18\x01 \x01(\tR\x06target\x12\x1f\n" + - "\vmethod_name\x18\x02 \x01(\tR\n" + - "methodName\x12\x14\n" + - "\x05nonce\x18\x03 \x01(\rR\x05nonce\x12<\n" + - "\fdial_timeout\x18\x04 \x01(\v2\x19.google.protobuf.DurationR\vdialTimeout\x12\"\n" + - "\rauthz_dry_run\x18\x05 \x01(\bR\vauthzDryRun\"\x9c\x01\n" + - "\x10StartStreamReply\x12\x16\n" + - "\x06target\x18\x01 \x01(\tR\x06target\x12\x14\n" + - "\x05nonce\x18\x02 \x01(\rR\x05nonce\x12\x1d\n" + - "\tstream_id\x18\x03 \x01(\x04H\x00R\bstreamId\x122\n" + - "\ferror_status\x18\x04 \x01(\v2\r.Proxy.StatusH\x00R\verrorStatusB\a\n" + - "\x05reply\",\n" + - "\vClientClose\x12\x1d\n" + - "\n" + - "stream_ids\x18\x01 \x03(\x04R\tstreamIds\"-\n" + - "\fClientCancel\x12\x1d\n" + - "\n" + - "stream_ids\x18\x01 \x03(\x04R\tstreamIds\"[\n" + - "\n" + - "StreamData\x12\x1d\n" + - "\n" + - "stream_ids\x18\x01 \x03(\x04R\tstreamIds\x12.\n" + - "\apayload\x18\x02 \x01(\v2\x14.google.protobuf.AnyR\apayload\"S\n" + - "\vServerClose\x12\x1d\n" + - "\n" + - "stream_ids\x18\x01 \x03(\x04R\tstreamIds\x12%\n" + - "\x06status\x18\x02 \x01(\v2\r.Proxy.StatusR\x06status\"f\n" + - "\x06Status\x12\x12\n" + - "\x04code\x18\x01 \x01(\x05R\x04code\x12\x18\n" + - "\amessage\x18\x02 \x01(\tR\amessage\x12.\n" + - "\adetails\x18\x03 \x03(\v2\x14.google.protobuf.AnyR\adetails2>\n" + - "\x05Proxy\x125\n" + - "\x05Proxy\x12\x13.Proxy.ProxyRequest\x1a\x11.Proxy.ProxyReply\"\x00(\x010\x01B+Z)github.com/Snowflake-Labs/sansshell/proxyb\x06proto3" +var file_proxy_proto_rawDesc = []byte{ + 0x0a, 0x0b, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x50, + 0x72, 0x6f, 0x78, 0x79, 0x1a, 0x19, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, + 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x61, 0x6e, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, + 0x1e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, + 0x2f, 0x64, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, + 0xfd, 0x01, 0x0a, 0x0c, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x37, 0x0a, 0x0c, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, + 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x53, + 0x74, 0x61, 0x72, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x48, 0x00, 0x52, 0x0b, 0x73, 0x74, + 0x61, 0x72, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, 0x34, 0x0a, 0x0b, 0x73, 0x74, 0x72, + 0x65, 0x61, 0x6d, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, + 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, + 0x61, 0x48, 0x00, 0x52, 0x0a, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x12, + 0x37, 0x0a, 0x0c, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x63, 0x6c, 0x6f, 0x73, 0x65, 0x18, + 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x48, 0x00, 0x52, 0x0b, 0x63, 0x6c, 0x69, + 0x65, 0x6e, 0x74, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x12, 0x3a, 0x0a, 0x0d, 0x63, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x5f, 0x63, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, + 0x13, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x43, 0x61, + 0x6e, 0x63, 0x65, 0x6c, 0x48, 0x00, 0x52, 0x0c, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x43, 0x61, + 0x6e, 0x63, 0x65, 0x6c, 0x42, 0x09, 0x0a, 0x07, 0x72, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, + 0xcd, 0x01, 0x0a, 0x0a, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x47, + 0x0a, 0x12, 0x73, 0x74, 0x61, 0x72, 0x74, 0x5f, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x72, + 0x65, 0x70, 0x6c, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x17, 0x2e, 0x50, 0x72, 0x6f, + 0x78, 0x79, 0x2e, 0x53, 0x74, 0x61, 0x72, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, + 0x70, 0x6c, 0x79, 0x48, 0x00, 0x52, 0x10, 0x73, 0x74, 0x61, 0x72, 0x74, 0x53, 0x74, 0x72, 0x65, + 0x61, 0x6d, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x12, 0x34, 0x0a, 0x0b, 0x73, 0x74, 0x72, 0x65, 0x61, + 0x6d, 0x5f, 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x50, + 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x48, + 0x00, 0x52, 0x0a, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x44, 0x61, 0x74, 0x61, 0x12, 0x37, 0x0a, + 0x0c, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x5f, 0x63, 0x6c, 0x6f, 0x73, 0x65, 0x18, 0x03, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x53, 0x65, 0x72, 0x76, + 0x65, 0x72, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x48, 0x00, 0x52, 0x0b, 0x73, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x42, 0x07, 0x0a, 0x05, 0x72, 0x65, 0x70, 0x6c, 0x79, 0x22, + 0xe9, 0x01, 0x0a, 0x0b, 0x53, 0x74, 0x61, 0x72, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x12, + 0x16, 0x0a, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, + 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x12, 0x1f, 0x0a, 0x0b, 0x6d, 0x65, 0x74, 0x68, 0x6f, + 0x64, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6d, 0x65, + 0x74, 0x68, 0x6f, 0x64, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x6f, 0x6e, 0x63, + 0x65, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x12, 0x3c, + 0x0a, 0x0c, 0x64, 0x69, 0x61, 0x6c, 0x5f, 0x74, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x18, 0x04, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x44, 0x75, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x52, + 0x0b, 0x64, 0x69, 0x61, 0x6c, 0x54, 0x69, 0x6d, 0x65, 0x6f, 0x75, 0x74, 0x12, 0x22, 0x0a, 0x0d, + 0x61, 0x75, 0x74, 0x68, 0x7a, 0x5f, 0x64, 0x72, 0x79, 0x5f, 0x72, 0x75, 0x6e, 0x18, 0x05, 0x20, + 0x01, 0x28, 0x08, 0x52, 0x0b, 0x61, 0x75, 0x74, 0x68, 0x7a, 0x44, 0x72, 0x79, 0x52, 0x75, 0x6e, + 0x12, 0x29, 0x0a, 0x10, 0x66, 0x6f, 0x72, 0x63, 0x65, 0x5f, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, + 0x74, 0x69, 0x61, 0x6c, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0f, 0x66, 0x6f, 0x72, 0x63, + 0x65, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x22, 0x9c, 0x01, 0x0a, 0x10, + 0x53, 0x74, 0x61, 0x72, 0x74, 0x53, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x52, 0x65, 0x70, 0x6c, 0x79, + 0x12, 0x16, 0x0a, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, + 0x52, 0x06, 0x74, 0x61, 0x72, 0x67, 0x65, 0x74, 0x12, 0x14, 0x0a, 0x05, 0x6e, 0x6f, 0x6e, 0x63, + 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0d, 0x52, 0x05, 0x6e, 0x6f, 0x6e, 0x63, 0x65, 0x12, 0x1d, + 0x0a, 0x09, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x04, 0x48, 0x00, 0x52, 0x08, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x12, 0x32, 0x0a, + 0x0c, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x5f, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x53, 0x74, 0x61, 0x74, + 0x75, 0x73, 0x48, 0x00, 0x52, 0x0b, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x42, 0x07, 0x0a, 0x05, 0x72, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x2c, 0x0a, 0x0b, 0x43, 0x6c, + 0x69, 0x65, 0x6e, 0x74, 0x43, 0x6c, 0x6f, 0x73, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x72, + 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x04, 0x52, 0x09, 0x73, + 0x74, 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x73, 0x22, 0x2d, 0x0a, 0x0c, 0x43, 0x6c, 0x69, 0x65, + 0x6e, 0x74, 0x43, 0x61, 0x6e, 0x63, 0x65, 0x6c, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x72, 0x65, + 0x61, 0x6d, 0x5f, 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x04, 0x52, 0x09, 0x73, 0x74, + 0x72, 0x65, 0x61, 0x6d, 0x49, 0x64, 0x73, 0x22, 0x5b, 0x0a, 0x0a, 0x53, 0x74, 0x72, 0x65, 0x61, + 0x6d, 0x44, 0x61, 0x74, 0x61, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, + 0x69, 0x64, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x04, 0x52, 0x09, 0x73, 0x74, 0x72, 0x65, 0x61, + 0x6d, 0x49, 0x64, 0x73, 0x12, 0x2e, 0x0a, 0x07, 0x70, 0x61, 0x79, 0x6c, 0x6f, 0x61, 0x64, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, + 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x07, 0x70, 0x61, 0x79, + 0x6c, 0x6f, 0x61, 0x64, 0x22, 0x53, 0x0a, 0x0b, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x43, 0x6c, + 0x6f, 0x73, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x5f, 0x69, 0x64, + 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x04, 0x52, 0x09, 0x73, 0x74, 0x72, 0x65, 0x61, 0x6d, 0x49, + 0x64, 0x73, 0x12, 0x25, 0x0a, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x18, 0x02, 0x20, 0x01, + 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x53, 0x74, 0x61, 0x74, 0x75, + 0x73, 0x52, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, 0x73, 0x22, 0x66, 0x0a, 0x06, 0x53, 0x74, 0x61, + 0x74, 0x75, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x05, 0x52, 0x04, 0x63, 0x6f, 0x64, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, + 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6d, 0x65, 0x73, 0x73, 0x61, 0x67, + 0x65, 0x12, 0x2e, 0x0a, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, 0x73, 0x18, 0x03, 0x20, 0x03, + 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x41, 0x6e, 0x79, 0x52, 0x07, 0x64, 0x65, 0x74, 0x61, 0x69, 0x6c, + 0x73, 0x32, 0x3e, 0x0a, 0x05, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x12, 0x35, 0x0a, 0x05, 0x50, 0x72, + 0x6f, 0x78, 0x79, 0x12, 0x13, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x2e, 0x50, 0x72, 0x6f, 0x78, + 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x11, 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, + 0x2e, 0x50, 0x72, 0x6f, 0x78, 0x79, 0x52, 0x65, 0x70, 0x6c, 0x79, 0x22, 0x00, 0x28, 0x01, 0x30, + 0x01, 0x42, 0x2b, 0x5a, 0x29, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, + 0x53, 0x6e, 0x6f, 0x77, 0x66, 0x6c, 0x61, 0x6b, 0x65, 0x2d, 0x4c, 0x61, 0x62, 0x73, 0x2f, 0x73, + 0x61, 0x6e, 0x73, 0x73, 0x68, 0x65, 0x6c, 0x6c, 0x2f, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x62, 0x06, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} var ( file_proxy_proto_rawDescOnce sync.Once - file_proxy_proto_rawDescData []byte + file_proxy_proto_rawDescData = file_proxy_proto_rawDesc ) func file_proxy_proto_rawDescGZIP() []byte { file_proxy_proto_rawDescOnce.Do(func() { - file_proxy_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_proxy_proto_rawDesc), len(file_proxy_proto_rawDesc))) + file_proxy_proto_rawDescData = protoimpl.X.CompressGZIP(file_proxy_proto_rawDescData) }) return file_proxy_proto_rawDescData } @@ -840,6 +901,116 @@ func file_proxy_proto_init() { if File_proxy_proto != nil { return } + if !protoimpl.UnsafeEnabled { + file_proxy_proto_msgTypes[0].Exporter = func(v any, i int) any { + switch v := v.(*ProxyRequest); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_proto_msgTypes[1].Exporter = func(v any, i int) any { + switch v := v.(*ProxyReply); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_proto_msgTypes[2].Exporter = func(v any, i int) any { + switch v := v.(*StartStream); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_proto_msgTypes[3].Exporter = func(v any, i int) any { + switch v := v.(*StartStreamReply); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_proto_msgTypes[4].Exporter = func(v any, i int) any { + switch v := v.(*ClientClose); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_proto_msgTypes[5].Exporter = func(v any, i int) any { + switch v := v.(*ClientCancel); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_proto_msgTypes[6].Exporter = func(v any, i int) any { + switch v := v.(*StreamData); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_proto_msgTypes[7].Exporter = func(v any, i int) any { + switch v := v.(*ServerClose); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + file_proxy_proto_msgTypes[8].Exporter = func(v any, i int) any { + switch v := v.(*Status); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } file_proxy_proto_msgTypes[0].OneofWrappers = []any{ (*ProxyRequest_StartStream)(nil), (*ProxyRequest_StreamData)(nil), @@ -859,7 +1030,7 @@ func file_proxy_proto_init() { out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: unsafe.Slice(unsafe.StringData(file_proxy_proto_rawDesc), len(file_proxy_proto_rawDesc)), + RawDescriptor: file_proxy_proto_rawDesc, NumEnums: 0, NumMessages: 9, NumExtensions: 0, @@ -870,6 +1041,7 @@ func file_proxy_proto_init() { MessageInfos: file_proxy_proto_msgTypes, }.Build() File_proxy_proto = out.File + file_proxy_proto_rawDesc = nil file_proxy_proto_goTypes = nil file_proxy_proto_depIdxs = nil } diff --git a/proxy/proxy.proto b/proxy/proxy.proto index 470ba0b3..6189c7b7 100644 --- a/proxy/proxy.proto +++ b/proxy/proxy.proto @@ -87,6 +87,12 @@ message StartStream { // Perform authz dry run instead actual execution. bool authz_dry_run = 5; + + // Optional. Forces the proxy to use a specific client credential when + // dialing this target. If empty or unset, the proxy uses its default + // credential. The proxy will reject unrecognized non-empty values with + // InvalidArgument rather than falling back to the default. + string force_credential = 6; } message StartStreamReply { diff --git a/proxy/proxy/proxy.go b/proxy/proxy/proxy.go index 57c69141..9d141265 100644 --- a/proxy/proxy/proxy.go +++ b/proxy/proxy/proxy.go @@ -79,6 +79,11 @@ type Conn struct { // Perform authz dry run instead of actual execution AuthzDryRun bool + // ForceCredential is passed in each StartStream to tell the proxy which + // client credential to use when dialing the target. The proxy will fail + // if the requested credential is not configured. Empty means default. + ForceCredential string + // UnaryInterceptors allow intercepting Invoke and InvokeOneMany calls // that go through a proxy. // It is unsafe to modify Intercepters while calls are in progress. @@ -455,10 +460,11 @@ func (p *Conn) createStreams(ctx context.Context, method string) (proxypb.Proxy_ req := &proxypb.ProxyRequest{ Request: &proxypb.ProxyRequest_StartStream{ StartStream: &proxypb.StartStream{ - Target: t, - MethodName: method, - Nonce: uint32(i), - AuthzDryRun: p.AuthzDryRun, + Target: t, + MethodName: method, + Nonce: uint32(i), + AuthzDryRun: p.AuthzDryRun, + ForceCredential: p.ForceCredential, }, }, } diff --git a/proxy/proxy_grpc.pb.go b/proxy/proxy_grpc.pb.go index b0650da0..7e73e5ae 100644 --- a/proxy/proxy_grpc.pb.go +++ b/proxy/proxy_grpc.pb.go @@ -16,7 +16,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.5.1 -// - protoc v6.33.4 +// - protoc v7.34.0 // source: proxy.proto package proxy diff --git a/proxy/server/server.go b/proxy/server/server.go index f9346648..296c562c 100644 --- a/proxy/server/server.go +++ b/proxy/server/server.go @@ -77,8 +77,9 @@ type Server struct { // A map of /Package.Service/Method => ServiceMethod serviceMap map[string]*ServiceMethod - // A dialer for making proxy -> target connections - dialer TargetDialer + // Named dialers for making proxy -> target connections. + // Key "" is the default dialer used when no force_credential is specified. + dialers map[string]TargetDialer // A policy authorizer, for authorizing proxy -> target requests authorizer rpcauth.RPCAuthorizer @@ -104,9 +105,16 @@ func New(dialer TargetDialer, authorizer rpcauth.RPCAuthorizer) *Server { // The supplied authorizer is used to authorize requests made // to targets. func NewWithServiceMap(dialer TargetDialer, authorizer rpcauth.RPCAuthorizer, serviceMap map[string]*ServiceMethod) *Server { + return NewWithDialersAndServiceMap(map[string]TargetDialer{"": dialer}, authorizer, serviceMap) +} + +// NewWithDialers creates a new Server with named dialers for credential-hint-based +// dialer selection and the global service map. The dialers map must contain a "" +// key for the default dialer. +func NewWithDialersAndServiceMap(dialers map[string]TargetDialer, authorizer rpcauth.RPCAuthorizer, serviceMap map[string]*ServiceMethod) *Server { return &Server{ serviceMap: serviceMap, - dialer: dialer, + dialers: dialers, authorizer: authorizer, } } @@ -122,7 +130,7 @@ func (s *Server) Proxy(stream pb.Proxy_ProxyServer) error { // create a new TargetStreamSet to manage the target streams // associated with this proxy connection - streamSet := NewTargetStreamSet(s.serviceMap, s.dialer, s.authorizer) + streamSet := NewTargetStreamSet(s.serviceMap, s.dialers, s.authorizer) // A single go-routine for handling all sends to the reply // channel diff --git a/proxy/server/target.go b/proxy/server/target.go index b85e2f8a..5a5de6bf 100644 --- a/proxy/server/target.go +++ b/proxy/server/target.go @@ -424,8 +424,10 @@ type TargetStreamSet struct { // A service method map used to resolve incoming stream requests to service methods serviceMethods map[string]*ServiceMethod - // A TargetDialer for initiating target connections - targetDialer TargetDialer + // Named dialers for initiating target connections. + // Key "" is the default dialer used when no force_credential is specified. + // Non-empty values that are not present in this map cause an error. + dialers map[string]TargetDialer // [rpcauthz.rpcAuthorizerImpl], for authorizing requests sent to targets. authorizer rpcauth.RPCAuthorizer @@ -444,11 +446,12 @@ type TargetStreamSet struct { noncePairs map[string]bool } -// NewTargetStreamSet creates a TargetStreamSet which manages a set of related TargetStreams -func NewTargetStreamSet(serviceMethods map[string]*ServiceMethod, dialer TargetDialer, authorizer rpcauth.RPCAuthorizer) *TargetStreamSet { +// NewTargetStreamSet creates a TargetStreamSet which manages a set of related TargetStreams. +// The dialers map must contain a "" key for the default dialer. +func NewTargetStreamSet(serviceMethods map[string]*ServiceMethod, dialers map[string]TargetDialer, authorizer rpcauth.RPCAuthorizer) *TargetStreamSet { return &TargetStreamSet{ serviceMethods: serviceMethods, - targetDialer: dialer, + dialers: dialers, authorizer: authorizer, streams: make(map[uint64]*TargetStream), closedStreams: make(map[uint64]bool), @@ -504,12 +507,21 @@ func (t *TargetStreamSet) Add(ctx context.Context, req *pb.StartStream, replyCha sendReply(reply) return nil } + hint := req.GetForceCredential() + dialer, ok := t.dialers[hint] + if !ok { + reply.GetStartStreamReply().Reply = &pb.StartStreamReply_ErrorStatus{ + ErrorStatus: convertStatus(status.Newf(codes.InvalidArgument, "unknown credential %q: not configured on this proxy", hint)), + } + sendReply(reply) + return nil + } var dialTimeout *time.Duration if req.DialTimeout != nil { d := req.DialTimeout.AsDuration() dialTimeout = &d } - stream, err := NewTargetStream(ctx, req.GetTarget(), t.targetDialer, dialTimeout, serviceMethod, t.authorizer, req.GetAuthzDryRun()) + stream, err := NewTargetStream(ctx, req.GetTarget(), dialer, dialTimeout, serviceMethod, t.authorizer, req.GetAuthzDryRun()) if err != nil { reply.GetStartStreamReply().Reply = &pb.StartStreamReply_ErrorStatus{ ErrorStatus: convertStatus(status.New(codes.Internal, err.Error())), diff --git a/proxy/server/target_test.go b/proxy/server/target_test.go index 61cc7805..dcad26f0 100644 --- a/proxy/server/target_test.go +++ b/proxy/server/target_test.go @@ -43,7 +43,7 @@ func (e dialErrTargetDialer) DialContext(ctx context.Context, target string, dia func TestEmptyStreamSet(t *testing.T) { ctx := context.Background() errDialer := dialErrTargetDialer(codes.Unimplemented) - ss := NewTargetStreamSet(map[string]*ServiceMethod{}, errDialer, nil) + ss := NewTargetStreamSet(map[string]*ServiceMethod{}, map[string]TargetDialer{"": errDialer}, nil) // wait does not block when no work is being done finishedWait := make(chan struct{}) @@ -82,13 +82,14 @@ func TestEmptyStreamSet(t *testing.T) { func TestStreamSetAddErrors(t *testing.T) { errDialer := dialErrTargetDialer(codes.Unimplemented) serviceMap := LoadGlobalServiceMap() - ss := NewTargetStreamSet(serviceMap, errDialer, nil) + ss := NewTargetStreamSet(serviceMap, map[string]TargetDialer{"": errDialer}, nil) for _, tc := range []struct { - name string - method string - nonce uint32 - errCode codes.Code + name string + method string + nonce uint32 + forceCredential string + errCode codes.Code }{ { name: "dial failure no error", @@ -101,6 +102,13 @@ func TestStreamSetAddErrors(t *testing.T) { method: "/Nosuch.Method/Foo", errCode: codes.InvalidArgument, }, + { + name: "unknown credential hint", + nonce: 3, + method: "/Testdata.TestService/TestUnary", + forceCredential: "nonexistent", + errCode: codes.InvalidArgument, + }, } { tc := tc t.Run(tc.name, func(t *testing.T) { @@ -108,9 +116,10 @@ func TestStreamSetAddErrors(t *testing.T) { replyChan := make(chan *pb.ProxyReply, 1) req := &pb.StartStream{ - Target: "nosuchhost:000", - Nonce: tc.nonce, - MethodName: tc.method, + Target: "nosuchhost:000", + Nonce: tc.nonce, + MethodName: tc.method, + ForceCredential: tc.forceCredential, } err := ss.Add(context.Background(), req, replyChan, nil /*doneChan should not be called*/) testutil.FatalOnErr(fmt.Sprintf("StartStream(+%v)", req), err, t) @@ -167,7 +176,7 @@ func TestTargetStreamAddNonBlocking(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() serviceMap := LoadGlobalServiceMap() - ss := NewTargetStreamSet(serviceMap, blockingClientDialer{}, nil) + ss := NewTargetStreamSet(serviceMap, map[string]TargetDialer{"": blockingClientDialer{}}, nil) replyChan := make(chan *pb.ProxyReply, 1) doneChan := make(chan struct{}) req := &pb.StartStream{