Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,7 @@ branch-compare-*
tmp/*
doc/task-*.md
doc/issue-*.md
doc/review-*.md
.claude/settings.local.json
.superset/config.json
.claude/agent-memory/
1 change: 1 addition & 0 deletions .vscode/gorums.txt
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ testutils
timestamppb
tmpl
Tormod
Twoway
ucast
unexport
Unexported
Expand Down
19 changes: 3 additions & 16 deletions callopts.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,18 @@ package gorums

import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/runtime/protoimpl"
)

type callOptions struct {
callType *protoimpl.ExtensionInfo
ignoreErrors bool
interceptors []any // Type-erased interceptors, restored by QuorumCall
}

// mustWaitSendDone returns true if the caller of a one-way call type must wait
// for send completion. This is the default behavior unless the IgnoreErrors
// call option is set. This always returns false for two-way call types, since
// they should always wait for actual server responses.
func (o callOptions) mustWaitSendDone() bool {
// must wait for send completion if we are not ignoring errors
// and the call type is Unicast or Multicast
return !o.ignoreErrors && (o.callType == E_Unicast || o.callType == E_Multicast)
}

// CallOption is a function that sets a value in the given callOptions struct
type CallOption func(*callOptions)

func getCallOptions(callType *protoimpl.ExtensionInfo, opts ...CallOption) callOptions {
func getCallOptions(opts ...CallOption) callOptions {
o := callOptions{
callType: callType,
ignoreErrors: false, // default: return error and wait for send completion
}
for _, opt := range opts {
Expand All @@ -46,8 +33,8 @@ func IgnoreErrors() CallOption {
}

// Interceptors returns a CallOption that adds quorum call interceptors.
// Interceptors are executed in the order provided, modifying the Responses object
// before the user calls a terminal method.
// Interceptors are executed in the order provided, modifying the Responses
// object before the user calls a terminal method.
//
// Example:
//
Expand Down
64 changes: 48 additions & 16 deletions callopts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,35 +3,67 @@ package gorums
import (
"fmt"
"testing"
"time"

"github.com/relab/gorums/internal/testutils/mock"
pb "google.golang.org/protobuf/types/known/wrapperspb"
)

func TestCallOptionsMustWaitSendDone(t *testing.T) {
func TestCallOptionsIgnoreErrors(t *testing.T) {
tests := []struct {
name string
callOpts callOptions
wantWaitSendDone bool
wantIgnoreErrors bool
}{
// One-way call types
{name: "Unicast/Default", callOpts: getCallOptions(E_Unicast), wantWaitSendDone: true},
{name: "Unicast/IgnoreErrors", callOpts: getCallOptions(E_Unicast, IgnoreErrors()), wantWaitSendDone: false},
{name: "Multicast/Default", callOpts: getCallOptions(E_Multicast), wantWaitSendDone: true},
{name: "Multicast/IgnoreErrors", callOpts: getCallOptions(E_Multicast, IgnoreErrors()), wantWaitSendDone: false},
// Two-way call types (never wait for send completion, regardless of option)
{name: "Rpc/Default", callOpts: getCallOptions(E_Rpc), wantWaitSendDone: false},
{name: "Rpc/IgnoreErrors", callOpts: getCallOptions(E_Rpc, IgnoreErrors()), wantWaitSendDone: false},
{name: "Quorumcall/Default", callOpts: getCallOptions(E_Quorumcall), wantWaitSendDone: false},
{name: "Quorumcall/IgnoreErrors", callOpts: getCallOptions(E_Quorumcall, IgnoreErrors()), wantWaitSendDone: false},
{name: "Default", callOpts: getCallOptions(), wantIgnoreErrors: false},
{name: "IgnoreErrors", callOpts: getCallOptions(IgnoreErrors()), wantIgnoreErrors: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotWaitSendDone := tt.callOpts.mustWaitSendDone()
if gotWaitSendDone != tt.wantWaitSendDone {
t.Errorf("mustWaitSendDone() = %v, want %v", gotWaitSendDone, tt.wantWaitSendDone)
if got := tt.callOpts.ignoreErrors; got != tt.wantIgnoreErrors {
t.Errorf("ignoreErrors = %v, want %v", got, tt.wantIgnoreErrors)
}
})
}
}

func TestCallOptionsIgnoreErrorsResourceLeak(t *testing.T) {
// Previously leaked because fire-and-forget multicast still registered in router.
// Now fixed: no replyChan → no ResponseChan → no Register.
systems := TestSystems(t, 3)
for _, sys := range systems {
sys.RegisterService(nil, func(srv *Server) {
srv.RegisterHandler(mock.TestMethod, func(_ ServerCtx, _ *Message) (*Message, error) {
return nil, nil
})
})
}
for _, sys := range systems {
sys.WaitForConfig(t.Context(), func(cfg Configuration) bool {
return cfg.Size() == 3
})
}
cfg := systems[0].OutboundConfig()
ctx := TestContext(t, 5*time.Second)
for i := range 1000 {
Multicast(cfg.Context(ctx), pb.String(fmt.Sprintf("mc-%d", i)), mock.TestMethod, IgnoreErrors())
}
TestWaitUntil(t, 5*time.Second, func() bool {
for _, node := range cfg.Nodes() {
if node.PendingCount() > 0 {
return false
}
}
return true
})
Comment thread
meling marked this conversation as resolved.

for _, node := range cfg.Nodes() {
if pc := node.PendingCount(); pc > 0 {
t.Errorf("node %d: pending = %d; expected 0", node.ID(), pc)
}
}
}

func BenchmarkGetCallOptions(b *testing.B) {
interceptor := func(_ *ClientCtx[msg, msg], next ResponseSeq[msg]) ResponseSeq[msg] { return next }
tests := []struct {
Expand All @@ -48,7 +80,7 @@ func BenchmarkGetCallOptions(b *testing.B) {
b.Run(fmt.Sprintf("options=%d", tc.numOpts), func(b *testing.B) {
b.ReportAllocs()
for b.Loop() {
_ = getCallOptions(E_Quorumcall, opts...)
_ = getCallOptions(opts...)
}
})
}
Expand Down
122 changes: 73 additions & 49 deletions client_interceptor.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"github.com/relab/gorums/internal/stream"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/emptypb"
)

// QuorumInterceptor intercepts and processes quorum calls, allowing modification of
Expand Down Expand Up @@ -55,8 +56,8 @@ type ClientCtx[Req, Resp msg] struct {
// streaming indicates whether this is a streaming call (for correctable streams).
streaming bool

// waitSendDone indicates whether the caller waits for send completion (for multicast).
waitSendDone bool
// oneway indicates whether this is a one-way call (for multicast).
oneway bool

// sendOnce ensures messages are sent exactly once, on the first
// call to Responses(). This deferred sending allows interceptors
Expand All @@ -69,50 +70,65 @@ func (c *ClientCtx[Req, Resp]) sendNow() {
c.sendOnce.Do(c.send)
}

type clientCtxOptions struct {
streaming bool
waitSendDone bool
interceptors []any
}

// newClientCtx constructs and initializes a ClientCtx for quorum-style calls.
// It creates call metadata, configures the response iterator, and applies
// interceptors after the base iterator has been established.
func newClientCtx[Req, Resp msg](
// newQuorumCallClientCtx constructs a ClientCtx for quorum calls (two-way, always returns responses).
// A reply channel is always created; streaming controls both its buffer size and the response iterator type.
func newQuorumCallClientCtx[Req, Resp msg](
ctx *ConfigContext,
req Req,
method string,
opts clientCtxOptions,
streaming bool,
interceptors []any,
) *ClientCtx[Req, Resp] {
config := ctx.Configuration()
n := config.Size()
if streaming {
n *= 10
}
clientCtx := &ClientCtx[Req, Resp]{
Context: ctx,
config: config,
request: req,
method: method,
msgID: config.nextMsgID(),
replyChan: make(chan NodeResponse[*stream.Message], chanSize(config, opts.streaming)),
streaming: opts.streaming,
waitSendDone: opts.waitSendDone,
Context: ctx,
config: config,
request: req,
method: method,
msgID: config.nextMsgID(),
streaming: streaming,
replyChan: make(chan NodeResponse[*stream.Message], n),
}

if clientCtx.streaming {
if streaming {
clientCtx.responseSeq = clientCtx.streamingResponseSeq()
} else {
clientCtx.responseSeq = clientCtx.defaultResponseSeq()
}
clientCtx.applyInterceptors(opts.interceptors)
clientCtx.applyInterceptors(interceptors)
return clientCtx
}

// chanSize returns the channel buffer size based on the configuration and
// whether the call is streaming. For streaming calls, we use a larger buffer
// to accommodate more in-flight messages without blocking.
func chanSize(config Configuration, streaming bool) int {
if streaming {
return config.Size() * 10
// newMulticastClientCtx constructs a ClientCtx for multicast (one-way, no responses).
// A reply channel is created only when waitForSend=true (blocking send); fire-and-forget
// calls receive a nil channel, meaning no router entry is registered.
func newMulticastClientCtx[Req msg](
ctx *ConfigContext,
req Req,
method string,
waitForSend bool,
interceptors []any,
) *ClientCtx[Req, *emptypb.Empty] {
config := ctx.Configuration()
var replyChan chan NodeResponse[*stream.Message]
if waitForSend {
replyChan = make(chan NodeResponse[*stream.Message], config.Size())
}
clientCtx := &ClientCtx[Req, *emptypb.Empty]{
Context: ctx,
config: config,
request: req,
method: method,
msgID: config.nextMsgID(),
oneway: true,
replyChan: replyChan,
}
return config.Size()
clientCtx.responseSeq = clientCtx.defaultResponseSeq()
clientCtx.applyInterceptors(interceptors)
return clientCtx
}

// -------------------------------------------------------------------------
Expand Down Expand Up @@ -156,6 +172,26 @@ func (c *ClientCtx[Req, Resp]) Size() int {
return c.config.Size()
}

// reportNodeError sends an error response for the given node to replyChan.
// It is a no-op for fire-and-forget calls where replyChan is nil.
func (c *ClientCtx[Req, Resp]) reportNodeError(nodeID uint32, err error) {
if c.replyChan != nil {
c.replyChan <- NodeResponse[*stream.Message]{NodeID: nodeID, Err: err}
}
}

// enqueue sends a stream.Request to the given node, populating the shared
// fields from ClientCtx so call sites only need to supply the message.
func (c *ClientCtx[Req, Resp]) enqueue(n *Node, msg *stream.Message) {
n.Enqueue(stream.Request{
Ctx: c.Context,
Msg: msg,
Streaming: c.streaming,
Oneway: c.oneway,
ResponseChan: c.replyChan,
})
}

// applyInterceptors chains the given interceptors, wrapping the response sequence.
// Each interceptor receives the current response sequence and returns a new one.
// Interceptors are applied in order, with each wrapping the previous result.
Expand Down Expand Up @@ -186,18 +222,12 @@ func (c *ClientCtx[Req, Resp]) sendShared() {
if err != nil {
// Marshaling fails identically for all nodes; report and return.
for _, n := range c.config {
c.replyChan <- NodeResponse[*stream.Message]{NodeID: n.ID(), Err: err}
c.reportNodeError(n.ID(), err)
}
return
}
for _, n := range c.config {
n.Enqueue(stream.Request{
Ctx: c.Context,
Msg: sharedMsg,
Streaming: c.streaming,
WaitSendDone: c.waitSendDone,
ResponseChan: c.replyChan,
})
c.enqueue(n, sharedMsg)
}
}

Expand All @@ -209,32 +239,26 @@ func (c *ClientCtx[Req, Resp]) sendWithPerNodeTransformation() {
if streamMsg == nil {
continue // Skip node: transformAndMarshal already sent ErrSkipNode
}
n.Enqueue(stream.Request{
Ctx: c.Context,
Msg: streamMsg,
Streaming: c.streaming,
WaitSendDone: c.waitSendDone,
ResponseChan: c.replyChan,
})
c.enqueue(n, streamMsg)
}
}

// transformAndMarshal applies transformations to the request for the given node,
// then marshals it into a stream.Message. Returns nil if transformation fails
// or marshaling fails (in which case the error is sent on replyChan).
// or marshaling fails (in which case the error is reported via reportNodeError).
func (c *ClientCtx[Req, Resp]) transformAndMarshal(n *Node) *stream.Message {
transformedRequest := c.request
for _, transform := range c.reqTransforms {
transformedRequest = transform(transformedRequest, n)
}
// Check if the result is valid
if protoReq, ok := any(transformedRequest).(proto.Message); !ok || protoReq == nil || !protoReq.ProtoReflect().IsValid() {
c.replyChan <- NodeResponse[*stream.Message]{NodeID: n.ID(), Err: ErrSkipNode}
c.reportNodeError(n.ID(), ErrSkipNode)
return nil
}
streamMsg, err := stream.NewMessage(c.Context, c.msgID, c.method, transformedRequest)
if err != nil {
c.replyChan <- NodeResponse[*stream.Message]{NodeID: n.ID(), Err: err}
c.reportNodeError(n.ID(), err)
return nil
}
return streamMsg
Expand Down
Loading
Loading