diff --git a/pulsar/message_chunking_test.go b/pulsar/message_chunking_test.go index 12f0517c44..b7918756b7 100644 --- a/pulsar/message_chunking_test.go +++ b/pulsar/message_chunking_test.go @@ -26,7 +26,6 @@ import ( "net/http" "os" "strings" - "sync" "testing" "time" @@ -554,30 +553,29 @@ func sendSingleChunk(p Producer, uuid string, chunkID int, totalChunks int) { mm.TotalChunkMsgSize = proto.Int32(int32(len(wholePayload))) mm.ChunkId = proto.Int32(int32(chunkID)) producerImpl.updateMetadataSeqID(mm, msg) + sr := newSendRequest( + context.Background(), + producerImpl, + msg, + func(MessageID, *ProducerMessage, error) {}, + true, + ) + sr.totalChunks = totalChunks + sr.chunkID = chunkID + sr.uuid = uuid + sr.chunkRecorder = newChunkRecorder() + sr.uncompressedPayload = wholePayload + sr.uncompressedSize = int64(len(wholePayload)) + sr.compressedPayload = wholePayload + sr.compressedSize = len(wholePayload) + sr.payloadChunkSize = internal.MaxMessageSize - proto.Size(mm) + sr.mm = mm + sr.deliverAt = time.Now() + sr.maxMessageSize = internal.MaxMessageSize producerImpl.internalSingleSend( mm, msg.Payload, - &sendRequest{ - callback: func(id MessageID, producerMessage *ProducerMessage, err error) { - }, - callbackOnce: &sync.Once{}, - ctx: context.Background(), - msg: msg, - producer: producerImpl, - flushImmediately: true, - totalChunks: totalChunks, - chunkID: chunkID, - uuid: uuid, - chunkRecorder: newChunkRecorder(), - uncompressedPayload: wholePayload, - uncompressedSize: int64(len(wholePayload)), - compressedPayload: wholePayload, - compressedSize: len(wholePayload), - payloadChunkSize: internal.MaxMessageSize - proto.Size(mm), - mm: mm, - deliverAt: time.Now(), - maxMessageSize: internal.MaxMessageSize, - }, + sr, uint32(internal.MaxMessageSize), ) } diff --git a/pulsar/producer_partition.go b/pulsar/producer_partition.go index 833e9f4b38..65b5980001 100755 --- a/pulsar/producer_partition.go +++ b/pulsar/producer_partition.go @@ -660,36 +660,7 @@ func (p *partitionProducer) internalSend(sr *sendRequest) { } // update chunk id sr.mm.ChunkId = proto.Int32(int32(chunkID)) - nsr := sendRequestPool.Get().(*sendRequest) - *nsr = sendRequest{ - pool: sendRequestPool, - ctx: sr.ctx, - msg: sr.msg, - producer: sr.producer, - callback: sr.callback, - callbackOnce: sr.callbackOnce, - publishTime: sr.publishTime, - flushImmediately: sr.flushImmediately, - totalChunks: sr.totalChunks, - chunkID: chunkID, - uuid: uuid, - chunkRecorder: cr, - transaction: sr.transaction, - memLimit: sr.memLimit, - semaphore: sr.semaphore, - reservedMem: int64(rhs - lhs), - sendAsBatch: sr.sendAsBatch, - schema: sr.schema, - schemaVersion: sr.schemaVersion, - uncompressedPayload: sr.uncompressedPayload, - uncompressedSize: sr.uncompressedSize, - compressedPayload: sr.compressedPayload, - compressedSize: sr.compressedSize, - payloadChunkSize: sr.payloadChunkSize, - mm: sr.mm, - deliverAt: sr.deliverAt, - maxMessageSize: sr.maxMessageSize, - } + nsr := newChunkSendRequest(sr, chunkID, uuid, cr, int64(rhs-lhs)) p.internalSingleSend(nsr.mm, nsr.compressedPayload[lhs:rhs], nsr, uint32(nsr.maxMessageSize)) } @@ -1326,18 +1297,7 @@ func (p *partitionProducer) internalSendAsync( return } - sr := sendRequestPool.Get().(*sendRequest) - *sr = sendRequest{ - pool: sendRequestPool, - ctx: ctx, - msg: msg, - producer: p, - callback: callback, - callbackOnce: &sync.Once{}, - flushImmediately: flushImmediately, - publishTime: time.Now(), - chunkID: -1, - } + sr := newSendRequest(ctx, p, msg, callback, flushImmediately) if err := p.prepareTransaction(sr); err != nil { sr.done(nil, err) @@ -1612,6 +1572,7 @@ func (p *partitionProducer) Close() { } type sendRequest struct { + doneFlag atomic.Bool pool *sync.Pool ctx context.Context msg *ProducerMessage @@ -1648,7 +1609,83 @@ type sendRequest struct { maxMessageSize int32 } +func newSendRequest( + ctx context.Context, + p *partitionProducer, + msg *ProducerMessage, + callback func(MessageID, *ProducerMessage, error), + flushImmediately bool, +) *sendRequest { + sr := sendRequestPool.Get().(*sendRequest) + sr.pool = sendRequestPool + sr.ctx = ctx + sr.msg = msg + sr.producer = p + sr.callback = callback + sr.callbackOnce = &sync.Once{} + sr.flushImmediately = flushImmediately + sr.publishTime = time.Now() + sr.chunkID = -1 + sr.totalChunks = 0 + sr.uuid = "" + sr.chunkRecorder = nil + sr.transaction = nil + sr.memLimit = nil + sr.semaphore = nil + sr.reservedMem = 0 + sr.sendAsBatch = false + sr.schema = nil + sr.schemaVersion = nil + sr.uncompressedPayload = nil + sr.uncompressedSize = 0 + sr.compressedPayload = nil + sr.compressedSize = 0 + sr.payloadChunkSize = 0 + sr.mm = nil + sr.deliverAt = time.Time{} + sr.maxMessageSize = 0 + sr.doneFlag.Store(false) + return sr +} + +func newChunkSendRequest(p *sendRequest, chunkID int, uuid string, cr *chunkRecorder, reservedMem int64) *sendRequest { + sr := sendRequestPool.Get().(*sendRequest) + sr.pool = sendRequestPool + sr.ctx = p.ctx + sr.msg = p.msg + sr.producer = p.producer + sr.callback = p.callback + sr.callbackOnce = p.callbackOnce + sr.publishTime = p.publishTime + sr.flushImmediately = p.flushImmediately + sr.totalChunks = p.totalChunks + sr.chunkID = chunkID + sr.uuid = uuid + sr.chunkRecorder = cr + sr.transaction = p.transaction + sr.memLimit = p.memLimit + sr.semaphore = p.semaphore + sr.reservedMem = reservedMem + sr.sendAsBatch = p.sendAsBatch + sr.schema = p.schema + sr.schemaVersion = p.schemaVersion + sr.uncompressedPayload = p.uncompressedPayload + sr.uncompressedSize = p.uncompressedSize + sr.compressedPayload = p.compressedPayload + sr.compressedSize = p.compressedSize + sr.payloadChunkSize = p.payloadChunkSize + sr.mm = p.mm + sr.deliverAt = p.deliverAt + sr.maxMessageSize = p.maxMessageSize + sr.doneFlag.Store(false) + return sr +} + func (sr *sendRequest) done(msgID MessageID, err error) { + if !sr.doneFlag.CompareAndSwap(false, true) { + return + } + if err == nil { sr.producer.metrics.PublishLatency.Observe(float64(time.Now().UnixNano()-sr.publishTime.UnixNano()) / 1.0e9) sr.producer.metrics.MessagesPublished.Inc() @@ -1698,12 +1735,45 @@ func (sr *sendRequest) done(msgID MessageID, err error) { pool := sr.pool if pool != nil { - // reset all the fields - *sr = sendRequest{} + sr.reset() pool.Put(sr) } } +// reset clears all fields and returns the sendRequest to a reusable state. +// The doneFlag is intentionally left raised; newSendRequest will lower it. +func (sr *sendRequest) reset() { + sr.doneFlag.Store(true) + sr.pool = nil + sr.ctx = nil + sr.msg = nil + sr.producer = nil + sr.callback = nil + sr.callbackOnce = nil + sr.publishTime = time.Time{} + sr.flushImmediately = false + sr.totalChunks = 0 + sr.chunkID = 0 + sr.uuid = "" + sr.chunkRecorder = nil + sr.memLimit = nil + sr.reservedMem = 0 + sr.semaphore = nil + sr.reservedSemaphore = 0 + sr.sendAsBatch = false + sr.transaction = nil + sr.schema = nil + sr.schemaVersion = nil + sr.uncompressedPayload = nil + sr.uncompressedSize = 0 + sr.compressedPayload = nil + sr.compressedSize = 0 + sr.payloadChunkSize = 0 + sr.mm = nil + sr.deliverAt = time.Time{} + sr.maxMessageSize = 0 +} + func (p *partitionProducer) blockIfQueueFull() bool { // DisableBlockIfQueueFull == false means enable block return !p.options.DisableBlockIfQueueFull diff --git a/pulsar/send_request_pool_test.go b/pulsar/send_request_pool_test.go new file mode 100644 index 0000000000..40c642530f --- /dev/null +++ b/pulsar/send_request_pool_test.go @@ -0,0 +1,45 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package pulsar + +import ( + "context" + "errors" + "testing" + + plog "github.com/apache/pulsar-client-go/pulsar/log" + "github.com/stretchr/testify/require" +) + +func TestSendRequestDoneIsIdempotentAfterPutToPool(t *testing.T) { + sr := newSendRequest( + context.Background(), + &partitionProducer{log: plog.DefaultNopLogger()}, + &ProducerMessage{Properties: map[string]string{"k": "v"}}, + func(MessageID, *ProducerMessage, error) {}, + false, + ) + + // First done() call returns sr to the pool and resets it. + sr.done(nil, errors.New("first error")) + + // A second done() call on the same pointer should be ignored safely. + require.NotPanics(t, func() { + sr.done(nil, errors.New("second error")) + }) +}