Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 20 additions & 27 deletions internal/provider/connectors/amqp091/amqp091.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ type BrokerDetails struct {
knownQueues *util.ConcurrentMap
knownBindings *util.ConcurrentMap
activeMessages *util.ConcurrentMap
state uint16
state atomic.Uint32
connectionConfig *pb.ConnectionConfiguration
tlsSkipVerify bool
ActiveStreams int64
Expand Down Expand Up @@ -157,7 +157,6 @@ func (prov *amqp091provider) getBrokerDetails(ctx context.Context) (*BrokerDetai
}

if bd := prov.getBrokerDetailsByIdentifier(clientIdentifier); bd != nil {
bd.tlsConfig = prov.tlsConfig
return bd, nil
}

Expand Down Expand Up @@ -1100,10 +1099,10 @@ func (prov *amqp091provider) queueSubscribe(ctx context.Context, bd *BrokerDetai
if cancelErr != (amqp091Error{}) {
util.Logger.Debugf("Received channel notify for client during subscribe %v : %v", bd.ClientIdentifier, cancelErr)
return &pb.Error{Message: cancelErr.Error()}
} else if bd.state != provider.CONNECTED {
} else if s := bd.state.Load(); s != provider.CONNECTED {
// The connection was closed without an error on the channel, so this was expected.
// TODO: Should we check for DISCONNECTED/CONNECTING as well?
util.Logger.Debugf("Received channel state not connected during subscribe %v : %v", bd.ClientIdentifier, bd.state)
util.Logger.Debugf("Received channel state not connected during subscribe %v : %v", bd.ClientIdentifier, s)
return nil
}
case chanErr, ok := <-connErrChan:
Expand All @@ -1115,10 +1114,10 @@ func (prov *amqp091provider) queueSubscribe(ctx context.Context, bd *BrokerDetai
if chanErr != (amqp091Error{}) {
util.Logger.Debugf("Received connection notify for client during subscribe %v : %v", bd.ClientIdentifier, chanErr)
return &pb.Error{Message: chanErr.Error()}
} else if bd.state != provider.CONNECTED {
} else if s := bd.state.Load(); s != provider.CONNECTED {
// The connection was closed without an error on the channel, so this was expected.
// TODO: Should we check for DISCONNECTED/CONNECTING as well?
util.Logger.Debugf("Received connection state not connected during subscribe %v : %v", bd.ClientIdentifier, bd.state)
util.Logger.Debugf("Received connection state not connected during subscribe %v : %v", bd.ClientIdentifier, s)
return nil
}
case msg, ok := <-messages:
Expand Down Expand Up @@ -1401,10 +1400,10 @@ func (prov *amqp091provider) Publish(ctx context.Context, messageChannel <-chan
if cancelErr != (amqp091Error{}) {
util.Logger.Debugf("Received channel notify for client during publish %v : %v", bd.ClientIdentifier, cancelErr)
return &pb.Error{Message: cancelErr.Error()}
} else if bd.state != provider.CONNECTED {
} else if s := bd.state.Load(); s != provider.CONNECTED {
// The connection was closed without an error on the channel, so this was expected.
// TODO: Should we check for DISCONNECTED/CONNECTING as well?
util.Logger.Debugf("Received channel state not connected during publish %v : %v", bd.ClientIdentifier, bd.state)
util.Logger.Debugf("Received channel state not connected during publish %v : %v", bd.ClientIdentifier, s)
return nil
}
case chanErr, ok := <-connErrChan:
Expand All @@ -1417,10 +1416,10 @@ func (prov *amqp091provider) Publish(ctx context.Context, messageChannel <-chan
util.Logger.Debugf("Received connection notify for client during publish %v : %v", bd.ClientIdentifier, chanErr)
retError := &pb.Error{Message: chanErr.Error()}
return retError
} else if bd.state != provider.CONNECTED {
} else if s := bd.state.Load(); s != provider.CONNECTED {
// The connection was closed without an error on the channel, so this was expected.
// TODO: Should we check for DISCONNECTED/CONNECTING as well?
util.Logger.Debugf("Received connection state not connected during publish %v : %v", bd.ClientIdentifier, bd.state)
util.Logger.Debugf("Received connection state not connected during publish %v : %v", bd.ClientIdentifier, s)
return nil
}
case message := <-messageChannel:
Expand Down Expand Up @@ -1652,7 +1651,7 @@ func (prov *amqp091provider) WaitForConnect(ctx context.Context) bool {
defer bd.decrementStreamCount()

for start := time.Now(); time.Since(start) < provider.CONNECTTIMEOUT*time.Second; {
if bd.state == provider.CONNECTED {
if bd.state.Load() == provider.CONNECTED {
util.Logger.Info(i18n.ClientConnected, bd.ClientIdentifier)
return true
}
Expand Down Expand Up @@ -1806,9 +1805,7 @@ func (bd *BrokerDetails) connectionWatcher() {
// frame or a TCP-level close/RST). Unconditionally reconnecting
// avoids a race where IsClosed() is still false at the time of
// the check even though the connection is genuinely gone.
bd.Lock()
bd.state = provider.DISCONNECTED
bd.Unlock()
bd.state.Store(provider.DISCONNECTED)
// Retry until we reconnect or the client explicitly disconnects.
// Without this loop, a single failed connect() would leave the
// watcher waiting for the 30-second fallback timer before trying
Expand All @@ -1820,19 +1817,15 @@ func (bd *BrokerDetails) connectionWatcher() {
break
}
// connect() failed; reset state so the next attempt can proceed.
bd.Lock()
bd.state = provider.DISCONNECTED
bd.Unlock()
bd.state.Store(provider.DISCONNECTED)
}
continue
case <-time.After(30 * time.Second):
// if we never get an error on the bd.ErrorChannel, try again after 30 seconds
// this is to help deal with race condition where we're not listening on the bd.ErrorChannel
// when there is an error on the connection
if bd.Connection.IsClosed() {
bd.Lock()
bd.state = provider.DISCONNECTED
bd.Unlock()
bd.state.Store(provider.DISCONNECTED)
// Ignore this error because we will reconnect in 30 seconds
bd.connect() //nolint:errcheck
}
Expand All @@ -1843,9 +1836,9 @@ func (bd *BrokerDetails) connectionWatcher() {

// waitWhileConnecting waits up to 30 seconds for an in-progress connection attempt to resolve.
// It returns the resulting state: CONNECTED, CLOSED, or DISCONNECTED (also used for timeouts).
func (bd *BrokerDetails) waitWhileConnecting() uint16 {
func (bd *BrokerDetails) waitWhileConnecting() uint32 {
for start := time.Now(); time.Since(start) < 30*time.Second; {
switch bd.state {
switch bd.state.Load() {
case provider.CONNECTED:
return provider.CONNECTED
case provider.CONNECTING:
Expand All @@ -1864,7 +1857,7 @@ func (bd *BrokerDetails) connect() (bool, error) {
return false, nil
}

if bd.state == provider.CONNECTING {
if bd.state.Load() == provider.CONNECTING {
switch bd.waitWhileConnecting() {
case provider.CONNECTED:
return true, nil
Expand All @@ -1875,11 +1868,11 @@ func (bd *BrokerDetails) connect() (bool, error) {

bd.Lock()
defer bd.Unlock()
if bd.state == provider.CONNECTED {
if bd.state.Load() == provider.CONNECTED {
return true, nil
}

bd.state = provider.CONNECTING
bd.state.Store(provider.CONNECTING)
var conn amqp091ConnectionShim
var err error

Expand Down Expand Up @@ -1928,14 +1921,14 @@ func (bd *BrokerDetails) connect() (bool, error) {

if err != nil {
util.Logger.Warn(i18n.BrokerConnectError, err.Error())
bd.state = provider.CLOSED
bd.state.Store(provider.CLOSED)
return false, err
}

bd.Connection = conn
bd.ErrorChannel = make(chan amqp091Error, 1)
bd.ErrorChannel = bd.Connection.NotifyClose(bd.ErrorChannel) // this looks unneeded but it aids in unit testing
bd.state = provider.CONNECTED
bd.state.Store(provider.CONNECTED)

util.Logger.Info(i18n.ClientConnected, bd.ClientIdentifier)

Expand Down
19 changes: 9 additions & 10 deletions internal/provider/connectors/amqp091/amqp091_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -832,13 +832,12 @@ func Test_Retry(t *testing.T) {
mm.DeliveryTag = 1
delMock.On("Ack").Return(nil)
mm.SetDelivery(&delMock)
go func() {
msgs <- mm
}()

mm.Headers = amqp091Table{
retryCountHeaderName: 1,
}
go func() {
msgs <- mm
}()

cancels := make(chan amqp091Error)
cmock.On("NotifyClose").Return(cancels)
Expand Down Expand Up @@ -2356,11 +2355,11 @@ func Test_connect_clientDisconnect(t *testing.T) {

func Test_connect_connecting_connected(t *testing.T) {
bd := BrokerDetails{}
bd.state = provider.CONNECTING
bd.state.Store(provider.CONNECTING)
bd.clientDisconnect = false
go func() {
time.Sleep(1 * time.Second)
bd.state = provider.CONNECTED
bd.state.Store(provider.CONNECTED)
}()
ok, err := bd.connect()
assert.True(t, ok)
Expand All @@ -2369,11 +2368,11 @@ func Test_connect_connecting_connected(t *testing.T) {

func Test_connect_connecting_closed(t *testing.T) {
bd := BrokerDetails{}
bd.state = provider.CONNECTING
bd.state.Store(provider.CONNECTING)
bd.clientDisconnect = false
go func() {
time.Sleep(1 * time.Second)
bd.state = provider.CLOSED
bd.state.Store(provider.CLOSED)
}()
ok, err := bd.connect()
assert.False(t, ok)
Expand All @@ -2382,7 +2381,7 @@ func Test_connect_connecting_closed(t *testing.T) {

func Test_connect_connecting_disconnected(t *testing.T) {
bd := BrokerDetails{}
bd.state = provider.CONNECTING
bd.state.Store(provider.CONNECTING)
bd.clientDisconnect = false

msrv := mockManagementRequestServer()
Expand Down Expand Up @@ -2412,7 +2411,7 @@ func Test_connect_connecting_disconnected(t *testing.T) {

go func() {
time.Sleep(1 * time.Second)
bd.state = provider.DISCONNECTED
bd.state.Store(provider.DISCONNECTED)
}()
ok, err := bd.connect()
assert.True(t, ok)
Expand Down