diff --git a/internal/provider/connectors/amqp091/amqp091.go b/internal/provider/connectors/amqp091/amqp091.go index d6476f0..0c5e211 100644 --- a/internal/provider/connectors/amqp091/amqp091.go +++ b/internal/provider/connectors/amqp091/amqp091.go @@ -1209,11 +1209,31 @@ func (prov *amqp091provider) streamSubscribe(ctx context.Context, bd *BrokerDeta ttl = val } + amqpChannel, err := bd.Connection.NewChannel(false) + if err != nil { + return &pb.Error{Message: err.Error()} + } + defer amqpChannel.Close() + + if source.GetAddress().GetType() != pb.Address_STREAM { + err := prov.declareExchange(source.GetAddress(), bd, amqpChannel) + if err != nil { + util.Logger.Debugf("Failed to declare exchange for source %s: %v", source.GetName(), err) + } + } + dErr := bd.StreamConnection.DeclareStream(source.GetName(), ttl) if dErr != nil { return &pb.Error{IsFatal: true, Message: fmt.Sprintf("failed to declare stream: %s", dErr.Error())} } + if source.GetAddress().GetType() != pb.Address_STREAM { + err := prov.declareBinding(source, bd, amqpChannel, true) + if err != nil { + util.Logger.Debugf("Failed to declare binding for source %s: %s", source.GetName(), err.Error()) + } + } + if source.GetDeclareOnly() { // if we reach here, everything has succeeded and we should return from Consume if source.DeclareOnly = true return nil diff --git a/internal/provider/connectors/amqp091/stream_test.go b/internal/provider/connectors/amqp091/stream_test.go index 0d1abc6..49b71b1 100644 --- a/internal/provider/connectors/amqp091/stream_test.go +++ b/internal/provider/connectors/amqp091/stream_test.go @@ -8,7 +8,9 @@ import ( "crypto/tls" "errors" "fmt" + "net/url" "reflect" + "strconv" "sync" "testing" "time" @@ -533,12 +535,14 @@ func Test_SubscribeStream(t *testing.T) { pmock := &streamConsumerMock{} pmock.On("Close").Return(nil) smock.On("NewConsumer", src.GetName(), src.GetName(), src.GetOptions()["Offset"], mock.Anything, mock.AnythingOfType("bool")).Return(pmock, nil) + smock.On("StoreOffset", src.GetName(), src.GetName(), mock.Anything).Return(nil) oldNewStreamConn := NewStreamConn NewStreamConn = func(string, string, *tls.Config) streamConnectionShim { return smock } cmock := &amqpChannelMock{} + cmock.On("Close").Return(nil) amock := &amqpConnectionMock{} amock.On("Connect").Return(nil) @@ -547,6 +551,7 @@ func Test_SubscribeStream(t *testing.T) { errs := make(chan amqp091Error) amock.On("NotifyClose").Return(errs) + amock.On("NewChannel", false).Return(cmock, nil) oldNewAmqpConn091 := NewAmqpConn091 NewAmqpConn091 = func(string, string, *tls.Config) amqp091ConnectionShim { return amock @@ -585,10 +590,13 @@ func Test_SubscribeStream(t *testing.T) { cancel() time.Sleep(1000 * time.Millisecond) + // address type is STREAM: exchange and binding must NOT be declared + cmock.AssertNumberOfCalls(t, "ExchangeDeclare", 0) + cmock.AssertNumberOfCalls(t, "QueueBind", 0) + cmock.AssertExpectations(t) amock.AssertExpectations(t) pmock.AssertExpectations(t) smock.AssertExpectations(t) - cmock.AssertExpectations(t) } func Test_SubscribeStreamBadOpt(t *testing.T) { @@ -716,6 +724,11 @@ func Test_streamSubscribe(t *testing.T) { return amock } + cmock := &amqpChannelMock{} + cmock.On("Close").Return(nil) + amock.On("NewChannel", false).Return(cmock, nil) + bd.Connection = amock + pmock := &streamConsumerMock{} pmock.On("Close").Return(nil) @@ -1023,6 +1036,7 @@ func Test_SubscribeStreamFailedDeclare(t *testing.T) { } cmock := &amqpChannelMock{} + cmock.On("Close").Return(nil) amock := &amqpConnectionMock{} amock.On("Connect").Return(nil) @@ -1031,6 +1045,7 @@ func Test_SubscribeStreamFailedDeclare(t *testing.T) { errs := make(chan amqp091Error) amock.On("NotifyClose").Return(errs) + amock.On("NewChannel", false).Return(cmock, nil) oldNewAmqpConn091 := NewAmqpConn091 NewAmqpConn091 = func(string, string, *tls.Config) amqp091ConnectionShim { return amock @@ -1102,6 +1117,7 @@ func Test_StreamRetry(t *testing.T) { } cmock := &amqpChannelMock{} + cmock.On("Close").Return(nil) amock := &amqpConnectionMock{} amock.On("Connect").Return(nil) @@ -1110,6 +1126,7 @@ func Test_StreamRetry(t *testing.T) { errs := make(chan amqp091Error) amock.On("NotifyClose").Return(errs) + amock.On("NewChannel", false).Return(cmock, nil) oldNewAmqpConn091 := NewAmqpConn091 NewAmqpConn091 = func(string, string, *tls.Config) amqp091ConnectionShim { return amock @@ -1199,6 +1216,10 @@ func Test_Subscribe_Stream_DeclareOnly(t *testing.T) { return amock } + cmock := &amqpChannelMock{} + cmock.On("Close").Return(nil) + amock.On("NewChannel", false).Return(cmock, nil) + defer func() { NewStreamConn = oldNewStreamConn }() @@ -1223,3 +1244,139 @@ func Test_Subscribe_Stream_DeclareOnly(t *testing.T) { }) } } + +// Test_streamSubscribe_ExchangeAndBindingDeclaration verifies that +// when subscribing to a Source_STREAM source, an exchange and +// binding are declared on the AMQP channel if (and only if) the address type +// is NOT Address_STREAM. +func Test_streamSubscribe_ExchangeAndBindingDeclaration(t *testing.T) { + tests := []struct { + name string + addressType pb.Address_TargetType + expectExchangeDeclare bool + expectQueueBind bool + needsManagementServer bool // cleanupBindings makes HTTP calls only when bindings are actually declared + }{ + { + name: "non-stream address type declares exchange and binding", + addressType: pb.Address_TOPIC, + expectExchangeDeclare: true, + expectQueueBind: true, + needsManagementServer: true, + }, + { + name: "stream address type skips exchange and binding", + addressType: pb.Address_STREAM, + expectExchangeDeclare: false, + expectQueueBind: false, + needsManagementServer: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + prov := NewAMQP091Provider() + + oldGetClientIdentifier := GetClientIdentifier + GetClientIdentifier = func(context.Context) (string, error) { + return "1234", nil + } + oldNewAmqpConn091 := NewAmqpConn091 + oldNewStreamConn := NewStreamConn + defer func() { + GetClientIdentifier = oldGetClientIdentifier + NewAmqpConn091 = oldNewAmqpConn091 + NewStreamConn = oldNewStreamConn + }() + + const addressName = "addressName" + const sourceName = "srcName" + addr := &pb.Address{ + Name: addressName, + Subjects: []string{"subject1"}, + Type: tc.addressType, + } + src := &pb.Source{ + Name: sourceName, + Address: addr, + Type: pb.Source_STREAM, + Options: map[string]string{"Offset": "0"}, + DeclareOnly: true, // stop after declare; avoids consumer setup + } + + // stream connection mock + smock := &streamConnectionMock{} + smock.On("Connect").Return(nil) + smock.On("IsClosed").Return(false) + smock.On("DeclareStream").Return(nil) + NewStreamConn = func(string, string, *tls.Config) streamConnectionShim { + return smock + } + + // amqp channel mock – ExchangeDeclare and QueueBind are only + // expected when the address type is not Address_STREAM + cmock := &amqpChannelMock{} + cmock.On("Close").Return(nil) + if tc.expectExchangeDeclare { + cmock.On("ExchangeDeclare", addressName, "topic", false).Return(nil) + } + if tc.expectQueueBind { + cmock.On("QueueBind", sourceName, "subject1", addressName, mock.Anything).Return(nil) + } + + // amqp connection mock + amock := &amqpConnectionMock{} + amock.On("Connect").Return(nil) + errs := make(chan amqp091Error) + amock.On("NotifyClose").Return(errs) + amock.On("NewChannel", false).Return(cmock, nil) + NewAmqpConn091 = func(string, string, *tls.Config) amqp091ConnectionShim { + return amock + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + cc := &pb.ConnectionConfiguration{} + + // The cleanupBindings call inside declareBinding issues an HTTP + // request to the management API; wire that up when needed. + if tc.needsManagementServer { + msrv := mockManagementRequestServer() + defer msrv.Close() + u, serr := url.Parse(msrv.URL) + assert.Nil(t, serr) + cc.Host = u.Hostname() + cc.Tenant = testTenant + i, _ := strconv.Atoi(u.Port()) + cc.AdminPort = int32(i) //nolint:gosec + } + + connectErr := prov.Connect(ctx, cc, false) + assert.Nil(t, connectErr) + + mc := make(chan *pb.Message) + defer close(mc) + + subscribeErr := prov.Subscribe(ctx, src, mc) + assert.Nil(t, subscribeErr) + + // Core assertions: verify exchange and binding call counts + cmock.AssertNumberOfCalls(t, "ExchangeDeclare", func() int { + if tc.expectExchangeDeclare { + return 1 + } + return 0 + }()) + cmock.AssertNumberOfCalls(t, "QueueBind", func() int { + if tc.expectQueueBind { + return 1 + } + return 0 + }()) + + cmock.AssertExpectations(t) + amock.AssertExpectations(t) + smock.AssertExpectations(t) + }) + } +}