diff --git a/internal/rpc/sui/client.go b/internal/rpc/sui/client.go index 2b824f3..c475c29 100644 --- a/internal/rpc/sui/client.go +++ b/internal/rpc/sui/client.go @@ -7,7 +7,9 @@ import ( "time" "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" "google.golang.org/grpc/credentials" + _ "google.golang.org/grpc/encoding/gzip" "google.golang.org/grpc/keepalive" "google.golang.org/protobuf/types/known/fieldmaskpb" @@ -17,6 +19,7 @@ import ( // SuiClient implements the SuiAPI interface using gRPC type SuiClient struct { + connMu sync.Mutex conn *grpc.ClientConn ledgerClient v2.LedgerServiceClient subscriberClient v2.SubscriptionServiceClient @@ -42,7 +45,21 @@ func NewSuiClient(url string) *SuiClient { } // connect establishes the gRPC connection if not already connected -func (c *SuiClient) connect(ctx context.Context) error { +func (c *SuiClient) connect(_ context.Context) error { + c.connMu.Lock() + defer c.connMu.Unlock() + + if c.conn != nil { + if c.conn.GetState() != connectivity.Shutdown { + c.ensureServiceClientsLocked(c.conn) + return nil + } + _ = c.conn.Close() + c.conn = nil + c.ledgerClient = nil + c.subscriberClient = nil + } + // Apply default options options := &clientOptions{ maxMsgSize: 50 * 1024 * 1024, // 50MB @@ -78,12 +95,19 @@ func (c *SuiClient) connect(ctx context.Context) error { return fmt.Errorf("failed to connect: %w", err) } c.conn = conn - c.ledgerClient = v2.NewLedgerServiceClient(conn) - // Initialize subscription client - c.subscriberClient = v2.NewSubscriptionServiceClient(conn) + c.ensureServiceClientsLocked(conn) return nil } +func (c *SuiClient) ensureServiceClientsLocked(conn *grpc.ClientConn) { + if c.ledgerClient == nil { + c.ledgerClient = v2.NewLedgerServiceClient(conn) + } + if c.subscriberClient == nil { + c.subscriberClient = v2.NewSubscriptionServiceClient(conn) + } +} + // StartStreaming starts the background streaming process func (c *SuiClient) StartStreaming(ctx context.Context) error { // Call connect to ensure connection exists before streaming @@ -377,8 +401,16 @@ func (c *SuiClient) GetURL() string { // Close closes the gRPC connection func (c *SuiClient) Close() error { - if c.conn != nil { - return c.conn.Close() + c.connMu.Lock() + defer c.connMu.Unlock() + + if c.conn == nil { + return nil } - return nil + + conn := c.conn + c.conn = nil + c.ledgerClient = nil + c.subscriberClient = nil + return conn.Close() } diff --git a/internal/rpc/sui/client_test.go b/internal/rpc/sui/client_test.go new file mode 100644 index 0000000..937634c --- /dev/null +++ b/internal/rpc/sui/client_test.go @@ -0,0 +1,42 @@ +package sui + +import ( + "context" + "testing" + + "google.golang.org/grpc/encoding" + grpcgzip "google.golang.org/grpc/encoding/gzip" + + "github.com/stretchr/testify/require" +) + +func TestSuiClientRegistersGzipCompressor(t *testing.T) { + t.Parallel() + + require.NotNil(t, encoding.GetCompressor(grpcgzip.Name)) +} + +func TestSuiClientConnectReusesExistingConnection(t *testing.T) { + t.Parallel() + + client := NewSuiClient("passthrough:///127.0.0.1:12345") + + require.NoError(t, client.connect(context.Background())) + firstConn := client.conn + require.NotNil(t, firstConn) + require.NotNil(t, client.ledgerClient) + require.NotNil(t, client.subscriberClient) + + require.NoError(t, client.connect(context.Background())) + require.Same(t, firstConn, client.conn) + + require.NoError(t, client.Close()) + require.Nil(t, client.conn) + require.Nil(t, client.ledgerClient) + require.Nil(t, client.subscriberClient) + + require.NoError(t, client.connect(context.Background())) + require.NotNil(t, client.conn) + require.NotSame(t, firstConn, client.conn) + require.NoError(t, client.Close()) +}