Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions internal/provider/connectors/amqp091/amqp091.go
Original file line number Diff line number Diff line change
Expand Up @@ -1209,11 +1209,31 @@ func (prov *amqp091provider) streamSubscribe(ctx context.Context, bd *BrokerDeta
ttl = val
}

amqpChannel, err := bd.Connection.NewChannel(false)
if err != nil {
return &pb.Error{Message: err.Error()}
}
defer amqpChannel.Close()

if source.GetAddress().GetType() != pb.Address_STREAM {
err := prov.declareExchange(source.GetAddress(), bd, amqpChannel)
if err != nil {
util.Logger.Debugf("Failed to declare exchange for source %s: %v", source.GetName(), err)
}
}

dErr := bd.StreamConnection.DeclareStream(source.GetName(), ttl)
if dErr != nil {
return &pb.Error{IsFatal: true, Message: fmt.Sprintf("failed to declare stream: %s", dErr.Error())}
}

if source.GetAddress().GetType() != pb.Address_STREAM {
err := prov.declareBinding(source, bd, amqpChannel, true)
if err != nil {
util.Logger.Debugf("Failed to declare binding for source %s: %s", source.GetName(), err.Error())
}
}

if source.GetDeclareOnly() {
// if we reach here, everything has succeeded and we should return from Consume if source.DeclareOnly = true
return nil
Expand Down
159 changes: 158 additions & 1 deletion internal/provider/connectors/amqp091/stream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ import (
"crypto/tls"
"errors"
"fmt"
"net/url"
"reflect"
"strconv"
"sync"
"testing"
"time"
Expand Down Expand Up @@ -533,12 +535,14 @@ func Test_SubscribeStream(t *testing.T) {
pmock := &streamConsumerMock{}
pmock.On("Close").Return(nil)
smock.On("NewConsumer", src.GetName(), src.GetName(), src.GetOptions()["Offset"], mock.Anything, mock.AnythingOfType("bool")).Return(pmock, nil)
smock.On("StoreOffset", src.GetName(), src.GetName(), mock.Anything).Return(nil)
oldNewStreamConn := NewStreamConn
NewStreamConn = func(string, string, *tls.Config) streamConnectionShim {
return smock
}

cmock := &amqpChannelMock{}
cmock.On("Close").Return(nil)

amock := &amqpConnectionMock{}
amock.On("Connect").Return(nil)
Expand All @@ -547,6 +551,7 @@ func Test_SubscribeStream(t *testing.T) {

errs := make(chan amqp091Error)
amock.On("NotifyClose").Return(errs)
amock.On("NewChannel", false).Return(cmock, nil)
oldNewAmqpConn091 := NewAmqpConn091
NewAmqpConn091 = func(string, string, *tls.Config) amqp091ConnectionShim {
return amock
Expand Down Expand Up @@ -585,10 +590,13 @@ func Test_SubscribeStream(t *testing.T) {
cancel()
time.Sleep(1000 * time.Millisecond)

// address type is STREAM: exchange and binding must NOT be declared
cmock.AssertNumberOfCalls(t, "ExchangeDeclare", 0)
cmock.AssertNumberOfCalls(t, "QueueBind", 0)
cmock.AssertExpectations(t)
amock.AssertExpectations(t)
pmock.AssertExpectations(t)
smock.AssertExpectations(t)
cmock.AssertExpectations(t)
}

func Test_SubscribeStreamBadOpt(t *testing.T) {
Expand Down Expand Up @@ -716,6 +724,11 @@ func Test_streamSubscribe(t *testing.T) {
return amock
}

cmock := &amqpChannelMock{}
cmock.On("Close").Return(nil)
amock.On("NewChannel", false).Return(cmock, nil)
bd.Connection = amock

pmock := &streamConsumerMock{}
pmock.On("Close").Return(nil)

Expand Down Expand Up @@ -1023,6 +1036,7 @@ func Test_SubscribeStreamFailedDeclare(t *testing.T) {
}

cmock := &amqpChannelMock{}
cmock.On("Close").Return(nil)

amock := &amqpConnectionMock{}
amock.On("Connect").Return(nil)
Expand All @@ -1031,6 +1045,7 @@ func Test_SubscribeStreamFailedDeclare(t *testing.T) {

errs := make(chan amqp091Error)
amock.On("NotifyClose").Return(errs)
amock.On("NewChannel", false).Return(cmock, nil)
oldNewAmqpConn091 := NewAmqpConn091
NewAmqpConn091 = func(string, string, *tls.Config) amqp091ConnectionShim {
return amock
Expand Down Expand Up @@ -1102,6 +1117,7 @@ func Test_StreamRetry(t *testing.T) {
}

cmock := &amqpChannelMock{}
cmock.On("Close").Return(nil)

amock := &amqpConnectionMock{}
amock.On("Connect").Return(nil)
Expand All @@ -1110,6 +1126,7 @@ func Test_StreamRetry(t *testing.T) {

errs := make(chan amqp091Error)
amock.On("NotifyClose").Return(errs)
amock.On("NewChannel", false).Return(cmock, nil)
oldNewAmqpConn091 := NewAmqpConn091
NewAmqpConn091 = func(string, string, *tls.Config) amqp091ConnectionShim {
return amock
Expand Down Expand Up @@ -1199,6 +1216,10 @@ func Test_Subscribe_Stream_DeclareOnly(t *testing.T) {
return amock
}

cmock := &amqpChannelMock{}
cmock.On("Close").Return(nil)
amock.On("NewChannel", false).Return(cmock, nil)

defer func() {
NewStreamConn = oldNewStreamConn
}()
Expand All @@ -1223,3 +1244,139 @@ func Test_Subscribe_Stream_DeclareOnly(t *testing.T) {
})
}
}

// Test_streamSubscribe_ExchangeAndBindingDeclaration verifies that
// when subscribing to a Source_STREAM source, an exchange and
// binding are declared on the AMQP channel if (and only if) the address type
// is NOT Address_STREAM.
func Test_streamSubscribe_ExchangeAndBindingDeclaration(t *testing.T) {
tests := []struct {
name string
addressType pb.Address_TargetType
expectExchangeDeclare bool
expectQueueBind bool
needsManagementServer bool // cleanupBindings makes HTTP calls only when bindings are actually declared
}{
{
name: "non-stream address type declares exchange and binding",
addressType: pb.Address_TOPIC,
expectExchangeDeclare: true,
expectQueueBind: true,
needsManagementServer: true,
},
{
name: "stream address type skips exchange and binding",
addressType: pb.Address_STREAM,
expectExchangeDeclare: false,
expectQueueBind: false,
needsManagementServer: false,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
prov := NewAMQP091Provider()

oldGetClientIdentifier := GetClientIdentifier
GetClientIdentifier = func(context.Context) (string, error) {
return "1234", nil
}
oldNewAmqpConn091 := NewAmqpConn091
oldNewStreamConn := NewStreamConn
defer func() {
GetClientIdentifier = oldGetClientIdentifier
NewAmqpConn091 = oldNewAmqpConn091
NewStreamConn = oldNewStreamConn
}()

const addressName = "addressName"
const sourceName = "srcName"
addr := &pb.Address{
Name: addressName,
Subjects: []string{"subject1"},
Type: tc.addressType,
}
src := &pb.Source{
Name: sourceName,
Address: addr,
Type: pb.Source_STREAM,
Options: map[string]string{"Offset": "0"},
DeclareOnly: true, // stop after declare; avoids consumer setup
}

// stream connection mock
smock := &streamConnectionMock{}
smock.On("Connect").Return(nil)
smock.On("IsClosed").Return(false)
smock.On("DeclareStream").Return(nil)
NewStreamConn = func(string, string, *tls.Config) streamConnectionShim {
return smock
}

// amqp channel mock – ExchangeDeclare and QueueBind are only
// expected when the address type is not Address_STREAM
cmock := &amqpChannelMock{}
cmock.On("Close").Return(nil)
if tc.expectExchangeDeclare {
cmock.On("ExchangeDeclare", addressName, "topic", false).Return(nil)
}
if tc.expectQueueBind {
cmock.On("QueueBind", sourceName, "subject1", addressName, mock.Anything).Return(nil)
}

// amqp connection mock
amock := &amqpConnectionMock{}
amock.On("Connect").Return(nil)
errs := make(chan amqp091Error)
amock.On("NotifyClose").Return(errs)
amock.On("NewChannel", false).Return(cmock, nil)
NewAmqpConn091 = func(string, string, *tls.Config) amqp091ConnectionShim {
return amock
}

ctx, cancel := context.WithCancel(context.Background())
defer cancel()
cc := &pb.ConnectionConfiguration{}

// The cleanupBindings call inside declareBinding issues an HTTP
// request to the management API; wire that up when needed.
if tc.needsManagementServer {
msrv := mockManagementRequestServer()
defer msrv.Close()
u, serr := url.Parse(msrv.URL)
assert.Nil(t, serr)
cc.Host = u.Hostname()
cc.Tenant = testTenant
i, _ := strconv.Atoi(u.Port())
cc.AdminPort = int32(i) //nolint:gosec
}

connectErr := prov.Connect(ctx, cc, false)
assert.Nil(t, connectErr)

mc := make(chan *pb.Message)
defer close(mc)

subscribeErr := prov.Subscribe(ctx, src, mc)
assert.Nil(t, subscribeErr)

// Core assertions: verify exchange and binding call counts
cmock.AssertNumberOfCalls(t, "ExchangeDeclare", func() int {
if tc.expectExchangeDeclare {
return 1
}
return 0
}())
cmock.AssertNumberOfCalls(t, "QueueBind", func() int {
if tc.expectQueueBind {
return 1
}
return 0
}())

cmock.AssertExpectations(t)
amock.AssertExpectations(t)
smock.AssertExpectations(t)
})
}
}