diff --git a/.gitignore b/.gitignore index 8c310b3c..464d20fe 100644 --- a/.gitignore +++ b/.gitignore @@ -38,5 +38,7 @@ branch-compare-* tmp/* doc/task-*.md doc/issue-*.md +doc/review-*.md .claude/settings.local.json .superset/config.json +.claude/agent-memory/ diff --git a/.vscode/gorums.txt b/.vscode/gorums.txt index 4eea5bce..53dc1808 100644 --- a/.vscode/gorums.txt +++ b/.vscode/gorums.txt @@ -116,6 +116,7 @@ testutils timestamppb tmpl Tormod +Twoway ucast unexport Unexported diff --git a/callopts.go b/callopts.go index 5c7efc41..758fb38c 100644 --- a/callopts.go +++ b/callopts.go @@ -2,31 +2,18 @@ package gorums import ( "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/runtime/protoimpl" ) type callOptions struct { - callType *protoimpl.ExtensionInfo ignoreErrors bool interceptors []any // Type-erased interceptors, restored by QuorumCall } -// mustWaitSendDone returns true if the caller of a one-way call type must wait -// for send completion. This is the default behavior unless the IgnoreErrors -// call option is set. This always returns false for two-way call types, since -// they should always wait for actual server responses. -func (o callOptions) mustWaitSendDone() bool { - // must wait for send completion if we are not ignoring errors - // and the call type is Unicast or Multicast - return !o.ignoreErrors && (o.callType == E_Unicast || o.callType == E_Multicast) -} - // CallOption is a function that sets a value in the given callOptions struct type CallOption func(*callOptions) -func getCallOptions(callType *protoimpl.ExtensionInfo, opts ...CallOption) callOptions { +func getCallOptions(opts ...CallOption) callOptions { o := callOptions{ - callType: callType, ignoreErrors: false, // default: return error and wait for send completion } for _, opt := range opts { @@ -46,8 +33,8 @@ func IgnoreErrors() CallOption { } // Interceptors returns a CallOption that adds quorum call interceptors. 
-// Interceptors are executed in the order provided, modifying the Responses object -// before the user calls a terminal method. +// Interceptors are executed in the order provided, modifying the Responses +// object before the user calls a terminal method. // // Example: // diff --git a/callopts_test.go b/callopts_test.go index 35418268..6dc191fa 100644 --- a/callopts_test.go +++ b/callopts_test.go @@ -3,35 +3,67 @@ package gorums import ( "fmt" "testing" + "time" + + "github.com/relab/gorums/internal/testutils/mock" + pb "google.golang.org/protobuf/types/known/wrapperspb" ) -func TestCallOptionsMustWaitSendDone(t *testing.T) { +func TestCallOptionsIgnoreErrors(t *testing.T) { tests := []struct { name string callOpts callOptions - wantWaitSendDone bool + wantIgnoreErrors bool }{ - // One-way call types - {name: "Unicast/Default", callOpts: getCallOptions(E_Unicast), wantWaitSendDone: true}, - {name: "Unicast/IgnoreErrors", callOpts: getCallOptions(E_Unicast, IgnoreErrors()), wantWaitSendDone: false}, - {name: "Multicast/Default", callOpts: getCallOptions(E_Multicast), wantWaitSendDone: true}, - {name: "Multicast/IgnoreErrors", callOpts: getCallOptions(E_Multicast, IgnoreErrors()), wantWaitSendDone: false}, - // Two-way call types (never wait for send completion, regardless of option) - {name: "Rpc/Default", callOpts: getCallOptions(E_Rpc), wantWaitSendDone: false}, - {name: "Rpc/IgnoreErrors", callOpts: getCallOptions(E_Rpc, IgnoreErrors()), wantWaitSendDone: false}, - {name: "Quorumcall/Default", callOpts: getCallOptions(E_Quorumcall), wantWaitSendDone: false}, - {name: "Quorumcall/IgnoreErrors", callOpts: getCallOptions(E_Quorumcall, IgnoreErrors()), wantWaitSendDone: false}, + {name: "Default", callOpts: getCallOptions(), wantIgnoreErrors: false}, + {name: "IgnoreErrors", callOpts: getCallOptions(IgnoreErrors()), wantIgnoreErrors: true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotWaitSendDone := tt.callOpts.mustWaitSendDone() - if 
gotWaitSendDone != tt.wantWaitSendDone { - t.Errorf("mustWaitSendDone() = %v, want %v", gotWaitSendDone, tt.wantWaitSendDone) + if got := tt.callOpts.ignoreErrors; got != tt.wantIgnoreErrors { + t.Errorf("ignoreErrors = %v, want %v", got, tt.wantIgnoreErrors) } }) } } +func TestCallOptionsIgnoreErrorsResourceLeak(t *testing.T) { + // Previously leaked because fire-and-forget multicast still registered in router. + // Now fixed: no replyChan → no ResponseChan → no Register. + systems := TestSystems(t, 3) + for _, sys := range systems { + sys.RegisterService(nil, func(srv *Server) { + srv.RegisterHandler(mock.TestMethod, func(_ ServerCtx, _ *Message) (*Message, error) { + return nil, nil + }) + }) + } + for _, sys := range systems { + sys.WaitForConfig(t.Context(), func(cfg Configuration) bool { + return cfg.Size() == 3 + }) + } + cfg := systems[0].OutboundConfig() + ctx := TestContext(t, 5*time.Second) + for i := range 1000 { + Multicast(cfg.Context(ctx), pb.String(fmt.Sprintf("mc-%d", i)), mock.TestMethod, IgnoreErrors()) + } + TestWaitUntil(t, 5*time.Second, func() bool { + for _, node := range cfg.Nodes() { + if node.PendingCount() > 0 { + return false + } + } + return true + }) + + for _, node := range cfg.Nodes() { + if pc := node.PendingCount(); pc > 0 { + t.Errorf("node %d: pending = %d; expected 0", node.ID(), pc) + } + } +} + func BenchmarkGetCallOptions(b *testing.B) { interceptor := func(_ *ClientCtx[msg, msg], next ResponseSeq[msg]) ResponseSeq[msg] { return next } tests := []struct { @@ -48,7 +80,7 @@ func BenchmarkGetCallOptions(b *testing.B) { b.Run(fmt.Sprintf("options=%d", tc.numOpts), func(b *testing.B) { b.ReportAllocs() for b.Loop() { - _ = getCallOptions(E_Quorumcall, opts...) + _ = getCallOptions(opts...) 
} }) } diff --git a/client_interceptor.go b/client_interceptor.go index 3a13bf89..5a0e8b19 100644 --- a/client_interceptor.go +++ b/client_interceptor.go @@ -7,6 +7,7 @@ import ( "github.com/relab/gorums/internal/stream" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/emptypb" ) // QuorumInterceptor intercepts and processes quorum calls, allowing modification of @@ -55,8 +56,8 @@ type ClientCtx[Req, Resp msg] struct { // streaming indicates whether this is a streaming call (for correctable streams). streaming bool - // waitSendDone indicates whether the caller waits for send completion (for multicast). - waitSendDone bool + // oneway indicates whether this is a one-way call (for multicast). + oneway bool // sendOnce ensures messages are sent exactly once, on the first // call to Responses(). This deferred sending allows interceptors @@ -69,50 +70,65 @@ func (c *ClientCtx[Req, Resp]) sendNow() { c.sendOnce.Do(c.send) } -type clientCtxOptions struct { - streaming bool - waitSendDone bool - interceptors []any -} - -// newClientCtx constructs and initializes a ClientCtx for quorum-style calls. -// It creates call metadata, configures the response iterator, and applies -// interceptors after the base iterator has been established. -func newClientCtx[Req, Resp msg]( +// newQuorumCallClientCtx constructs a ClientCtx for quorum calls (two-way, always returns responses). +// A reply channel is always created; streaming controls both its buffer size and the response iterator type. 
+func newQuorumCallClientCtx[Req, Resp msg]( ctx *ConfigContext, req Req, method string, - opts clientCtxOptions, + streaming bool, + interceptors []any, ) *ClientCtx[Req, Resp] { config := ctx.Configuration() + n := config.Size() + if streaming { + n *= 10 + } clientCtx := &ClientCtx[Req, Resp]{ - Context: ctx, - config: config, - request: req, - method: method, - msgID: config.nextMsgID(), - replyChan: make(chan NodeResponse[*stream.Message], chanSize(config, opts.streaming)), - streaming: opts.streaming, - waitSendDone: opts.waitSendDone, + Context: ctx, + config: config, + request: req, + method: method, + msgID: config.nextMsgID(), + streaming: streaming, + replyChan: make(chan NodeResponse[*stream.Message], n), } - - if clientCtx.streaming { + if streaming { clientCtx.responseSeq = clientCtx.streamingResponseSeq() } else { clientCtx.responseSeq = clientCtx.defaultResponseSeq() } - clientCtx.applyInterceptors(opts.interceptors) + clientCtx.applyInterceptors(interceptors) return clientCtx } -// chanSize returns the channel buffer size based on the configuration and -// whether the call is streaming. For streaming calls, we use a larger buffer -// to accommodate more in-flight messages without blocking. -func chanSize(config Configuration, streaming bool) int { - if streaming { - return config.Size() * 10 +// newMulticastClientCtx constructs a ClientCtx for multicast (one-way, no responses). +// A reply channel is created only when waitForSend=true (blocking send); fire-and-forget +// calls receive a nil channel, meaning no router entry is registered. 
+func newMulticastClientCtx[Req msg]( + ctx *ConfigContext, + req Req, + method string, + waitForSend bool, + interceptors []any, +) *ClientCtx[Req, *emptypb.Empty] { + config := ctx.Configuration() + var replyChan chan NodeResponse[*stream.Message] + if waitForSend { + replyChan = make(chan NodeResponse[*stream.Message], config.Size()) + } + clientCtx := &ClientCtx[Req, *emptypb.Empty]{ + Context: ctx, + config: config, + request: req, + method: method, + msgID: config.nextMsgID(), + oneway: true, + replyChan: replyChan, } - return config.Size() + clientCtx.responseSeq = clientCtx.defaultResponseSeq() + clientCtx.applyInterceptors(interceptors) + return clientCtx } // ------------------------------------------------------------------------- @@ -156,6 +172,26 @@ func (c *ClientCtx[Req, Resp]) Size() int { return c.config.Size() } +// reportNodeError sends an error response for the given node to replyChan. +// It is a no-op for fire-and-forget calls where replyChan is nil. +func (c *ClientCtx[Req, Resp]) reportNodeError(nodeID uint32, err error) { + if c.replyChan != nil { + c.replyChan <- NodeResponse[*stream.Message]{NodeID: nodeID, Err: err} + } +} + +// enqueue sends a stream.Request to the given node, populating the shared +// fields from ClientCtx so call sites only need to supply the message. +func (c *ClientCtx[Req, Resp]) enqueue(n *Node, msg *stream.Message) { + n.Enqueue(stream.Request{ + Ctx: c.Context, + Msg: msg, + Streaming: c.streaming, + Oneway: c.oneway, + ResponseChan: c.replyChan, + }) +} + // applyInterceptors chains the given interceptors, wrapping the response sequence. // Each interceptor receives the current response sequence and returns a new one. // Interceptors are applied in order, with each wrapping the previous result. @@ -186,18 +222,12 @@ func (c *ClientCtx[Req, Resp]) sendShared() { if err != nil { // Marshaling fails identically for all nodes; report and return. 
for _, n := range c.config { - c.replyChan <- NodeResponse[*stream.Message]{NodeID: n.ID(), Err: err} + c.reportNodeError(n.ID(), err) } return } for _, n := range c.config { - n.Enqueue(stream.Request{ - Ctx: c.Context, - Msg: sharedMsg, - Streaming: c.streaming, - WaitSendDone: c.waitSendDone, - ResponseChan: c.replyChan, - }) + c.enqueue(n, sharedMsg) } } @@ -209,19 +239,13 @@ func (c *ClientCtx[Req, Resp]) sendWithPerNodeTransformation() { if streamMsg == nil { continue // Skip node: transformAndMarshal already sent ErrSkipNode } - n.Enqueue(stream.Request{ - Ctx: c.Context, - Msg: streamMsg, - Streaming: c.streaming, - WaitSendDone: c.waitSendDone, - ResponseChan: c.replyChan, - }) + c.enqueue(n, streamMsg) } } // transformAndMarshal applies transformations to the request for the given node, // then marshals it into a stream.Message. Returns nil if transformation fails -// or marshaling fails (in which case the error is sent on replyChan). +// or marshaling fails (in which case the error is reported via reportNodeError). 
func (c *ClientCtx[Req, Resp]) transformAndMarshal(n *Node) *stream.Message { transformedRequest := c.request for _, transform := range c.reqTransforms { @@ -229,12 +253,12 @@ func (c *ClientCtx[Req, Resp]) transformAndMarshal(n *Node) *stream.Message { } // Check if the result is valid if protoReq, ok := any(transformedRequest).(proto.Message); !ok || protoReq == nil || !protoReq.ProtoReflect().IsValid() { - c.replyChan <- NodeResponse[*stream.Message]{NodeID: n.ID(), Err: ErrSkipNode} + c.reportNodeError(n.ID(), ErrSkipNode) return nil } streamMsg, err := stream.NewMessage(c.Context, c.msgID, c.method, transformedRequest) if err != nil { - c.replyChan <- NodeResponse[*stream.Message]{NodeID: n.ID(), Err: err} + c.reportNodeError(n.ID(), err) return nil } return streamMsg diff --git a/internal/stream/channel.go b/internal/stream/channel.go index b3652de0..8c4884f4 100644 --- a/internal/stream/channel.go +++ b/internal/stream/channel.go @@ -1,6 +1,7 @@ package stream import ( + "cmp" "context" "sync" "time" @@ -28,11 +29,25 @@ type Request struct { Ctx context.Context Msg *Message Streaming bool - WaitSendDone bool + Oneway bool ResponseChan chan<- response SendTime time.Time } +// wantServerResponse returns true if the request expects an actual +// server response and needs a router entry. It returns true for +// two-way calls (RPC, QuorumCall) and streaming calls (correctable). +func (r Request) wantServerResponse() bool { + return r.ResponseChan != nil && !r.Oneway +} + +// wantSendConfirmation returns true if the request needs send confirmation +// delivered directly on its ResponseChan, bypassing the router. It returns +// true for one-way calls (Unicast, Multicast) that are not fire-and-forget. +func (r Request) wantSendConfirmation() bool { + return r.Oneway && r.ResponseChan != nil +} + // deliver sends the response on request's response channel, preferring delivery // even if request's context is already canceled. 
If the channel is full, // it falls back to respecting context cancellation to avoid blocking forever. @@ -281,13 +296,15 @@ func (c *Channel) isConnected() bool { // the registered RequestHandler without touching the network. // If the node is closed, it responds with an error instead. // -// WaitSendDone and Streaming are mutually exclusive: WaitSendDone is for one-way -// calls that want send-completion confirmation, while Streaming is for -// correctable calls that keep the router entry alive for multiple server responses. +// Requests cannot combine Oneway and Streaming; they are mutually exclusive: +// - one-way calls (Unicast, Multicast) do not expect server responses. +// - streaming (correctable) calls expect multiple server responses and +// require the router entry to stay alive for the duration of the stream. +// // Combining them would cause double delivery on the response channel. func (c *Channel) Enqueue(req Request) { - if req.WaitSendDone && req.Streaming { - panic("gorums: WaitSendDone and Streaming are mutually exclusive") + if req.Oneway && req.Streaming { + panic("gorums: Oneway and Streaming are mutually exclusive") } if c.isLocal() { c.router.DispatchLocalRequest(c.id, req) @@ -373,13 +390,13 @@ func (c *Channel) drainSendQ() { // If the stream is down, it tries to re-establish it. // // Delivery contract: -// - Pre-registration exits (stream error, cancelled request context, nil stream): replyError + continue. -// The request never enters the router, so no routeResponse lookup is needed. -// - Send failure: requeuePendingMsgs handles the registered entry (requeue or cancel). -// continue skips routeResponse since the entry is already gone. -// - Send success, WaitSendDone=true: routeResponse delivers the confirmation. -// - Send success, WaitSendDone=false: the router entry stays alive for receiver() -// to deliver the actual server response, so routeResponse is not called here. 
+// - Pre-registration exits (stream error, cancelled request context, nil stream): +// replyError + continue. The request never enters the router. +// - Send failure: requeuePendingMsgs handles registered two-way entries (requeue or cancel). +// One-way errors are delivered directly via replyError. +// - Send success, one-way call: confirm send directly on ResponseChan. +// - Send success, two-way call: the router entry stays alive for receiver() +// to deliver the actual server response. func (c *Channel) sender() { defer c.drainSendQ() @@ -410,9 +427,9 @@ func (c *Channel) sender() { continue } - // Register call in the response router only for calls that are genuinely - // in-flight on the current stream, after all early-exit checks pass. - if req.ResponseChan != nil { + // One-way calls bypass the router and confirm directly after Send below. + if req.wantServerResponse() { + // Register only for two-way/streaming calls that expect server responses. c.router.Register(req.Msg.GetMessageSeqNo(), req) } @@ -433,14 +450,20 @@ func (c *Channel) sender() { stop() c.setLastErr(err) c.clearStream(stream) - c.requeuePendingMsgs() // removes and requeues/cancels all router entries + c.requeuePendingMsgs() // handles registered two-way entries + // One-way calls are not registered in the router to receive server responses, + // so requeuePendingMsgs won't handle them. Deliver error directly to caller. + if !req.wantServerResponse() { + // prefer context error when cancellation caused the failure. + req.replyError(c.id, cmp.Or(req.Ctx.Err(), err)) + } continue } stop() - // For one-way calls (Unicast/Multicast) with WaitSendDone, confirm successful send. - if req.WaitSendDone { - c.router.RouteResponse(req.Msg.GetMessageSeqNo(), response{NodeID: c.id}) + // For one-way calls, confirm successful send directly (no router round-trip). 
+ if req.wantSendConfirmation() { + req.deliver(response{NodeID: c.id}) } } } diff --git a/internal/stream/channel_test.go b/internal/stream/channel_test.go index c6cd8d3e..eed7c797 100644 --- a/internal/stream/channel_test.go +++ b/internal/stream/channel_test.go @@ -223,7 +223,7 @@ func TestChannelCreation(t *testing.T) { tc := setupChannelWithoutServer(t) // send message when server is down - resp := sendRequest(t, tc.Channel, Request{WaitSendDone: true}, 1) + resp := sendRequest(t, tc.Channel, Request{Oneway: true}, 1) if resp.Err == nil { t.Error("response err: got , want error") } @@ -278,7 +278,7 @@ func TestChannelLatency(t *testing.T) { // Send a few requests to update latency for i := range 10 { - sendRequest(t, tc.Channel, Request{WaitSendDone: false}, uint64(i)) + sendRequest(t, tc.Channel, Request{Oneway: false}, uint64(i)) } latency := tc.router.Latency() @@ -294,16 +294,16 @@ func TestChannelSendCompletionWaiting(t *testing.T) { tc := setupChannel(t, echoServer) tests := []struct { - name string - waitSendDone bool + name string + oneway bool }{ - {name: "WaitForSend", waitSendDone: true}, - {name: "NoSendWaiting", waitSendDone: false}, + {name: "Oneway", oneway: true}, + {name: "Twoway", oneway: false}, } for i, tt := range tests { t.Run(tt.name, func(t *testing.T) { start := time.Now() - resp := sendRequest(t, tc.Channel, Request{WaitSendDone: tt.waitSendDone}, uint64(i)) + resp := sendRequest(t, tc.Channel, Request{Oneway: tt.oneway}, uint64(i)) elapsed := time.Since(start) if resp.Err != nil { t.Errorf("unexpected error: %v", resp.Err) @@ -342,7 +342,7 @@ func TestChannelErrors(t *testing.T) { setup: func(t *testing.T) *testChannel { tc := setupChannel(t, echoServer) // Send a message to ensure connection is established - resp := sendRequest(t, tc.Channel, Request{WaitSendDone: true}, 1) + resp := sendRequest(t, tc.Channel, Request{Oneway: true}, 1) if resp.Err != nil { t.Errorf("initial message send should succeed, got error: %v", resp.Err) } @@ 
-358,7 +358,7 @@ func TestChannelErrors(t *testing.T) { tc := tt.setup(t) time.Sleep(100 * time.Millisecond) - resp := sendRequest(t, tc.Channel, Request{WaitSendDone: true}, uint64(i)) + resp := sendRequest(t, tc.Channel, Request{Oneway: true}, uint64(i)) if resp.Err == nil { t.Errorf("expected error containing %q but got nil", tt.wantErr) } else if !strings.Contains(resp.Err.Error(), tt.wantErr) { @@ -559,35 +559,35 @@ func TestChannelContext(t *testing.T) { name string serverFn func(Gorums_NodeStreamServer) error contextSetup func(context.Context) (context.Context, context.CancelFunc) - waitSendDone bool + oneway bool wantErr error }{ { name: "CancelBeforeSend/WaitSending", serverFn: echoServer, contextSetup: cancelledContext, - waitSendDone: true, + oneway: true, wantErr: context.Canceled, }, { name: "CancelBeforeSend/NoSendWaiting", serverFn: echoServer, contextSetup: cancelledContext, - waitSendDone: false, + oneway: false, wantErr: context.Canceled, }, { name: "CancelDuringSend/WaitSending", serverFn: holdServer, contextSetup: expireBeforeSend, - waitSendDone: true, + oneway: true, wantErr: context.DeadlineExceeded, }, { name: "CancelDuringSend/NoSendWaiting", serverFn: holdServer, contextSetup: expireBeforeSend, - waitSendDone: false, + oneway: false, wantErr: context.DeadlineExceeded, }, } @@ -598,7 +598,7 @@ func TestChannelContext(t *testing.T) { t.Cleanup(cancel) tc := setupChannel(t, tt.serverFn) - resp := sendRequest(t, tc.Channel, Request{Ctx: ctx, WaitSendDone: tt.waitSendDone}, uint64(i)) + resp := sendRequest(t, tc.Channel, Request{Ctx: ctx, Oneway: tt.oneway}, uint64(i)) if !errors.Is(resp.Err, tt.wantErr) { t.Errorf("expected %v, got: %v", tt.wantErr, resp.Err) } @@ -1047,16 +1047,16 @@ func TestChannelRouterLifecycle(t *testing.T) { } tests := []struct { - name string - waitSendDone bool - streaming bool - wantRouter bool - wantPanic bool + name string + oneway bool + streaming bool + wantRouter bool + wantPanic bool }{ - {name: 
"WaitSendDone/NoStreaming/Cleanup", waitSendDone: true, streaming: false, wantRouter: false}, - {name: "WaitSendDone/Streaming/Invalid", waitSendDone: true, streaming: true, wantPanic: true}, - {name: "NoSendWaiting/NoStreaming/Cleanup", waitSendDone: false, streaming: false, wantRouter: false}, - {name: "NoSendWaiting/Streaming/KeepsRouterAlive", waitSendDone: false, streaming: true, wantRouter: true}, + {name: "Oneway/NoStreaming/Cleanup", oneway: true, streaming: false, wantRouter: false}, + {name: "Oneway/Streaming/Invalid", oneway: true, streaming: true, wantPanic: true}, + {name: "Twoway/NoStreaming/Cleanup", oneway: false, streaming: false, wantRouter: false}, + {name: "Twoway/Streaming/KeepsRouterAlive", oneway: false, streaming: true, wantRouter: true}, } for i, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1070,7 +1070,7 @@ func TestChannelRouterLifecycle(t *testing.T) { } }() msgID := uint64(i) - resp := sendRequest(t, tc.Channel, Request{WaitSendDone: tt.waitSendDone, Streaming: tt.streaming}, msgID) + resp := sendRequest(t, tc.Channel, Request{Oneway: tt.oneway, Streaming: tt.streaming}, msgID) if resp.Err != nil { t.Errorf("unexpected error: %v", resp.Err) } @@ -1100,7 +1100,7 @@ func TestChannelResponseRouting(t *testing.T) { results := make(chan msgResponse, numMessages) for i := range numMessages { - go sendReq(t, results, tc.Channel, i, 1, Request{WaitSendDone: true}) + go sendReq(t, results, tc.Channel, i, 1, Request{Oneway: true}) } // Collect and verify results @@ -1131,8 +1131,8 @@ func TestChannelConcurrentSends(t *testing.T) { results := make(chan msgResponse, numMessages) for goID := range numGoroutines { go func() { - sendReq(t, results, tc.Channel, goID, msgsPerGoroutine, Request{WaitSendDone: true}) - sendReq(t, results, tc.Channel, goID, msgsPerGoroutine, Request{WaitSendDone: false}) + sendReq(t, results, tc.Channel, goID, msgsPerGoroutine, Request{Oneway: true}) + sendReq(t, results, tc.Channel, goID, msgsPerGoroutine, 
Request{Oneway: false}) }() } @@ -1173,7 +1173,7 @@ func TestChannelDeadlock(t *testing.T) { } // Send message to activate stream - sendRequest(t, tc.Channel, Request{WaitSendDone: true}, 1) + sendRequest(t, tc.Channel, Request{Oneway: true}, 1) // Break the stream, forcing a reconnection on next send tc.clearStream(tc.getStream()) @@ -1286,7 +1286,7 @@ func TestChannelClearStreamDeadlock(t *testing.T) { Ctx: ctx, Msg: msg, Streaming: false, - WaitSendDone: false, + Oneway: false, ResponseChan: replyChannels[i], }) } @@ -1389,7 +1389,7 @@ func TestIsInbound(t *testing.T) { // TestInboundChannel verifies that an inbound channel can send messages. // No receiver goroutine is started for inbound channels; the caller's NodeStream -// Recv loop is the sole reader. WaitSendDone confirms successful delivery to the +// Recv loop is the sole reader. Oneway confirms successful delivery to the // stream without requiring a routed response. func TestInboundChannel(t *testing.T) { stream := newMockBidiStream() @@ -1399,7 +1399,7 @@ func TestInboundChannel(t *testing.T) { }) // Send a message and verify it is delivered to the stream. - resp := sendRequest(t, c, Request{WaitSendDone: true}, 1) + resp := sendRequest(t, c, Request{Oneway: true}, 1) if resp.Err != nil { t.Errorf("unexpected error: %v", resp.Err) } @@ -1420,7 +1420,7 @@ func TestInboundChannelClose(t *testing.T) { } // Subsequent sends should fail with ErrNodeClosed. - resp := sendRequest(t, c, Request{WaitSendDone: true}, 2) + resp := sendRequest(t, c, Request{Oneway: true}, 2) if resp.Err == nil { t.Error("expected error after close, got nil") } else if !errors.Is(resp.Err, ErrNodeClosed) { @@ -1443,7 +1443,7 @@ func TestInboundChannelStreamDown(t *testing.T) { c := NewInboundChannel(t.Context(), 1, 10, stream, NewMessageRouter()) // Verify initial send works. 
- resp := sendRequest(t, c, Request{WaitSendDone: true}, 1) + resp := sendRequest(t, c, Request{Oneway: true}, 1) if resp.Err != nil { t.Fatalf("initial send failed: %v", resp.Err) } @@ -1456,7 +1456,7 @@ func TestInboundChannelStreamDown(t *testing.T) { } // Sends after close must fail with ErrNodeClosed, not silently reconnect. - resp = sendRequest(t, c, Request{WaitSendDone: true}, 2) + resp = sendRequest(t, c, Request{Oneway: true}, 2) if resp.Err == nil { t.Error("expected error after stream down, got nil") } else if !errors.Is(resp.Err, ErrNodeClosed) { @@ -1587,7 +1587,7 @@ func BenchmarkChannelSend(b *testing.B) { Method: mock.TestMethod, Payload: payload, }.Build() - req := Request{Ctx: context.Background(), Msg: msg, WaitSendDone: true, ResponseChan: replyChan} + req := Request{Ctx: context.Background(), Msg: msg, Oneway: true, ResponseChan: replyChan} tc.Enqueue(req) <-replyChan } @@ -1623,7 +1623,7 @@ func BenchmarkChannelSendParallel(b *testing.B) { Method: mock.TestMethod, Payload: payload, }.Build() - req := Request{Ctx: context.Background(), Msg: msg, WaitSendDone: true, ResponseChan: replyChan} + req := Request{Ctx: context.Background(), Msg: msg, Oneway: true, ResponseChan: replyChan} tc.Enqueue(req) <-replyChan } diff --git a/internal/stream/router.go b/internal/stream/router.go index a57f68f9..2f4a5a42 100644 --- a/internal/stream/router.go +++ b/internal/stream/router.go @@ -36,7 +36,7 @@ type RequestHandler interface { // survives channel replacement (e.g., inbound reconnects). // // The router maintains a map of pending calls keyed by message sequence number. -// When a response arrives, RouteResponse looks up the matching request and +// When a response arrives, deliverPending looks up the matching request and // delivers the response on its response channel. // // The router also provides handler lookup via a shared handler map. 
All routers @@ -53,10 +53,10 @@ type MessageRouter struct { } // NewMessageRouter creates a new MessageRouter with an optional RequestHandler. -// The handler, if provided, is used to dispatch incoming requests: -// in RouteMessage, it processes server-initiated back-channel calls (high-bit IDs); -// in RouteInboundMessage, it dispatches client-initiated requests (low-bit IDs). -// Passing nil (or omitting the argument) disables request dispatch on this router. +// The handler, if provided, is used to dispatch incoming requests: on the client +// side it handles server-initiated back-channel calls; on the server side it +// dispatches client-initiated requests. Passing nil (or omitting the argument) +// disables request dispatch on this router. func NewMessageRouter(handler ...RequestHandler) *MessageRouter { handler = append(handler, nil) // ensure handler[0] is always valid return &MessageRouter{ @@ -86,14 +86,21 @@ func (r *MessageRouter) SetLatency(latency time.Duration) { r.latency = latency } +// PendingCount returns the number of pending calls currently registered in the router. +func (r *MessageRouter) PendingCount() int { + r.mu.Lock() + defer r.mu.Unlock() + return len(r.pending) +} + // DispatchLocalRequest handles the request in-process for the local node, // bypassing the network. It delivers the request to the registered handler, // serializing execution the same way remote nodes do: the next dispatch is // blocked until the handler returns or calls [ServerCtx.Release]. // -// For one-way calls, send-completion is confirmed before the handler runs -// if WaitSendDone is true. For two-way calls, the response is delivered -// directly to the caller's response channel via the send closure. +// For one-way calls with a ResponseChan, send-completion is confirmed before +// the handler runs; fire-and-forget calls get no confirmation. For two-way +// calls, the response is delivered on the caller's ResponseChan via send.
func (r *MessageRouter) DispatchLocalRequest(nodeID uint32, req Request) { if req.Ctx.Err() != nil { req.replyError(nodeID, req.Ctx.Err()) @@ -103,19 +110,19 @@ func (r *MessageRouter) DispatchLocalRequest(nodeID uint32, req Request) { req.replyError(nodeID, status.Error(codes.Unimplemented, "no request handler registered")) return } - if req.WaitSendDone && req.ResponseChan != nil { + // One-way calls: confirm "send" completion before running the handler, + // since the caller blocks until confirmation arrives on ResponseChan. + if req.wantSendConfirmation() { if !req.deliver(response{NodeID: nodeID}) { - return + return // request cancelled while waiting for send confirmation; do not run the handler. } } - // For two-way calls, deliver the response via the send closure. - // For one-way calls (WaitSendDone=true or ResponseChan==nil), send is a no-op: - // the confirmation was already delivered above, and a second write would either - // race with the caller consuming the channel or block on a full response channel. send := func(msg *Message) { - if req.WaitSendDone || req.ResponseChan == nil { + // One-way fire-and-forget calls have no ResponseChan, so send is a no-op. + if !req.wantServerResponse() { return } + // Two-way calls: deliver the handler's response on ResponseChan. req.deliver(response{NodeID: nodeID, Value: msg, Err: msg.ErrorStatus()}) } @@ -126,13 +133,11 @@ func (r *MessageRouter) DispatchLocalRequest(nodeID uint32, req Request) { go r.handler.HandleRequest(req.Msg.AppendToIncomingContext(req.Ctx), req.Msg, release, send) } -// RouteMessage delivers a response to a pending call registered via [Register], -// or dispatches a server-initiated request to the registered handler. -// It is the primary entry point for messages received on the client-side stream. -// -// Responses to client-initiated calls are delivered to the matching pending call; -// responses to cancelled calls are silently dropped. 
Server-initiated requests -// (back-channel calls) are dispatched to the handler in a new goroutine. +// RouteMessage demultiplexes a message received on the client-side (outbound) stream. +// Server-initiated requests (back-channel calls, high-bit IDs) are dispatched to the +// handler in a new goroutine. Responses to client-initiated calls (low-bit IDs) are +// delivered to the matching pending call; responses to cancelled or unknown calls are +// silently dropped. func (r *MessageRouter) RouteMessage(ctx context.Context, nodeID uint32, msg *Message, enqueue func(Request)) { msgID := msg.GetMessageSeqNo() @@ -148,20 +153,7 @@ func (r *MessageRouter) RouteMessage(ctx context.Context, nodeID uint32, msg *Me return } - r.mu.Lock() - req, ok := r.pending[msgID] - if ok && !req.Streaming { - delete(r.pending, msgID) - } - r.mu.Unlock() - - if ok { - resp := response{NodeID: nodeID, Value: msg, Err: msg.ErrorStatus()} - if resp.Err == nil { - r.updateLatency(time.Since(req.SendTime)) - } - req.deliver(resp) - } + r.deliverPending(msgID, response{NodeID: nodeID, Value: msg, Err: msg.ErrorStatus()}) } // Register registers a pending call awaiting a response. @@ -173,13 +165,12 @@ func (r *MessageRouter) Register(msgID uint64, req Request) { r.mu.Unlock() } -// RouteInboundMessage delivers a response to a pending call registered via [Register], -// or dispatches a client-initiated request to the registered handler. +// RouteInboundMessage demultiplexes a message received on the server-side (inbound) stream. // It is the symmetric counterpart of [RouteMessage] for the server-side receive path. -// -// Responses to server-initiated calls are delivered to the matching pending call; -// responses to cancelled calls are silently absorbed. Client-initiated requests -// are dispatched to the handler in a new goroutine. The release function is always called. 
+// Client-initiated requests (low-bit IDs) are dispatched to the handler in a new goroutine, +// or release is called immediately when no handler is registered. Responses to server-initiated +// calls (high-bit IDs) are delivered to the matching pending call; stale responses from +// cancelled calls are silently absorbed. The release function is always called. func (r *MessageRouter) RouteInboundMessage(ctx context.Context, nodeID uint32, msg *Message, release func(), send func(*Message)) { msgID := msg.GetMessageSeqNo() if !isServerSequenceNumber(msgID) { @@ -191,33 +182,18 @@ func (r *MessageRouter) RouteInboundMessage(ctx context.Context, nodeID uint32, } return } - // Server-initiated response: look up pending call and deliver if found; - // silently absorb if not found (stale response from a cancelled call). - r.mu.Lock() - req, ok := r.pending[msgID] - if ok && !req.Streaming { - delete(r.pending, msgID) - } - r.mu.Unlock() - - if ok { - resp := response{NodeID: nodeID, Value: msg, Err: msg.ErrorStatus()} - if resp.Err == nil { - r.updateLatency(time.Since(req.SendTime)) - } - req.deliver(resp) - } + // Server-initiated response: deliver to the matching pending call (if any) and + // release the ordering lock. Stale responses from cancelled calls are silently absorbed. + r.deliverPending(msgID, response{NodeID: nodeID, Value: msg, Err: msg.ErrorStatus()}) release() } -// RouteResponse delivers a response to a pending call registered via [Register]. +// deliverPending looks up the pending call for msgID and delivers resp to it. // For non-streaming calls, the entry is removed after delivery. // For streaming calls (correctable), the entry remains for subsequent responses. -// -// Unmatched server-initiated calls (back-channel responses) are absorbed and -// the method returns true. Returns false only for unmatched client-initiated -// calls (stale responses). 
-func (r *MessageRouter) RouteResponse(msgID uint64, resp response) bool { +// Returns true if a matching pending entry was found (delivery is attempted but +// may be a no-op if the caller's context is already canceled), false otherwise. +func (r *MessageRouter) deliverPending(msgID uint64, resp response) bool { r.mu.Lock() req, ok := r.pending[msgID] if ok && !req.Streaming { @@ -230,9 +206,8 @@ func (r *MessageRouter) RouteResponse(msgID uint64, resp response) bool { r.updateLatency(time.Since(req.SendTime)) } req.deliver(resp) - return true } - return isServerSequenceNumber(msgID) + return ok } // Latency returns the estimated round-trip latency based on recent responses. diff --git a/internal/stream/router_test.go b/internal/stream/router_test.go index 176fd13a..44deb615 100644 --- a/internal/stream/router_test.go +++ b/internal/stream/router_test.go @@ -10,7 +10,7 @@ import ( "github.com/relab/gorums/internal/testutils/mock" ) -func TestRouterRegisterAndRoute(t *testing.T) { +func TestRouterRegisterAndDeliver(t *testing.T) { r := NewMessageRouter() replyChan := make(chan response, 1) r.Register(42, Request{ @@ -20,8 +20,8 @@ func TestRouterRegisterAndRoute(t *testing.T) { }) resp := response{NodeID: 1, Value: nil} - if !r.RouteResponse(42, resp) { - t.Fatal("RouteResponse should return true for registered msgID") + if !r.deliverPending(42, resp) { + t.Fatal("deliverPending should return true for registered msgID") } // The response should be delivered on the channel. @@ -35,38 +35,40 @@ func TestRouterRegisterAndRoute(t *testing.T) { } // After routing a non-streaming request, it should be removed. 
- if r.RouteResponse(42, resp) { - t.Error("RouteResponse should return false for already-consumed msgID") + if r.deliverPending(42, resp) { + t.Error("deliverPending should return false for already-consumed msgID") } } -func TestRouterRouteUnknown(t *testing.T) { +func TestRouterDeliverUnknown(t *testing.T) { r := NewMessageRouter() - if r.RouteResponse(999, response{NodeID: 1}) { - t.Error("RouteResponse should return false for unknown msgID") + if r.deliverPending(999, response{NodeID: 1}) { + t.Error("deliverPending should return false for unknown msgID") } } -// TestRouterRouteResponseServerInitiated verifies that RouteResponse absorbs -// unmatched server-initiated IDs and rejects unmatched client-initiated IDs. -func TestRouterRouteResponseServerInitiated(t *testing.T) { - t.Run("ServerInitiatedReturnsTrue", func(t *testing.T) { +// TestRouterDeliverPendingUnknown verifies that deliverPending returns false for +// any unmatched msgID, regardless of whether it is a client- or server-initiated ID. +// Callers (RouteMessage, RouteInboundMessage) are responsible for handling the +// server-initiated case before invoking deliverPending. 
+func TestRouterDeliverPendingUnknown(t *testing.T) { + t.Run("ServerInitiatedUnknownReturnsFalse", func(t *testing.T) { r := NewMessageRouter() - if !r.RouteResponse(ServerSequenceNumber(1), response{NodeID: 1}) { - t.Error("RouteResponse should return true for server-initiated msgID") + if r.deliverPending(ServerSequenceNumber(1), response{NodeID: 1}) { + t.Error("deliverPending should return false for unmatched server-initiated msgID") } }) t.Run("ClientInitiatedUnknownReturnsFalse", func(t *testing.T) { r := NewMessageRouter() - if r.RouteResponse(1, response{NodeID: 1}) { - t.Error("RouteResponse should return false for unmatched client-initiated msgID") + if r.deliverPending(1, response{NodeID: 1}) { + t.Error("deliverPending should return false for unmatched client-initiated msgID") } }) } -func TestRouterStreamingKeepsEntry(t *testing.T) { +func TestRouterDeliverPendingStreamingKeepsEntry(t *testing.T) { r := NewMessageRouter() replyChan := make(chan response, 3) r.Register(10, Request{ @@ -78,20 +80,20 @@ func TestRouterStreamingKeepsEntry(t *testing.T) { // First route should succeed and keep the entry. resp := response{NodeID: 1} - if !r.RouteResponse(10, resp) { - t.Fatal("first RouteResponse should succeed") + if !r.deliverPending(10, resp) { + t.Fatal("first deliverPending should succeed") } <-replyChan // drain // Second route should also succeed (streaming keeps entry alive). - if !r.RouteResponse(10, resp) { - t.Fatal("second RouteResponse should succeed for streaming entry") + if !r.deliverPending(10, resp) { + t.Fatal("second deliverPending should succeed for streaming entry") } <-replyChan // drain // Third route should also succeed. 
- if !r.RouteResponse(10, resp) { - t.Fatal("third RouteResponse should succeed for streaming entry") + if !r.deliverPending(10, resp) { + t.Fatal("third deliverPending should succeed for streaming entry") } <-replyChan // drain } @@ -112,7 +114,7 @@ func TestRouterCancelPending(t *testing.T) { } // Map should be empty now. - if r.RouteResponse(0, response{}) { + if r.PendingCount() != 0 { t.Error("pending map should be empty after CancelPending") } } @@ -161,7 +163,7 @@ func TestRouterRequeuePending(t *testing.T) { } // Map should be empty. - if r.RouteResponse(0, response{}) { + if r.PendingCount() != 0 { t.Error("pending map should be empty after RequeuePending") } } @@ -262,7 +264,7 @@ func (m *mockRequestHandler) HandleRequest(_ context.Context, _ *Message, releas } } -func TestRouterRouteResponseDoesNotBlockOnCanceledRequest(t *testing.T) { +func TestRouterDeliverPendingDoesNotBlockOnCanceledRequest(t *testing.T) { r := NewMessageRouter() ctx, cancel := context.WithCancel(context.Background()) replyChan := make(chan response, 1) @@ -276,8 +278,8 @@ func TestRouterRouteResponseDoesNotBlockOnCanceledRequest(t *testing.T) { done := make(chan struct{}) go func() { - if !r.RouteResponse(42, response{NodeID: 1}) { - t.Error("RouteResponse should return true for registered msgID") + if !r.deliverPending(42, response{NodeID: 1}) { + t.Error("deliverPending should return true for registered msgID") } close(done) }() @@ -285,11 +287,11 @@ func TestRouterRouteResponseDoesNotBlockOnCanceledRequest(t *testing.T) { select { case <-done: case <-time.After(time.Second): - t.Fatal("RouteResponse blocked on a canceled request with a full reply channel") + t.Fatal("deliverPending blocked on a canceled request with a full reply channel") } } -func TestRouterRouteResponsePrefersDeliveryWhenCanceledAndReplyChanReady(t *testing.T) { +func TestRouterDeliverPendingPrefersDeliveryWhenCanceledAndReplyChanReady(t *testing.T) { r := NewMessageRouter() ctx, cancel := 
context.WithCancel(context.Background()) replyChan := make(chan response, 1) @@ -300,8 +302,8 @@ func TestRouterRouteResponsePrefersDeliveryWhenCanceledAndReplyChanReady(t *test }) cancel() - if !r.RouteResponse(42, response{NodeID: 1, Err: ErrStreamDown}) { - t.Fatal("RouteResponse should return true for registered msgID") + if !r.deliverPending(42, response{NodeID: 1, Err: ErrStreamDown}) { + t.Fatal("deliverPending should return true for registered msgID") } select { @@ -313,7 +315,7 @@ func TestRouterRouteResponsePrefersDeliveryWhenCanceledAndReplyChanReady(t *test t.Fatalf("reply error = %v, want ErrStreamDown", got.Err) } case <-time.After(time.Second): - t.Fatal("RouteResponse dropped a ready delivery on canceled context") + t.Fatal("deliverPending dropped a ready delivery on canceled context") } } diff --git a/multicast.go b/multicast.go index b4086181..93e9d5c6 100644 --- a/multicast.go +++ b/multicast.go @@ -2,8 +2,6 @@ package gorums import ( "errors" - - "google.golang.org/protobuf/types/known/emptypb" ) // Multicast is a one-way call; no replies are returned to the client. @@ -21,19 +19,16 @@ import ( // // This method should be used by generated code only. func Multicast[Req msg](ctx *ConfigContext, req Req, method string, opts ...CallOption) error { - callOpts := getCallOptions(E_Multicast, opts...) - waitSendDone := callOpts.mustWaitSendDone() + callOpts := getCallOptions(opts...) + waitForSend := !callOpts.ignoreErrors - clientCtx := newClientCtx[Req, *emptypb.Empty](ctx, req, method, clientCtxOptions{ - waitSendDone: waitSendDone, - interceptors: callOpts.interceptors, - }) + clientCtx := newMulticastClientCtx(ctx, req, method, waitForSend, callOpts.interceptors) // Send messages immediately (multicast doesn't use lazy sending) clientCtx.sendNow() // If waiting for send completion, drain the reply channel and return the first error. 
- if waitSendDone { + if waitForSend { var errs []nodeError for range clientCtx.Size() { select { diff --git a/node.go b/node.go index b263d383..d2424357 100644 --- a/node.go +++ b/node.go @@ -150,6 +150,14 @@ func (n *Node) IsInbound() bool { return ch != nil && ch.IsInbound() } +// PendingCount returns the number of pending calls currently registered in the router. +func (n *Node) PendingCount() int { + if n == nil || n.router == nil { + return 0 + } + return n.router.PendingCount() +} + // attachStream attaches a new inbound channel to the node when a peer connects. // If the node already has an active channel (e.g., a stale stream from a previous // connection), it is atomically replaced and the old channel is closed. diff --git a/node_test.go b/node_test.go index 57a12445..409277b7 100644 --- a/node_test.go +++ b/node_test.go @@ -414,7 +414,7 @@ func BenchmarkNodeEnqueueSend(b *testing.B) { n.Enqueue(stream.Request{ Ctx: context.Background(), Msg: reqMsg, - WaitSendDone: true, + Oneway: true, ResponseChan: replyChan, }) <-replyChan diff --git a/quorumcall.go b/quorumcall.go index 4b036239..e0e3b816 100644 --- a/quorumcall.go +++ b/quorumcall.go @@ -49,11 +49,7 @@ func invokeQuorumCall[Req, Resp msg]( streaming bool, opts ...CallOption, ) *Responses[Resp] { - callOpts := getCallOptions(E_Quorumcall, opts...) - clientCtx := newClientCtx[Req, Resp](ctx, req, method, clientCtxOptions{ - streaming: streaming, - interceptors: callOpts.interceptors, - }) - + callOpts := getCallOptions(opts...) + clientCtx := newQuorumCallClientCtx[Req, Resp](ctx, req, method, streaming, callOpts.interceptors) return NewResponses(clientCtx) } diff --git a/testing_shared.go b/testing_shared.go index bfadd59d..39a673d0 100644 --- a/testing_shared.go +++ b/testing_shared.go @@ -40,6 +40,32 @@ func TestContext(t testing.TB, timeout time.Duration) context.Context { return ctx } +// TestWaitUntil polls predicate until it returns true or timeout elapses. 
+// It returns true when predicate succeeds within timeout, and false otherwise. +func TestWaitUntil(t testing.TB, timeout time.Duration, predicate func() bool) bool { + t.Helper() + + if predicate() { + return true + } + + ctx, cancel := context.WithTimeout(t.Context(), timeout) + defer cancel() + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return predicate() + case <-ticker.C: + if predicate() { + return true + } + } + } +} + // InsecureDialOptions returns a DialOption with insecure transport credentials // for testing. func InsecureDialOptions(_ testing.TB) DialOption { @@ -48,7 +74,6 @@ func InsecureDialOptions(_ testing.TB) DialOption { ) } - // TestQuorumCallError creates a QuorumCallError for testing. // The nodeErrors map contains node IDs and their corresponding errors. func TestQuorumCallError(_ testing.TB, nodeErrors map[uint32]error) QuorumCallError { diff --git a/unicast.go b/unicast.go index e666be44..e4bd6910 100644 --- a/unicast.go +++ b/unicast.go @@ -14,22 +14,21 @@ import "github.com/relab/gorums/internal/stream" // // This method should be used by generated code only. func Unicast[Req msg](ctx *NodeContext, req Req, method string, opts ...CallOption) error { - callOpts := getCallOptions(E_Unicast, opts...) + callOpts := getCallOptions(opts...) 
reqMsg, err := stream.NewMessage(ctx, ctx.nextMsgID(), method, req) if err != nil { return err } - waitSendDone := callOpts.mustWaitSendDone() - if !waitSendDone { + if callOpts.ignoreErrors { // Fire-and-forget: enqueue and return immediately - ctx.enqueue(stream.Request{Ctx: ctx, Msg: reqMsg}) + ctx.enqueue(stream.Request{Ctx: ctx, Msg: reqMsg, Oneway: true}) return nil } // Default: block until send completes replyChan := make(chan NodeResponse[*stream.Message], 1) - ctx.enqueue(stream.Request{Ctx: ctx, Msg: reqMsg, WaitSendDone: true, ResponseChan: replyChan}) + ctx.enqueue(stream.Request{Ctx: ctx, Msg: reqMsg, Oneway: true, ResponseChan: replyChan}) // Wait for send confirmation select {