From 49cccb0192f5183be50cb22d754aaf504df3d57d Mon Sep 17 00:00:00 2001 From: Michael Otteni Date: Fri, 22 May 2026 14:14:52 -0400 Subject: [PATCH] fix(server): close Consume race on returnError and stopForLoop teardown Consume spawned sender and Subscribe goroutines that wrote to returnError and read stopForLoop through heap-escaped pointers, then returned without waiting for them. The goroutines kept running past Consume's return, racing against its writes to returnError and its deferred close+nil of stopForLoop. Wait for those goroutines before returning, and make stopForLoop carry the error itself so main is the sole writer of returnError. The defers are ordered so loopCancel signals first, wg.Wait drains the goroutines next, and the channels close last with no writers left. Signed-off-by: Michael Otteni --- internal/server/server.go | 104 +++++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 46 deletions(-) diff --git a/internal/server/server.go b/internal/server/server.go index 4b1729a..8e8853d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -168,26 +168,33 @@ func (s *ConsumerServer) Consume(stream pb.Consumer_ConsumeServer) error { //nol var returnError error isSubscribing := false - // stopForLoop is used for errors that require the exiting of Consume, - // essentially if a long running goroutine needs to exit, it needs to - // use stopForLoop as a killswitch for the entire Consume() because - // we don't want leaks. - stopForLoop := make(chan bool) - defer func() { - close(stopForLoop) - stopForLoop = nil - }() + // wg tracks the long-running sender/subscribe goroutines so Consume waits + // for them to drain before returning. Otherwise they outlive Consume and + // race against returnError / channel teardown. + var wg sync.WaitGroup + + // stopForLoop is the killswitch a long-running goroutine uses to exit + // Consume(). The error it carries (if non-nil) becomes Consume's return + // value; carrying it on the channel keeps main as the sole writer of + // returnError so goroutines never have to touch it. + stopForLoop := make(chan error) // messageChannel is used for sending messages from the provider back to Consume // for sending back to the client messageChannel := make(chan *pb.Message, 10) - defer close(messageChannel) var source *pb.Source - // lCtx is our loop context, we use it to shutdown our long running goroutines + // loopCtx is our loop context, we use it to shutdown our long running goroutines loopCtx, loopCancel := context.WithCancel(ctx) - // lCancel will signal to all long running goroutines to shutdown + + // Defers fire LIFO, so the execution order is loopCancel -> wg.Wait -> + // close(messageChannel) -> close(stopForLoop). Cancelling first lets the + // goroutines unblock; waiting before closing channels ensures nothing is + // still writing when they're closed. + defer close(stopForLoop) + defer close(messageChannel) + defer wg.Wait() defer loopCancel() consumeLoop: @@ -262,7 +269,9 @@ consumeLoop: isSubscribing = true // this goroutine receives messages from the provider and sends them to the client for processing - go func(mc <-chan *pb.Message, cont context.Context, stopFor *chan bool, returnErr *error) { + wg.Add(1) + go func(mc <-chan *pb.Message, cont context.Context, stopFor chan<- error) { + defer wg.Done() defer util.RecoverPanic() for { select { @@ -284,9 +293,9 @@ consumeLoop: if err != nil { util.Logger.Warn(i18n.StreamSendError, err.Error(), clientIdentifier) span.RecordError(err) - *returnErr = err - if *stopFor != nil { - *stopFor <- true + select { + case stopFor <- err: + case <-cont.Done(): } span.End() return @@ -294,45 +303,45 @@ consumeLoop: span.End() } } - }(messageChannel, loopCtx, &stopForLoop, &returnError) + }(messageChannel, loopCtx, stopForLoop) // call provider.Subscribe and use messageChannel to pass messages from the provider to the receiver func above - go func(mc chan<- *pb.Message, prov provider.Provider, cont context.Context, stopFor *chan bool, returnErr *error) { + wg.Add(1) + go func(mc chan<- *pb.Message, prov provider.Provider, cont context.Context, stopFor chan<- error) { + defer wg.Done() defer util.RecoverPanic() connected := prov.WaitForConnect(cont) - if connected { - err := prov.Subscribe(cont, source, mc) - if err != nil { - util.Logger.Warn(i18n.SubscribeError, err.Message) - *returnErr = errors.New(err.GetMessage()) + if !connected { + util.Logger.Warn(i18n.BrokerConnectError, "could not connect to broker") + select { + case stopFor <- errors.New("could not connect to broker"): + case <-cont.Done(): } + return + } - if source.GetDeclareOnly() { - dor := &pb.DeclareOnlyResponse{Success: true} - dor.Error = err - if err != nil { - dor.Success = false - } + err := prov.Subscribe(cont, source, mc) + var stopErr error + if err != nil { + util.Logger.Warn(i18n.SubscribeError, err.Message) + stopErr = errors.New(err.GetMessage()) + } - cr := &pb.ConsumeResponse{ - Resp: &pb.ConsumeResponse_DeclareOnlyResponse{ - DeclareOnlyResponse: dor, - }, - } - _ = sender.Send(cr) + if source.GetDeclareOnly() { + dor := &pb.DeclareOnlyResponse{Success: err == nil, Error: err} + cr := &pb.ConsumeResponse{ + Resp: &pb.ConsumeResponse_DeclareOnlyResponse{ + DeclareOnlyResponse: dor, + }, } + _ = sender.Send(cr) + } - if *stopFor != nil { - *stopFor <- true - } - } else { - util.Logger.Warn(i18n.BrokerConnectError, "could not connect to broker") - *returnErr = errors.New("could not connect to broker") - if *stopFor != nil { - *stopFor <- true - } + select { + case stopFor <- stopErr: + case <-cont.Done(): } - }(messageChannel, prov, loopCtx, &stopForLoop, &returnError) + }(messageChannel, prov, loopCtx, stopForLoop) } else if cnsmRecv.msg.GetAck() != nil { // Ack or Nack the message go func() { ackmsg := cnsmRecv.msg.GetAck() @@ -365,7 +374,10 @@ consumeLoop: _ = sender.Send(&pb.ConsumeResponse{Resp: &pb.ConsumeResponse_ConsumedResponse{ConsumedResponse: mcr}}) }() } - case <-stopForLoop: + case stopErr := <-stopForLoop: + if stopErr != nil { + returnError = stopErr + } break consumeLoop } }