Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 39 additions & 7 deletions internal/rpc/sui/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ import (
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/credentials"
_ "google.golang.org/grpc/encoding/gzip"
"google.golang.org/grpc/keepalive"
"google.golang.org/protobuf/types/known/fieldmaskpb"

Expand All @@ -17,6 +19,7 @@ import (

// SuiClient implements the SuiAPI interface using gRPC
type SuiClient struct {
connMu sync.Mutex
conn *grpc.ClientConn
ledgerClient v2.LedgerServiceClient
subscriberClient v2.SubscriptionServiceClient
Expand All @@ -42,7 +45,21 @@ func NewSuiClient(url string) *SuiClient {
}

// connect establishes the gRPC connection if not already connected
func (c *SuiClient) connect(ctx context.Context) error {
func (c *SuiClient) connect(_ context.Context) error {
c.connMu.Lock()
defer c.connMu.Unlock()

if c.conn != nil {
if c.conn.GetState() != connectivity.Shutdown {
c.ensureServiceClientsLocked(c.conn)
return nil
}
_ = c.conn.Close()
c.conn = nil
c.ledgerClient = nil
c.subscriberClient = nil
}

// Apply default options
options := &clientOptions{
maxMsgSize: 50 * 1024 * 1024, // 50MB
Expand Down Expand Up @@ -78,12 +95,19 @@ func (c *SuiClient) connect(ctx context.Context) error {
return fmt.Errorf("failed to connect: %w", err)
}
c.conn = conn
c.ledgerClient = v2.NewLedgerServiceClient(conn)
// Initialize subscription client
c.subscriberClient = v2.NewSubscriptionServiceClient(conn)
c.ensureServiceClientsLocked(conn)
return nil
}

func (c *SuiClient) ensureServiceClientsLocked(conn *grpc.ClientConn) {
if c.ledgerClient == nil {
c.ledgerClient = v2.NewLedgerServiceClient(conn)
}
if c.subscriberClient == nil {
c.subscriberClient = v2.NewSubscriptionServiceClient(conn)
}
}

// StartStreaming starts the background streaming process
func (c *SuiClient) StartStreaming(ctx context.Context) error {
// Call connect to ensure connection exists before streaming
Expand Down Expand Up @@ -377,8 +401,16 @@ func (c *SuiClient) GetURL() string {

// Close closes the gRPC connection
func (c *SuiClient) Close() error {
if c.conn != nil {
return c.conn.Close()
c.connMu.Lock()
defer c.connMu.Unlock()

if c.conn == nil {
return nil
}
return nil

conn := c.conn
c.conn = nil
c.ledgerClient = nil
c.subscriberClient = nil
return conn.Close()
}
42 changes: 42 additions & 0 deletions internal/rpc/sui/client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package sui

import (
"context"
"testing"

"google.golang.org/grpc/encoding"
grpcgzip "google.golang.org/grpc/encoding/gzip"

"github.com/stretchr/testify/require"
)

func TestSuiClientRegistersGzipCompressor(t *testing.T) {
t.Parallel()

require.NotNil(t, encoding.GetCompressor(grpcgzip.Name))
}

func TestSuiClientConnectReusesExistingConnection(t *testing.T) {
t.Parallel()

client := NewSuiClient("passthrough:///127.0.0.1:12345")

require.NoError(t, client.connect(context.Background()))
firstConn := client.conn
require.NotNil(t, firstConn)
require.NotNil(t, client.ledgerClient)
require.NotNil(t, client.subscriberClient)

require.NoError(t, client.connect(context.Background()))
require.Same(t, firstConn, client.conn)

require.NoError(t, client.Close())
require.Nil(t, client.conn)
require.Nil(t, client.ledgerClient)
require.Nil(t, client.subscriberClient)

require.NoError(t, client.connect(context.Background()))
require.NotNil(t, client.conn)
require.NotSame(t, firstConn, client.conn)
require.NoError(t, client.Close())
}
Loading