diff --git a/async.go b/async.go index 50d99591..209717c6 100644 --- a/async.go +++ b/async.go @@ -31,27 +31,27 @@ func (f *Async[Resp]) Done() bool { // AsyncMajority returns an Async future that resolves when a majority quorum is reached. // Messages are sent immediately (synchronously) to preserve ordering when multiple // async calls are created in sequence. -func (r *Responses[Resp]) AsyncMajority() *Async[Resp] { +func (r *Responses[T, Resp]) AsyncMajority() *Async[Resp] { quorumSize := r.size/2 + 1 return r.AsyncThreshold(quorumSize) } // AsyncFirst returns an Async future that resolves when the first response is received. // Messages are sent immediately (synchronously) to preserve ordering. -func (r *Responses[Resp]) AsyncFirst() *Async[Resp] { +func (r *Responses[T, Resp]) AsyncFirst() *Async[Resp] { return r.AsyncThreshold(1) } // AsyncAll returns an Async future that resolves when all nodes have responded. // Messages are sent immediately (synchronously) to preserve ordering. -func (r *Responses[Resp]) AsyncAll() *Async[Resp] { +func (r *Responses[T, Resp]) AsyncAll() *Async[Resp] { return r.AsyncThreshold(r.size) } // AsyncThreshold returns an Async future that resolves when the threshold is reached. // Messages are sent immediately (synchronously) to preserve ordering when multiple // async calls are created in sequence. -func (r *Responses[Resp]) AsyncThreshold(threshold int) *Async[Resp] { +func (r *Responses[T, Resp]) AsyncThreshold(threshold int) *Async[Resp] { // Send messages synchronously before spawning the goroutine to preserve ordering r.sendNow() diff --git a/async_test.go b/async_test.go index 307367ba..571290f4 100644 --- a/async_test.go +++ b/async_test.go @@ -13,7 +13,7 @@ import ( func TestAsync(t *testing.T) { // a type alias short hand for the responses type - type respType = *gorums.Responses[*pb.StringValue] + type respType = *gorums.Responses[uint32, *pb.StringValue] tests := []struct { name string call func(respType) *gorums.Async[*pb.StringValue] @@ -45,7 +45,7 @@ func TestAsync(t *testing.T) { t.Run(tt.name, func(t *testing.T) { config := gorums.TestConfiguration(t, tt.numNodes, gorums.EchoServerFn) ctx := gorums.TestContext(t, 2*time.Second) - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( config.Context(ctx), pb.String("test"), mock.TestMethod, @@ -71,7 +71,7 @@ func TestAsync_Error(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() // Cancel immediately - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( config.Context(ctx), pb.String("test"), mock.TestMethod, @@ -92,7 +92,7 @@ func BenchmarkAsyncQuorumCall(b *testing.B) { b.Run(fmt.Sprintf("AsyncMajority/%d", numNodes), func(b *testing.B) { b.ReportAllocs() for b.Loop() { - future := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + future := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("test"), mock.TestMethod, @@ -108,7 +108,7 @@ func BenchmarkAsyncQuorumCall(b *testing.B) { b.Run(fmt.Sprintf("BlockingMajority/%d", numNodes), func(b *testing.B) { b.ReportAllocs() for b.Loop() { - _, err := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + _, err := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("test"), mock.TestMethod, diff --git a/benchmark/benchmark.go b/benchmark/benchmark.go index 667cd358..d9e77f7e 100644 --- a/benchmark/benchmark.go +++ b/benchmark/benchmark.go @@ -35,8 +35,8 @@ type Bench struct { type ( benchFunc func(Options) (*Result, error) - qcFunc func(*gorums.ConfigContext, *Echo, int, ...gorums.CallOption) (*Echo, error) - asyncQCFunc func(*gorums.ConfigContext, *Echo, int, ...gorums.CallOption) AsyncEcho + qcFunc func(*ConfigContext, *Echo, int, ...gorums.CallOption) (*Echo, error) + asyncQCFunc func(*ConfigContext, *Echo, int, ...gorums.CallOption) AsyncEcho serverFunc func(context.Context, *TimedMsg) ) @@ -233,7 +233,7 @@ func GetBenchmarks(config Configuration) []Bench { Name: "QuorumCall", Description: "NodeStream based quorum call implementation with FIFO ordering", runBench: func(opts Options) (*Result, error) { - return runQCBenchmark(opts, config, func(ctx *gorums.ConfigContext, in *Echo, quorumSize int, callOpts ...gorums.CallOption) (*Echo, error) { + return runQCBenchmark(opts, config, func(ctx *ConfigContext, in *Echo, quorumSize int, callOpts ...gorums.CallOption) (*Echo, error) { return QuorumCall(ctx, in, callOpts...).Threshold(quorumSize) }) }, @@ -242,7 +242,7 @@ func GetBenchmarks(config Configuration) []Bench { Name: "AsyncQuorumCall", Description: "NodeStream based async quorum call implementation with FIFO ordering", runBench: func(opts Options) (*Result, error) { - return runAsyncQCBenchmark(opts, config, func(ctx *gorums.ConfigContext, in *Echo, quorumSize int, callOpts ...gorums.CallOption) AsyncEcho { + return runAsyncQCBenchmark(opts, config, func(ctx *ConfigContext, in *Echo, quorumSize int, callOpts ...gorums.CallOption) AsyncEcho { return QuorumCall(ctx, in, callOpts...).AsyncThreshold(quorumSize) }) }, @@ -251,7 +251,7 @@ func GetBenchmarks(config Configuration) []Bench { Name: "SlowServer", Description: "Quorum Call with a 10s processing time on the server", runBench: func(opts Options) (*Result, error) { - return runQCBenchmark(opts, config, func(ctx *gorums.ConfigContext, in *Echo, quorumSize int, callOpts ...gorums.CallOption) (*Echo, error) { + return runQCBenchmark(opts, config, func(ctx *ConfigContext, in *Echo, quorumSize int, callOpts ...gorums.CallOption) (*Echo, error) { return SlowServer(ctx, in, callOpts...).Threshold(quorumSize) }) }, diff --git a/benchmark/benchmark.pb.go b/benchmark/benchmark.pb.go index 25131081..4040a087 100644 --- a/benchmark/benchmark.pb.go +++ b/benchmark/benchmark.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v6.33.4 // source: benchmark/benchmark.proto package benchmark diff --git a/benchmark/benchmark_gorums.pb.go b/benchmark/benchmark_gorums.pb.go index 494703ce..3eefa9b1 100644 --- a/benchmark/benchmark_gorums.pb.go +++ b/benchmark/benchmark_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: benchmark/benchmark.proto package benchmark @@ -21,9 +21,11 @@ const ( // Type aliases for important Gorums types to make them more accessible // from user code already interacting with the generated code. type ( - Configuration = gorums.Configuration - Manager = gorums.Manager - Node = gorums.Node + Configuration = gorums.Configuration[NodeID] + Manager = gorums.Manager[NodeID] + Node = gorums.Node[NodeID] + NodeContext = gorums.NodeContext[NodeID] + ConfigContext = gorums.ConfigContext[NodeID] ) // Use the aliased types to add them to the reserved identifiers list. @@ -32,20 +34,22 @@ var ( _ = (*Configuration)(nil) _ = (*Manager)(nil) _ = (*Node)(nil) + _ = (*NodeContext)(nil) + _ = (*ConfigContext)(nil) ) // NewManager returns a new Manager for managing connection to nodes added // to the manager. This function accepts manager options used to configure // various aspects of the manager. func NewManager(opts ...gorums.ManagerOption) *Manager { - return gorums.NewManager(opts...) + return gorums.NewManager[NodeID](opts...) } // NewConfiguration returns a configuration based on the provided list of nodes. // Nodes can be supplied using WithNodeMap or WithNodeList, or WithNodeIDs. // A new configuration can also be created from an existing configuration, // using the And, WithNewNodes, Except, and WithoutNodes methods. -func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, error) { +func NewConfiguration(mgr *Manager, opt gorums.NodeListOption[NodeID]) (Configuration, error) { return gorums.NewConfiguration(mgr, opt) } @@ -66,9 +70,12 @@ func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, e // creates a new manager; if a manager already exists, use [NewConfiguration] // instead, and provide the existing manager as the first argument. func NewConfig(opts ...gorums.Option) (Configuration, error) { - return gorums.NewConfig(opts...) + return gorums.NewConfig[NodeID](opts...) } +// NodeID is a type alias for the type used to identify nodes. +type NodeID = uint32 + // AsyncEcho is a future for async quorum calls returning *Echo. type AsyncEcho = *gorums.Async[*Echo] @@ -97,55 +104,55 @@ type CorrectableStartResponse = *gorums.Correctable[*StartResponse] var _ emptypb.Empty // StartServerBenchmark starts a server-side benchmark campaign. -func StartServerBenchmark(ctx *gorums.ConfigContext, in *StartRequest, opts ...gorums.CallOption) *gorums.Responses[*StartResponse] { - return gorums.QuorumCall[*StartRequest, *StartResponse]( +func StartServerBenchmark(ctx *ConfigContext, in *StartRequest, opts ...gorums.CallOption) *gorums.Responses[NodeID, *StartResponse] { + return gorums.QuorumCall[NodeID, *StartRequest, *StartResponse]( ctx, in, "benchmark.Benchmark.StartServerBenchmark", opts..., ) } // StopServerBenchmark stops a server-side benchmark campaign. -func StopServerBenchmark(ctx *gorums.ConfigContext, in *StopRequest, opts ...gorums.CallOption) *gorums.Responses[*Result] { - return gorums.QuorumCall[*StopRequest, *Result]( +func StopServerBenchmark(ctx *ConfigContext, in *StopRequest, opts ...gorums.CallOption) *gorums.Responses[NodeID, *Result] { + return gorums.QuorumCall[NodeID, *StopRequest, *Result]( ctx, in, "benchmark.Benchmark.StopServerBenchmark", opts..., ) } // StartBenchmark starts a client-side benchmark campaign. -func StartBenchmark(ctx *gorums.ConfigContext, in *StartRequest, opts ...gorums.CallOption) *gorums.Responses[*StartResponse] { - return gorums.QuorumCall[*StartRequest, *StartResponse]( +func StartBenchmark(ctx *ConfigContext, in *StartRequest, opts ...gorums.CallOption) *gorums.Responses[NodeID, *StartResponse] { + return gorums.QuorumCall[NodeID, *StartRequest, *StartResponse]( ctx, in, "benchmark.Benchmark.StartBenchmark", opts..., ) } // StopBenchmark stops a client-side benchmark campaign. -func StopBenchmark(ctx *gorums.ConfigContext, in *StopRequest, opts ...gorums.CallOption) *gorums.Responses[*MemoryStat] { - return gorums.QuorumCall[*StopRequest, *MemoryStat]( +func StopBenchmark(ctx *ConfigContext, in *StopRequest, opts ...gorums.CallOption) *gorums.Responses[NodeID, *MemoryStat] { + return gorums.QuorumCall[NodeID, *StopRequest, *MemoryStat]( ctx, in, "benchmark.Benchmark.StopBenchmark", opts..., ) } // QuorumCall performs an echo quorum call on all servers. -func QuorumCall(ctx *gorums.ConfigContext, in *Echo, opts ...gorums.CallOption) *gorums.Responses[*Echo] { - return gorums.QuorumCall[*Echo, *Echo]( +func QuorumCall(ctx *ConfigContext, in *Echo, opts ...gorums.CallOption) *gorums.Responses[NodeID, *Echo] { + return gorums.QuorumCall[NodeID, *Echo, *Echo]( ctx, in, "benchmark.Benchmark.QuorumCall", opts..., ) } // SlowServer performs an echo quorum call on slow servers. -func SlowServer(ctx *gorums.ConfigContext, in *Echo, opts ...gorums.CallOption) *gorums.Responses[*Echo] { - return gorums.QuorumCall[*Echo, *Echo]( +func SlowServer(ctx *ConfigContext, in *Echo, opts ...gorums.CallOption) *gorums.Responses[NodeID, *Echo] { + return gorums.QuorumCall[NodeID, *Echo, *Echo]( ctx, in, "benchmark.Benchmark.SlowServer", opts..., ) } // Multicast performs a multicast call to all servers. -func Multicast(ctx *gorums.ConfigContext, in *TimedMsg, opts ...gorums.CallOption) error { +func Multicast(ctx *ConfigContext, in *TimedMsg, opts ...gorums.CallOption) error { return gorums.Multicast(ctx, in, "benchmark.Benchmark.Multicast", opts...) } diff --git a/callopts.go b/callopts.go index 5c7efc41..352f9eb8 100644 --- a/callopts.go +++ b/callopts.go @@ -1,7 +1,6 @@ package gorums import ( - "google.golang.org/protobuf/proto" "google.golang.org/protobuf/runtime/protoimpl" ) @@ -54,7 +53,7 @@ func IgnoreErrors() CallOption { // resp, err := ReadQC(ctx, req, // gorums.Interceptors(loggingInterceptor, filterInterceptor), // ).Majority() -func Interceptors[Req, Resp proto.Message](interceptors ...QuorumInterceptor[Req, Resp]) CallOption { +func Interceptors[T NodeID, Req, Resp msg](interceptors ...QuorumInterceptor[T, Req, Resp]) CallOption { return func(o *callOptions) { for _, interceptor := range interceptors { o.interceptors = append(o.interceptors, interceptor) diff --git a/callopts_test.go b/callopts_test.go index 35418268..41dabbf0 100644 --- a/callopts_test.go +++ b/callopts_test.go @@ -33,7 +33,9 @@ func TestCallOptionsMustWaitSendDone(t *testing.T) { } func BenchmarkGetCallOptions(b *testing.B) { - interceptor := func(_ *ClientCtx[msg, msg], next ResponseSeq[msg]) ResponseSeq[msg] { return next } + interceptor := func(_ *ClientCtx[uint32, msg, msg], next ResponseSeq[uint32, msg]) ResponseSeq[uint32, msg] { + return next + } tests := []struct { numOpts int }{ diff --git a/channel.go b/channel.go index 6e032bcf..25cb0e00 100644 --- a/channel.go +++ b/channel.go @@ -14,17 +14,17 @@ import ( ) // NodeResponse wraps a response value from node ID, and an error if any. -type NodeResponse[T any] struct { - NodeID uint32 - Value T +type NodeResponse[T NodeID, V any] struct { + NodeID T + Value V Err error } -// newNodeResponse converts a NodeResponse[msg] to a NodeResponse[Resp]. -// This is necessary because the channel layer's response router returns a -// NodeResponse[msg], while the calltype expects a NodeResponse[Resp]. -func newNodeResponse[Resp msg](r NodeResponse[msg]) NodeResponse[Resp] { - res := NodeResponse[Resp]{ +// newNodeResponse converts a NodeResponse[T, msg] to a NodeResponse[T, Resp]. +// This is necessary because the channel's response router returns a +// NodeResponse[T, msg], while the calltype expects a NodeResponse[T, Resp]. +func newNodeResponse[T NodeID, Resp msg](r NodeResponse[T, msg]) NodeResponse[T, Resp] { + res := NodeResponse[T, Resp]{ NodeID: r.NodeID, Err: r.Err, } @@ -43,18 +43,18 @@ var ( streamDownErr = status.Error(codes.Unavailable, "stream is down") ) -type request struct { +type request[T NodeID] struct { ctx context.Context msg *Message waitSendDone bool streaming bool - responseChan chan<- NodeResponse[proto.Message] + responseChan chan<- NodeResponse[T, proto.Message] sendTime time.Time } -type channel struct { - sendQ chan request - id uint32 +type channel[T NodeID] struct { + sendQ chan request[T] + id T // Connection lifecycle management: node close() cancels the // connection context to stop all goroutines and the NodeStream @@ -79,7 +79,7 @@ type channel struct { // Response routing; the map holds pending requests waiting for responses. // The request contains the responseChan on which to send the response // to the caller. - responseRouters map[uint64]request + responseRouters map[uint64]request[T] responseMut sync.Mutex closeOnceFunc func() error } @@ -91,16 +91,16 @@ type channel struct { // have not yet been established. This is to prevent deadlock when invoking // a call type. The sender blocks on the sendQ and the receiver waits for // the stream to become available. -func newChannel(parentCtx context.Context, conn *grpc.ClientConn, id uint32, sendBufferSize uint) *channel { +func newChannel[T NodeID](parentCtx context.Context, conn *grpc.ClientConn, id T, sendBufferSize uint) *channel[T] { ctx, connCancel := context.WithCancel(parentCtx) - c := &channel{ - sendQ: make(chan request, sendBufferSize), + c := &channel[T]{ + sendQ: make(chan request[T], sendBufferSize), id: id, conn: conn, connCtx: ctx, connCancel: connCancel, latency: -1 * time.Second, - responseRouters: make(map[uint64]request), + responseRouters: make(map[uint64]request[T]), streamReady: make(chan struct{}, 1), } c.closeOnceFunc = sync.OnceValue(func() error { @@ -117,7 +117,7 @@ func newChannel(parentCtx context.Context, conn *grpc.ClientConn, id uint32, sen } // close closes the channel and the underlying connection exactly once. -func (c *channel) close() error { +func (c *channel[T]) close() error { return c.closeOnceFunc() } @@ -125,7 +125,7 @@ func (c *channel) close() error { // receiver goroutines, and signals the receiver when the stream is ready. // gRPC automatically handles TCP connection state when creating the stream. // This method is safe for concurrent use. -func (c *channel) ensureStream() error { +func (c *channel[T]) ensureStream() error { if err := c.ensureConnectedNodeStream(); err != nil { return err } @@ -141,7 +141,7 @@ func (c *channel) ensureStream() error { // ensureConnectedNodeStream ensures there is an active and connected // NodeStream, or creates a new stream if one doesn't already exist. // This method is safe for concurrent use. -func (c *channel) ensureConnectedNodeStream() (err error) { +func (c *channel[T]) ensureConnectedNodeStream() (err error) { c.streamMut.Lock() defer c.streamMut.Unlock() // if we already have a ready connection and an active stream, do nothing @@ -155,7 +155,7 @@ func (c *channel) ensureConnectedNodeStream() (err error) { } // getStream returns the current stream, or nil if no stream is available. -func (c *channel) getStream() grpc.ClientStream { +func (c *channel[T]) getStream() grpc.ClientStream { c.streamMut.Lock() defer c.streamMut.Unlock() return c.gorumsStream @@ -163,7 +163,7 @@ func (c *channel) getStream() grpc.ClientStream { // clearStream cancels the current stream context and clears the stream reference. // This triggers reconnection on the next send attempt. -func (c *channel) clearStream() { +func (c *channel[T]) clearStream() { c.streamMut.Lock() c.streamCancel() c.gorumsStream = nil @@ -172,13 +172,13 @@ func (c *channel) clearStream() { // isConnected returns true if the gRPC connection is in Ready state and we have an active stream. // This method is safe for concurrent use. -func (c *channel) isConnected() bool { +func (c *channel[T]) isConnected() bool { return c.conn.GetState() == connectivity.Ready && c.getStream() != nil } // enqueue adds the request to the send queue and sets up response routing if needed. // If the node is closed, it responds with an error instead. -func (c *channel) enqueue(req request) { +func (c *channel[T]) enqueue(req request[T]) { if req.responseChan != nil { req.sendTime = time.Now() msgID := req.msg.GetMessageID() @@ -191,7 +191,7 @@ func (c *channel) enqueue(req request) { select { case <-c.connCtx.Done(): // the node's close() method was called: respond with error instead of enqueueing - c.routeResponse(req.msg.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.id, Err: nodeClosedErr}) + c.routeResponse(req.msg.GetMessageID(), NodeResponse[T, proto.Message]{NodeID: c.id, Err: nodeClosedErr}) return case c.sendQ <- req: // enqueued successfully @@ -200,7 +200,7 @@ func (c *channel) enqueue(req request) { // routeResponse routes the response to the appropriate response channel based on msgID. // If no matching request is found, the response is discarded. -func (c *channel) routeResponse(msgID uint64, resp NodeResponse[proto.Message]) { +func (c *channel[T]) routeResponse(msgID uint64, resp NodeResponse[T, proto.Message]) { c.responseMut.Lock() defer c.responseMut.Unlock() if req, ok := c.responseRouters[msgID]; ok { @@ -217,11 +217,11 @@ func (c *channel) routeResponse(msgID uint64, resp NodeResponse[proto.Message]) // cancelPendingMsgs cancels all pending messages by sending an error response to each // associated request. This is called when the stream goes down to notify all waiting calls. -func (c *channel) cancelPendingMsgs(err error) { +func (c *channel[T]) cancelPendingMsgs(err error) { c.responseMut.Lock() defer c.responseMut.Unlock() for msgID, req := range c.responseRouters { - req.responseChan <- NodeResponse[proto.Message]{NodeID: c.id, Err: err} + req.responseChan <- NodeResponse[T, proto.Message]{NodeID: c.id, Err: err} // delete the router if we are only expecting a single reply message if !req.streaming { delete(c.responseRouters, msgID) @@ -231,7 +231,7 @@ func (c *channel) cancelPendingMsgs(err error) { // deleteRouter removes the response router for the given msgID. // This is used for cleaning up after streaming calls are done. -func (c *channel) deleteRouter(msgID uint64) { +func (c *channel[T]) deleteRouter(msgID uint64) { c.responseMut.Lock() defer c.responseMut.Unlock() delete(c.responseRouters, msgID) @@ -239,11 +239,11 @@ func (c *channel) deleteRouter(msgID uint64) { // sender goroutine takes requests from the sendQ and sends them on the stream. // If the stream is down, it tries to re-establish it. -func (c *channel) sender() { +func (c *channel[T]) sender() { // eager connect; ignored if stream is down (will be retried on send) _ = c.ensureStream() - var req request + var req request[T] for { select { case <-c.connCtx.Done(): @@ -253,11 +253,11 @@ func (c *channel) sender() { // take next request from sendQ } if err := c.ensureStream(); err != nil { - c.routeResponse(req.msg.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.id, Err: err}) + c.routeResponse(req.msg.GetMessageID(), NodeResponse[T, proto.Message]{NodeID: c.id, Err: err}) continue } if err := c.sendMsg(req); err != nil { - c.routeResponse(req.msg.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.id, Err: err}) + c.routeResponse(req.msg.GetMessageID(), NodeResponse[T, proto.Message]{NodeID: c.id, Err: err}) } } } @@ -265,7 +265,7 @@ func (c *channel) sender() { // receiver goroutine receives messages from the stream and routes them to // the appropriate response router. If the stream goes down, it clears the // stream reference and cancels all pending messages with a stream down error. -func (c *channel) receiver() { +func (c *channel[T]) receiver() { for { stream := c.getStream() if stream == nil { @@ -287,7 +287,7 @@ func (c *channel) receiver() { c.clearStream() } else { err := resp.GetStatus().Err() - c.routeResponse(resp.GetMessageID(), NodeResponse[proto.Message]{NodeID: c.id, Value: resp.GetProtoMessage(), Err: err}) + c.routeResponse(resp.GetMessageID(), NodeResponse[T, proto.Message]{NodeID: c.id, Value: resp.GetProtoMessage(), Err: err}) } select { @@ -299,7 +299,7 @@ func (c *channel) receiver() { } } -func (c *channel) sendMsg(req request) (err error) { +func (c *channel[T]) sendMsg(req request[T]) (err error) { defer func() { // For one-way call types (Unicast/Multicast), the caller can choose between two behaviors: // @@ -317,7 +317,7 @@ func (c *channel) sendMsg(req request) (err error) { // wait for actual server responses, so waitSendDone is false for them. if req.waitSendDone && err == nil { // Send succeeded: unblock the caller and clean up the responseRouter - c.routeResponse(req.msg.GetMessageID(), NodeResponse[proto.Message]{}) + c.routeResponse(req.msg.GetMessageID(), NodeResponse[T, proto.Message]{}) } }() @@ -363,21 +363,21 @@ func (c *channel) sendMsg(req request) (err error) { return err } -func (c *channel) setLastErr(err error) { +func (c *channel[T]) setLastErr(err error) { c.mu.Lock() defer c.mu.Unlock() c.lastError = err } // lastErr returns the last error encountered (if any) when using this channel. -func (c *channel) lastErr() error { +func (c *channel[T]) lastErr() error { c.mu.Lock() defer c.mu.Unlock() return c.lastError } // channelLatency returns the latency between the client and the server associated with this channel. -func (c *channel) channelLatency() time.Duration { +func (c *channel[T]) channelLatency() time.Duration { c.mu.Lock() defer c.mu.Unlock() return c.latency @@ -385,7 +385,7 @@ func (c *channel) channelLatency() time.Duration { // updateLatency updates the latency between the client and the server associated with this channel. // It uses a simple moving average to calculate the latency. -func (c *channel) updateLatency(rtt time.Duration) { +func (c *channel[T]) updateLatency(rtt time.Duration) { c.mu.Lock() defer c.mu.Unlock() if c.latency < 0 { diff --git a/channel_test.go b/channel_test.go index 6a3b1432..26f3bad0 100644 --- a/channel_test.go +++ b/channel_test.go @@ -42,13 +42,13 @@ func delayServerFn(delay time.Duration) func(_ int) ServerIface { } } -func sendRequest(t testing.TB, node *Node, req request, msgID uint64) NodeResponse[proto.Message] { +func sendRequest(t testing.TB, node *Node[uint32], req request[uint32], msgID uint64) NodeResponse[uint32, proto.Message] { t.Helper() if req.ctx == nil { req.ctx = t.Context() } req.msg = NewRequestMessage(ordering.NewGorumsMetadata(req.ctx, msgID, mock.TestMethod), nil) - replyChan := make(chan NodeResponse[proto.Message], 1) + replyChan := make(chan NodeResponse[uint32, proto.Message], 1) req.responseChan = replyChan node.channel.enqueue(req) @@ -57,16 +57,16 @@ func sendRequest(t testing.TB, node *Node, req request, msgID uint64) NodeRespon return resp case <-time.After(defaultTestTimeout): t.Fatalf("timeout waiting for response to message %d", msgID) - return NodeResponse[proto.Message]{} + return NodeResponse[uint32, proto.Message]{} } } type msgResponse struct { msgID uint64 - resp NodeResponse[proto.Message] + resp NodeResponse[uint32, proto.Message] } -func sendReq(t testing.TB, results chan<- msgResponse, node *Node, goroutineID, msgsToSend int, req request) { +func sendReq(t testing.TB, results chan<- msgResponse, node *Node[uint32], goroutineID, msgsToSend int, req request[uint32]) { for j := range msgsToSend { msgID := uint64(goroutineID*1000 + j) resp := sendRequest(t, node, req, msgID) @@ -78,10 +78,10 @@ func sendReq(t testing.TB, results chan<- msgResponse, node *Node, goroutineID, // adds it to a new manager. This is useful for testing node and channel // behavior without an active server. The manager is automatically closed // when the test finishes. -func testNodeWithoutServer(t testing.TB, opts ...ManagerOption) *Node { +func testNodeWithoutServer(t testing.TB, opts ...ManagerOption) *Node[uint32] { t.Helper() mgrOpts := append([]ManagerOption{InsecureDialOptions(t)}, opts...) - mgr := NewManager(mgrOpts...) + mgr := NewManager[uint32](mgrOpts...) t.Cleanup(Closer(t, mgr)) // Use a high port number that's unlikely to have anything listening. // We use a fixed ID for simplicity. @@ -94,14 +94,14 @@ func testNodeWithoutServer(t testing.TB, opts ...ManagerOption) *Node { // Helper functions for accessing channel internals -func routerExists(node *Node, msgID uint64) bool { +func routerExists(node *Node[uint32], msgID uint64) bool { node.channel.responseMut.Lock() defer node.channel.responseMut.Unlock() _, exists := node.channel.responseRouters[msgID] return exists } -func getStream(node *Node) grpc.ClientStream { +func getStream(node *Node[uint32]) grpc.ClientStream { return node.channel.getStream() } @@ -109,7 +109,7 @@ func TestChannelCreation(t *testing.T) { node := testNodeWithoutServer(t) // send message when server is down - resp := sendRequest(t, node, request{waitSendDone: true}, 1) + resp := sendRequest(t, node, request[uint32]{waitSendDone: true}, 1) if resp.Err == nil { t.Error("response err: got , want error") } @@ -129,7 +129,7 @@ func TestChannelShutdown(t *testing.T) { var wg sync.WaitGroup for i := range numMessages { wg.Go(func() { - resp := sendRequest(t, node, request{}, uint64(i)) + resp := sendRequest(t, node, request[uint32]{}, uint64(i)) if resp.Err != nil { t.Errorf("unexpected error for message %d, got error: %v", i, resp.Err) } @@ -143,7 +143,7 @@ func TestChannelShutdown(t *testing.T) { } // try to send a message after node closure - resp := sendRequest(t, node, request{}, 999) + resp := sendRequest(t, node, request[uint32]{}, 999) if resp.Err == nil { t.Error("expected error when sending to closed channel") } else if !strings.Contains(resp.Err.Error(), "node closed") { @@ -218,7 +218,7 @@ func TestChannelSendCompletionWaiting(t *testing.T) { for i, tt := range tests { t.Run(tt.name, func(t *testing.T) { start := time.Now() - resp := sendRequest(t, node, request{waitSendDone: tt.waitSendDone}, uint64(i)) + resp := sendRequest(t, node, request[uint32]{waitSendDone: tt.waitSendDone}, uint64(i)) elapsed := time.Since(start) if resp.Err != nil { t.Errorf("unexpected error: %v", resp.Err) @@ -232,19 +232,19 @@ func TestChannelSendCompletionWaiting(t *testing.T) { func TestChannelErrors(t *testing.T) { tests := []struct { name string - setup func(t *testing.T) *Node + setup func(t *testing.T) *Node[uint32] wantErr string }{ { name: "EnqueueWithoutServer", - setup: func(t *testing.T) *Node { + setup: func(t *testing.T) *Node[uint32] { return testNodeWithoutServer(t) }, wantErr: "connection error", }, { name: "EnqueueToClosedChannel", - setup: func(t *testing.T) *Node { + setup: func(t *testing.T) *Node[uint32] { node := testNodeWithoutServer(t) err := node.close() if err != nil { @@ -256,7 +256,7 @@ func TestChannelErrors(t *testing.T) { }, { name: "EnqueueToServerWithClosedNode", - setup: func(t *testing.T) *Node { + setup: func(t *testing.T) *Node[uint32] { node := TestNode(t, delayServerFn(0)) err := node.close() if err != nil { @@ -268,10 +268,10 @@ func TestChannelErrors(t *testing.T) { }, { name: "ServerFailureDuringCommunication", - setup: func(t *testing.T) *Node { + setup: func(t *testing.T) *Node[uint32] { var stopServer func(...int) node := TestNode(t, delayServerFn(0), WithStopFunc(t, &stopServer)) - resp := sendRequest(t, node, request{waitSendDone: true}, 1) + resp := sendRequest(t, node, request[uint32]{waitSendDone: true}, 1) if resp.Err != nil { t.Errorf("first message should succeed, got error: %v", resp.Err) } @@ -287,7 +287,7 @@ func TestChannelErrors(t *testing.T) { time.Sleep(100 * time.Millisecond) // Send message and verify error - resp := sendRequest(t, node, request{waitSendDone: true}, uint64(i)) + resp := sendRequest(t, node, request[uint32]{waitSendDone: true}, uint64(i)) if resp.Err == nil { t.Errorf("expected error containing %q but got nil", tt.wantErr) } else if !strings.Contains(resp.Err.Error(), tt.wantErr) { @@ -300,7 +300,7 @@ func TestChannelErrors(t *testing.T) { // TestChannelEnsureStream verifies that ensureStream correctly manages stream lifecycle. func TestChannelEnsureStream(t *testing.T) { // Helper to prepare a fresh node with no stream - newNodeWithoutStream := func(t *testing.T) *Node { + newNodeWithoutStream := func(t *testing.T) *Node[uint32] { node := TestNode(t, delayServerFn(0)) // ensure sender and receiver goroutines are stopped node.channel.connCancel() @@ -329,14 +329,14 @@ func TestChannelEnsureStream(t *testing.T) { tests := []struct { name string - setup func(t *testing.T) *Node - action func(node *Node) (first, second grpc.ClientStream) + setup func(t *testing.T) *Node[uint32] + action func(node *Node[uint32]) (first, second grpc.ClientStream) wantSame bool }{ { name: "UnconnectedNodeHasNoStream", - setup: func(t *testing.T) *Node { return testNodeWithoutServer(t) }, - action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) { + setup: func(t *testing.T) *Node[uint32] { return testNodeWithoutServer(t) }, + action: func(node *Node[uint32]) (grpc.ClientStream, grpc.ClientStream) { if err := node.channel.ensureStream(); err == nil { t.Error("ensureStream succeeded unexpectedly") } @@ -349,7 +349,7 @@ func TestChannelEnsureStream(t *testing.T) { { name: "CreatesStreamWhenConnected", setup: newNodeWithoutStream, - action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) { + action: func(node *Node[uint32]) (grpc.ClientStream, grpc.ClientStream) { if err := node.channel.ensureStream(); err != nil { t.Errorf("ensureStream failed: %v", err) } @@ -359,7 +359,7 @@ func TestChannelEnsureStream(t *testing.T) { { name: "RepeatedCallsReturnSameStream", setup: newNodeWithoutStream, - action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) { + action: func(node *Node[uint32]) (grpc.ClientStream, grpc.ClientStream) { if err := node.channel.ensureStream(); err != nil { t.Errorf("first ensureStream failed: %v", err) } @@ -374,7 +374,7 @@ func TestChannelEnsureStream(t *testing.T) { { name: "StreamDisconnectionCreatesNewStream", setup: newNodeWithoutStream, - action: func(node *Node) (grpc.ClientStream, grpc.ClientStream) { + action: func(node *Node[uint32]) (grpc.ClientStream, grpc.ClientStream) { if err := node.channel.ensureStream(); err != nil { t.Errorf("initial ensureStream failed: %v", err) } @@ -402,22 +402,22 @@ func TestChannelEnsureStream(t *testing.T) { func TestChannelConnectionState(t *testing.T) { tests := []struct { name string - setup func(t *testing.T) *Node + setup func(t *testing.T) *Node[uint32] wantConnected bool }{ { name: "WithoutServer", - setup: func(t *testing.T) *Node { return testNodeWithoutServer(t) }, + setup: func(t *testing.T) *Node[uint32] { return testNodeWithoutServer(t) }, wantConnected: false, }, { name: "WithLiveServer", - setup: func(t *testing.T) *Node { return TestNode(t, delayServerFn(0)) }, + setup: func(t *testing.T) *Node[uint32] { return TestNode(t, delayServerFn(0)) }, wantConnected: true, }, { name: "RequiresBothReadyAndStream", - setup: func(t *testing.T) *Node { + setup: func(t *testing.T) *Node[uint32] { node := TestNode(t, delayServerFn(0)) // Wait for stream to be established time.Sleep(streamConnectDelay) @@ -454,8 +454,8 @@ func TestChannelConcurrentSends(t *testing.T) { results := make(chan msgResponse, numMessages) for goID := range numGoroutines { go func() { - sendReq(t, results, node, goID, msgsPerGoroutine, request{waitSendDone: true}) - sendReq(t, results, node, goID, msgsPerGoroutine, request{waitSendDone: false}) + sendReq(t, results, node, goID, msgsPerGoroutine, request[uint32]{waitSendDone: true}) + sendReq(t, results, node, goID, msgsPerGoroutine, request[uint32]{waitSendDone: false}) }() } @@ -537,7 +537,7 @@ func TestChannelContext(t *testing.T) { t.Cleanup(cancel) node := TestNode(t, delayServerFn(tt.serverDelay)) - resp := sendRequest(t, node, request{ctx: ctx, waitSendDone: tt.waitSendDone}, uint64(i)) + resp := sendRequest(t, node, request[uint32]{ctx: ctx, waitSendDone: tt.waitSendDone}, uint64(i)) if !errors.Is(resp.Err, tt.wantErr) { t.Errorf("expected %v, got: %v", tt.wantErr, resp.Err) } @@ -567,7 +567,7 @@ func TestChannelDeadlock(t *testing.T) { } // Send a message to activate the stream - sendRequest(t, node, request{waitSendDone: true}, 1) + sendRequest(t, node, request[uint32]{waitSendDone: true}, 1) // Break the stream, forcing a reconnection on next send node.channel.clearStream() @@ -580,7 +580,7 @@ func TestChannelDeadlock(t *testing.T) { go func() { ctx := TestContext(t, 3*time.Second) md := ordering.NewGorumsMetadata(ctx, uint64(100+id), mock.TestMethod) - req := request{ctx: ctx, msg: NewRequestMessage(md, nil)} + req := request[uint32]{ctx: ctx, msg: NewRequestMessage(md, nil)} // try to enqueue select { @@ -628,7 +628,7 @@ func TestChannelRouterLifecycle(t *testing.T) { name string waitSendDone bool streaming bool - afterSend func(t *testing.T, node *Node, msgID uint64) + afterSend func(t *testing.T, node *Node[uint32], msgID uint64) wantRouter bool }{ {name: "WaitSendDone/NonStreamingAutoCleanup", waitSendDone: true, streaming: false, wantRouter: false}, @@ -640,7 +640,7 @@ func TestChannelRouterLifecycle(t *testing.T) { name := fmt.Sprintf("msgID=%d/%s/streaming=%t", i, tt.name, tt.streaming) t.Run(name, func(t *testing.T) { msgID := uint64(i) - resp := sendRequest(t, node, request{waitSendDone: tt.waitSendDone, streaming: tt.streaming}, msgID) + resp := sendRequest(t, node, request[uint32]{waitSendDone: tt.waitSendDone, streaming: tt.streaming}, msgID) if resp.Err != nil { t.Errorf("unexpected error: %v", resp.Err) } @@ -661,7 +661,7 @@ func TestChannelResponseRouting(t *testing.T) { results := make(chan msgResponse, numMessages) for i := range numMessages { - go sendReq(t, results, node, i, 1, request{}) + go sendReq(t, results, node, i, 1, request[uint32]{}) } // Collect and verify results @@ -692,7 +692,7 @@ func TestChannelStreamReadySignaling(t *testing.T) { // The first request triggers stream creation. We measure how quickly // the receiver starts processing after the stream is ready. start := time.Now() - resp := sendRequest(t, node, request{}, 1) + resp := sendRequest(t, node, request[uint32]{}, 1) firstLatency := time.Since(start) if resp.Err != nil { @@ -701,7 +701,7 @@ func TestChannelStreamReadySignaling(t *testing.T) { // Second request should be faster since stream is already established start = time.Now() - resp = sendRequest(t, node, request{}, 2) + resp = sendRequest(t, node, request[uint32]{}, 2) secondLatency := time.Since(start) if resp.Err != nil { @@ -726,7 +726,7 @@ func TestChannelStreamReadyAfterReconnect(t *testing.T) { node := TestNode(t, delayServerFn(0)) // First request to establish the stream - resp := sendRequest(t, node, request{}, 1) + resp := sendRequest(t, node, request[uint32]{}, 1) if resp.Err != nil { t.Fatalf("unexpected error on first request: %v", resp.Err) } @@ -743,7 +743,7 @@ func TestChannelStreamReadyAfterReconnect(t *testing.T) { var lastErr error start := time.Now() for i := range 5 { - resp = sendRequest(t, node, request{}, uint64(i+2)) + resp = sendRequest(t, node, request[uint32]{}, uint64(i+2)) if resp.Err == nil { reconnectLatency = time.Since(start) break @@ -786,9 +786,9 @@ func BenchmarkChannelStreamReadyFirstRequest(b *testing.B) { // Use a fresh context for the benchmark request ctx := TestContext(b, defaultTestTimeout) - req := request{ctx: ctx} + req := request[uint32]{ctx: ctx} req.msg = NewRequestMessage(ordering.NewGorumsMetadata(ctx, 1, mock.TestMethod), nil) - replyChan := make(chan NodeResponse[proto.Message], 1) + replyChan := make(chan NodeResponse[uint32, proto.Message], 1) req.responseChan = replyChan node.channel.enqueue(req) @@ -813,14 +813,14 @@ func BenchmarkChannelStreamReadySubsequentRequest(b *testing.B) { node := TestNode(b, delayServerFn(0)) // Warm up: establish the stream - resp := sendRequest(b, node, request{}, 0) + resp := sendRequest(b, node, request[uint32]{}, 0) if resp.Err != nil { b.Fatalf("warmup error: %v", resp.Err) } b.ResetTimer() for i := range b.N { - resp := sendRequest(b, node, request{}, uint64(i+1)) + resp := sendRequest(b, node, request[uint32]{}, uint64(i+1)) if resp.Err != nil { b.Fatalf("unexpected error: %v", resp.Err) } @@ -836,9 +836,9 @@ func BenchmarkChannelStreamReadyReconnect(b *testing.B) { // Establish initial stream with a fresh context ctx := context.Background() - req := request{ctx: ctx} + req := request[uint32]{ctx: ctx} req.msg = NewRequestMessage(ordering.NewGorumsMetadata(ctx, 0, mock.TestMethod), nil) - replyChan := make(chan NodeResponse[proto.Message], 1) + replyChan := make(chan NodeResponse[uint32, proto.Message], 1) req.responseChan = replyChan node.channel.enqueue(req) @@ -862,9 +862,9 @@ func BenchmarkChannelStreamReadyReconnect(b *testing.B) { // Now send a request which will trigger ensureStream -> newNodeStream -> signal ctx := context.Background() - req := request{ctx: ctx} + req := request[uint32]{ctx: ctx} req.msg = NewRequestMessage(ordering.NewGorumsMetadata(ctx, uint64(i+1), mock.TestMethod), nil) - replyChan := make(chan NodeResponse[proto.Message], 1) + replyChan := make(chan NodeResponse[uint32, proto.Message], 1) req.responseChan = replyChan node.channel.enqueue(req) diff --git a/client_interceptor.go b/client_interceptor.go index 96826599..905c9209 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -13,6 +13,7 @@ import ( // requests, responses, and aggregation logic. Interceptors can be chained together. // // Type parameters: +// - T: The node ID type // - Req: The request message type sent to nodes // - Resp: The response message type from individual nodes // @@ -22,35 +23,35 @@ import ( // // Custom interceptors can be created like this: // -// func LoggingInterceptor[Req, Resp proto.Message]( -// ctx *gorums.ClientCtx[Req, Resp], -// next gorums.ResponseSeq[Resp], -// ) gorums.ResponseSeq[Resp] { -// return func(yield func(gorums.NodeResponse[Resp]) bool) { +// func LoggingInterceptor[T NodeID, Req, Resp proto.Message]( +// ctx *gorums.ClientCtx[T, Req, Resp], +// next gorums.ResponseSeq[T, Resp], +// ) gorums.ResponseSeq[T, Resp] { +// return func(yield func(gorums.NodeResponse[T, Resp]) bool) { // for resp := range next { -// log.Printf("Response from node %d", resp.NodeID) +// log.Printf("Response from node %v", resp.NodeID) // if !yield(resp) { return } // } // } // } -type QuorumInterceptor[Req, Resp msg] func(ctx *ClientCtx[Req, Resp], next ResponseSeq[Resp]) ResponseSeq[Resp] +type QuorumInterceptor[T NodeID, Req, Resp msg] func(ctx *ClientCtx[T, Req, Resp], next ResponseSeq[T, Resp]) ResponseSeq[T, Resp] // ClientCtx provides context and access to the quorum call state for interceptors. // It exposes the request, configuration, metadata about the call, and the response iterator. -type ClientCtx[Req, Resp msg] struct { +type ClientCtx[T NodeID, Req, Resp msg] struct { context.Context - config Configuration + config Configuration[T] request Req method string md *ordering.Metadata - replyChan chan NodeResponse[msg] + replyChan chan NodeResponse[T, msg] // reqTransforms holds request transformation functions registered by interceptors. - reqTransforms []func(Req, *Node) Req + reqTransforms []func(Req, *Node[T]) Req // responseSeq is the iterator that yields node responses. // Interceptors can wrap this iterator to modify responses. - responseSeq ResponseSeq[Resp] + responseSeq ResponseSeq[T, Resp] // expectedReplies is the number of responses expected from nodes. // It is set when messages are sent and may be lower than config size @@ -70,8 +71,8 @@ type ClientCtx[Req, Resp msg] struct { } // clientCtxBuilder provides an interface for constructing ClientCtx instances. -type clientCtxBuilder[Req, Resp msg] struct { - c *ClientCtx[Req, Resp] +type clientCtxBuilder[T NodeID, Req, Resp msg] struct { + c *ClientCtx[T, Req, Resp] // chanMultiplier is the buffer multiplier for the reply channel. // Default is 1; streaming calls use a larger multiplier. chanMultiplier int @@ -80,13 +81,13 @@ type clientCtxBuilder[Req, Resp msg] struct { // newClientCtxBuilder creates a new builder for constructing a ClientCtx. // The required parameters are provided upfront; optional settings use builder methods. // The metadata and reply channel are created at Build() time. -func newClientCtxBuilder[Req, Resp msg]( - ctx *ConfigContext, +func newClientCtxBuilder[T NodeID, Req, Resp msg]( + ctx *ConfigContext[T], req Req, method string, -) *clientCtxBuilder[Req, Resp] { - return &clientCtxBuilder[Req, Resp]{ - c: &ClientCtx[Req, Resp]{ +) *clientCtxBuilder[T, Req, Resp] { + return &clientCtxBuilder[T, Req, Resp]{ + c: &ClientCtx[T, Req, Resp]{ Context: ctx, config: ctx.Configuration(), request: req, @@ -101,7 +102,7 @@ func newClientCtxBuilder[Req, Resp msg]( // When enabled, the response iterator continues until context cancellation // rather than stopping after expectedReplies responses. // It also increases the reply channel buffer size (10x) to handle streaming volume. -func (b *clientCtxBuilder[Req, Resp]) WithStreaming() *clientCtxBuilder[Req, Resp] { +func (b *clientCtxBuilder[T, Req, Resp]) WithStreaming() *clientCtxBuilder[T, Req, Resp] { b.c.streaming = true b.chanMultiplier = 10 return b @@ -109,17 +110,17 @@ func (b *clientCtxBuilder[Req, Resp]) WithStreaming() *clientCtxBuilder[Req, Res // WithWaitSendDone configures the clientCtx to wait for send completion. // Used by multicast calls to ensure messages are sent before returning. -func (b *clientCtxBuilder[Req, Resp]) WithWaitSendDone(waitSendDone bool) *clientCtxBuilder[Req, Resp] { +func (b *clientCtxBuilder[T, Req, Resp]) WithWaitSendDone(waitSendDone bool) *clientCtxBuilder[T, Req, Resp] { b.c.waitSendDone = waitSendDone return b } // Build finalizes the ClientCtx configuration and returns the constructed instance. // It creates the metadata and reply channel, and sets up the appropriate response iterator. -func (b *clientCtxBuilder[Req, Resp]) Build() *ClientCtx[Req, Resp] { +func (b *clientCtxBuilder[T, Req, Resp]) Build() *ClientCtx[T, Req, Resp] { // Create metadata and reply channel at build time b.c.md = ordering.NewGorumsMetadata(b.c.Context, b.c.config.nextMsgID(), b.c.method) - b.c.replyChan = make(chan NodeResponse[msg], b.c.config.Size()*b.chanMultiplier) + b.c.replyChan = make(chan NodeResponse[T, msg], b.c.config.Size()*b.chanMultiplier) if b.c.streaming { b.c.responseSeq = b.c.streamingResponseSeq() @@ -134,29 +135,29 @@ func (b *clientCtxBuilder[Req, Resp]) Build() *ClientCtx[Req, Resp] { // ------------------------------------------------------------------------- // Request returns the original request message for this quorum call. -func (c *ClientCtx[Req, Resp]) Request() Req { +func (c *ClientCtx[T, Req, Resp]) Request() Req { return c.request } // Config returns the configuration (set of nodes) for this quorum call. -func (c *ClientCtx[Req, Resp]) Config() Configuration { +func (c *ClientCtx[T, Req, Resp]) Config() Configuration[T] { return c.config } // Method returns the name of the RPC method being called. -func (c *ClientCtx[Req, Resp]) Method() string { +func (c *ClientCtx[T, Req, Resp]) Method() string { return c.method } // Nodes returns the slice of nodes in this configuration. -func (c *ClientCtx[Req, Resp]) Nodes() []*Node { +func (c *ClientCtx[T, Req, Resp]) Nodes() []*Node[T] { return c.config.Nodes() } // Node returns the node with the given ID. -func (c *ClientCtx[Req, Resp]) Node(id uint32) *Node { +func (c *ClientCtx[T, Req, Resp]) Node(id T) *Node[T] { nodes := c.config.Nodes() - index := slices.IndexFunc(nodes, func(n *Node) bool { + index := slices.IndexFunc(nodes, func(n *Node[T]) bool { return n.ID() == id }) if index != -1 { @@ -166,7 +167,7 @@ func (c *ClientCtx[Req, Resp]) Node(id uint32) *Node { } // Size returns the number of nodes in this configuration. -func (c *ClientCtx[Req, Resp]) Size() int { +func (c *ClientCtx[T, Req, Resp]) Size() int { return c.config.Size() } @@ -174,7 +175,7 @@ func (c *ClientCtx[Req, Resp]) Size() int { // invalid or the node should be skipped. It applies the registered transformation functions to // the given request for the specified node. Transformation functions are applied in the order // they were registered. -func (c *ClientCtx[Req, Resp]) applyTransforms(req Req, node *Node) proto.Message { +func (c *ClientCtx[T, Req, Resp]) applyTransforms(req Req, node *Node[T]) proto.Message { result := req for _, transform := range c.reqTransforms { result = transform(result, node) @@ -190,10 +191,10 @@ func (c *ClientCtx[Req, Resp]) applyTransforms(req Req, node *Node) proto.Messag // applyInterceptors chains the given interceptors, wrapping the response sequence. // Each interceptor receives the current response sequence and returns a new one. // Interceptors are applied in order, with each wrapping the previous result. -func (c *ClientCtx[Req, Resp]) applyInterceptors(interceptors []any) { +func (c *ClientCtx[T, Req, Resp]) applyInterceptors(interceptors []any) { responseSeq := c.responseSeq for _, ic := range interceptors { - interceptor := ic.(QuorumInterceptor[Req, Resp]) + interceptor := ic.(QuorumInterceptor[T, Req, Resp]) responseSeq = interceptor(c, responseSeq) } c.responseSeq = responseSeq @@ -202,7 +203,7 @@ func (c *ClientCtx[Req, Resp]) applyInterceptors(interceptors []any) { // send dispatches requests to all nodes, applying any registered transformations. // It updates expectedReplies based on how many nodes actually receive requests // (nodes may be skipped if a transformation returns nil). -func (c *ClientCtx[Req, Resp]) send() { +func (c *ClientCtx[T, Req, Resp]) send() { var expected int for _, n := range c.config { msg := c.applyTransforms(c.request, n) @@ -210,7 +211,7 @@ func (c *ClientCtx[Req, Resp]) send() { continue // Skip node if transformation returns nil } expected++ - n.channel.enqueue(request{ + n.channel.enqueue(request[T]{ ctx: c.Context, msg: NewRequestMessage(c.md, msg), streaming: c.streaming, @@ -223,14 +224,14 @@ func (c *ClientCtx[Req, Resp]) send() { // defaultResponseSeq returns an iterator that yields at most c.expectedReplies responses // from nodes until the context is canceled or all expected responses are received. -func (c *ClientCtx[Req, Resp]) defaultResponseSeq() ResponseSeq[Resp] { - return func(yield func(NodeResponse[Resp]) bool) { +func (c *ClientCtx[T, Req, Resp]) defaultResponseSeq() ResponseSeq[T, Resp] { + return func(yield func(NodeResponse[T, Resp]) bool) { // Trigger sending on first iteration c.sendOnce.Do(c.send) for range c.expectedReplies { select { case r := <-c.replyChan: - res := newNodeResponse[Resp](r) + res := newNodeResponse[T, Resp](r) if !yield(res) { return // Consumer stopped iteration } @@ -243,14 +244,14 @@ func (c *ClientCtx[Req, Resp]) defaultResponseSeq() ResponseSeq[Resp] { // streamingResponseSeq returns an iterator that yields responses as they arrive // from nodes until the context is canceled or breaking from the range loop. -func (c *ClientCtx[Req, Resp]) streamingResponseSeq() ResponseSeq[Resp] { - return func(yield func(NodeResponse[Resp]) bool) { +func (c *ClientCtx[T, Req, Resp]) streamingResponseSeq() ResponseSeq[T, Resp] { + return func(yield func(NodeResponse[T, Resp]) bool) { // Trigger sending on first iteration c.sendOnce.Do(c.send) for { select { case r := <-c.replyChan: - res := newNodeResponse[Resp](r) + res := newNodeResponse[T, Resp](r) if !yield(res) { return // Consumer stopped iteration } @@ -271,8 +272,8 @@ func (c *ClientCtx[Req, Resp]) streamingResponseSeq() ResponseSeq[Resp] { // The fn receives the original request and a node, and returns the transformed // request to send to that node. If the function returns an invalid message or nil, // the request to that node is skipped. -func MapRequest[Req, Resp msg](fn func(Req, *Node) Req) QuorumInterceptor[Req, Resp] { - return func(ctx *ClientCtx[Req, Resp], next ResponseSeq[Resp]) ResponseSeq[Resp] { +func MapRequest[T NodeID, Req, Resp msg](fn func(Req, *Node[T]) Req) QuorumInterceptor[T, Req, Resp] { + return func(ctx *ClientCtx[T, Req, Resp], next ResponseSeq[T, Resp]) ResponseSeq[T, Resp] { if fn != nil { ctx.reqTransforms = append(ctx.reqTransforms, fn) } @@ -284,13 +285,13 @@ func MapRequest[Req, Resp msg](fn func(Req, *Node) Req) QuorumInterceptor[Req, R // // The fn receives the response from a node and the node itself, and returns the // transformed response. -func MapResponse[Req, Resp msg](fn func(Resp, *Node) Resp) QuorumInterceptor[Req, Resp] { - return func(ctx *ClientCtx[Req, Resp], next ResponseSeq[Resp]) ResponseSeq[Resp] { +func MapResponse[T NodeID, Req, Resp msg](fn func(Resp, *Node[T]) Resp) QuorumInterceptor[T, Req, Resp] { + return func(ctx *ClientCtx[T, Req, Resp], next ResponseSeq[T, Resp]) ResponseSeq[T, Resp] { if fn == nil { return next } // Wrap the response iterator with the transformation logic. - return func(yield func(NodeResponse[Resp]) bool) { + return func(yield func(NodeResponse[T, Resp]) bool) { for resp := range next { // We only apply the transformation if there is no error. // Errors are passed through as-is. diff --git a/client_interceptor_test.go b/client_interceptor_test.go index 9755da86..9d2a3f0a 100644 --- a/client_interceptor_test.go +++ b/client_interceptor_test.go @@ -10,13 +10,15 @@ import ( pb "google.golang.org/protobuf/types/known/wrapperspb" ) +type NodeID = gorums.NodeID + // LoggingInterceptor is a custom interceptor that logs each response. -func LoggingInterceptor[Req, Resp proto.Message]( - ctx *gorums.ClientCtx[Req, Resp], - next gorums.ResponseSeq[Resp], -) gorums.ResponseSeq[Resp] { +func LoggingInterceptor[T NodeID, Req, Resp proto.Message]( + ctx *gorums.ClientCtx[T, Req, Resp], + next gorums.ResponseSeq[T, Resp], +) gorums.ResponseSeq[T, Resp] { _ = ctx.Method() // Access method name (could be used for logging) - return func(yield func(gorums.NodeResponse[Resp]) bool) { + return func(yield func(gorums.NodeResponse[T, Resp]) bool) { for resp := range next { // In a real interceptor, you would log here if !yield(resp) { @@ -27,12 +29,12 @@ func LoggingInterceptor[Req, Resp proto.Message]( } // FilterInterceptor returns an interceptor that filters responses based on a predicate. -func FilterInterceptor[Req, Resp proto.Message]( - keep func(resp gorums.NodeResponse[Resp]) bool, -) gorums.QuorumInterceptor[Req, Resp] { - return func(ctx *gorums.ClientCtx[Req, Resp], next gorums.ResponseSeq[Resp]) gorums.ResponseSeq[Resp] { +func FilterInterceptor[T NodeID, Req, Resp proto.Message]( + keep func(resp gorums.NodeResponse[T, Resp]) bool, +) gorums.QuorumInterceptor[T, Req, Resp] { + return func(ctx *gorums.ClientCtx[T, Req, Resp], next gorums.ResponseSeq[T, Resp]) gorums.ResponseSeq[T, Resp] { _ = ctx.Method() // Access method name (could be used for filtering) - return func(yield func(gorums.NodeResponse[Resp]) bool) { + return func(yield func(gorums.NodeResponse[T, Resp]) bool) { for resp := range next { if keep(resp) { if !yield(resp) { @@ -45,11 +47,11 @@ func FilterInterceptor[Req, Resp proto.Message]( } // CountingInterceptor counts the number of responses passing through. -func CountingInterceptor[Req, Resp proto.Message]( +func CountingInterceptor[T NodeID, Req, Resp proto.Message]( counter *int, -) gorums.QuorumInterceptor[Req, Resp] { - return func(_ *gorums.ClientCtx[Req, Resp], next gorums.ResponseSeq[Resp]) gorums.ResponseSeq[Resp] { - return func(yield func(gorums.NodeResponse[Resp]) bool) { +) gorums.QuorumInterceptor[T, Req, Resp] { + return func(_ *gorums.ClientCtx[T, Req, Resp], next gorums.ResponseSeq[T, Resp]) gorums.ResponseSeq[T, Resp] { + return func(yield func(gorums.NodeResponse[T, Resp]) bool) { for resp := range next { *counter++ if !yield(resp) { @@ -67,11 +69,11 @@ func TestCustomLoggingInterceptor(t *testing.T) { ctx := gorums.TestContext(t, 2*time.Second) // Use the custom logging interceptor from this external package - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( config.Context(ctx), pb.String("test"), mock.TestMethod, - gorums.Interceptors(LoggingInterceptor[*pb.StringValue, *pb.StringValue]), + gorums.Interceptors(LoggingInterceptor[uint32, *pb.StringValue, *pb.StringValue]), ) result, err := responses.Majority() @@ -90,12 +92,12 @@ func TestCustomFilterInterceptor(t *testing.T) { // Use a filter interceptor that only keeps responses from node 1 // (In practice, this would filter based on response content) - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( config.Context(ctx), pb.String("test"), mock.TestMethod, - gorums.Interceptors(FilterInterceptor[*pb.StringValue]( - func(resp gorums.NodeResponse[*pb.StringValue]) bool { + gorums.Interceptors(FilterInterceptor[uint32, *pb.StringValue]( + func(resp gorums.NodeResponse[uint32, *pb.StringValue]) bool { return resp.Err == nil // Only keep successful responses }, )), @@ -118,13 +120,13 @@ func TestInterceptorChaining(t *testing.T) { var count int // Chain multiple custom interceptors - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( config.Context(ctx), pb.String("test"), mock.TestMethod, gorums.Interceptors( - LoggingInterceptor[*pb.StringValue, *pb.StringValue], - CountingInterceptor[*pb.StringValue, *pb.StringValue](&count), + LoggingInterceptor[uint32, *pb.StringValue, *pb.StringValue], + CountingInterceptor[uint32, *pb.StringValue, *pb.StringValue](&count), ), ) @@ -150,16 +152,16 @@ func TestCustomInterceptorWithMapRequest(t *testing.T) { var count int // Mix custom interceptor with built-in MapRequest - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( config.Context(ctx), pb.String("test"), mock.TestMethod, gorums.Interceptors( // Custom: count responses - CountingInterceptor[*pb.StringValue, *pb.StringValue](&count), + CountingInterceptor[uint32, *pb.StringValue, *pb.StringValue](&count), // Built-in: transform request (identity transform for this test) - gorums.MapRequest[*pb.StringValue, *pb.StringValue]( - func(req *pb.StringValue, node *gorums.Node) *pb.StringValue { + gorums.MapRequest[uint32, *pb.StringValue, *pb.StringValue]( + func(req *pb.StringValue, node *gorums.Node[uint32]) *pb.StringValue { return req }, ), diff --git a/cmd/benchmark/main.go b/cmd/benchmark/main.go index 6b8d4dc7..752aea84 100644 --- a/cmd/benchmark/main.go +++ b/cmd/benchmark/main.go @@ -196,10 +196,10 @@ func main() { gorums.WithSendBufferSize(*sendBuffer), } - mgr := gorums.NewManager(mgrOpts...) + mgr := benchmark.NewManager(mgrOpts...) defer mgr.Close() - cfg, err := gorums.NewConfiguration(mgr, gorums.WithNodeList(remotes[:options.NumNodes])) + cfg, err := benchmark.NewConfiguration(mgr, gorums.WithNodeList(remotes[:options.NumNodes])) checkf("Failed to create configuration: %v", err) results, err := benchmark.RunBenchmarks(benchReg, options, cfg) diff --git a/cmd/protoc-gen-gorums/dev/aliases.go b/cmd/protoc-gen-gorums/dev/aliases.go index 94608a0d..e2c6fb8a 100644 --- a/cmd/protoc-gen-gorums/dev/aliases.go +++ b/cmd/protoc-gen-gorums/dev/aliases.go @@ -5,9 +5,11 @@ import gorums "github.com/relab/gorums" // Type aliases for important Gorums types to make them more accessible // from user code already interacting with the generated code. type ( - Configuration = gorums.Configuration - Manager = gorums.Manager - Node = gorums.Node + Configuration = gorums.Configuration[NodeID] + Manager = gorums.Manager[NodeID] + Node = gorums.Node[NodeID] + NodeContext = gorums.NodeContext[NodeID] + ConfigContext = gorums.ConfigContext[NodeID] ) // Use the aliased types to add them to the reserved identifiers list. @@ -16,20 +18,22 @@ var ( _ = (*Configuration)(nil) _ = (*Manager)(nil) _ = (*Node)(nil) + _ = (*NodeContext)(nil) + _ = (*ConfigContext)(nil) ) // NewManager returns a new Manager for managing connection to nodes added // to the manager. This function accepts manager options used to configure // various aspects of the manager. func NewManager(opts ...gorums.ManagerOption) *Manager { - return gorums.NewManager(opts...) + return gorums.NewManager[NodeID](opts...) } // NewConfiguration returns a configuration based on the provided list of nodes. // Nodes can be supplied using WithNodeMap or WithNodeList, or WithNodeIDs. // A new configuration can also be created from an existing configuration, // using the And, WithNewNodes, Except, and WithoutNodes methods. -func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, error) { +func NewConfiguration(mgr *Manager, opt gorums.NodeListOption[NodeID]) (Configuration, error) { return gorums.NewConfiguration(mgr, opt) } @@ -50,5 +54,5 @@ func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, e // creates a new manager; if a manager already exists, use [NewConfiguration] // instead, and provide the existing manager as the first argument. func NewConfig(opts ...gorums.Option) (Configuration, error) { - return gorums.NewConfig(opts...) + return gorums.NewConfig[NodeID](opts...) } diff --git a/cmd/protoc-gen-gorums/dev/zorums.pb.go b/cmd/protoc-gen-gorums/dev/zorums.pb.go index 042594bf..a5944b08 100644 --- a/cmd/protoc-gen-gorums/dev/zorums.pb.go +++ b/cmd/protoc-gen-gorums/dev/zorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v6.33.4 // source: zorums.proto package dev diff --git a/cmd/protoc-gen-gorums/dev/zorums_multicast_gorums.pb.go b/cmd/protoc-gen-gorums/dev/zorums_multicast_gorums.pb.go index d2fe8694..d05250d9 100644 --- a/cmd/protoc-gen-gorums/dev/zorums_multicast_gorums.pb.go +++ b/cmd/protoc-gen-gorums/dev/zorums_multicast_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: zorums.proto package dev @@ -22,21 +22,21 @@ const ( var _ emptypb.Empty // Multicast plain. Response type is not needed here. -func Multicast(ctx *gorums.ConfigContext, in *Request, opts ...gorums.CallOption) error { +func Multicast(ctx *ConfigContext, in *Request, opts ...gorums.CallOption) error { return gorums.Multicast(ctx, in, "dev.ZorumsService.Multicast", opts...) } // Multicast2 is testing whether multiple streams work. -func Multicast2(ctx *gorums.ConfigContext, in *Request, opts ...gorums.CallOption) error { +func Multicast2(ctx *ConfigContext, in *Request, opts ...gorums.CallOption) error { return gorums.Multicast(ctx, in, "dev.ZorumsService.Multicast2", opts...) } // Multicast3 is testing imported message type. -func Multicast3(ctx *gorums.ConfigContext, in *Request, opts ...gorums.CallOption) error { +func Multicast3(ctx *ConfigContext, in *Request, opts ...gorums.CallOption) error { return gorums.Multicast(ctx, in, "dev.ZorumsService.Multicast3", opts...) } // Multicast4 is testing imported message type. -func Multicast4(ctx *gorums.ConfigContext, in *emptypb.Empty, opts ...gorums.CallOption) error { +func Multicast4(ctx *ConfigContext, in *emptypb.Empty, opts ...gorums.CallOption) error { return gorums.Multicast(ctx, in, "dev.ZorumsService.Multicast4", opts...) } diff --git a/cmd/protoc-gen-gorums/dev/zorums_quorumcall_gorums.pb.go b/cmd/protoc-gen-gorums/dev/zorums_quorumcall_gorums.pb.go index 9e47a638..d7915a98 100644 --- a/cmd/protoc-gen-gorums/dev/zorums_quorumcall_gorums.pb.go +++ b/cmd/protoc-gen-gorums/dev/zorums_quorumcall_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: zorums.proto package dev @@ -19,40 +19,40 @@ const ( ) // QuorumCall plain. -func QuorumCall(ctx *gorums.ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[*Response] { - return gorums.QuorumCall[*Request, *Response]( +func QuorumCall(ctx *ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[NodeID, *Response] { + return gorums.QuorumCall[NodeID, *Request, *Response]( ctx, in, "dev.ZorumsService.QuorumCall", opts..., ) } // QuorumCallEmpty for testing imported message type. -func QuorumCallEmpty(ctx *gorums.ConfigContext, in *emptypb.Empty, opts ...gorums.CallOption) *gorums.Responses[*Response] { - return gorums.QuorumCall[*emptypb.Empty, *Response]( +func QuorumCallEmpty(ctx *ConfigContext, in *emptypb.Empty, opts ...gorums.CallOption) *gorums.Responses[NodeID, *Response] { + return gorums.QuorumCall[NodeID, *emptypb.Empty, *Response]( ctx, in, "dev.ZorumsService.QuorumCallEmpty", opts..., ) } // QuorumCallEmpty2 for testing imported message type. -func QuorumCallEmpty2(ctx *gorums.ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[*emptypb.Empty] { - return gorums.QuorumCall[*Request, *emptypb.Empty]( +func QuorumCallEmpty2(ctx *ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[NodeID, *emptypb.Empty] { + return gorums.QuorumCall[NodeID, *Request, *emptypb.Empty]( ctx, in, "dev.ZorumsService.QuorumCallEmpty2", opts..., ) } // QuorumCallStream plain. -func QuorumCallStream(ctx *gorums.ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[*Response] { - return gorums.QuorumCallStream[*Request, *Response]( +func QuorumCallStream(ctx *ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[NodeID, *Response] { + return gorums.QuorumCallStream[NodeID, *Request, *Response]( ctx, in, "dev.ZorumsService.QuorumCallStream", opts..., ) } // QuorumCallStreamWithEmpty for testing imported message type. -func QuorumCallStreamWithEmpty(ctx *gorums.ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[*emptypb.Empty] { - return gorums.QuorumCallStream[*Request, *emptypb.Empty]( +func QuorumCallStreamWithEmpty(ctx *ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[NodeID, *emptypb.Empty] { + return gorums.QuorumCallStream[NodeID, *Request, *emptypb.Empty]( ctx, in, "dev.ZorumsService.QuorumCallStreamWithEmpty", opts..., ) @@ -60,8 +60,8 @@ func QuorumCallStreamWithEmpty(ctx *gorums.ConfigContext, in *Request, opts ...g // QuorumCallStreamWithEmpty2 for testing imported message type; with same return // type as QuorumCallStream: Response. -func QuorumCallStreamWithEmpty2(ctx *gorums.ConfigContext, in *emptypb.Empty, opts ...gorums.CallOption) *gorums.Responses[*Response] { - return gorums.QuorumCallStream[*emptypb.Empty, *Response]( +func QuorumCallStreamWithEmpty2(ctx *ConfigContext, in *emptypb.Empty, opts ...gorums.CallOption) *gorums.Responses[NodeID, *Response] { + return gorums.QuorumCallStream[NodeID, *emptypb.Empty, *Response]( ctx, in, "dev.ZorumsService.QuorumCallStreamWithEmpty2", opts..., ) diff --git a/cmd/protoc-gen-gorums/dev/zorums_rpc_gorums.pb.go b/cmd/protoc-gen-gorums/dev/zorums_rpc_gorums.pb.go index 07c6df7a..7537022c 100644 --- a/cmd/protoc-gen-gorums/dev/zorums_rpc_gorums.pb.go +++ b/cmd/protoc-gen-gorums/dev/zorums_rpc_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: zorums.proto package dev @@ -19,7 +19,7 @@ const ( // GRPCCall plain gRPC call; testing that Gorums can ignore these, but that // they are added to the _grpc.pb.go generated file. -func GRPCCall(ctx *gorums.NodeContext, in *Request) (resp *Response, err error) { +func GRPCCall(ctx *NodeContext, in *Request) (resp *Response, err error) { res, err := gorums.RPCCall(ctx, in, "dev.ZorumsService.GRPCCall") if err != nil { return nil, err diff --git a/cmd/protoc-gen-gorums/dev/zorums_server_gorums.pb.go b/cmd/protoc-gen-gorums/dev/zorums_server_gorums.pb.go index 6d5e0740..98759772 100644 --- a/cmd/protoc-gen-gorums/dev/zorums_server_gorums.pb.go +++ b/cmd/protoc-gen-gorums/dev/zorums_server_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: zorums.proto package dev diff --git a/cmd/protoc-gen-gorums/dev/zorums_types_gorums.pb.go b/cmd/protoc-gen-gorums/dev/zorums_types_gorums.pb.go index f191b836..1cd12ea4 100644 --- a/cmd/protoc-gen-gorums/dev/zorums_types_gorums.pb.go +++ b/cmd/protoc-gen-gorums/dev/zorums_types_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: zorums.proto package dev @@ -18,6 +18,9 @@ const ( _ = gorums.EnforceVersion(gorums.MaxVersion - 11) ) +// NodeID is a type alias for the type used to identify nodes. +type NodeID = uint32 + // AsyncEmpty is a future for async quorum calls returning *emptypb.Empty. type AsyncEmpty = *gorums.Async[*emptypb.Empty] diff --git a/cmd/protoc-gen-gorums/dev/zorums_unicast_gorums.pb.go b/cmd/protoc-gen-gorums/dev/zorums_unicast_gorums.pb.go index 37c4f1c9..1453ed0c 100644 --- a/cmd/protoc-gen-gorums/dev/zorums_unicast_gorums.pb.go +++ b/cmd/protoc-gen-gorums/dev/zorums_unicast_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: zorums.proto package dev @@ -23,12 +23,12 @@ var _ emptypb.Empty // Unicast is a unicast call invoked on the node in ctx. // No reply is returned to the client. -func Unicast(ctx *gorums.NodeContext, in *Request, opts ...gorums.CallOption) error { +func Unicast(ctx *NodeContext, in *Request, opts ...gorums.CallOption) error { return gorums.Unicast(ctx, in, "dev.ZorumsService.Unicast", opts...) } // Unicast2 is a unicast call invoked on the node in ctx. // No reply is returned to the client. -func Unicast2(ctx *gorums.NodeContext, in *Request, opts ...gorums.CallOption) error { +func Unicast2(ctx *NodeContext, in *Request, opts ...gorums.CallOption) error { return gorums.Unicast(ctx, in, "dev.ZorumsService.Unicast2", opts...) } diff --git a/cmd/protoc-gen-gorums/gengorums/gorums.go b/cmd/protoc-gen-gorums/gengorums/gorums.go index 5425b95f..2220b151 100644 --- a/cmd/protoc-gen-gorums/gengorums/gorums.go +++ b/cmd/protoc-gen-gorums/gengorums/gorums.go @@ -96,7 +96,7 @@ func gorumsGuard(file *protogen.File) bool { // GenerateFileContent generates the Gorums service definitions, excluding the package statement. func generateFileContent(file *protogen.File, g *protogen.GeneratedFile) { - // 1. Data Types (Async/Correctable aliases) + // 1. Data Types (NodeID/Async/Correctable aliases) genGorumsType(g, file.Services, "types") // 2. Reference Imports diff --git a/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go b/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go index becfa5d8..f5a94b91 100644 --- a/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go +++ b/cmd/protoc-gen-gorums/gengorums/gorums_func_map.go @@ -8,6 +8,7 @@ import ( "github.com/relab/gorums" "google.golang.org/protobuf/compiler/protogen" + "google.golang.org/protobuf/proto" ) // importMap holds the mapping between short-hand import name @@ -73,6 +74,13 @@ var funcMap = template.FuncMap{ "serviceName": func(method *protogen.Method) string { return string(method.Parent.Desc.Name()) }, + "nodeIDType": func(service *protogen.Service) string { + options := service.Desc.ParentFile().Options() + if proto.HasExtension(options, gorums.E_NodeIdType) { + return proto.GetExtension(options, gorums.E_NodeIdType).(string) + } + return "uint32" + }, "in": func(g *protogen.GeneratedFile, method *protogen.Method) string { return g.QualifiedGoIdent(method.Input.GoIdent) }, diff --git a/cmd/protoc-gen-gorums/gengorums/template_datatypes.go b/cmd/protoc-gen-gorums/gengorums/template_datatypes.go index 23bb380e..9f05d594 100644 --- a/cmd/protoc-gen-gorums/gengorums/template_datatypes.go +++ b/cmd/protoc-gen-gorums/gengorums/template_datatypes.go @@ -1,5 +1,14 @@ package gengorums +// nodeIDDataType is the template for the node ID type. +var nodeIDDataType = ` +{{$gorums := use "gorums.EnforceVersion" .GenFile}} +// NodeID is a type alias for the type used to identify nodes. +{{- with index .Services 0 }} +type NodeID = {{ nodeIDType . }} +{{- end }} +` + // This type alias is generated only once per return type for an async call type. // That is, if multiple async calls use the same return type, this type alias // is only generated once. @@ -22,4 +31,4 @@ type {{$correctableOut}} = *{{$correctable}}[*{{$customOut}}] {{end}} ` -var dataTypes = asyncDataType + correctableDataType +var dataTypes = nodeIDDataType + asyncDataType + correctableDataType diff --git a/cmd/protoc-gen-gorums/gengorums/template_multicast.go b/cmd/protoc-gen-gorums/gengorums/template_multicast.go index ccd5d957..055550c3 100644 --- a/cmd/protoc-gen-gorums/gengorums/template_multicast.go +++ b/cmd/protoc-gen-gorums/gengorums/template_multicast.go @@ -2,7 +2,7 @@ package gengorums var mcVar = ` {{$genFile := .GenFile}} -{{$configContext := use "gorums.ConfigContext" .GenFile}} +{{$configContext := "ConfigContext"}} {{$multicast := use "gorums.Multicast" .GenFile}} {{$callOpt := use "gorums.CallOption" .GenFile}} ` diff --git a/cmd/protoc-gen-gorums/gengorums/template_quorumcall.go b/cmd/protoc-gen-gorums/gengorums/template_quorumcall.go index 8a1b2432..8020c9aa 100644 --- a/cmd/protoc-gen-gorums/gengorums/template_quorumcall.go +++ b/cmd/protoc-gen-gorums/gengorums/template_quorumcall.go @@ -25,7 +25,7 @@ var quorumCallComment = ` var quorumCallVariables = ` {{$genFile := .GenFile}} -{{$configContext := use "gorums.ConfigContext" .GenFile}} +{{$configContext := "ConfigContext"}} {{$quorumCall := use "gorums.QuorumCall" .GenFile}} {{$quorumCallStream := use "gorums.QuorumCallStream" .GenFile}} {{$responses := use "gorums.Responses" .GenFile}} @@ -35,10 +35,10 @@ var quorumCallVariables = ` var quorumCallSignature = `func {{$method}}(` + `ctx *{{$configContext}}, in *{{$in}}, ` + `opts ...{{$callOption}})` + - ` *{{$responses}}[*{{$out}}] { + ` *{{$responses}}[NodeID, *{{$out}}] { ` -var quorumCallBody = ` return {{$quorumCall}}[*{{$in}}, *{{$out}}]( +var quorumCallBody = ` return {{$quorumCall}}[NodeID, *{{$in}}, *{{$out}}]( ctx, in, "{{$fullName}}", opts..., ) @@ -67,7 +67,7 @@ var quorumCallStreamComment = ` {{end -}} ` -var quorumCallStreamBody = ` return {{$quorumCallStream}}[*{{$in}}, *{{$out}}]( +var quorumCallStreamBody = ` return {{$quorumCallStream}}[NodeID, *{{$in}}, *{{$out}}]( ctx, in, "{{$fullName}}", opts..., ) diff --git a/cmd/protoc-gen-gorums/gengorums/template_rpc.go b/cmd/protoc-gen-gorums/gengorums/template_rpc.go index 228eaad2..df495373 100644 --- a/cmd/protoc-gen-gorums/gengorums/template_rpc.go +++ b/cmd/protoc-gen-gorums/gengorums/template_rpc.go @@ -15,7 +15,7 @@ var rpcSignature = `func {{$method}}(` + var rpcVar = ` {{$genFile := .GenFile}} -{{$nodeContext := use "gorums.NodeContext" .GenFile}} +{{$nodeContext := "NodeContext"}} {{$rpcCall := use "gorums.RPCCall" .GenFile}} {{$_ := use "gorums.EnforceVersion" .GenFile}} ` diff --git a/cmd/protoc-gen-gorums/gengorums/template_static.go b/cmd/protoc-gen-gorums/gengorums/template_static.go index 706028c3..e6d79889 100644 --- a/cmd/protoc-gen-gorums/gengorums/template_static.go +++ b/cmd/protoc-gen-gorums/gengorums/template_static.go @@ -5,18 +5,20 @@ package gengorums // pkgIdentMap maps from package name to one of the package's identifiers. // These identifiers are used by the Gorums protoc plugin to generate import statements. -var pkgIdentMap = map[string]string{"github.com/relab/gorums": "Configuration"} +var pkgIdentMap = map[string]string{"github.com/relab/gorums": "ConfigContext"} // reservedIdents holds the set of Gorums reserved identifiers. // These identifiers cannot be used to define message types in a proto file. -var reservedIdents = []string{"Configuration", "Manager", "Node"} +var reservedIdents = []string{"ConfigContext", "Configuration", "Manager", "Node", "NodeContext", "NodeID"} var staticCode = `// Type aliases for important Gorums types to make them more accessible // from user code already interacting with the generated code. type ( - Configuration = gorums.Configuration - Manager = gorums.Manager - Node = gorums.Node + Configuration = gorums.Configuration[NodeID] + Manager = gorums.Manager[NodeID] + Node = gorums.Node[NodeID] + NodeContext = gorums.NodeContext[NodeID] + ConfigContext = gorums.ConfigContext[NodeID] ) // Use the aliased types to add them to the reserved identifiers list. @@ -25,20 +27,22 @@ var ( _ = (*Configuration)(nil) _ = (*Manager)(nil) _ = (*Node)(nil) + _ = (*NodeContext)(nil) + _ = (*ConfigContext)(nil) ) // NewManager returns a new Manager for managing connection to nodes added // to the manager. This function accepts manager options used to configure // various aspects of the manager. func NewManager(opts ...gorums.ManagerOption) *Manager { - return gorums.NewManager(opts...) + return gorums.NewManager[NodeID](opts...) } // NewConfiguration returns a configuration based on the provided list of nodes. // Nodes can be supplied using WithNodeMap or WithNodeList, or WithNodeIDs. // A new configuration can also be created from an existing configuration, // using the And, WithNewNodes, Except, and WithoutNodes methods. -func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, error) { +func NewConfiguration(mgr *Manager, opt gorums.NodeListOption[NodeID]) (Configuration, error) { return gorums.NewConfiguration(mgr, opt) } @@ -59,7 +63,7 @@ func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, e // creates a new manager; if a manager already exists, use [NewConfiguration] // instead, and provide the existing manager as the first argument. func NewConfig(opts ...gorums.Option) (Configuration, error) { - return gorums.NewConfig(opts...) + return gorums.NewConfig[NodeID](opts...) } ` diff --git a/cmd/protoc-gen-gorums/gengorums/template_unicast.go b/cmd/protoc-gen-gorums/gengorums/template_unicast.go index ff8a18ae..9eb0442a 100644 --- a/cmd/protoc-gen-gorums/gengorums/template_unicast.go +++ b/cmd/protoc-gen-gorums/gengorums/template_unicast.go @@ -2,7 +2,7 @@ package gengorums var unicastVar = ` {{$genFile := .GenFile}} -{{$nodeContext := use "gorums.NodeContext" .GenFile}} +{{$nodeContext := "NodeContext"}} {{$unicast := use "gorums.Unicast" .GenFile}} {{$callOpt := use "gorums.CallOption" .GenFile}} ` diff --git a/config.go b/config.go index e495c3f8..8c2daf55 100644 --- a/config.go +++ b/config.go @@ -9,13 +9,13 @@ import ( // It embeds context.Context and provides access to the Configuration. // // Use [Configuration.Context] to create a ConfigContext from an existing context. -type ConfigContext struct { +type ConfigContext[T NodeID] struct { context.Context - cfg Configuration + cfg Configuration[T] } // Configuration returns the Configuration associated with this context. -func (c ConfigContext) Configuration() Configuration { +func (c ConfigContext[T]) Configuration() Configuration[T] { return c.cfg } @@ -23,7 +23,7 @@ func (c ConfigContext) Configuration() Configuration { // // Mutating the configuration is not supported; instead, use NewConfiguration to create // a new configuration. -type Configuration []*Node +type Configuration[T NodeID] []*Node[T] // Context creates a new ConfigContext from the given parent context // and this configuration. @@ -33,18 +33,18 @@ type Configuration []*Node // config, _ := gorums.NewConfiguration(mgr, gorums.WithNodeList(addrs)) // cfgCtx := config.Context(context.Background()) // resp, err := paxos.Prepare(cfgCtx, req) -func (cfg Configuration) Context(parent context.Context) *ConfigContext { - if len(cfg) == 0 { +func (c Configuration[T]) Context(parent context.Context) *ConfigContext[T] { + if len(c) == 0 { panic("gorums: Context called with empty configuration") } - return &ConfigContext{Context: parent, cfg: cfg} + return &ConfigContext[T]{Context: parent, cfg: c} } // NewConfiguration returns a configuration based on the provided list of nodes. // Nodes can be supplied using WithNodeMap or WithNodeList, or WithNodeIDs. // A new configuration can also be created from an existing configuration, // using the And, WithNewNodes, Except, and WithoutNodes methods. -func NewConfiguration(mgr *Manager, opt NodeListOption) (nodes Configuration, err error) { +func NewConfiguration[T NodeID](mgr *Manager[T], opt NodeListOption[T]) (nodes Configuration[T], err error) { if opt == nil { return nil, fmt.Errorf("config: missing required node list") } @@ -57,7 +57,7 @@ func NewConfiguration(mgr *Manager, opt NodeListOption) (nodes Configuration, er // // Example: // -// cfg, err := NewConfig( +// cfg, err := NewConfig[uint32]( // gorums.WithNodeList([]string{"localhost:8080", "localhost:8081", "localhost:8082"}), // gorums.WithDialOptions(grpc.WithTransportCredentials(insecure.NewCredentials())), // ) @@ -67,16 +67,16 @@ func NewConfiguration(mgr *Manager, opt NodeListOption) (nodes Configuration, er // [Configuration.Manager] method. This method should only be used once since it // creates a new manager; if a manager already exists, use [NewConfiguration] // instead, and provide the existing manager as the first argument. -func NewConfig(opts ...Option) (Configuration, error) { +func NewConfig[T NodeID](opts ...Option) (Configuration[T], error) { var ( managerOptions []ManagerOption - nodeListOption NodeListOption + nodeListOption NodeListOption[T] ) for _, opt := range opts { switch o := opt.(type) { case ManagerOption: managerOptions = append(managerOptions, o) - case NodeListOption: + case NodeListOption[T]: if nodeListOption != nil { return nil, fmt.Errorf("gorums: multiple NodeListOptions provided") } @@ -88,13 +88,13 @@ func NewConfig(opts ...Option) (Configuration, error) { if nodeListOption == nil { return nil, fmt.Errorf("gorums: missing required NodeListOption") } - mgr := NewManager(managerOptions...) + mgr := NewManager[T](managerOptions...) return NewConfiguration(mgr, nodeListOption) } // NodeIDs returns a slice of this configuration's Node IDs. -func (c Configuration) NodeIDs() []uint32 { - ids := make([]uint32, len(c)) +func (c Configuration[T]) NodeIDs() []T { + ids := make([]T, len(c)) for i, node := range c { ids[i] = node.ID() } @@ -102,17 +102,17 @@ func (c Configuration) NodeIDs() []uint32 { } // Nodes returns the nodes in this configuration. -func (c Configuration) Nodes() []*Node { +func (c Configuration[T]) Nodes() []*Node[T] { return c } // Size returns the number of nodes in this configuration. -func (c Configuration) Size() int { +func (c Configuration[T]) Size() int { return len(c) } // Equal returns true if configurations b and c have the same set of nodes. -func (c Configuration) Equal(b Configuration) bool { +func (c Configuration[T]) Equal(b Configuration[T]) bool { if len(c) != len(b) { return false } @@ -126,7 +126,7 @@ func (c Configuration) Equal(b Configuration) bool { // Manager returns the Manager that manages this configuration's nodes. // Returns nil if the configuration is empty. -func (c Configuration) Manager() *Manager { +func (c Configuration[T]) Manager() *Manager[T] { if len(c) == 0 { return nil } @@ -134,6 +134,6 @@ func (c Configuration) Manager() *Manager { } // nextMsgID returns the next message ID from this client's manager. -func (c Configuration) nextMsgID() uint64 { +func (c Configuration[T]) nextMsgID() uint64 { return c[0].msgIDGen() } diff --git a/config_opts.go b/config_opts.go index aeaea28f..66aedac2 100644 --- a/config_opts.go +++ b/config_opts.go @@ -6,9 +6,9 @@ import ( ) // NodeListOption must be implemented by node providers. -type NodeListOption interface { +type NodeListOption[T NodeID] interface { Option - newConfig(*Manager) (Configuration, error) + newConfig(*Manager[T]) (Configuration[T], error) } // NodeAddress must be implemented by types that can be used as node addresses. @@ -16,17 +16,17 @@ type NodeAddress interface { Addr() string } -type structNodeMap[T NodeAddress] struct { - nodes map[uint32]T +type structNodeMap[K NodeID, V NodeAddress] struct { + nodes map[K]V } -func (structNodeMap[T]) isOption() {} +func (structNodeMap[K, V]) isOption() {} -func (o structNodeMap[T]) newConfig(mgr *Manager) (nodes Configuration, err error) { +func (o structNodeMap[K, V]) newConfig(mgr *Manager[K]) (nodes Configuration[K], err error) { if len(o.nodes) == 0 { return nil, fmt.Errorf("config: missing required node map") } - nodes = make(Configuration, 0, len(o.nodes)) + nodes = make(Configuration[K], 0, len(o.nodes)) for id, n := range o.nodes { node, found := mgr.Node(id) if !found { @@ -38,15 +38,15 @@ func (o structNodeMap[T]) newConfig(mgr *Manager) (nodes Configuration, err erro nodes = append(nodes, node) } // Sort nodes to ensure deterministic iteration. - OrderedBy(ID).Sort(mgr.nodes) - OrderedBy(ID).Sort(nodes) + OrderedBy(ID[K]).Sort(mgr.nodes) + OrderedBy(ID[K]).Sort(nodes) return nodes, nil } // WithNodes returns a NodeListOption containing the provided // mapping from application-specific IDs to types implementing NodeAddress. -func WithNodes[T NodeAddress](nodes map[uint32]T) NodeListOption { - return &structNodeMap[T]{nodes: nodes} +func WithNodes[K NodeID, V NodeAddress](nodes map[K]V) NodeListOption[K] { + return &structNodeMap[K, V]{nodes: nodes} } type nodeIDMap struct { @@ -55,11 +55,11 @@ type nodeIDMap struct { func (nodeIDMap) isOption() {} -func (o nodeIDMap) newConfig(mgr *Manager) (nodes Configuration, err error) { +func (o nodeIDMap) newConfig(mgr *Manager[uint32]) (nodes Configuration[uint32], err error) { if len(o.idMap) == 0 { return nil, fmt.Errorf("config: missing required node map") } - nodes = make(Configuration, 0, len(o.idMap)) + nodes = make(Configuration[uint32], 0, len(o.idMap)) for naddr, id := range o.idMap { node, found := mgr.Node(id) if !found { @@ -71,14 +71,14 @@ func (o nodeIDMap) newConfig(mgr *Manager) (nodes Configuration, err error) { nodes = append(nodes, node) } // Sort nodes to ensure deterministic iteration. - OrderedBy(ID).Sort(mgr.nodes) - OrderedBy(ID).Sort(nodes) + OrderedBy(ID[uint32]).Sort(mgr.nodes) + OrderedBy(ID[uint32]).Sort(nodes) return nodes, nil } // WithNodeMap returns a NodeListOption containing the provided // mapping from node addresses to application-specific IDs. -func WithNodeMap(idMap map[string]uint32) NodeListOption { +func WithNodeMap(idMap map[string]uint32) NodeListOption[uint32] { return &nodeIDMap{idMap: idMap} } @@ -88,11 +88,11 @@ type nodeList struct { func (nodeList) isOption() {} -func (o nodeList) newConfig(mgr *Manager) (nodes Configuration, err error) { +func (o nodeList) newConfig(mgr *Manager[uint32]) (nodes Configuration[uint32], err error) { if len(o.addrsList) == 0 { return nil, fmt.Errorf("config: missing required node addresses") } - nodes = make(Configuration, 0, len(o.addrsList)) + nodes = make(Configuration[uint32], 0, len(o.addrsList)) for _, naddr := range o.addrsList { id, err := nodeID(naddr) if err != nil { @@ -108,80 +108,80 @@ func (o nodeList) newConfig(mgr *Manager) (nodes Configuration, err error) { nodes = append(nodes, node) } // Sort nodes to ensure deterministic iteration. - OrderedBy(ID).Sort(mgr.nodes) - OrderedBy(ID).Sort(nodes) + OrderedBy(ID[uint32]).Sort(mgr.nodes) + OrderedBy(ID[uint32]).Sort(nodes) return nodes, nil } // WithNodeList returns a NodeListOption containing the provided list of node addresses. -// With this option, node IDs are generated by the Manager. -func WithNodeList(addrsList []string) NodeListOption { +// With this option, node IDs are generated by the Manager (using uint32). +func WithNodeList(addrsList []string) NodeListOption[uint32] { return &nodeList{addrsList: addrsList} } -type nodeIDs struct { - nodeIDs []uint32 +type nodeIDs[T NodeID] struct { + nodeIDs []T } -func (nodeIDs) isOption() {} +func (nodeIDs[T]) isOption() {} -func (o nodeIDs) newConfig(mgr *Manager) (nodes Configuration, err error) { +func (o nodeIDs[T]) newConfig(mgr *Manager[T]) (nodes Configuration[T], err error) { if len(o.nodeIDs) == 0 { return nil, fmt.Errorf("config: missing required node IDs") } - nodes = make(Configuration, 0, len(o.nodeIDs)) + nodes = make(Configuration[T], 0, len(o.nodeIDs)) for _, id := range o.nodeIDs { node, found := mgr.Node(id) if !found { // Node IDs must have been registered previously - return nil, fmt.Errorf("config: node %d not found", id) + return nil, fmt.Errorf("config: node %v not found", id) } nodes = append(nodes, node) } // Sort nodes to ensure deterministic iteration. - OrderedBy(ID).Sort(mgr.nodes) - OrderedBy(ID).Sort(nodes) + OrderedBy(ID[T]).Sort(mgr.nodes) + OrderedBy(ID[T]).Sort(nodes) return nodes, nil } // WithNodeIDs returns a NodeListOption containing a list of node IDs. // This assumes that the provided node IDs have already been registered with the manager. -func WithNodeIDs(ids []uint32) NodeListOption { - return &nodeIDs{nodeIDs: ids} +func WithNodeIDs[T NodeID](ids []T) NodeListOption[T] { + return &nodeIDs[T]{nodeIDs: ids} } -type addNodes struct { - old Configuration - new NodeListOption +type addNodes[T NodeID] struct { + old Configuration[T] + new NodeListOption[T] } -func (addNodes) isOption() {} +func (addNodes[T]) isOption() {} -func (o addNodes) newConfig(mgr *Manager) (nodes Configuration, err error) { +func (o addNodes[T]) newConfig(mgr *Manager[T]) (nodes Configuration[T], err error) { newNodes, err := o.new.newConfig(mgr) if err != nil { return nil, err } - ac := &addConfig{old: o.old, add: newNodes} + ac := &addConfig[T]{old: o.old, add: newNodes} return ac.newConfig(mgr) } // WithNewNodes returns a NodeListOption that can be used to create a new configuration // combining c and the new nodes. -func (c Configuration) WithNewNodes(newNodes NodeListOption) NodeListOption { - return &addNodes{old: c, new: newNodes} +func (c Configuration[T]) WithNewNodes(newNodes NodeListOption[T]) NodeListOption[T] { + return &addNodes[T]{old: c, new: newNodes} } -type addConfig struct { - old Configuration - add Configuration +type addConfig[T NodeID] struct { + old Configuration[T] + add Configuration[T] } -func (addConfig) isOption() {} +func (addConfig[T]) isOption() {} -func (o addConfig) newConfig(mgr *Manager) (nodes Configuration, err error) { - nodes = make(Configuration, 0, len(o.old)+len(o.add)) - m := make(map[uint32]bool) +func (o addConfig[T]) newConfig(mgr *Manager[T]) (nodes Configuration[T], err error) { + nodes = make(Configuration[T], 0, len(o.old)+len(o.add)) + m := make(map[T]bool) for _, n := range append(o.old, o.add...) { if !m[n.id] { m[n.id] = true @@ -189,46 +189,46 @@ func (o addConfig) newConfig(mgr *Manager) (nodes Configuration, err error) { } } // Sort nodes to ensure deterministic iteration. - OrderedBy(ID).Sort(mgr.nodes) - OrderedBy(ID).Sort(nodes) + OrderedBy(ID[T]).Sort(mgr.nodes) + OrderedBy(ID[T]).Sort(nodes) return nodes, err } // And returns a NodeListOption that can be used to create a new configuration combining c and d. -func (c Configuration) And(d Configuration) NodeListOption { - return &addConfig{old: c, add: d} +func (c Configuration[T]) And(d Configuration[T]) NodeListOption[T] { + return &addConfig[T]{old: c, add: d} } // WithoutNodes returns a NodeListOption that can be used to create a new configuration // from c without the given node IDs. -func (c Configuration) WithoutNodes(ids ...uint32) NodeListOption { - rmIDs := make(map[uint32]bool) +func (c Configuration[T]) WithoutNodes(ids ...T) NodeListOption[T] { + rmIDs := make(map[T]bool) for _, id := range ids { rmIDs[id] = true } - keepIDs := make([]uint32, 0, len(c)) + keepIDs := make([]T, 0, len(c)) for _, cNode := range c { if !rmIDs[cNode.id] { keepIDs = append(keepIDs, cNode.id) } } - return &nodeIDs{nodeIDs: keepIDs} + return &nodeIDs[T]{nodeIDs: keepIDs} } // Except returns a NodeListOption that can be used to create a new configuration // from c without the nodes in rm. -func (c Configuration) Except(rm Configuration) NodeListOption { - rmIDs := make(map[uint32]bool) +func (c Configuration[T]) Except(rm Configuration[T]) NodeListOption[T] { + rmIDs := make(map[T]bool) for _, rmNode := range rm { rmIDs[rmNode.id] = true } - keepIDs := make([]uint32, 0, len(c)) + keepIDs := make([]T, 0, len(c)) for _, cNode := range c { if !rmIDs[cNode.id] { keepIDs = append(keepIDs, cNode.id) } } - return &nodeIDs{nodeIDs: keepIDs} + return &nodeIDs[T]{nodeIDs: keepIDs} } // WithoutErrors returns a NodeListOption that creates a new configuration @@ -236,7 +236,7 @@ func (c Configuration) Except(rm Configuration) NodeListOption { // If specific error types are provided, only nodes whose errors match // one of those types (using errors.Is) will be excluded. // If no error types are provided, all failed nodes are excluded. -func (c Configuration) WithoutErrors(err QuorumCallError, errorTypes ...error) NodeListOption { +func (c Configuration[T]) WithoutErrors(err QuorumCallError, errorTypes ...error) NodeListOption[T] { // Decide whether an error should exclude a node. exclude := func(cause error) bool { if len(errorTypes) == 0 { @@ -251,17 +251,17 @@ func (c Configuration) WithoutErrors(err QuorumCallError, errorTypes ...error) N } // Build a map of node IDs to exclude - rm := make(map[uint32]bool, len(err.errors)) + rm := make(map[T]bool, len(err.errors)) for _, ne := range err.errors { - rm[ne.nodeID] = exclude(ne.cause) + rm[ne.nodeID.(T)] = exclude(ne.cause) } // Build the list of node IDs to keep - keepIDs := make([]uint32, 0, len(c)) + keepIDs := make([]T, 0, len(c)) for _, node := range c { if !rm[node.id] { keepIDs = append(keepIDs, node.id) } } - return &nodeIDs{nodeIDs: keepIDs} + return &nodeIDs[T]{nodeIDs: keepIDs} } diff --git a/config_test.go b/config_test.go index 3a94f0c4..ce519796 100644 --- a/config_test.go +++ b/config_test.go @@ -2,6 +2,7 @@ package gorums_test import ( "errors" + "slices" "sync" "testing" @@ -25,7 +26,7 @@ var ( func TestNewConfigurationEmptyNodeList(t *testing.T) { wantErr := errors.New("config: missing required node addresses") - mgr := gorums.NewManager(gorums.InsecureDialOptions(t)) + mgr := gorums.NewManager[uint32](gorums.InsecureDialOptions(t)) t.Cleanup(gorums.Closer(t, mgr)) _, err := gorums.NewConfiguration(mgr, gorums.WithNodeList([]string{})) @@ -38,7 +39,7 @@ func TestNewConfigurationEmptyNodeList(t *testing.T) { } func TestNewConfigurationNodeList(t *testing.T) { - mgr := gorums.NewManager(gorums.InsecureDialOptions(t)) + mgr := gorums.NewManager[uint32](gorums.InsecureDialOptions(t)) t.Cleanup(gorums.Closer(t, mgr)) cfg, err := gorums.NewConfiguration(mgr, gorums.WithNodeList(nodes)) @@ -49,13 +50,10 @@ func TestNewConfigurationNodeList(t *testing.T) { t.Errorf("cfg.Size() = %d, expected %d", cfg.Size(), len(nodes)) } - contains := func(nodes []*gorums.Node, addr string) bool { - for _, node := range nodes { - if addr == node.Address() { - return true - } - } - return false + contains := func(nodes []*gorums.Node[uint32], addr string) bool { + return slices.ContainsFunc(nodes, func(n *gorums.Node[uint32]) bool { + return n.Address() == addr + }) } cfgNodes := cfg.Nodes() for _, n := range nodes { @@ -76,7 +74,7 @@ func TestNewConfigurationNodeList(t *testing.T) { } func TestNewConfigurationNodeMap(t *testing.T) { - mgr := gorums.NewManager(gorums.InsecureDialOptions(t)) + mgr := gorums.NewManager[uint32](gorums.InsecureDialOptions(t)) t.Cleanup(gorums.Closer(t, mgr)) cfg, err := gorums.NewConfiguration(mgr, gorums.WithNodeMap(nodeMap)) @@ -101,6 +99,8 @@ func TestNewConfigurationNodeMap(t *testing.T) { } } +type CustomID int + type testNode struct { addr string } @@ -110,16 +110,15 @@ func (n testNode) Addr() string { } func TestNewConfigurationWithNodes(t *testing.T) { - mgr := gorums.NewManager(gorums.InsecureDialOptions(t)) + mgr := gorums.NewManager[CustomID](gorums.InsecureDialOptions(t)) t.Cleanup(gorums.Closer(t, mgr)) - nodes := map[uint32]testNode{ + nodes := map[CustomID]testNode{ 1: {addr: "127.0.0.1:9080"}, 2: {addr: "127.0.0.1:9081"}, 3: {addr: "127.0.0.1:9082"}, 4: {addr: "127.0.0.1:9083"}, } - cfg, err := gorums.NewConfiguration(mgr, gorums.WithNodes(nodes)) if err != nil { t.Fatal(err) @@ -143,7 +142,7 @@ func TestNewConfigurationWithNodes(t *testing.T) { } func TestNewConfigurationNodeIDs(t *testing.T) { - mgr := gorums.NewManager(gorums.InsecureDialOptions(t)) + mgr := gorums.NewManager[uint32](gorums.InsecureDialOptions(t)) t.Cleanup(gorums.Closer(t, mgr)) c1, err := gorums.NewConfiguration(mgr, gorums.WithNodeList(nodes)) @@ -181,7 +180,7 @@ func TestNewConfigurationNodeIDs(t *testing.T) { } func TestNewConfigurationAnd(t *testing.T) { - mgr := gorums.NewManager(gorums.InsecureDialOptions(t)) + mgr := gorums.NewManager[uint32](gorums.InsecureDialOptions(t)) t.Cleanup(gorums.Closer(t, mgr)) c1, err := gorums.NewConfiguration(mgr, gorums.WithNodeList(nodes)) @@ -235,7 +234,7 @@ func TestNewConfigurationAnd(t *testing.T) { } func TestNewConfigurationExcept(t *testing.T) { - mgr := gorums.NewManager(gorums.InsecureDialOptions(t)) + mgr := gorums.NewManager[uint32](gorums.InsecureDialOptions(t)) t.Cleanup(gorums.Closer(t, mgr)) c1, err := gorums.NewConfiguration(mgr, gorums.WithNodeList(nodes)) @@ -294,7 +293,7 @@ func TestConfigConcurrentAccess(t *testing.T) { } func TestConfigurationWithoutErrors(t *testing.T) { - mgr := gorums.NewManager(gorums.InsecureDialOptions(t)) + mgr := gorums.NewManager[uint32](gorums.InsecureDialOptions(t)) t.Cleanup(gorums.Closer(t, mgr)) cfg, err := gorums.NewConfiguration(mgr, gorums.WithNodeMap(nodeMap)) diff --git a/correctable.go b/correctable.go index 2724ccb1..8a2da226 100644 --- a/correctable.go +++ b/correctable.go @@ -99,7 +99,7 @@ func (c *Correctable[Resp]) update(reply Resp, level int, done bool, err error) // // Wait for level 2 to be reached // <-corr.Watch(2) // resp, level, err := corr.Get() -func (r *Responses[Resp]) Correctable(threshold int) *Correctable[Resp] { +func (r *Responses[T, Resp]) Correctable(threshold int) *Correctable[Resp] { corr := &Correctable[Resp]{ level: LevelNotSet, donech: make(chan struct{}, 1), diff --git a/correctable_test.go b/correctable_test.go index 9d434e01..9af225d8 100644 --- a/correctable_test.go +++ b/correctable_test.go @@ -14,7 +14,7 @@ func TestCorrectableQuorumCall(t *testing.T) { config := gorums.TestConfiguration(t, 3, gorums.EchoServerFn) ctx := gorums.TestContext(t, 2*time.Second) - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( config.Context(ctx), pb.String("test"), mock.TestMethod, @@ -70,7 +70,7 @@ func TestCorrectableQuorumCallStream(t *testing.T) { config := gorums.TestConfiguration(t, 3, gorums.StreamServerFn) ctx := gorums.TestContext(t, 2*time.Second) - responses := gorums.QuorumCallStream[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCallStream[uint32, *pb.StringValue, *pb.StringValue]( config.Context(ctx), pb.String("test"), mock.Stream, @@ -107,7 +107,7 @@ func TestCorrectableWatch(t *testing.T) { config := gorums.TestConfiguration(t, 3, gorums.StreamServerFn) ctx := gorums.TestContext(t, 2*time.Second) - responses := gorums.QuorumCallStream[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCallStream[uint32, *pb.StringValue, *pb.StringValue]( config.Context(ctx), pb.String("test"), mock.Stream, @@ -141,7 +141,7 @@ func BenchmarkCorrectable(b *testing.B) { // skipcq: GO-R1005 threshold := numNodes/2 + 1 b.ReportAllocs() for b.Loop() { - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("test"), mock.TestMethod, @@ -161,7 +161,7 @@ func BenchmarkCorrectable(b *testing.B) { // skipcq: GO-R1005 threshold := numNodes/2 + 1 b.ReportAllocs() for b.Loop() { - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("test"), mock.TestMethod, @@ -187,7 +187,7 @@ func BenchmarkCorrectable(b *testing.B) { // skipcq: GO-R1005 threshold := numNodes/2 + 1 b.ReportAllocs() for b.Loop() { - responses := gorums.QuorumCallStream[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCallStream[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("test"), mock.Stream, @@ -207,7 +207,7 @@ func BenchmarkCorrectable(b *testing.B) { // skipcq: GO-R1005 threshold := numNodes/2 + 1 b.ReportAllocs() for b.Loop() { - responses := gorums.QuorumCallStream[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCallStream[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("test"), mock.Stream, diff --git a/errors.go b/errors.go index b2201858..28e3cc39 100644 --- a/errors.go +++ b/errors.go @@ -70,9 +70,9 @@ func (e QuorumCallError) Error() string { // nodeError reports on a failed RPC call from a specific node. type nodeError struct { cause error - nodeID uint32 + nodeID any } func (e nodeError) Error() string { - return fmt.Sprintf("node %d: %v", e.nodeID, e.cause) + return fmt.Sprintf("node %v: %v", e.nodeID, e.cause) } diff --git a/errors_test.go b/errors_test.go index 5b5f3765..5cdf381b 100644 --- a/errors_test.go +++ b/errors_test.go @@ -89,7 +89,7 @@ func TestQuorumCallErrorAccessors(t *testing.T) { qcErr: QuorumCallError{ cause: ErrIncomplete, errors: []nodeError{ - {nodeID: 1, cause: status.Error(codes.Unavailable, "node down")}, + {nodeID: uint32(1), cause: status.Error(codes.Unavailable, "node down")}, }, }, wantCause: ErrIncomplete, @@ -100,9 +100,9 @@ func TestQuorumCallErrorAccessors(t *testing.T) { qcErr: QuorumCallError{ cause: ErrIncomplete, errors: []nodeError{ - {nodeID: 1, cause: status.Error(codes.Unavailable, "node down")}, - {nodeID: 3, cause: status.Error(codes.DeadlineExceeded, "timeout")}, - {nodeID: 5, cause: status.Error(codes.Unavailable, "connection refused")}, + {nodeID: uint32(1), cause: status.Error(codes.Unavailable, "node down")}, + {nodeID: uint32(3), cause: status.Error(codes.DeadlineExceeded, "timeout")}, + {nodeID: uint32(5), cause: status.Error(codes.Unavailable, "connection refused")}, }, }, wantCause: ErrIncomplete, @@ -113,7 +113,7 @@ func TestQuorumCallErrorAccessors(t *testing.T) { qcErr: QuorumCallError{ cause: ErrSendFailure, errors: []nodeError{ - {nodeID: 2, cause: errors.New("send failed")}, + {nodeID: uint32(2), cause: errors.New("send failed")}, }, }, wantCause: ErrSendFailure, @@ -141,9 +141,9 @@ func TestQuorumCallErrorUnwrap(t *testing.T) { qcErr := QuorumCallError{ cause: ErrIncomplete, errors: []nodeError{ - {nodeID: 1, cause: unavailableErr}, - {nodeID: 3, cause: timeoutErr}, - {nodeID: 5, cause: connectionErr}, + {nodeID: uint32(1), cause: unavailableErr}, + {nodeID: uint32(3), cause: timeoutErr}, + {nodeID: uint32(5), cause: connectionErr}, }, } @@ -198,8 +198,8 @@ func TestQuorumCallErrorUnwrapWithAs(t *testing.T) { qcErr := QuorumCallError{ cause: ErrIncomplete, errors: []nodeError{ - {nodeID: 1, cause: customErr}, - {nodeID: 2, cause: status.Error(codes.Unavailable, "down")}, + {nodeID: uint32(1), cause: customErr}, + {nodeID: uint32(2), cause: status.Error(codes.Unavailable, "down")}, }, } diff --git a/examples/storage/client.go b/examples/storage/client.go index cc868651..16d75e6c 100644 --- a/examples/storage/client.go +++ b/examples/storage/client.go @@ -15,13 +15,13 @@ func runClient(addresses []string) error { } // init gorums manager - mgr := gorums.NewManager( + mgr := proto.NewManager( gorums.WithDialOptions( grpc.WithTransportCredentials(insecure.NewCredentials()), // disable TLS ), ) // create configuration containing all nodes - cfg, err := gorums.NewConfiguration(mgr, gorums.WithNodeList(addresses)) + cfg, err := proto.NewConfiguration(mgr, gorums.WithNodeList(addresses)) if err != nil { log.Fatal(err) } @@ -31,7 +31,7 @@ func runClient(addresses []string) error { // newestValue processes responses from a ReadQC call and returns the reply // with the most recent timestamp. -func newestValue(responses *gorums.Responses[*proto.ReadResponse]) (*proto.ReadResponse, error) { +func newestValue(responses *gorums.Responses[proto.NodeID, *proto.ReadResponse]) (*proto.ReadResponse, error) { var newest *proto.ReadResponse for resp := range responses.Seq() { if resp.Err != nil { @@ -49,7 +49,7 @@ func newestValue(responses *gorums.Responses[*proto.ReadResponse]) (*proto.ReadR // numUpdated processes responses from a WriteQC call and returns true if // a majority of nodes updated their value. -func numUpdated(responses *gorums.Responses[*proto.WriteResponse]) (*proto.WriteResponse, error) { +func numUpdated(responses *gorums.Responses[proto.NodeID, *proto.WriteResponse]) (*proto.WriteResponse, error) { var count int size := responses.Size() for resp := range responses.Seq() { diff --git a/examples/storage/proto/storage.pb.go b/examples/storage/proto/storage.pb.go index 8ad5a8d7..c8b64d6a 100644 --- a/examples/storage/proto/storage.pb.go +++ b/examples/storage/proto/storage.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v6.33.4 // source: storage/proto/storage.proto package proto diff --git a/examples/storage/proto/storage_gorums.pb.go b/examples/storage/proto/storage_gorums.pb.go index 4c8239a8..a0de696f 100644 --- a/examples/storage/proto/storage_gorums.pb.go +++ b/examples/storage/proto/storage_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: storage/proto/storage.proto package proto @@ -21,9 +21,11 @@ const ( // Type aliases for important Gorums types to make them more accessible // from user code already interacting with the generated code. type ( - Configuration = gorums.Configuration - Manager = gorums.Manager - Node = gorums.Node + Configuration = gorums.Configuration[NodeID] + Manager = gorums.Manager[NodeID] + Node = gorums.Node[NodeID] + NodeContext = gorums.NodeContext[NodeID] + ConfigContext = gorums.ConfigContext[NodeID] ) // Use the aliased types to add them to the reserved identifiers list. @@ -32,20 +34,22 @@ var ( _ = (*Configuration)(nil) _ = (*Manager)(nil) _ = (*Node)(nil) + _ = (*NodeContext)(nil) + _ = (*ConfigContext)(nil) ) // NewManager returns a new Manager for managing connection to nodes added // to the manager. This function accepts manager options used to configure // various aspects of the manager. func NewManager(opts ...gorums.ManagerOption) *Manager { - return gorums.NewManager(opts...) + return gorums.NewManager[NodeID](opts...) } // NewConfiguration returns a configuration based on the provided list of nodes. // Nodes can be supplied using WithNodeMap or WithNodeList, or WithNodeIDs. // A new configuration can also be created from an existing configuration, // using the And, WithNewNodes, Except, and WithoutNodes methods. -func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, error) { +func NewConfiguration(mgr *Manager, opt gorums.NodeListOption[NodeID]) (Configuration, error) { return gorums.NewConfiguration(mgr, opt) } @@ -66,9 +70,12 @@ func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, e // creates a new manager; if a manager already exists, use [NewConfiguration] // instead, and provide the existing manager as the first argument. func NewConfig(opts ...gorums.Option) (Configuration, error) { - return gorums.NewConfig(opts...) + return gorums.NewConfig[NodeID](opts...) } +// NodeID is a type alias for the type used to identify nodes. +type NodeID = uint32 + // AsyncReadResponse is a future for async quorum calls returning *ReadResponse. type AsyncReadResponse = *gorums.Async[*ReadResponse] @@ -86,7 +93,7 @@ var _ emptypb.Empty // ReadRPC executes a Read RPC on a single node and // returns the value for the provided key. -func ReadRPC(ctx *gorums.NodeContext, in *ReadRequest) (resp *ReadResponse, err error) { +func ReadRPC(ctx *NodeContext, in *ReadRequest) (resp *ReadResponse, err error) { res, err := gorums.RPCCall(ctx, in, "proto.Storage.ReadRPC") if err != nil { return nil, err @@ -96,7 +103,7 @@ func ReadRPC(ctx *gorums.NodeContext, in *ReadRequest) (resp *ReadResponse, err // WriteRPC executes a Write RPC on a single node and // returns true if the value was updated. -func WriteRPC(ctx *gorums.NodeContext, in *WriteRequest) (resp *WriteResponse, err error) { +func WriteRPC(ctx *NodeContext, in *WriteRequest) (resp *WriteResponse, err error) { res, err := gorums.RPCCall(ctx, in, "proto.Storage.WriteRPC") if err != nil { return nil, err @@ -106,8 +113,8 @@ func WriteRPC(ctx *gorums.NodeContext, in *WriteRequest) (resp *WriteResponse, e // ReadQC executes a Read quorum call on a configuration of nodes and // returns the most recent value. -func ReadQC(ctx *gorums.ConfigContext, in *ReadRequest, opts ...gorums.CallOption) *gorums.Responses[*ReadResponse] { - return gorums.QuorumCall[*ReadRequest, *ReadResponse]( +func ReadQC(ctx *ConfigContext, in *ReadRequest, opts ...gorums.CallOption) *gorums.Responses[NodeID, *ReadResponse] { + return gorums.QuorumCall[NodeID, *ReadRequest, *ReadResponse]( ctx, in, "proto.Storage.ReadQC", opts..., ) @@ -115,8 +122,8 @@ func ReadQC(ctx *gorums.ConfigContext, in *ReadRequest, opts ...gorums.CallOptio // WriteQC executes a Write quorum call on a configuration of nodes and // returns true if a majority of nodes were updated. -func WriteQC(ctx *gorums.ConfigContext, in *WriteRequest, opts ...gorums.CallOption) *gorums.Responses[*WriteResponse] { - return gorums.QuorumCall[*WriteRequest, *WriteResponse]( +func WriteQC(ctx *ConfigContext, in *WriteRequest, opts ...gorums.CallOption) *gorums.Responses[NodeID, *WriteResponse] { + return gorums.QuorumCall[NodeID, *WriteRequest, *WriteResponse]( ctx, in, "proto.Storage.WriteQC", opts..., ) @@ -124,7 +131,7 @@ func WriteQC(ctx *gorums.ConfigContext, in *WriteRequest, opts ...gorums.CallOpt // WriteMulticast executes a Write multicast call on a configuration of nodes. // It does not wait for any responses. -func WriteMulticast(ctx *gorums.ConfigContext, in *WriteRequest, opts ...gorums.CallOption) error { +func WriteMulticast(ctx *ConfigContext, in *WriteRequest, opts ...gorums.CallOption) error { return gorums.Multicast(ctx, in, "proto.Storage.WriteMulticast", opts...) } diff --git a/gorums.pb.go b/gorums.pb.go index 8e35f17d..f45fe2b1 100644 --- a/gorums.pb.go +++ b/gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v6.33.4 // source: gorums.proto package gorums @@ -22,6 +22,14 @@ const ( ) var file_gorums_proto_extTypes = []protoimpl.ExtensionInfo{ + { + ExtendedType: (*descriptorpb.FileOptions)(nil), + ExtensionType: (*string)(nil), + Field: 40001, + Name: "gorums.node_id_type", + Tag: "bytes,40001,opt,name=node_id_type", + Filename: "gorums.proto", + }, { ExtendedType: (*descriptorpb.MethodOptions)(nil), ExtensionType: (*bool)(nil), @@ -56,25 +64,33 @@ var file_gorums_proto_extTypes = []protoimpl.ExtensionInfo{ }, } +// Extension fields to descriptorpb.FileOptions. +var ( + // optional string node_id_type = 40001; + E_NodeIdType = &file_gorums_proto_extTypes[0] +) + // Extension fields to descriptorpb.MethodOptions. var ( // call types // // optional bool rpc = 50001; - E_Rpc = &file_gorums_proto_extTypes[0] + E_Rpc = &file_gorums_proto_extTypes[1] // optional bool unicast = 50002; - E_Unicast = &file_gorums_proto_extTypes[1] + E_Unicast = &file_gorums_proto_extTypes[2] // optional bool multicast = 50003; - E_Multicast = &file_gorums_proto_extTypes[2] + E_Multicast = &file_gorums_proto_extTypes[3] // optional bool quorumcall = 50004; - E_Quorumcall = &file_gorums_proto_extTypes[3] + E_Quorumcall = &file_gorums_proto_extTypes[4] ) var File_gorums_proto protoreflect.FileDescriptor const file_gorums_proto_rawDesc = "" + "\n" + - "\fgorums.proto\x12\x06gorums\x1a google/protobuf/descriptor.proto:2\n" + + "\fgorums.proto\x12\x06gorums\x1a google/protobuf/descriptor.proto:@\n" + + "\fnode_id_type\x12\x1c.google.protobuf.FileOptions\x18\xc1\xb8\x02 \x01(\tR\n" + + "nodeIdType:2\n" + "\x03rpc\x12\x1e.google.protobuf.MethodOptions\x18ц\x03 \x01(\bR\x03rpc::\n" + "\aunicast\x12\x1e.google.protobuf.MethodOptions\x18҆\x03 \x01(\bR\aunicast:>\n" + "\tmulticast\x12\x1e.google.protobuf.MethodOptions\x18ӆ\x03 \x01(\bR\tmulticast:@\n" + @@ -83,17 +99,19 @@ const file_gorums_proto_rawDesc = "" + "quorumcallB\x1eZ\x17github.com/relab/gorums\x92\x03\x02\b\x02b\beditionsp\xe8\a" var file_gorums_proto_goTypes = []any{ - (*descriptorpb.MethodOptions)(nil), // 0: google.protobuf.MethodOptions + (*descriptorpb.FileOptions)(nil), // 0: google.protobuf.FileOptions + (*descriptorpb.MethodOptions)(nil), // 1: google.protobuf.MethodOptions } var file_gorums_proto_depIdxs = []int32{ - 0, // 0: gorums.rpc:extendee -> google.protobuf.MethodOptions - 0, // 1: gorums.unicast:extendee -> google.protobuf.MethodOptions - 0, // 2: gorums.multicast:extendee -> google.protobuf.MethodOptions - 0, // 3: gorums.quorumcall:extendee -> google.protobuf.MethodOptions - 4, // [4:4] is the sub-list for method output_type - 4, // [4:4] is the sub-list for method input_type - 4, // [4:4] is the sub-list for extension type_name - 0, // [0:4] is the sub-list for extension extendee + 0, // 0: gorums.node_id_type:extendee -> google.protobuf.FileOptions + 1, // 1: gorums.rpc:extendee -> google.protobuf.MethodOptions + 1, // 2: gorums.unicast:extendee -> google.protobuf.MethodOptions + 1, // 3: gorums.multicast:extendee -> google.protobuf.MethodOptions + 1, // 4: gorums.quorumcall:extendee -> google.protobuf.MethodOptions + 5, // [5:5] is the sub-list for method output_type + 5, // [5:5] is the sub-list for method input_type + 5, // [5:5] is the sub-list for extension type_name + 0, // [0:5] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name } @@ -109,7 +127,7 @@ func file_gorums_proto_init() { RawDescriptor: unsafe.Slice(unsafe.StringData(file_gorums_proto_rawDesc), len(file_gorums_proto_rawDesc)), NumEnums: 0, NumMessages: 0, - NumExtensions: 4, + NumExtensions: 5, NumServices: 0, }, GoTypes: file_gorums_proto_goTypes, diff --git a/gorums.proto b/gorums.proto index c78497d1..1169b8bd 100644 --- a/gorums.proto +++ b/gorums.proto @@ -1,16 +1,21 @@ edition = "2023"; package gorums; + +import "google/protobuf/descriptor.proto"; + option go_package = "github.com/relab/gorums"; option features.field_presence = IMPLICIT; -import "google/protobuf/descriptor.proto"; +extend google.protobuf.FileOptions { + string node_id_type = 40001; +} extend google.protobuf.MethodOptions { - // call types - bool rpc = 50001; - bool unicast = 50002; - bool multicast = 50003; - bool quorumcall = 50004; + // call types + bool rpc = 50001; + bool unicast = 50002; + bool multicast = 50003; + bool quorumcall = 50004; } diff --git a/internal/tests/config/config.pb.go b/internal/tests/config/config.pb.go index 3031b737..97228c4d 100644 --- a/internal/tests/config/config.pb.go +++ b/internal/tests/config/config.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v6.33.4 // source: config/config.proto package config @@ -161,7 +161,7 @@ const file_config_config_proto_rawDesc = "" + "\x03Num\x18\x02 \x01(\x04R\x03Num2?\n" + "\n" + "ConfigTest\x121\n" + - "\x06Config\x12\x0f.config.Request\x1a\x10.config.Response\"\x04\xa0\xb5\x18\x01B+Z$github.com/relab/gorums/tests/config\x92\x03\x02\b\x02b\beditionsp\xe8\a" + "\x06Config\x12\x0f.config.Request\x1a\x10.config.Response\"\x04\xa0\xb5\x18\x01B5\x8a\xc4\x13\x06uint32Z$github.com/relab/gorums/tests/config\x92\x03\x02\b\x02b\beditionsp\xe8\a" var file_config_config_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_config_config_proto_goTypes = []any{ diff --git a/internal/tests/config/config.proto b/internal/tests/config/config.proto index 10cdd817..d64a2de4 100644 --- a/internal/tests/config/config.proto +++ b/internal/tests/config/config.proto @@ -1,12 +1,13 @@ edition = "2023"; package config; -option go_package = "github.com/relab/gorums/tests/config"; - -option features.field_presence = IMPLICIT; import "gorums.proto"; +option go_package = "github.com/relab/gorums/tests/config"; +option features.field_presence = IMPLICIT; +option (gorums.node_id_type) = "uint32"; + service ConfigTest { rpc Config(Request) returns (Response) { option (gorums.quorumcall) = true; diff --git a/internal/tests/config/config_gorums.pb.go b/internal/tests/config/config_gorums.pb.go index ecf3e1b2..04bf4e8b 100644 --- a/internal/tests/config/config_gorums.pb.go +++ b/internal/tests/config/config_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: config/config.proto package config @@ -20,9 +20,11 @@ const ( // Type aliases for important Gorums types to make them more accessible // from user code already interacting with the generated code. type ( - Configuration = gorums.Configuration - Manager = gorums.Manager - Node = gorums.Node + Configuration = gorums.Configuration[NodeID] + Manager = gorums.Manager[NodeID] + Node = gorums.Node[NodeID] + NodeContext = gorums.NodeContext[NodeID] + ConfigContext = gorums.ConfigContext[NodeID] ) // Use the aliased types to add them to the reserved identifiers list. @@ -31,20 +33,22 @@ var ( _ = (*Configuration)(nil) _ = (*Manager)(nil) _ = (*Node)(nil) + _ = (*NodeContext)(nil) + _ = (*ConfigContext)(nil) ) // NewManager returns a new Manager for managing connection to nodes added // to the manager. This function accepts manager options used to configure // various aspects of the manager. func NewManager(opts ...gorums.ManagerOption) *Manager { - return gorums.NewManager(opts...) + return gorums.NewManager[NodeID](opts...) } // NewConfiguration returns a configuration based on the provided list of nodes. // Nodes can be supplied using WithNodeMap or WithNodeList, or WithNodeIDs. // A new configuration can also be created from an existing configuration, // using the And, WithNewNodes, Except, and WithoutNodes methods. -func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, error) { +func NewConfiguration(mgr *Manager, opt gorums.NodeListOption[NodeID]) (Configuration, error) { return gorums.NewConfiguration(mgr, opt) } @@ -65,9 +69,12 @@ func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, e // creates a new manager; if a manager already exists, use [NewConfiguration] // instead, and provide the existing manager as the first argument. func NewConfig(opts ...gorums.Option) (Configuration, error) { - return gorums.NewConfig(opts...) + return gorums.NewConfig[NodeID](opts...) } +// NodeID is a type alias for the type used to identify nodes. +type NodeID = uint32 + // AsyncResponse is a future for async quorum calls returning *Response. type AsyncResponse = *gorums.Async[*Response] @@ -81,8 +88,8 @@ type CorrectableResponse = *gorums.Correctable[*Response] // Example: // // resp, err := Config(ctx, in).Majority() -func Config(ctx *gorums.ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[*Response] { - return gorums.QuorumCall[*Request, *Response]( +func Config(ctx *ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[NodeID, *Response] { + return gorums.QuorumCall[NodeID, *Request, *Response]( ctx, in, "config.ConfigTest.Config", opts..., ) diff --git a/internal/tests/config/config_test.go b/internal/tests/config/config_test.go index 72b9936d..2c134ce2 100644 --- a/internal/tests/config/config_test.go +++ b/internal/tests/config/config_test.go @@ -25,7 +25,7 @@ func serverFn(_ int) gorums.ServerIface { // TestConfig creates and combines multiple configurations and invokes the Config RPC // method on the different configurations created below. func TestConfig(t *testing.T) { - callRPC := func(config gorums.Configuration) { + callRPC := func(config Configuration) { cfgCtx := config.Context(context.Background()) for i := range 5 { // Use the new terminal method API - wait for a majority diff --git a/internal/tests/correctable/correctable.pb.go b/internal/tests/correctable/correctable.pb.go index d6001b27..850dac26 100644 --- a/internal/tests/correctable/correctable.pb.go +++ b/internal/tests/correctable/correctable.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v6.33.4 // source: correctable/correctable.proto package correctable diff --git a/internal/tests/correctable/correctable_gorums.pb.go b/internal/tests/correctable/correctable_gorums.pb.go index 95171059..fffe2ebe 100644 --- a/internal/tests/correctable/correctable_gorums.pb.go +++ b/internal/tests/correctable/correctable_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: correctable/correctable.proto package correctable @@ -21,9 +21,11 @@ const ( // Type aliases for important Gorums types to make them more accessible // from user code already interacting with the generated code. type ( - Configuration = gorums.Configuration - Manager = gorums.Manager - Node = gorums.Node + Configuration = gorums.Configuration[NodeID] + Manager = gorums.Manager[NodeID] + Node = gorums.Node[NodeID] + NodeContext = gorums.NodeContext[NodeID] + ConfigContext = gorums.ConfigContext[NodeID] ) // Use the aliased types to add them to the reserved identifiers list. @@ -32,20 +34,22 @@ var ( _ = (*Configuration)(nil) _ = (*Manager)(nil) _ = (*Node)(nil) + _ = (*NodeContext)(nil) + _ = (*ConfigContext)(nil) ) // NewManager returns a new Manager for managing connection to nodes added // to the manager. This function accepts manager options used to configure // various aspects of the manager. func NewManager(opts ...gorums.ManagerOption) *Manager { - return gorums.NewManager(opts...) + return gorums.NewManager[NodeID](opts...) } // NewConfiguration returns a configuration based on the provided list of nodes. // Nodes can be supplied using WithNodeMap or WithNodeList, or WithNodeIDs. // A new configuration can also be created from an existing configuration, // using the And, WithNewNodes, Except, and WithoutNodes methods. -func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, error) { +func NewConfiguration(mgr *Manager, opt gorums.NodeListOption[NodeID]) (Configuration, error) { return gorums.NewConfiguration(mgr, opt) } @@ -66,9 +70,12 @@ func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, e // creates a new manager; if a manager already exists, use [NewConfiguration] // instead, and provide the existing manager as the first argument. func NewConfig(opts ...gorums.Option) (Configuration, error) { - return gorums.NewConfig(opts...) + return gorums.NewConfig[NodeID](opts...) } +// NodeID is a type alias for the type used to identify nodes. +type NodeID = uint32 + // AsyncResponse is a future for async quorum calls returning *Response. type AsyncResponse = *gorums.Async[*Response] @@ -82,8 +89,8 @@ type CorrectableResponse = *gorums.Correctable[*Response] // Example: // // resp, err := Correctable(ctx, in).Majority() -func Correctable(ctx *gorums.ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[*Response] { - return gorums.QuorumCall[*Request, *Response]( +func Correctable(ctx *ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[NodeID, *Response] { + return gorums.QuorumCall[NodeID, *Request, *Response]( ctx, in, "correctable.CorrectableTest.Correctable", opts..., ) @@ -97,8 +104,8 @@ func Correctable(ctx *gorums.ConfigContext, in *Request, opts ...gorums.CallOpti // corr := CorrectableStream(ctx, in).Correctable(2) // <-corr.Watch(2) // resp, level, err := corr.Get() -func CorrectableStream(ctx *gorums.ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[*Response] { - return gorums.QuorumCallStream[*Request, *Response]( +func CorrectableStream(ctx *ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[NodeID, *Response] { + return gorums.QuorumCallStream[NodeID, *Request, *Response]( ctx, in, "correctable.CorrectableTest.CorrectableStream", opts..., ) diff --git a/internal/tests/correctable/correctable_test.go b/internal/tests/correctable/correctable_test.go index d821e790..686dbed4 100644 --- a/internal/tests/correctable/correctable_test.go +++ b/internal/tests/correctable/correctable_test.go @@ -10,7 +10,7 @@ import ( // run a test on a correctable call. // n is the number of replicas. // the target level is n (quorum size). -func run(t testing.TB, n int, corr func(*gorums.ConfigContext, int) CorrectableResponse) { +func run(t testing.TB, n int, corr func(*ConfigContext, int) CorrectableResponse) { t.Helper() config := gorums.TestConfiguration(t, n, func(_ int) gorums.ServerIface { gorumsSrv := gorums.NewServer() @@ -36,14 +36,14 @@ func run(t testing.TB, n int, corr func(*gorums.ConfigContext, int) CorrectableR } func TestCorrectable(t *testing.T) { - run(t, 4, func(ctx *gorums.ConfigContext, n int) CorrectableResponse { + run(t, 4, func(ctx *ConfigContext, n int) CorrectableResponse { // Correctable returns *Responses, user calls Correctable to get *Correctable return Correctable(ctx, &Request{}).Correctable(n) }) } func TestCorrectableStream(t *testing.T) { - run(t, 4, func(ctx *gorums.ConfigContext, n int) CorrectableResponse { + run(t, 4, func(ctx *ConfigContext, n int) CorrectableResponse { // CorrectableStream returns *Responses, user calls Correctable to get *Correctable return CorrectableStream(ctx, &Request{}).Correctable(n) }) diff --git a/internal/tests/metadata/metadata.pb.go b/internal/tests/metadata/metadata.pb.go index f5a78ae6..ae608d12 100644 --- a/internal/tests/metadata/metadata.pb.go +++ b/internal/tests/metadata/metadata.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v6.33.4 // source: metadata/metadata.proto package metadata @@ -21,27 +21,27 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -type NodeID struct { +type NodeIDMsg struct { state protoimpl.MessageState `protogen:"opaque.v1"` xxx_hidden_ID uint32 `protobuf:"varint,1,opt,name=ID"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } -func (x *NodeID) Reset() { - *x = NodeID{} +func (x *NodeIDMsg) Reset() { + *x = NodeIDMsg{} mi := &file_metadata_metadata_proto_msgTypes[0] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } -func (x *NodeID) String() string { +func (x *NodeIDMsg) String() string { return protoimpl.X.MessageStringOf(x) } -func (*NodeID) ProtoMessage() {} +func (*NodeIDMsg) ProtoMessage() {} -func (x *NodeID) ProtoReflect() protoreflect.Message { +func (x *NodeIDMsg) ProtoReflect() protoreflect.Message { mi := &file_metadata_metadata_proto_msgTypes[0] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) @@ -53,25 +53,25 @@ func (x *NodeID) ProtoReflect() protoreflect.Message { return mi.MessageOf(x) } -func (x *NodeID) GetID() uint32 { +func (x *NodeIDMsg) GetID() uint32 { if x != nil { return x.xxx_hidden_ID } return 0 } -func (x *NodeID) SetID(v uint32) { +func (x *NodeIDMsg) SetID(v uint32) { x.xxx_hidden_ID = v } -type NodeID_builder struct { +type NodeIDMsg_builder struct { _ [0]func() // Prevents comparability and use of unkeyed literals for the builder. ID uint32 } -func (b0 NodeID_builder) Build() *NodeID { - m0 := &NodeID{} +func (b0 NodeIDMsg_builder) Build() *NodeIDMsg { + m0 := &NodeIDMsg{} b, x := &b0, m0 _, _ = b, x x.xxx_hidden_ID = b.ID @@ -139,25 +139,25 @@ var File_metadata_metadata_proto protoreflect.FileDescriptor const file_metadata_metadata_proto_rawDesc = "" + "\n" + - "\x17metadata/metadata.proto\x12\bmetadata\x1a\x1bgoogle/protobuf/empty.proto\"\x18\n" + - "\x06NodeID\x12\x0e\n" + + "\x17metadata/metadata.proto\x12\bmetadata\x1a\x1bgoogle/protobuf/empty.proto\"\x1b\n" + + "\tNodeIDMsg\x12\x0e\n" + "\x02ID\x18\x01 \x01(\rR\x02ID\"\x1c\n" + "\x06IPAddr\x12\x12\n" + - "\x04Addr\x18\x01 \x01(\tR\x04Addr2|\n" + - "\fMetadataTest\x126\n" + - "\bIDFromMD\x12\x16.google.protobuf.Empty\x1a\x10.metadata.NodeID\"\x00\x124\n" + + "\x04Addr\x18\x01 \x01(\tR\x04Addr2\x7f\n" + + "\fMetadataTest\x129\n" + + "\bIDFromMD\x12\x16.google.protobuf.Empty\x1a\x13.metadata.NodeIDMsg\"\x00\x124\n" + "\x06WhatIP\x12\x16.google.protobuf.Empty\x1a\x10.metadata.IPAddr\"\x00B-Z&github.com/relab/gorums/tests/metadata\x92\x03\x02\b\x02b\beditionsp\xe8\a" var file_metadata_metadata_proto_msgTypes = make([]protoimpl.MessageInfo, 2) var file_metadata_metadata_proto_goTypes = []any{ - (*NodeID)(nil), // 0: metadata.NodeID + (*NodeIDMsg)(nil), // 0: metadata.NodeIDMsg (*IPAddr)(nil), // 1: metadata.IPAddr (*emptypb.Empty)(nil), // 2: google.protobuf.Empty } var file_metadata_metadata_proto_depIdxs = []int32{ 2, // 0: metadata.MetadataTest.IDFromMD:input_type -> google.protobuf.Empty 2, // 1: metadata.MetadataTest.WhatIP:input_type -> google.protobuf.Empty - 0, // 2: metadata.MetadataTest.IDFromMD:output_type -> metadata.NodeID + 0, // 2: metadata.MetadataTest.IDFromMD:output_type -> metadata.NodeIDMsg 1, // 3: metadata.MetadataTest.WhatIP:output_type -> metadata.IPAddr 2, // [2:4] is the sub-list for method output_type 0, // [0:2] is the sub-list for method input_type diff --git a/internal/tests/metadata/metadata.proto b/internal/tests/metadata/metadata.proto index 8d095072..8bc69fdf 100644 --- a/internal/tests/metadata/metadata.proto +++ b/internal/tests/metadata/metadata.proto @@ -9,12 +9,12 @@ import "google/protobuf/empty.proto"; service MetadataTest { // IDFromMD returns the 'id' field from the metadata. - rpc IDFromMD(google.protobuf.Empty) returns (NodeID) {} + rpc IDFromMD(google.protobuf.Empty) returns (NodeIDMsg) {} // WhatIP returns the address of the client that calls it. rpc WhatIP(google.protobuf.Empty) returns (IPAddr) {} } -message NodeID { +message NodeIDMsg { uint32 ID = 1; } diff --git a/internal/tests/metadata/metadata_gorums.pb.go b/internal/tests/metadata/metadata_gorums.pb.go index f819ff04..537775b2 100644 --- a/internal/tests/metadata/metadata_gorums.pb.go +++ b/internal/tests/metadata/metadata_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: metadata/metadata.proto package metadata @@ -21,9 +21,11 @@ const ( // Type aliases for important Gorums types to make them more accessible // from user code already interacting with the generated code. type ( - Configuration = gorums.Configuration - Manager = gorums.Manager - Node = gorums.Node + Configuration = gorums.Configuration[NodeID] + Manager = gorums.Manager[NodeID] + Node = gorums.Node[NodeID] + NodeContext = gorums.NodeContext[NodeID] + ConfigContext = gorums.ConfigContext[NodeID] ) // Use the aliased types to add them to the reserved identifiers list. @@ -32,20 +34,22 @@ var ( _ = (*Configuration)(nil) _ = (*Manager)(nil) _ = (*Node)(nil) + _ = (*NodeContext)(nil) + _ = (*ConfigContext)(nil) ) // NewManager returns a new Manager for managing connection to nodes added // to the manager. This function accepts manager options used to configure // various aspects of the manager. func NewManager(opts ...gorums.ManagerOption) *Manager { - return gorums.NewManager(opts...) + return gorums.NewManager[NodeID](opts...) } // NewConfiguration returns a configuration based on the provided list of nodes. // Nodes can be supplied using WithNodeMap or WithNodeList, or WithNodeIDs. // A new configuration can also be created from an existing configuration, // using the And, WithNewNodes, Except, and WithoutNodes methods. -func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, error) { +func NewConfiguration(mgr *Manager, opt gorums.NodeListOption[NodeID]) (Configuration, error) { return gorums.NewConfiguration(mgr, opt) } @@ -66,20 +70,23 @@ func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, e // creates a new manager; if a manager already exists, use [NewConfiguration] // instead, and provide the existing manager as the first argument. func NewConfig(opts ...gorums.Option) (Configuration, error) { - return gorums.NewConfig(opts...) + return gorums.NewConfig[NodeID](opts...) } +// NodeID is a type alias for the type used to identify nodes. +type NodeID = uint32 + // IDFromMD returns the 'id' field from the metadata. -func IDFromMD(ctx *gorums.NodeContext, in *emptypb.Empty) (resp *NodeID, err error) { +func IDFromMD(ctx *NodeContext, in *emptypb.Empty) (resp *NodeIDMsg, err error) { res, err := gorums.RPCCall(ctx, in, "metadata.MetadataTest.IDFromMD") if err != nil { return nil, err } - return res.(*NodeID), err + return res.(*NodeIDMsg), err } // WhatIP returns the address of the client that calls it. -func WhatIP(ctx *gorums.NodeContext, in *emptypb.Empty) (resp *IPAddr, err error) { +func WhatIP(ctx *NodeContext, in *emptypb.Empty) (resp *IPAddr, err error) { res, err := gorums.RPCCall(ctx, in, "metadata.MetadataTest.WhatIP") if err != nil { return nil, err @@ -89,7 +96,7 @@ func WhatIP(ctx *gorums.NodeContext, in *emptypb.Empty) (resp *IPAddr, err error // MetadataTest is the server-side API for the MetadataTest Service type MetadataTestServer interface { - IDFromMD(ctx gorums.ServerCtx, request *emptypb.Empty) (response *NodeID, err error) + IDFromMD(ctx gorums.ServerCtx, request *emptypb.Empty) (response *NodeIDMsg, err error) WhatIP(ctx gorums.ServerCtx, request *emptypb.Empty) (response *IPAddr, err error) } diff --git a/internal/tests/metadata/metadata_test.go b/internal/tests/metadata/metadata_test.go index a8eeb5fc..0a88a722 100644 --- a/internal/tests/metadata/metadata_test.go +++ b/internal/tests/metadata/metadata_test.go @@ -15,7 +15,7 @@ import ( type testSrv struct{} -func (testSrv) IDFromMD(ctx gorums.ServerCtx, _ *emptypb.Empty) (resp *NodeID, err error) { +func (testSrv) IDFromMD(ctx gorums.ServerCtx, _ *emptypb.Empty) (resp *NodeIDMsg, err error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { return nil, status.Error(codes.NotFound, "metadata unavailable") @@ -28,7 +28,7 @@ func (testSrv) IDFromMD(ctx gorums.ServerCtx, _ *emptypb.Empty) (resp *NodeID, e if err != nil { return nil, status.Errorf(codes.InvalidArgument, "value of id field: %q is not a number: %v", v[0], err) } - return NodeID_builder{ID: id}.Build(), nil + return NodeIDMsg_builder{ID: id}.Build(), nil } func (testSrv) WhatIP(ctx gorums.ServerCtx, _ *emptypb.Empty) (resp *IPAddr, err error) { diff --git a/internal/tests/oneway/oneway.pb.go b/internal/tests/oneway/oneway.pb.go index e4e9d244..73bbe742 100644 --- a/internal/tests/oneway/oneway.pb.go +++ b/internal/tests/oneway/oneway.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v6.33.4 // source: oneway/oneway.proto package oneway diff --git a/internal/tests/oneway/oneway_gorums.pb.go b/internal/tests/oneway/oneway_gorums.pb.go index 9550704b..3cf6724c 100644 --- a/internal/tests/oneway/oneway_gorums.pb.go +++ b/internal/tests/oneway/oneway_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: oneway/oneway.proto package oneway @@ -20,9 +20,11 @@ const ( // Type aliases for important Gorums types to make them more accessible // from user code already interacting with the generated code. type ( - Configuration = gorums.Configuration - Manager = gorums.Manager - Node = gorums.Node + Configuration = gorums.Configuration[NodeID] + Manager = gorums.Manager[NodeID] + Node = gorums.Node[NodeID] + NodeContext = gorums.NodeContext[NodeID] + ConfigContext = gorums.ConfigContext[NodeID] ) // Use the aliased types to add them to the reserved identifiers list. @@ -31,20 +33,22 @@ var ( _ = (*Configuration)(nil) _ = (*Manager)(nil) _ = (*Node)(nil) + _ = (*NodeContext)(nil) + _ = (*ConfigContext)(nil) ) // NewManager returns a new Manager for managing connection to nodes added // to the manager. This function accepts manager options used to configure // various aspects of the manager. func NewManager(opts ...gorums.ManagerOption) *Manager { - return gorums.NewManager(opts...) + return gorums.NewManager[NodeID](opts...) } // NewConfiguration returns a configuration based on the provided list of nodes. // Nodes can be supplied using WithNodeMap or WithNodeList, or WithNodeIDs. // A new configuration can also be created from an existing configuration, // using the And, WithNewNodes, Except, and WithoutNodes methods. -func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, error) { +func NewConfiguration(mgr *Manager, opt gorums.NodeListOption[NodeID]) (Configuration, error) { return gorums.NewConfiguration(mgr, opt) } @@ -65,18 +69,21 @@ func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, e // creates a new manager; if a manager already exists, use [NewConfiguration] // instead, and provide the existing manager as the first argument. func NewConfig(opts ...gorums.Option) (Configuration, error) { - return gorums.NewConfig(opts...) + return gorums.NewConfig[NodeID](opts...) } +// NodeID is a type alias for the type used to identify nodes. +type NodeID = uint32 + // Unicast is a unicast call invoked on the node in ctx. // No reply is returned to the client. -func Unicast(ctx *gorums.NodeContext, in *Request, opts ...gorums.CallOption) error { +func Unicast(ctx *NodeContext, in *Request, opts ...gorums.CallOption) error { return gorums.Unicast(ctx, in, "oneway.OnewayTest.Unicast", opts...) } // Multicast is a multicast call invoked on all nodes in the configuration in ctx. // Use gorums.MapRequest to send different messages to each node. No replies are collected. -func Multicast(ctx *gorums.ConfigContext, in *Request, opts ...gorums.CallOption) error { +func Multicast(ctx *ConfigContext, in *Request, opts ...gorums.CallOption) error { return gorums.Multicast(ctx, in, "oneway.OnewayTest.Multicast", opts...) } diff --git a/internal/tests/oneway/oneway_test.go b/internal/tests/oneway/oneway_test.go index 8b1c96d5..8e58ad6c 100644 --- a/internal/tests/oneway/oneway_test.go +++ b/internal/tests/oneway/oneway_test.go @@ -38,7 +38,7 @@ func (s *onewaySrv) Multicast(_ gorums.ServerCtx, r *oneway.Request) { // setupWithNodeMap sets up servers and configuration with sequential node IDs // (0, 1, 2, ...) matching the server array indices. This is needed for tests like // TestMulticastPerNode that verify per-node message transformations based on node ID. -func setupWithNodeMap(t testing.TB, cfgSize int) (cfg gorums.Configuration, srvs []*onewaySrv) { +func setupWithNodeMap(t testing.TB, cfgSize int) (cfg oneway.Configuration, srvs []*onewaySrv) { t.Helper() srvs = make([]*onewaySrv, cfgSize) for i := range cfgSize { @@ -126,7 +126,7 @@ func TestMulticastPerNode(t *testing.T) { // transformation function that uses the MapRequest interceptor // to add the msg ID + node ID to the Num field - f := func(msg *oneway.Request, node *gorums.Node) *oneway.Request { + f := func(msg *oneway.Request, node *oneway.Node) *oneway.Request { return oneway.Request_builder{Num: add(msg.GetNum(), node.ID())}.Build() } // the ignoreNodes slice is updated in each test case below; it is a hack @@ -139,7 +139,7 @@ func TestMulticastPerNode(t *testing.T) { return false } // transformation for all except some nodes that are ignored - g := func(msg *oneway.Request, node *gorums.Node) *oneway.Request { + g := func(msg *oneway.Request, node *oneway.Node) *oneway.Request { if ignore(node.ID()) { return nil } @@ -151,7 +151,7 @@ func TestMulticastPerNode(t *testing.T) { servers int sendWait bool ignoreNodes []int - mapFunc func(*oneway.Request, *gorums.Node) *oneway.Request + mapFunc func(*oneway.Request, *oneway.Node) *oneway.Request }{ {name: "MulticastPerNodeNoSendWaiting", calls: numCalls, servers: 1, sendWait: false, mapFunc: f}, {name: "MulticastPerNodeNoSendWaiting", calls: numCalls, servers: 3, sendWait: false, mapFunc: f}, @@ -186,7 +186,7 @@ func TestMulticastPerNode(t *testing.T) { for c := 1; c <= test.calls; c++ { in := oneway.Request_builder{Num: uint64(c)}.Build() cfgCtx := config.Context(context.Background()) - mapInterceptor := gorums.MapRequest[*oneway.Request, *emptypb.Empty](test.mapFunc) + mapInterceptor := gorums.MapRequest[uint32, *oneway.Request, *emptypb.Empty](test.mapFunc) if test.sendWait { if err := oneway.Multicast(cfgCtx, in, gorums.Interceptors(mapInterceptor), diff --git a/internal/tests/ordering/order.pb.go b/internal/tests/ordering/order.pb.go index 817d974c..f80bf58b 100644 --- a/internal/tests/ordering/order.pb.go +++ b/internal/tests/ordering/order.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v6.33.4 // source: ordering/order.proto package ordering diff --git a/internal/tests/ordering/order_gorums.pb.go b/internal/tests/ordering/order_gorums.pb.go index 57d4f5cf..ca3cb610 100644 --- a/internal/tests/ordering/order_gorums.pb.go +++ b/internal/tests/ordering/order_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: ordering/order.proto package ordering @@ -20,9 +20,11 @@ const ( // Type aliases for important Gorums types to make them more accessible // from user code already interacting with the generated code. type ( - Configuration = gorums.Configuration - Manager = gorums.Manager - Node = gorums.Node + Configuration = gorums.Configuration[NodeID] + Manager = gorums.Manager[NodeID] + Node = gorums.Node[NodeID] + NodeContext = gorums.NodeContext[NodeID] + ConfigContext = gorums.ConfigContext[NodeID] ) // Use the aliased types to add them to the reserved identifiers list. @@ -31,20 +33,22 @@ var ( _ = (*Configuration)(nil) _ = (*Manager)(nil) _ = (*Node)(nil) + _ = (*NodeContext)(nil) + _ = (*ConfigContext)(nil) ) // NewManager returns a new Manager for managing connection to nodes added // to the manager. This function accepts manager options used to configure // various aspects of the manager. func NewManager(opts ...gorums.ManagerOption) *Manager { - return gorums.NewManager(opts...) + return gorums.NewManager[NodeID](opts...) } // NewConfiguration returns a configuration based on the provided list of nodes. // Nodes can be supplied using WithNodeMap or WithNodeList, or WithNodeIDs. // A new configuration can also be created from an existing configuration, // using the And, WithNewNodes, Except, and WithoutNodes methods. -func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, error) { +func NewConfiguration(mgr *Manager, opt gorums.NodeListOption[NodeID]) (Configuration, error) { return gorums.NewConfiguration(mgr, opt) } @@ -65,9 +69,12 @@ func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, e // creates a new manager; if a manager already exists, use [NewConfiguration] // instead, and provide the existing manager as the first argument. func NewConfig(opts ...gorums.Option) (Configuration, error) { - return gorums.NewConfig(opts...) + return gorums.NewConfig[NodeID](opts...) } +// NodeID is a type alias for the type used to identify nodes. +type NodeID = uint32 + // AsyncResponse is a future for async quorum calls returning *Response. type AsyncResponse = *gorums.Async[*Response] @@ -81,15 +88,15 @@ type CorrectableResponse = *gorums.Correctable[*Response] // Example: // // resp, err := QuorumCall(ctx, in).Majority() -func QuorumCall(ctx *gorums.ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[*Response] { - return gorums.QuorumCall[*Request, *Response]( +func QuorumCall(ctx *ConfigContext, in *Request, opts ...gorums.CallOption) *gorums.Responses[NodeID, *Response] { + return gorums.QuorumCall[NodeID, *Request, *Response]( ctx, in, "ordering.GorumsTest.QuorumCall", opts..., ) } // UnaryRPC is an RPC call invoked on the node in ctx. -func UnaryRPC(ctx *gorums.NodeContext, in *Request) (resp *Response, err error) { +func UnaryRPC(ctx *NodeContext, in *Request) (resp *Response, err error) { res, err := gorums.RPCCall(ctx, in, "ordering.GorumsTest.UnaryRPC") if err != nil { return nil, err diff --git a/internal/tests/tls/tls.pb.go b/internal/tests/tls/tls.pb.go index e9d5668a..680d216f 100644 --- a/internal/tests/tls/tls.pb.go +++ b/internal/tests/tls/tls.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v6.33.4 // source: tls/tls.proto package tls diff --git a/internal/tests/tls/tls_gorums.pb.go b/internal/tests/tls/tls_gorums.pb.go index 9285b8e0..e73eb211 100644 --- a/internal/tests/tls/tls_gorums.pb.go +++ b/internal/tests/tls/tls_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: tls/tls.proto package tls @@ -20,9 +20,11 @@ const ( // Type aliases for important Gorums types to make them more accessible // from user code already interacting with the generated code. type ( - Configuration = gorums.Configuration - Manager = gorums.Manager - Node = gorums.Node + Configuration = gorums.Configuration[NodeID] + Manager = gorums.Manager[NodeID] + Node = gorums.Node[NodeID] + NodeContext = gorums.NodeContext[NodeID] + ConfigContext = gorums.ConfigContext[NodeID] ) // Use the aliased types to add them to the reserved identifiers list. @@ -31,20 +33,22 @@ var ( _ = (*Configuration)(nil) _ = (*Manager)(nil) _ = (*Node)(nil) + _ = (*NodeContext)(nil) + _ = (*ConfigContext)(nil) ) // NewManager returns a new Manager for managing connection to nodes added // to the manager. This function accepts manager options used to configure // various aspects of the manager. func NewManager(opts ...gorums.ManagerOption) *Manager { - return gorums.NewManager(opts...) + return gorums.NewManager[NodeID](opts...) } // NewConfiguration returns a configuration based on the provided list of nodes. // Nodes can be supplied using WithNodeMap or WithNodeList, or WithNodeIDs. // A new configuration can also be created from an existing configuration, // using the And, WithNewNodes, Except, and WithoutNodes methods. -func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, error) { +func NewConfiguration(mgr *Manager, opt gorums.NodeListOption[NodeID]) (Configuration, error) { return gorums.NewConfiguration(mgr, opt) } @@ -65,11 +69,14 @@ func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, e // creates a new manager; if a manager already exists, use [NewConfiguration] // instead, and provide the existing manager as the first argument. func NewConfig(opts ...gorums.Option) (Configuration, error) { - return gorums.NewConfig(opts...) + return gorums.NewConfig[NodeID](opts...) } +// NodeID is a type alias for the type used to identify nodes. +type NodeID = uint32 + // TestTLS is an RPC call invoked on the node in ctx. -func TestTLS(ctx *gorums.NodeContext, in *Request) (resp *Response, err error) { +func TestTLS(ctx *NodeContext, in *Request) (resp *Response, err error) { res, err := gorums.RPCCall(ctx, in, "tls.TLS.TestTLS") if err != nil { return nil, err diff --git a/internal/tests/unresponsive/unresponsive.pb.go b/internal/tests/unresponsive/unresponsive.pb.go index f1e9fe1c..1fdc0eec 100644 --- a/internal/tests/unresponsive/unresponsive.pb.go +++ b/internal/tests/unresponsive/unresponsive.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v6.33.4 // source: unresponsive/unresponsive.proto package unresponsive diff --git a/internal/tests/unresponsive/unresponsive_gorums.pb.go b/internal/tests/unresponsive/unresponsive_gorums.pb.go index fa40190f..d109973a 100644 --- a/internal/tests/unresponsive/unresponsive_gorums.pb.go +++ b/internal/tests/unresponsive/unresponsive_gorums.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-gorums. DO NOT EDIT. // versions: // protoc-gen-gorums v0.11.0-devel -// protoc v6.33.2 +// protoc v6.33.4 // source: unresponsive/unresponsive.proto package unresponsive @@ -20,9 +20,11 @@ const ( // Type aliases for important Gorums types to make them more accessible // from user code already interacting with the generated code. type ( - Configuration = gorums.Configuration - Manager = gorums.Manager - Node = gorums.Node + Configuration = gorums.Configuration[NodeID] + Manager = gorums.Manager[NodeID] + Node = gorums.Node[NodeID] + NodeContext = gorums.NodeContext[NodeID] + ConfigContext = gorums.ConfigContext[NodeID] ) // Use the aliased types to add them to the reserved identifiers list. @@ -31,20 +33,22 @@ var ( _ = (*Configuration)(nil) _ = (*Manager)(nil) _ = (*Node)(nil) + _ = (*NodeContext)(nil) + _ = (*ConfigContext)(nil) ) // NewManager returns a new Manager for managing connection to nodes added // to the manager. This function accepts manager options used to configure // various aspects of the manager. func NewManager(opts ...gorums.ManagerOption) *Manager { - return gorums.NewManager(opts...) + return gorums.NewManager[NodeID](opts...) } // NewConfiguration returns a configuration based on the provided list of nodes. // Nodes can be supplied using WithNodeMap or WithNodeList, or WithNodeIDs. // A new configuration can also be created from an existing configuration, // using the And, WithNewNodes, Except, and WithoutNodes methods. -func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, error) { +func NewConfiguration(mgr *Manager, opt gorums.NodeListOption[NodeID]) (Configuration, error) { return gorums.NewConfiguration(mgr, opt) } @@ -65,11 +69,14 @@ func NewConfiguration(mgr *Manager, opt gorums.NodeListOption) (Configuration, e // creates a new manager; if a manager already exists, use [NewConfiguration] // instead, and provide the existing manager as the first argument. func NewConfig(opts ...gorums.Option) (Configuration, error) { - return gorums.NewConfig(opts...) + return gorums.NewConfig[NodeID](opts...) } +// NodeID is a type alias for the type used to identify nodes. +type NodeID = uint32 + // TestUnresponsive is an RPC call invoked on the node in ctx. -func TestUnresponsive(ctx *gorums.NodeContext, in *Empty) (resp *Empty, err error) { +func TestUnresponsive(ctx *NodeContext, in *Empty) (resp *Empty, err error) { res, err := gorums.RPCCall(ctx, in, "unresponsive.Unresponsive.TestUnresponsive") if err != nil { return nil, err diff --git a/mgr.go b/mgr.go index 3f407a14..2237b466 100644 --- a/mgr.go +++ b/mgr.go @@ -9,14 +9,15 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/backoff" + "google.golang.org/grpc/metadata" ) // Manager maintains a connection pool of nodes on // which quorum calls can be performed. -type Manager struct { +type Manager[T NodeID] struct { mu sync.Mutex - nodes []*Node - lookup map[uint32]*Node + nodes []*Node[T] + lookup map[T]*Node[T] closeOnce sync.Once logger *log.Logger opts managerOptions @@ -26,9 +27,9 @@ type Manager struct { // NewManager returns a new Manager for managing connection to nodes added // to the manager. This function accepts manager options used to configure // various aspects of the manager. -func NewManager(opts ...ManagerOption) *Manager { - m := &Manager{ - lookup: make(map[uint32]*Node), +func NewManager[T NodeID](opts ...ManagerOption) *Manager[T] { + m := &Manager[T]{ + lookup: make(map[T]*Node[T]), opts: newManagerOptions(), } for _, opt := range opts { @@ -52,7 +53,7 @@ func NewManager(opts ...ManagerOption) *Manager { } // Close closes all node connections and any client streams. -func (m *Manager) Close() error { +func (m *Manager[T]) Close() error { var err error m.closeOnce.Do(func() { for _, node := range m.nodes { @@ -64,10 +65,10 @@ func (m *Manager) Close() error { // NodeIDs returns the identifier of each available node. IDs are returned in // the same order as they were provided in the creation of the Manager. -func (m *Manager) NodeIDs() []uint32 { +func (m *Manager[T]) NodeIDs() []T { m.mu.Lock() defer m.mu.Unlock() - ids := make([]uint32, 0, len(m.nodes)) + ids := make([]T, 0, len(m.nodes)) for _, node := range m.nodes { ids = append(ids, node.ID()) } @@ -75,7 +76,7 @@ func (m *Manager) NodeIDs() []uint32 { } // Node returns the node with the given identifier if present. -func (m *Manager) Node(id uint32) (node *Node, found bool) { +func (m *Manager[T]) Node(id T) (node *Node[T], found bool) { m.mu.Lock() defer m.mu.Unlock() node, found = m.lookup[id] @@ -84,39 +85,43 @@ func (m *Manager) Node(id uint32) (node *Node, found bool) { // Nodes returns a slice of each available node. IDs are returned in the same // order as they were provided in the creation of the Manager. -func (m *Manager) Nodes() []*Node { +func (m *Manager[T]) Nodes() []*Node[T] { m.mu.Lock() defer m.mu.Unlock() return m.nodes } // Size returns the number of nodes in the Manager. -func (m *Manager) Size() (nodes int) { +func (m *Manager[T]) Size() (nodes int) { m.mu.Lock() defer m.mu.Unlock() return len(m.nodes) } -func (m *Manager) addNode(node *Node) { +func (m *Manager[T]) addNode(node *Node[T]) { m.mu.Lock() defer m.mu.Unlock() m.lookup[node.id] = node m.nodes = append(m.nodes, node) } -func (m *Manager) newNode(addr string, id uint32) (*Node, error) { +func (m *Manager[T]) newNode(addr string, id T) (*Node[T], error) { if _, found := m.Node(id); found { - return nil, fmt.Errorf("node %d already exists", id) + return nil, fmt.Errorf("node %v already exists", id) } - opts := nodeOptions{ + opts := nodeOptions[T]{ ID: id, SendBufferSize: m.opts.sendBuffer, MsgIDGen: m.getMsgID, Metadata: m.opts.metadata, - PerNodeMD: m.opts.perNodeMD, DialOpts: m.opts.grpcDialOpts, Manager: m, } + if m.opts.perNodeMD != nil { + opts.PerNodeMD = func(id T) metadata.MD { + return m.opts.perNodeMD(id) + } + } n, err := newNode(addr, opts) if err != nil { return nil, err @@ -126,6 +131,6 @@ func (m *Manager) newNode(addr string, id uint32) (*Node, error) { } // getMsgID returns a unique message ID for a new RPC from this client's manager. -func (m *Manager) getMsgID() uint64 { +func (m *Manager[T]) getMsgID() uint64 { return atomic.AddUint64(&m.nextMsgID, 1) } diff --git a/mgr_test.go b/mgr_test.go index 8b3fa9a7..70bf1cde 100644 --- a/mgr_test.go +++ b/mgr_test.go @@ -22,17 +22,17 @@ func TestManagerLogging(t *testing.T) { buf bytes.Buffer logger = log.New(&buf, "logger: ", log.Lshortfile) ) - mgr := NewManager(InsecureDialOptions(t), WithLogger(logger)) + mgr := NewManager[uint32](InsecureDialOptions(t), WithLogger(logger)) t.Cleanup(Closer(t, mgr)) - want := "logger: mgr.go:49: ready" + want := "logger: mgr.go:50: ready" if strings.TrimSpace(buf.String()) != want { t.Errorf("logger: got %q, want %q", buf.String(), want) } } func TestManagerNewNode(t *testing.T) { - mgr := NewManager(InsecureDialOptions(t)) + mgr := NewManager[uint32](InsecureDialOptions(t)) t.Cleanup(Closer(t, mgr)) _, err := NewConfiguration(mgr, WithNodeMap(nodeMap)) @@ -61,7 +61,7 @@ func TestManagerNewNode(t *testing.T) { func TestManagerNewNodeWithConn(t *testing.T) { addrs := TestServers(t, 3, DefaultTestServer) - mgr := NewManager(InsecureDialOptions(t)) + mgr := NewManager[uint32](InsecureDialOptions(t)) t.Cleanup(Closer(t, mgr)) // Create configuration with only first 2 nodes diff --git a/multicast.go b/multicast.go index 58b0ee0a..a93c4e92 100644 --- a/multicast.go +++ b/multicast.go @@ -19,11 +19,11 @@ import ( // option. Use gorums.MapRequest to transform requests per-node. // // This method should be used by generated code only. -func Multicast[Req proto.Message](ctx *ConfigContext, msg Req, method string, opts ...CallOption) error { +func Multicast[T NodeID, Req proto.Message](ctx *ConfigContext[T], msg Req, method string, opts ...CallOption) error { callOpts := getCallOptions(E_Multicast, opts...) waitSendDone := callOpts.mustWaitSendDone() - clientCtx := newClientCtxBuilder[Req, *emptypb.Empty](ctx, msg, method).WithWaitSendDone(waitSendDone).Build() + clientCtx := newClientCtxBuilder[T, Req, *emptypb.Empty](ctx, msg, method).WithWaitSendDone(waitSendDone).Build() clientCtx.applyInterceptors(callOpts.interceptors) // Send messages immediately (multicast doesn't use lazy sending) diff --git a/node.go b/node.go index 6cefce96..35003e10 100644 --- a/node.go +++ b/node.go @@ -1,6 +1,7 @@ package gorums import ( + "cmp" "context" "fmt" "hash/fnv" @@ -15,40 +16,42 @@ import ( const nilAngleString = "" +type NodeID cmp.Ordered + // NodeContext is a context that carries a node for unicast and RPC calls. // It embeds context.Context and provides access to the Node. // // Use [Node.Context] to create a NodeContext from an existing context. -type NodeContext struct { +type NodeContext[T NodeID] struct { context.Context - node *Node + node *Node[T] } // Node returns the Node associated with this context. -func (c NodeContext) Node() *Node { +func (c NodeContext[T]) Node() *Node[T] { return c.node } // enqueue enqueues a request to this node's channel. -func (c NodeContext) enqueue(req request) { +func (c NodeContext[T]) enqueue(req request[T]) { c.node.channel.enqueue(req) } // nextMsgID returns the next message ID from this client's manager. -func (c NodeContext) nextMsgID() uint64 { +func (c NodeContext[T]) nextMsgID() uint64 { return c.node.msgIDGen() } // Node encapsulates the state of a node on which a remote procedure call // can be performed. -type Node struct { +type Node[T NodeID] struct { // Only assigned at creation. - id uint32 + id T addr string - mgr *Manager // only used for backward compatibility to allow Configuration.Manager() + mgr *Manager[T] // only used for backward compatibility to allow Configuration.Manager() msgIDGen func() uint64 - channel *channel + channel *channel[T] } // Context creates a new NodeContext from the given parent context @@ -58,33 +61,33 @@ type Node struct { // // nodeCtx := node.Context(context.Background()) // resp, err := service.GRPCCall(nodeCtx, req) -func (n *Node) Context(parent context.Context) *NodeContext { +func (n *Node[T]) Context(parent context.Context) *NodeContext[T] { if n == nil { panic("gorums: Context called with nil node") } - return &NodeContext{Context: parent, node: n} + return &NodeContext[T]{Context: parent, node: n} } // nodeOptions contains configuration options for creating a new Node. -type nodeOptions struct { - ID uint32 +type nodeOptions[T NodeID] struct { + ID T SendBufferSize uint MsgIDGen func() uint64 Metadata metadata.MD - PerNodeMD func(uint32) metadata.MD + PerNodeMD func(T) metadata.MD DialOpts []grpc.DialOption - Manager *Manager // only used for backward compatibility to allow Configuration.Manager() + Manager *Manager[T] // only used for backward compatibility to allow Configuration.Manager() } // newNode creates a new node using the provided options. // It establishes the connection (lazy dial) and initializes the channel. -func newNode(addr string, opts nodeOptions) (*Node, error) { +func newNode[T NodeID](addr string, opts nodeOptions[T]) (*Node[T], error) { tcpAddr, err := net.ResolveTCPAddr("tcp", addr) if err != nil { return nil, err } - n := &Node{ + n := &Node[T]{ id: opts.ID, addr: tcpAddr.String(), mgr: opts.Manager, @@ -122,7 +125,7 @@ func nodeID(addr string) (uint32, error) { } // close this node. -func (n *Node) close() error { +func (n *Node[T]) close() error { if n.channel != nil { return n.channel.close() } @@ -130,15 +133,16 @@ func (n *Node) close() error { } // ID returns the ID of n. -func (n *Node) ID() uint32 { +func (n *Node[T]) ID() T { if n != nil { return n.id } - return 0 + var zero T + return zero } // Address returns network address of n. -func (n *Node) Address() string { +func (n *Node[T]) Address() string { if n != nil { return n.addr } @@ -146,7 +150,7 @@ func (n *Node) Address() string { } // Host returns the network host of n. -func (n *Node) Host() string { +func (n *Node[T]) Host() string { if n == nil { return nilAngleString } @@ -155,7 +159,7 @@ func (n *Node) Host() string { } // Port returns network port of n. -func (n *Node) Port() string { +func (n *Node[T]) Port() string { if n != nil { _, port, _ := net.SplitHostPort(n.addr) return port @@ -163,7 +167,7 @@ func (n *Node) Port() string { return nilAngleString } -func (n *Node) String() string { +func (n *Node[T]) String() string { if n != nil { return fmt.Sprintf("addr: %s", n.addr) } @@ -172,53 +176,53 @@ func (n *Node) String() string { // FullString returns a more descriptive string representation of n that // includes id, network address and latency information. -func (n *Node) FullString() string { +func (n *Node[T]) FullString() string { if n != nil { - return fmt.Sprintf("node %d | addr: %s", n.id, n.addr) + return fmt.Sprintf("node %v | addr: %s", n.id, n.addr) } return nilAngleString } // LastErr returns the last error encountered (if any) for this node. -func (n *Node) LastErr() error { +func (n *Node[T]) LastErr() error { return n.channel.lastErr() } // Latency returns the latency between the client and this node. -func (n *Node) Latency() time.Duration { +func (n *Node[T]) Latency() time.Duration { return n.channel.channelLatency() } -type lessFunc func(n1, n2 *Node) bool +type lessFunc[T NodeID] func(n1, n2 *Node[T]) bool // MultiSorter implements the Sort interface, sorting the nodes within. -type MultiSorter struct { - nodes []*Node - less []lessFunc +type MultiSorter[T NodeID] struct { + nodes []*Node[T] + less []lessFunc[T] } // Sort sorts the argument slice according to the less functions passed to // OrderedBy. -func (ms *MultiSorter) Sort(nodes []*Node) { +func (ms *MultiSorter[T]) Sort(nodes []*Node[T]) { ms.nodes = nodes sort.Sort(ms) } // OrderedBy returns a Sorter that sorts using the less functions, in order. // Call its Sort method to sort the data. -func OrderedBy(less ...lessFunc) *MultiSorter { - return &MultiSorter{ +func OrderedBy[T NodeID](less ...lessFunc[T]) *MultiSorter[T] { + return &MultiSorter[T]{ less: less, } } // Len is part of sort.Interface. -func (ms *MultiSorter) Len() int { +func (ms *MultiSorter[T]) Len() int { return len(ms.nodes) } // Swap is part of sort.Interface. -func (ms *MultiSorter) Swap(i, j int) { +func (ms *MultiSorter[T]) Swap(i, j int) { ms.nodes[i], ms.nodes[j] = ms.nodes[j], ms.nodes[i] } @@ -227,7 +231,7 @@ func (ms *MultiSorter) Swap(i, j int) { // Less. Note that it can call the less functions twice per call. We // could change the functions to return -1, 0, 1 and reduce the // number of calls for greater efficiency: an exercise for the reader. -func (ms *MultiSorter) Less(i, j int) bool { +func (ms *MultiSorter[T]) Less(i, j int) bool { p, q := ms.nodes[i], ms.nodes[j] // Try all but the last comparison. var k int @@ -249,13 +253,13 @@ func (ms *MultiSorter) Less(i, j int) bool { } // ID sorts nodes by their identifier in increasing order. -var ID = func(n1, n2 *Node) bool { +func ID[T NodeID](n1, n2 *Node[T]) bool { return n1.id < n2.id } // Port sorts nodes by their port number in increasing order. // Warning: This function may be removed in the future. -var Port = func(n1, n2 *Node) bool { +func Port[T NodeID](n1, n2 *Node[T]) bool { p1, _ := strconv.Atoi(n1.Port()) p2, _ := strconv.Atoi(n2.Port()) return p1 < p2 @@ -263,7 +267,7 @@ var Port = func(n1, n2 *Node) bool { // LastNodeError sorts nodes by their LastErr() status in increasing order. A // node with LastErr() != nil is larger than a node with LastErr() == nil. -var LastNodeError = func(n1, n2 *Node) bool { +func LastNodeError[T NodeID](n1, n2 *Node[T]) bool { if n1.channel.lastErr() != nil && n2.channel.lastErr() == nil { return false } diff --git a/node_test.go b/node_test.go index 3a260411..737ae5fd 100644 --- a/node_test.go +++ b/node_test.go @@ -8,31 +8,31 @@ import ( ) func TestNodeSort(t *testing.T) { - nodes := []*Node{ + nodes := []*Node[uint32]{ { id: 100, - channel: &channel{ + channel: &channel[uint32]{ lastError: nil, latency: time.Second, }, }, { id: 101, - channel: &channel{ + channel: &channel[uint32]{ lastError: errors.New("some error"), latency: 250 * time.Millisecond, }, }, { id: 42, - channel: &channel{ + channel: &channel[uint32]{ lastError: nil, latency: 300 * time.Millisecond, }, }, { id: 99, - channel: &channel{ + channel: &channel[uint32]{ lastError: errors.New("some error"), latency: 500 * time.Millisecond, }, @@ -41,7 +41,7 @@ func TestNodeSort(t *testing.T) { n := len(nodes) - OrderedBy(ID).Sort(nodes) + OrderedBy(ID[uint32]).Sort(nodes) for i := n - 1; i > 0; i-- { if nodes[i].id < nodes[i-1].id { t.Error("by id: not sorted") @@ -49,7 +49,7 @@ func TestNodeSort(t *testing.T) { } } - OrderedBy(LastNodeError).Sort(nodes) + OrderedBy(LastNodeError[uint32]).Sort(nodes) for i := n - 1; i > 0; i-- { if nodes[i].LastErr() == nil && nodes[i-1].LastErr() != nil { t.Error("by error: not sorted") @@ -58,7 +58,7 @@ func TestNodeSort(t *testing.T) { } } -func printNodes(t *testing.T, nodes []*Node) { +func printNodes(t *testing.T, nodes []*Node[uint32]) { t.Helper() for i, n := range nodes { nodeStr := fmt.Sprintf( diff --git a/opts.go b/opts.go index bc68e19d..1a59c1cd 100644 --- a/opts.go +++ b/opts.go @@ -24,7 +24,7 @@ type managerOptions struct { backoff backoff.Config sendBuffer uint metadata metadata.MD - perNodeMD func(uint32) metadata.MD + perNodeMD func(any) metadata.MD } func newManagerOptions() managerOptions { @@ -75,10 +75,16 @@ func WithMetadata(md metadata.MD) ManagerOption { } } -// WithPerNodeMetadata returns a ManagerOption that allows you to set metadata for each -// node individually. -func WithPerNodeMetadata(f func(uint32) metadata.MD) ManagerOption { +// WithPerNodeMetadata returns a ManagerOption that sets the metadata that is sent to each node +// when the connection is initially established. The metadata is generated by the provided +// function, which takes the node's ID as input. +func WithPerNodeMetadata[T NodeID](f func(T) metadata.MD) ManagerOption { return func(o *managerOptions) { - o.perNodeMD = f + o.perNodeMD = func(id any) metadata.MD { + if idVal, ok := id.(T); ok { + return f(idVal) + } + return nil + } } } diff --git a/ordering/ordering.pb.go b/ordering/ordering.pb.go index e4f7cc85..bd60796d 100644 --- a/ordering/ordering.pb.go +++ b/ordering/ordering.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.2 +// protoc v6.33.4 // source: ordering/ordering.proto package ordering diff --git a/ordering/ordering_grpc.pb.go b/ordering/ordering_grpc.pb.go index 4c3c6384..a1b79e9f 100644 --- a/ordering/ordering_grpc.pb.go +++ b/ordering/ordering_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.6.0 -// - protoc v6.33.2 +// - protoc v6.33.4 // source: ordering/ordering.proto package ordering diff --git a/quorumcall.go b/quorumcall.go index afbc91ac..7226ed93 100644 --- a/quorumcall.go +++ b/quorumcall.go @@ -4,6 +4,7 @@ package gorums // that provides access to node responses via terminal methods and fluent iteration. // // Type parameters: +// - T: The node ID type (e.g., uint32) // - Req: The request message type // - Resp: The response message type from individual nodes // @@ -16,13 +17,13 @@ package gorums // This lazy sending is necessary to allow interceptors to register transformations prior to dispatch. // // This function should be used by generated code only. -func QuorumCall[Req, Resp msg]( - ctx *ConfigContext, +func QuorumCall[T NodeID, Req, Resp msg]( + ctx *ConfigContext[T], req Req, method string, opts ...CallOption, -) *Responses[Resp] { - return invokeQuorumCall[Req, Resp](ctx, req, method, false, opts...) +) *Responses[T, Resp] { + return invokeQuorumCall[T, Req, Resp](ctx, req, method, false, opts...) } // QuorumCallStream performs a streaming quorum call and returns a Responses object. @@ -32,25 +33,25 @@ func QuorumCall[Req, Resp msg]( // is canceled, allowing the server to send multiple responses over time. // // This function should be used by generated code only. -func QuorumCallStream[Req, Resp msg]( - ctx *ConfigContext, +func QuorumCallStream[T NodeID, Req, Resp msg]( + ctx *ConfigContext[T], req Req, method string, opts ...CallOption, -) *Responses[Resp] { - return invokeQuorumCall[Req, Resp](ctx, req, method, true, opts...) +) *Responses[T, Resp] { + return invokeQuorumCall[T, Req, Resp](ctx, req, method, true, opts...) } // invokeQuorumCall is the internal implementation shared by QuorumCall and QuorumCallStream. -func invokeQuorumCall[Req, Resp msg]( - ctx *ConfigContext, +func invokeQuorumCall[T NodeID, Req, Resp msg]( + ctx *ConfigContext[T], req Req, method string, streaming bool, opts ...CallOption, -) *Responses[Resp] { +) *Responses[T, Resp] { callOpts := getCallOptions(E_Quorumcall, opts...) - builder := newClientCtxBuilder[Req, Resp](ctx, req, method) + builder := newClientCtxBuilder[T, Req, Resp](ctx, req, method) if streaming { builder = builder.WithStreaming() } diff --git a/quorumcall_test.go b/quorumcall_test.go index 129949cb..b5de2c4c 100644 --- a/quorumcall_test.go +++ b/quorumcall_test.go @@ -51,7 +51,7 @@ func checkQuorumCall(t *testing.T, gotErr, wantErr error, expectedNodeErrors ... func TestQuorumCall(t *testing.T) { // type alias short hand for the responses type - type respType = *gorums.Responses[*pb.StringValue] + type respType = *gorums.Responses[uint32, *pb.StringValue] tests := []struct { name string call func(respType) (*pb.StringValue, error) @@ -82,7 +82,7 @@ func TestQuorumCall(t *testing.T) { t.Run(tt.name, func(t *testing.T) { config := gorums.TestConfiguration(t, tt.numNodes, gorums.EchoServerFn) ctx := gorums.TestContext(t, 2*time.Second) - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( config.Context(ctx), pb.String("test"), mock.TestMethod, @@ -104,19 +104,19 @@ func TestQuorumCallPartialFailures(t *testing.T) { // Function type for the call to test type callInfo struct { name string - callFunc func(*gorums.ConfigContext, *pb.StringValue) error + callFunc func(*gorums.ConfigContext[uint32], *pb.StringValue) error } const numServers = 3 - type respType = *gorums.Responses[*pb.StringValue] + type respType = *gorums.Responses[uint32, *pb.StringValue] // Helper to create QuorumCall variants quorumcall := func(name string, aggregateFunc func(respType) (*pb.StringValue, error)) callInfo { return callInfo{ name: "QuorumCall/" + name, - callFunc: func(ctx *gorums.ConfigContext, req *pb.StringValue) error { - _, err := aggregateFunc(gorums.QuorumCall[*pb.StringValue, *pb.StringValue](ctx, req, mock.TestMethod)) + callFunc: func(ctx *gorums.ConfigContext[uint32], req *pb.StringValue) error { + _, err := aggregateFunc(gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue](ctx, req, mock.TestMethod)) return err }, } @@ -126,7 +126,7 @@ func TestQuorumCallPartialFailures(t *testing.T) { multicast := func(name string, opts ...gorums.CallOption) callInfo { return callInfo{ name: "Multicast/" + name, - callFunc: func(ctx *gorums.ConfigContext, req *pb.StringValue) error { + callFunc: func(ctx *gorums.ConfigContext[uint32], req *pb.StringValue) error { return gorums.Multicast(ctx, req, mock.TestMethod, opts...) }, } @@ -210,7 +210,7 @@ func TestQuorumCallCustomAggregation(t *testing.T) { config := gorums.TestConfiguration(t, 3, gorums.DefaultTestServer) // uses default server that returns (i+1)*10 ctx := gorums.TestContext(t, 2*time.Second) - responses := gorums.QuorumCall[*pb.Int32Value, *pb.Int32Value]( + responses := gorums.QuorumCall[uint32, *pb.Int32Value, *pb.Int32Value]( config.Context(ctx), pb.Int32(0), mock.GetValueMethod, @@ -233,7 +233,7 @@ func TestQuorumCallCollectAll(t *testing.T) { config := gorums.TestConfiguration(t, 3, gorums.EchoServerFn) ctx := gorums.TestContext(t, 2*time.Second) - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( config.Context(ctx), pb.String("test"), mock.TestMethod, @@ -252,7 +252,7 @@ func TestQuorumCallSynctest(t *testing.T) { ctx := gorums.TestContext(t, 2*time.Second) cfgCtx := config.Context(ctx) - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("synctest-demo"), mock.TestMethod, @@ -278,13 +278,13 @@ func TestQuorumCallAsyncSynctest(t *testing.T) { cfgCtx := config.Context(ctx) // Test async operations - future1 := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + future1 := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("async1"), mock.TestMethod, ).AsyncMajority() - future2 := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + future2 := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("async2"), mock.TestMethod, @@ -321,7 +321,7 @@ func BenchmarkQuorumCallTerminalMethods(b *testing.B) { b.Run(fmt.Sprintf("Majority/%d", numNodes), func(b *testing.B) { b.ReportAllocs() for b.Loop() { - resp, err := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + resp, err := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("test"), mock.TestMethod, @@ -337,7 +337,7 @@ func BenchmarkQuorumCallTerminalMethods(b *testing.B) { b.ReportAllocs() threshold := numNodes/2 + 1 for b.Loop() { - resp, err := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + resp, err := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("test"), mock.TestMethod, @@ -352,7 +352,7 @@ func BenchmarkQuorumCallTerminalMethods(b *testing.B) { b.Run(fmt.Sprintf("First/%d", numNodes), func(b *testing.B) { b.ReportAllocs() for b.Loop() { - resp, err := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + resp, err := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("test"), mock.TestMethod, @@ -367,7 +367,7 @@ func BenchmarkQuorumCallTerminalMethods(b *testing.B) { b.Run(fmt.Sprintf("All/%d", numNodes), func(b *testing.B) { b.ReportAllocs() for b.Loop() { - resp, err := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + resp, err := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("test"), mock.TestMethod, @@ -392,7 +392,7 @@ func BenchmarkQuorumCall(b *testing.B) { b.ReportAllocs() quorum := numNodes/2 + 1 for b.Loop() { - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("test"), mock.TestMethod, @@ -414,7 +414,7 @@ func BenchmarkQuorumCall(b *testing.B) { b.ReportAllocs() quorum := numNodes/2 + 1 for b.Loop() { - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("test"), mock.TestMethod, @@ -436,7 +436,7 @@ func BenchmarkQuorumCall(b *testing.B) { b.ReportAllocs() quorum := numNodes/2 + 1 for b.Loop() { - responses := gorums.QuorumCall[*pb.StringValue, *pb.StringValue]( + responses := gorums.QuorumCall[uint32, *pb.StringValue, *pb.StringValue]( cfgCtx, pb.String("test"), mock.TestMethod, diff --git a/responses.go b/responses.go index 856e4b7a..b1dcb20b 100644 --- a/responses.go +++ b/responses.go @@ -13,8 +13,8 @@ type msg = proto.Message // Iterator Helpers // ------------------------------------------------------------------------- -// ResponseSeq is an iterator that yields NodeResponse[T] values from a quorum call. -type ResponseSeq[T msg] iter.Seq[NodeResponse[T]] +// ResponseSeq is an iterator that yields NodeResponse[T, Resp] values from a quorum call. +type ResponseSeq[T NodeID, Resp msg] iter.Seq[NodeResponse[T, Resp]] // IgnoreErrors returns an iterator that yields only successful responses, // discarding any responses with errors. This is useful when you want to process @@ -28,8 +28,8 @@ type ResponseSeq[T msg] iter.Seq[NodeResponse[T]] // // resp is guaranteed to be a successful response // sum += resp.Value.GetValue() // } -func (seq ResponseSeq[Resp]) IgnoreErrors() ResponseSeq[Resp] { - return func(yield func(NodeResponse[Resp]) bool) { +func (seq ResponseSeq[T, Resp]) IgnoreErrors() ResponseSeq[T, Resp] { + return func(yield func(NodeResponse[T, Resp]) bool) { for result := range seq { if result.Err == nil { if !yield(result) { @@ -53,8 +53,8 @@ func (seq ResponseSeq[Resp]) IgnoreErrors() ResponseSeq[Resp] { // }) { // // process resp // } -func (seq ResponseSeq[Resp]) Filter(keep func(NodeResponse[Resp]) bool) ResponseSeq[Resp] { - return func(yield func(NodeResponse[Resp]) bool) { +func (seq ResponseSeq[T, Resp]) Filter(keep func(NodeResponse[T, Resp]) bool) ResponseSeq[T, Resp] { + return func(yield func(NodeResponse[T, Resp]) bool) { for result := range seq { if keep(result) { if !yield(result) { @@ -76,8 +76,8 @@ func (seq ResponseSeq[Resp]) Filter(keep func(NodeResponse[Resp]) bool) Response // replies := responses.CollectN(2) // // or collect 2 successful responses // replies = responses.IgnoreErrors().CollectN(2) -func (seq ResponseSeq[Resp]) CollectN(n int) map[uint32]Resp { - replies := make(map[uint32]Resp, n) +func (seq ResponseSeq[T, Resp]) CollectN(n int) map[T]Resp { + replies := make(map[T]Resp, n) for result := range seq { replies[result.NodeID] = result.Value if len(replies) >= n { @@ -97,8 +97,8 @@ func (seq ResponseSeq[Resp]) CollectN(n int) map[uint32]Resp { // replies := responses.CollectAll() // // or collect all successful responses // replies = responses.IgnoreErrors().CollectAll() -func (seq ResponseSeq[Resp]) CollectAll() map[uint32]Resp { - replies := make(map[uint32]Resp) +func (seq ResponseSeq[T, Resp]) CollectAll() map[T]Resp { + replies := make(map[T]Resp) for result := range seq { replies[result.NodeID] = result.Value } @@ -119,15 +119,16 @@ func (seq ResponseSeq[Resp]) CollectAll() map[uint32]Resp { // replies := ReadQuorumCall(ctx, req).IgnoreErrors().CollectAll() // // Type parameter: +// - T: The node ID type // - Resp: The response message type -type Responses[Resp msg] struct { - ResponseSeq[Resp] +type Responses[T NodeID, Resp msg] struct { + ResponseSeq[T, Resp] size int sendNow func() // sendNow triggers immediate sending of requests } -func NewResponses[Req, Resp msg](ctx *ClientCtx[Req, Resp]) *Responses[Resp] { - return &Responses[Resp]{ +func NewResponses[T NodeID, Req, Resp msg](ctx *ClientCtx[T, Req, Resp]) *Responses[T, Resp] { + return &Responses[T, Resp]{ ResponseSeq: ctx.responseSeq, size: ctx.Size(), sendNow: func() { ctx.sendOnce.Do(ctx.send) }, @@ -135,7 +136,7 @@ func NewResponses[Req, Resp msg](ctx *ClientCtx[Req, Resp]) *Responses[Resp] { } // Size returns the number of nodes in the configuration. -func (r *Responses[Resp]) Size() int { +func (r *Responses[T, Resp]) Size() int { return r.size } @@ -157,7 +158,7 @@ func (r *Responses[Resp]) Size() int { // } // // Process result.Value // } -func (r *Responses[Resp]) Seq() ResponseSeq[Resp] { +func (r *Responses[T, Resp]) Seq() ResponseSeq[T, Resp] { return r.ResponseSeq } @@ -167,26 +168,26 @@ func (r *Responses[Resp]) Seq() ResponseSeq[Resp] { // First returns the first successful response received from any node. // This is useful for read-any patterns where any single response is sufficient. -func (r *Responses[Resp]) First() (Resp, error) { +func (r *Responses[T, Resp]) First() (Resp, error) { return r.Threshold(1) } // Majority returns the first response once a simple majority (⌈(n+1)/2⌉) // of successful responses are received. -func (r *Responses[Resp]) Majority() (Resp, error) { +func (r *Responses[T, Resp]) Majority() (Resp, error) { quorumSize := r.size/2 + 1 return r.Threshold(quorumSize) } // All returns the first response once all nodes have responded successfully. // If any node fails, it returns an error. -func (r *Responses[Resp]) All() (Resp, error) { +func (r *Responses[T, Resp]) All() (Resp, error) { return r.Threshold(r.size) } // Threshold waits for a threshold number of successful responses. // It returns the first response once the threshold is reached. -func (r *Responses[Resp]) Threshold(threshold int) (resp Resp, err error) { +func (r *Responses[T, Resp]) Threshold(threshold int) (resp Resp, err error) { var ( count int errs []nodeError diff --git a/responses_test.go b/responses_test.go index 30db3d2c..a4e17ba8 100644 --- a/responses_test.go +++ b/responses_test.go @@ -10,21 +10,21 @@ import ( // makeClientCtx is a helper to create a ClientCtx with mock responses for unit tests. // It creates a channel with the provided responses and returns a ClientCtx. -func makeClientCtx[Req, Resp proto.Message](t *testing.T, numNodes int, responses []NodeResponse[proto.Message]) *ClientCtx[Req, Resp] { +func makeClientCtx[Req, Resp proto.Message](t *testing.T, numNodes int, responses []NodeResponse[uint32, proto.Message]) *ClientCtx[uint32, Req, Resp] { t.Helper() - resultChan := make(chan NodeResponse[proto.Message], len(responses)) + resultChan := make(chan NodeResponse[uint32, proto.Message], len(responses)) for _, r := range responses { resultChan <- r } close(resultChan) - config := make(Configuration, numNodes) + config := make(Configuration[uint32], numNodes) for i := range numNodes { - config[i] = &Node{id: uint32(i + 1)} + config[i] = &Node[uint32]{id: uint32(i + 1)} } - c := &ClientCtx[Req, Resp]{ + c := &ClientCtx[uint32, Req, Resp]{ Context: t.Context(), config: config, replyChan: resultChan, @@ -63,11 +63,11 @@ func checkError(t *testing.T, wantErr bool, err, wantErrType error) bool { // TestTerminalMethods tests the terminal methods on Responses func TestTerminalMethods(t *testing.T) { - type respType = *Responses[*pb.StringValue] + type respType = *Responses[uint32, *pb.StringValue] tests := []struct { name string numNodes int - responses []NodeResponse[proto.Message] + responses []NodeResponse[uint32, proto.Message] call func(resp respType) (*pb.StringValue, error) wantValue string wantErr bool @@ -77,7 +77,7 @@ func TestTerminalMethods(t *testing.T) { { name: "First_Success", numNodes: 3, - responses: []NodeResponse[proto.Message]{ + responses: []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, }, call: respType.First, @@ -86,7 +86,7 @@ func TestTerminalMethods(t *testing.T) { { name: "First_Error", numNodes: 3, - responses: []NodeResponse[proto.Message]{ + responses: []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: nil, Err: errors.New("node error")}, {NodeID: 2, Value: nil, Err: errors.New("node error")}, {NodeID: 3, Value: nil, Err: errors.New("node error")}, @@ -99,7 +99,7 @@ func TestTerminalMethods(t *testing.T) { { name: "Majority_Success_3Nodes", numNodes: 3, - responses: []NodeResponse[proto.Message]{ + responses: []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, }, @@ -109,7 +109,7 @@ func TestTerminalMethods(t *testing.T) { { name: "Majority_Insufficient", numNodes: 3, - responses: []NodeResponse[proto.Message]{ + responses: []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: nil, Err: errors.New("node error")}, {NodeID: 3, Value: nil, Err: errors.New("node error")}, @@ -121,7 +121,7 @@ func TestTerminalMethods(t *testing.T) { { name: "Majority_Even_Success", numNodes: 4, - responses: []NodeResponse[proto.Message]{ + responses: []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, {NodeID: 3, Value: pb.String("response3"), Err: nil}, @@ -133,7 +133,7 @@ func TestTerminalMethods(t *testing.T) { { name: "All_Success", numNodes: 3, - responses: []NodeResponse[proto.Message]{ + responses: []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, {NodeID: 3, Value: pb.String("response3"), Err: nil}, @@ -144,7 +144,7 @@ func TestTerminalMethods(t *testing.T) { { name: "All_PartialFailure", numNodes: 3, - responses: []NodeResponse[proto.Message]{ + responses: []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, {NodeID: 3, Value: nil, Err: errors.New("node error")}, @@ -172,11 +172,11 @@ func TestTerminalMethods(t *testing.T) { } func TestTerminalMethodsThreshold(t *testing.T) { - type respType = *Responses[*pb.StringValue] + type respType = *Responses[uint32, *pb.StringValue] tests := []struct { name string numNodes int - responses []NodeResponse[proto.Message] + responses []NodeResponse[uint32, proto.Message] call func(resp respType, threshold int) (*pb.StringValue, error) threshold int wantValue string @@ -186,7 +186,7 @@ func TestTerminalMethodsThreshold(t *testing.T) { { name: "Threshold_Success", numNodes: 3, - responses: []NodeResponse[proto.Message]{ + responses: []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, }, @@ -197,7 +197,7 @@ func TestTerminalMethodsThreshold(t *testing.T) { { name: "Threshold_Insufficient", numNodes: 3, - responses: []NodeResponse[proto.Message]{ + responses: []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: nil, Err: errors.New("node error")}, {NodeID: 3, Value: nil, Err: errors.New("node error")}, @@ -232,7 +232,7 @@ func TestTerminalMethodsThreshold(t *testing.T) { // TestIteratorMethods tests the iterator helper methods func TestIteratorMethods(t *testing.T) { t.Run("IgnoreErrors", func(t *testing.T) { - responses := []NodeResponse[proto.Message]{ + responses := []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: nil, Err: errors.New("node error")}, {NodeID: 3, Value: pb.String("response3"), Err: nil}, @@ -250,7 +250,7 @@ func TestIteratorMethods(t *testing.T) { }) t.Run("Filter", func(t *testing.T) { - responses := []NodeResponse[proto.Message]{ + responses := []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, {NodeID: 3, Value: pb.String("response3"), Err: nil}, @@ -260,7 +260,7 @@ func TestIteratorMethods(t *testing.T) { // Filter to only node 2 var count int - for resp := range r.Seq().Filter(func(resp NodeResponse[*pb.StringValue]) bool { + for resp := range r.Seq().Filter(func(resp NodeResponse[uint32, *pb.StringValue]) bool { return resp.NodeID == 2 }) { count++ @@ -274,7 +274,7 @@ func TestIteratorMethods(t *testing.T) { }) t.Run("CollectN", func(t *testing.T) { - responses := []NodeResponse[proto.Message]{ + responses := []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, {NodeID: 3, Value: pb.String("response3"), Err: nil}, @@ -289,7 +289,7 @@ func TestIteratorMethods(t *testing.T) { }) t.Run("CollectAll", func(t *testing.T) { - responses := []NodeResponse[proto.Message]{ + responses := []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, {NodeID: 3, Value: pb.String("response3"), Err: nil}, @@ -313,7 +313,7 @@ func TestIteratorMethods(t *testing.T) { func TestCustomAggregation(t *testing.T) { t.Run("SameTypeAggregation", func(t *testing.T) { // Aggregation function that returns the same type (Resp -> Resp) - majorityQF := func(resp *Responses[*pb.StringValue]) (*pb.StringValue, error) { + majorityQF := func(resp *Responses[uint32, *pb.StringValue]) (*pb.StringValue, error) { replies := resp.IgnoreErrors().CollectN(2) if len(replies) < 2 { return nil, ErrIncomplete @@ -324,7 +324,7 @@ func TestCustomAggregation(t *testing.T) { return nil, ErrIncomplete } - responses := []NodeResponse[proto.Message]{ + responses := []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, {NodeID: 3, Value: pb.String("response3"), Err: nil}, @@ -345,7 +345,7 @@ func TestCustomAggregation(t *testing.T) { t.Run("CustomReturnType", func(t *testing.T) { // Aggregation function that returns a different type (Resp -> []string) // This demonstrates the key benefit: Out can differ from In - collectAllValues := func(resp *Responses[*pb.StringValue]) ([]string, error) { + collectAllValues := func(resp *Responses[uint32, *pb.StringValue]) ([]string, error) { replies := resp.IgnoreErrors().CollectAll() if len(replies) == 0 { return nil, ErrIncomplete @@ -357,7 +357,7 @@ func TestCustomAggregation(t *testing.T) { return result, nil } - responses := []NodeResponse[proto.Message]{ + responses := []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("alpha"), Err: nil}, {NodeID: 2, Value: pb.String("beta"), Err: nil}, {NodeID: 3, Value: pb.String("gamma"), Err: nil}, @@ -377,9 +377,9 @@ func TestCustomAggregation(t *testing.T) { t.Run("WithFiltering", func(t *testing.T) { // Aggregation function that uses filtering and custom logic - filterAndCount := func(resp *Responses[*pb.StringValue]) (int, error) { + filterAndCount := func(resp *Responses[uint32, *pb.StringValue]) (int, error) { count := 0 - for range resp.IgnoreErrors().Filter(func(r NodeResponse[*pb.StringValue]) bool { + for range resp.IgnoreErrors().Filter(func(r NodeResponse[uint32, *pb.StringValue]) bool { return r.NodeID > 1 // Only nodes 2 and 3 }) { count++ @@ -390,7 +390,7 @@ func TestCustomAggregation(t *testing.T) { return count, nil } - responses := []NodeResponse[proto.Message]{ + responses := []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: pb.String("response2"), Err: nil}, {NodeID: 3, Value: pb.String("response3"), Err: nil}, @@ -410,7 +410,7 @@ func TestCustomAggregation(t *testing.T) { t.Run("ErrorHandling", func(t *testing.T) { // Aggregation function that handles errors explicitly - requireAllSuccess := func(resp *Responses[*pb.StringValue]) (*pb.StringValue, error) { + requireAllSuccess := func(resp *Responses[uint32, *pb.StringValue]) (*pb.StringValue, error) { var first *pb.StringValue for r := range resp.Seq() { if r.Err != nil { @@ -426,7 +426,7 @@ func TestCustomAggregation(t *testing.T) { return first, nil } - responses := []NodeResponse[proto.Message]{ + responses := []NodeResponse[uint32, proto.Message]{ {NodeID: 1, Value: pb.String("response1"), Err: nil}, {NodeID: 2, Value: nil, Err: errors.New("node 2 failed")}, } diff --git a/rpc.go b/rpc.go index d7b45902..e13eae13 100644 --- a/rpc.go +++ b/rpc.go @@ -8,10 +8,10 @@ import ( // RPCCall executes a remote procedure call on the node. // // This method should be used by generated code only. -func RPCCall(ctx *NodeContext, msg proto.Message, method string) (proto.Message, error) { +func RPCCall[T NodeID](ctx *NodeContext[T], msg proto.Message, method string) (proto.Message, error) { md := ordering.NewGorumsMetadata(ctx, ctx.nextMsgID(), method) - replyChan := make(chan NodeResponse[proto.Message], 1) - ctx.enqueue(request{ctx: ctx, msg: NewRequestMessage(md, msg), responseChan: replyChan}) + replyChan := make(chan NodeResponse[T, proto.Message], 1) + ctx.enqueue(request[T]{ctx: ctx, msg: NewRequestMessage(md, msg), responseChan: replyChan}) select { case r := <-replyChan: diff --git a/server_test.go b/server_test.go index 49626797..b0291ac7 100644 --- a/server_test.go +++ b/server_test.go @@ -120,7 +120,7 @@ func TestTCPReconnection(t *testing.T) { _ = srv.Serve(lis) }() - mgr := gorums.NewManager(gorums.InsecureDialOptions(t)) + mgr := gorums.NewManager[uint32](gorums.InsecureDialOptions(t)) t.Cleanup(gorums.Closer(t, mgr)) cfg, err := gorums.NewConfiguration(mgr, gorums.WithNodeList([]string{addr})) diff --git a/testing_bufconn.go b/testing_bufconn.go index 66a25f0e..c05b5ab3 100644 --- a/testing_bufconn.go +++ b/testing_bufconn.go @@ -50,7 +50,7 @@ var testBufconnDialers = make(map[testing.TB]func(context.Context, string) (net. // getOrCreateManager returns the existing manager or creates a new one with bufconn dialing. // If a new manager is created, its cleanup is registered via t.Cleanup. -func (to *testOptions) getOrCreateManager(t testing.TB) *Manager { +func (to *testOptions) getOrCreateManager(t testing.TB) *Manager[uint32] { if to.existingMgr != nil { // Don't register cleanup - caller is responsible for closing the manager return to.existingMgr @@ -68,7 +68,7 @@ func (to *testOptions) getOrCreateManager(t testing.TB) *Manager { grpc.WithTransportCredentials(insecure.NewCredentials()), } mgrOpts := append([]ManagerOption{WithDialOptions(dialOpts...)}, to.managerOpts...) - mgr := NewManager(mgrOpts...) + mgr := NewManager[uint32](mgrOpts...) t.Cleanup(func() { Closer(t, mgr)() }) return mgr } diff --git a/testing_integration.go b/testing_integration.go index e37e9d75..63f5b32e 100644 --- a/testing_integration.go +++ b/testing_integration.go @@ -24,14 +24,14 @@ func testSetupServers(t testing.TB, numServers int, srvFn func(i int) ServerIfac // getOrCreateManager returns the existing manager or creates a new one with real network dialing. // If a new manager is created, its cleanup is registered via t.Cleanup. -func (to *testOptions) getOrCreateManager(t testing.TB) *Manager { +func (to *testOptions) getOrCreateManager(t testing.TB) *Manager[uint32] { if to.existingMgr != nil { // Don't register cleanup - caller is responsible for closing the manager return to.existingMgr } // Create manager and register its cleanup LAST so it runs FIRST (LIFO) mgrOpts := append([]ManagerOption{InsecureDialOptions(t)}, to.managerOpts...) - mgr := NewManager(mgrOpts...) + mgr := NewManager[uint32](mgrOpts...) t.Cleanup(Closer(t, mgr)) return mgr } diff --git a/testing_shared.go b/testing_shared.go index 1956bd01..514042b8 100644 --- a/testing_shared.go +++ b/testing_shared.go @@ -59,7 +59,7 @@ func TestQuorumCallError(_ testing.TB, nodeErrors map[uint32]error) QuorumCallEr // // This is the recommended way to set up tests that need both servers and a configuration. // It ensures proper cleanup and detects goroutine leaks. -func TestConfiguration(t testing.TB, numServers int, srvFn func(i int) ServerIface, opts ...TestOption) Configuration { +func TestConfiguration(t testing.TB, numServers int, srvFn func(i int) ServerIface, opts ...TestOption) Configuration[uint32] { t.Helper() testOpts := extractTestOptions(opts) @@ -104,7 +104,7 @@ func TestConfiguration(t testing.TB, numServers int, srvFn func(i int) ServerIfa // // This is the recommended way to set up tests that need only a single server node. // It ensures proper cleanup and detects goroutine leaks. -func TestNode(t testing.TB, srvFn func(i int) ServerIface, opts ...TestOption) *Node { +func TestNode(t testing.TB, srvFn func(i int) ServerIface, opts ...TestOption) *Node[uint32] { t.Helper() return TestConfiguration(t, 1, srvFn, opts...).Nodes()[0] } diff --git a/testopts.go b/testopts.go index 0fc208ec..ac88bcb2 100644 --- a/testopts.go +++ b/testopts.go @@ -21,8 +21,8 @@ type TestOption any type testOptions struct { managerOpts []ManagerOption serverOpts []ServerOption - nodeListOpts []NodeListOption - existingMgr *Manager + nodeListOpts []NodeListOption[uint32] + existingMgr *Manager[uint32] stopFuncPtr *func(...int) // pointer to capture the variadic stop function preConnectHook func(stopFn func()) // called before connecting to servers skipGoleak bool // skip goleak checks (useful for synctest) @@ -57,7 +57,7 @@ func (to *testOptions) serverFunc(srvFn func(i int) ServerIface) func(i int) Ser // nodeListOption returns the appropriate NodeListOption for the configuration. // It uses provided options if available, or generates defaults based on whether // an existing manager is being reused. -func (to *testOptions) nodeListOption(addrs []string) NodeListOption { +func (to *testOptions) nodeListOption(addrs []string) NodeListOption[uint32] { if len(to.nodeListOpts) > 0 { // Use the last provided NodeListOption (allows overriding) return to.nodeListOpts[len(to.nodeListOpts)-1] @@ -84,9 +84,9 @@ func extractTestOptions(opts []TestOption) testOptions { result.managerOpts = append(result.managerOpts, o) case ServerOption: result.serverOpts = append(result.serverOpts, o) - case NodeListOption: + case NodeListOption[uint32]: result.nodeListOpts = append(result.nodeListOpts, o) - case *Manager: + case *Manager[uint32]: result.existingMgr = o case stopFuncProvider: result.stopFuncPtr = o.stopFunc @@ -107,7 +107,7 @@ func extractTestOptions(opts []TestOption) testOptions { // SetupConfiguration will NOT register a cleanup function for the manager. // // This option is intended for testing purposes only. -func WithManager(_ testing.TB, mgr *Manager) TestOption { +func WithManager(_ testing.TB, mgr *Manager[uint32]) TestOption { if mgr == nil { panic("gorums: WithManager called with nil manager") } diff --git a/unicast.go b/unicast.go index 55d17633..d6373b71 100644 --- a/unicast.go +++ b/unicast.go @@ -16,7 +16,7 @@ import ( // enqueueing the message (fire-and-forget semantics). // // This method should be used by generated code only. -func Unicast[Req proto.Message](ctx *NodeContext, req Req, method string, opts ...CallOption) error { +func Unicast[T NodeID, Req proto.Message](ctx *NodeContext[T], req Req, method string, opts ...CallOption) error { callOpts := getCallOptions(E_Unicast, opts...) md := ordering.NewGorumsMetadata(ctx, ctx.nextMsgID(), method) msg := NewRequestMessage(md, req) @@ -24,13 +24,13 @@ func Unicast[Req proto.Message](ctx *NodeContext, req Req, method string, opts . waitSendDone := callOpts.mustWaitSendDone() if !waitSendDone { // Fire-and-forget: enqueue and return immediately - ctx.enqueue(request{ctx: ctx, msg: msg}) + ctx.enqueue(request[T]{ctx: ctx, msg: msg}) return nil } // Default: block until send completes - replyChan := make(chan NodeResponse[proto.Message], 1) - ctx.enqueue(request{ctx: ctx, msg: msg, waitSendDone: true, responseChan: replyChan}) + replyChan := make(chan NodeResponse[T, proto.Message], 1) + ctx.enqueue(request[T]{ctx: ctx, msg: msg, waitSendDone: true, responseChan: replyChan}) // Wait for send confirmation select {