diff --git a/internal/server/server.go b/internal/server/server.go index 4b1729a..8e8853d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -168,26 +168,33 @@ func (s *ConsumerServer) Consume(stream pb.Consumer_ConsumeServer) error { //nol var returnError error isSubscribing := false - // stopForLoop is used for errors that require the exiting of Consume, - // essentially if a long running goroutine needs to exit, it needs to - // use stopForLoop as a killswitch for the entire Consume() because - // we don't want leaks. - stopForLoop := make(chan bool) - defer func() { - close(stopForLoop) - stopForLoop = nil - }() + // wg tracks the long-running sender/subscribe goroutines so Consume waits + // for them to drain before returning. Otherwise they outlive Consume and + // race against returnError / channel teardown. + var wg sync.WaitGroup + + // stopForLoop is the killswitch a long-running goroutine uses to exit + // Consume(). The error it carries (if non-nil) becomes Consume's return + // value; carrying it on the channel keeps main as the sole writer of + // returnError so goroutines never have to touch it. + stopForLoop := make(chan error) // messageChannel is used for sending messages from the provider back to Consume // for sending back to the client messageChannel := make(chan *pb.Message, 10) - defer close(messageChannel) var source *pb.Source - // lCtx is our loop context, we use it to shutdown our long running goroutines + // loopCtx is our loop context, we use it to shutdown our long running goroutines loopCtx, loopCancel := context.WithCancel(ctx) - // lCancel will signal to all long running goroutines to shutdown + + // Defers fire LIFO, so the execution order is loopCancel -> wg.Wait -> + // close(messageChannel) -> close(stopForLoop). Cancelling first lets the + // goroutines unblock; waiting before closing channels ensures nothing is + // still writing when they're closed. + defer close(stopForLoop) + defer close(messageChannel) + defer wg.Wait() defer loopCancel() consumeLoop: @@ -262,7 +269,9 @@ consumeLoop: isSubscribing = true // this goroutine receives messages from the provider and sends them to the client for processing - go func(mc <-chan *pb.Message, cont context.Context, stopFor *chan bool, returnErr *error) { + wg.Add(1) + go func(mc <-chan *pb.Message, cont context.Context, stopFor chan<- error) { + defer wg.Done() defer util.RecoverPanic() for { select { @@ -284,9 +293,9 @@ consumeLoop: if err != nil { util.Logger.Warn(i18n.StreamSendError, err.Error(), clientIdentifier) span.RecordError(err) - *returnErr = err - if *stopFor != nil { - *stopFor <- true + select { + case stopFor <- err: + case <-cont.Done(): } span.End() return @@ -294,45 +303,45 @@ consumeLoop: span.End() } } - }(messageChannel, loopCtx, &stopForLoop, &returnError) + }(messageChannel, loopCtx, stopForLoop) // call provider.Subscribe and use messageChannel to pass messages from the provider to the receiver func above - go func(mc chan<- *pb.Message, prov provider.Provider, cont context.Context, stopFor *chan bool, returnErr *error) { + wg.Add(1) + go func(mc chan<- *pb.Message, prov provider.Provider, cont context.Context, stopFor chan<- error) { + defer wg.Done() defer util.RecoverPanic() connected := prov.WaitForConnect(cont) - if connected { - err := prov.Subscribe(cont, source, mc) - if err != nil { - util.Logger.Warn(i18n.SubscribeError, err.Message) - *returnErr = errors.New(err.GetMessage()) + if !connected { + util.Logger.Warn(i18n.BrokerConnectError, "could not connect to broker") + select { + case stopFor <- errors.New("could not connect to broker"): + case <-cont.Done(): } + return + } - if source.GetDeclareOnly() { - dor := &pb.DeclareOnlyResponse{Success: true} - dor.Error = err - if err != nil { - dor.Success = false - } + err := prov.Subscribe(cont, source, mc) + var stopErr error + if err != nil { + util.Logger.Warn(i18n.SubscribeError, err.Message) + stopErr = errors.New(err.GetMessage()) + } - cr := &pb.ConsumeResponse{ - Resp: &pb.ConsumeResponse_DeclareOnlyResponse{ - DeclareOnlyResponse: dor, - }, - } - _ = sender.Send(cr) + if source.GetDeclareOnly() { + dor := &pb.DeclareOnlyResponse{Success: err == nil, Error: err} + cr := &pb.ConsumeResponse{ + Resp: &pb.ConsumeResponse_DeclareOnlyResponse{ + DeclareOnlyResponse: dor, + }, } + _ = sender.Send(cr) + } - if *stopFor != nil { - *stopFor <- true - } - } else { - util.Logger.Warn(i18n.BrokerConnectError, "could not connect to broker") - *returnErr = errors.New("could not connect to broker") - if *stopFor != nil { - *stopFor <- true - } + select { + case stopFor <- stopErr: + case <-cont.Done(): } - }(messageChannel, prov, loopCtx, &stopForLoop, &returnError) + }(messageChannel, prov, loopCtx, stopForLoop) } else if cnsmRecv.msg.GetAck() != nil { // Ack or Nack the message go func() { ackmsg := cnsmRecv.msg.GetAck() @@ -365,7 +374,10 @@ consumeLoop: _ = sender.Send(&pb.ConsumeResponse{Resp: &pb.ConsumeResponse_ConsumedResponse{ConsumedResponse: mcr}}) }() } - case <-stopForLoop: + case stopErr := <-stopForLoop: + if stopErr != nil { + returnError = stopErr + } break consumeLoop } }