diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/delayed/InMemoryDelayedDeliveryTracker.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/delayed/InMemoryDelayedDeliveryTracker.java
index 8da74a553ddf5..2ca430c4f8866 100644
--- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/delayed/InMemoryDelayedDeliveryTracker.java
+++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/delayed/InMemoryDelayedDeliveryTracker.java
@@ -122,21 +122,29 @@ private static long trimLowerBit(long timestamp, int bits) {
     }
 
     @Override
-    public boolean addMessage(long ledgerId, long entryId, long deliverAt) {
+    public synchronized boolean addMessage(long ledgerId, long entryId, long deliverAt) {
         if (deliverAt < 0 || deliverAt <= getCutoffTime()) {
             messagesHaveFixedDelay = false;
             return false;
         }
-        log.debug()
-                .attr("ledgerId", ledgerId)
-                .attr("entryId", entryId)
-                .attr("deliveryInMs", () -> deliverAt - clock.millis())
-                .log("Add message");
-        long timestamp = trimLowerBit(deliverAt, timestampPrecisionBitCnt);
-        delayedMessageMap.computeIfAbsent(timestamp, k -> new TreeMap<>())
-                .computeIfAbsent(ledgerId, k -> new Roaring64Bitmap())
-                .add(entryId);
-        delayedMessagesCount.incrementAndGet();
+
+        log.debug()
+                .attr("ledgerId", ledgerId)
+                .attr("entryId", entryId)
+                .attr("deliveryInMs", () -> deliverAt - clock.millis())
+                .log("Add message");
+        long timestamp = trimLowerBit(deliverAt, timestampPrecisionBitCnt);
+
+        Roaring64Bitmap bitmap = delayedMessageMap.computeIfAbsent(timestamp, k -> new TreeMap<>())
+                .computeIfAbsent(ledgerId, k -> new Roaring64Bitmap());
+        // Roaring64Bitmap does not store duplicates, so track whether this is a new element
+        // so that delayedMessagesCount stays in sync with the bitmap contents
+        boolean isNew = !bitmap.contains(entryId);
+
+        if (isNew) {
+            bitmap.add(entryId);
+            delayedMessagesCount.incrementAndGet();
+        }
 
         updateTimer();
 
@@ -161,7 +169,7 @@ private void checkAndUpdateHighest(long deliverAt) {
      * Return true if there's at least a message that is scheduled to be delivered already.
      */
     @Override
-    public boolean hasMessageAvailable() {
+    public synchronized boolean hasMessageAvailable() {
         boolean hasMessageAvailable = !delayedMessageMap.isEmpty()
                 && delayedMessageMap.firstKey() <= getCutoffTime();
         if (!hasMessageAvailable) {
@@ -174,7 +182,7 @@ public boolean hasMessageAvailable() {
      * Get a set of position of messages that have already reached.
      */
     @Override
-    public NavigableSet getScheduledMessages(int maxMessages) {
+    public synchronized NavigableSet getScheduledMessages(int maxMessages) {
         int n = maxMessages;
         NavigableSet positions = new TreeSet<>();
         long cutoffTime = getCutoffTime();
@@ -237,14 +245,14 @@ public NavigableSet getScheduledMessages(int maxMessages) {
     }
 
     @Override
-    public CompletableFuture clear() {
+    public synchronized CompletableFuture clear() {
         this.delayedMessageMap.clear();
         this.delayedMessagesCount.set(0);
         return CompletableFuture.completedFuture(null);
     }
 
     @Override
-    public long getNumberOfDelayedMessages() {
+    public synchronized long getNumberOfDelayedMessages() {
         return delayedMessagesCount.get();
     }
 
@@ -255,14 +263,14 @@ public long getNumberOfDelayedMessages() {
      * @return the memory usage of the buffer
      */
     @Override
-    public long getBufferMemoryUsage() {
+    public synchronized long getBufferMemoryUsage() {
         return delayedMessageMap.values().stream().mapToLong(
                 ledgerMap -> ledgerMap.values().stream().mapToLong(
                         Roaring64Bitmap::getLongSizeInBytes).sum()).sum();
     }
 
     @Override
-    public void close() {
+    public synchronized void close() {
         super.close();
     }
 
@@ -275,7 +283,7 @@ && getNumberOfDelayedMessages() >= fixedDelayDetectionLookahead
                 && !hasMessageAvailable();
     }
 
-    protected long nextDeliveryTime() {
+    protected synchronized long nextDeliveryTime() {
         return delayedMessageMap.firstKey();
     }
 }
diff --git a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/persistent/PersistentDispatcherMultipleConsumers.java b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/persistent/PersistentDispatcherMultipleConsumers.java
index c569cf5b68cb2..274924b716cdb 100644
--- a/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/persistent/PersistentDispatcherMultipleConsumers.java
+++ b/pulsar-broker/src/main/java/org/apache/pulsar/broker/service/persistent/PersistentDispatcherMultipleConsumers.java
@@ -427,7 +427,7 @@ public synchronized void readMoreEntries() {
         }
     }
 
-    protected Predicate createReadEntriesSkipConditionForNormalRead() {
+    protected synchronized Predicate createReadEntriesSkipConditionForNormalRead() {
         Predicate skipCondition = null;
         // Filter out and skip read delayed messages exist in DelayedDeliveryTracker
         if (delayedDeliveryTracker.isPresent()) {
@@ -1378,7 +1378,7 @@ protected synchronized boolean shouldPauseDeliveryForDelayTracker() {
     }
 
     @Override
-    public long getNumberOfDelayedMessages() {
+    public synchronized long getNumberOfDelayedMessages() {
         return delayedDeliveryTracker.map(DelayedDeliveryTracker::getNumberOfDelayedMessages).orElse(0L);
     }
 
@@ -1389,7 +1389,9 @@ public CompletableFuture clearDelayedMessages() {
         }
 
         if (delayedDeliveryTracker.isPresent()) {
-            return this.delayedDeliveryTracker.get().clear();
+            synchronized (this) {
+                return this.delayedDeliveryTracker.get().clear();
+            }
         } else {
             DelayedDeliveryTrackerFactory delayedDeliveryTrackerFactory =
                     topic.getBrokerService().getDelayedDeliveryTrackerFactory();
@@ -1464,11 +1466,11 @@ public PersistentTopic getTopic() {
     }
 
 
-    public long getDelayedTrackerMemoryUsage() {
+    public synchronized long getDelayedTrackerMemoryUsage() {
         return delayedDeliveryTracker.map(DelayedDeliveryTracker::getBufferMemoryUsage).orElse(0L);
     }
 
-    public Map getBucketDelayedIndexStats() {
+    public synchronized Map getBucketDelayedIndexStats() {
         if (delayedDeliveryTracker.isEmpty()) {
             return Collections.emptyMap();
         }
diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/broker/delayed/InMemoryDeliveryTrackerTest.java b/pulsar-broker/src/test/java/org/apache/pulsar/broker/delayed/InMemoryDeliveryTrackerTest.java
index e25595072d3c9..3c5b576c32a81 100644
--- a/pulsar-broker/src/test/java/org/apache/pulsar/broker/delayed/InMemoryDeliveryTrackerTest.java
+++ b/pulsar-broker/src/test/java/org/apache/pulsar/broker/delayed/InMemoryDeliveryTrackerTest.java
@@ -36,8 +36,13 @@
 import java.util.NavigableMap;
 import java.util.Set;
 import java.util.TreeMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.concurrent.atomic.AtomicReference;
 import lombok.Cleanup;
 import org.apache.bookkeeper.mledger.Position;
 import org.apache.pulsar.broker.service.persistent.AbstractPersistentDispatcherMultipleConsumers;
@@ -274,4 +279,115 @@ public void testDelaySequence(InMemoryDelayedDeliveryTracker tracker) throws Exc
         tracker.close();
     }
 
+    @Test(dataProvider = "delayedTracker")
+    public void testAddMultipleMessagesSameWindow(InMemoryDelayedDeliveryTracker tracker) throws Exception {
+        tracker.addMessage(1, 1, 50);
+        tracker.addMessage(1, 1, 50);
+        tracker.addMessage(1, 1, 50);
+        // The same position was added three times, so it must be counted only once.
+        assertEquals(tracker.getNumberOfDelayedMessages(), 1);
+
+        clockTime.set(60);
+
+        assertEquals(tracker.getScheduledMessages(10).size(), 1);
+        assertEquals(tracker.getNumberOfDelayedMessages(), 0);
+    }
+
+    @Test(dataProvider = "delayedTracker")
+    public void testRaceConditionInUpdateTimer(InMemoryDelayedDeliveryTracker tracker) throws Exception {
+        // One latch count per submitted task: 2 clear + 5 add + 5 count + 5 getScheduledMessages workers.
+        final int numThreads = 17;
+        final int operationsPerThread = 1000;
+        final CountDownLatch startLatch = new CountDownLatch(1);
+        final CountDownLatch doneLatch = new CountDownLatch(numThreads);
+        final AtomicInteger errors = new AtomicInteger(0);
+        final AtomicReference<Exception> firstException = new AtomicReference<>();
+
+        @Cleanup("shutdown")
+        ExecutorService executorService = Executors.newFixedThreadPool(32);
+
+        for (int i = 0; i < 2; i++) {
+            executorService.submit(() -> {
+                try {
+                    startLatch.await();
+                    for (int j = 0; j < operationsPerThread; j++) {
+                        tracker.clear();
+                        Thread.sleep(1);
+                    }
+                } catch (Exception e) {
+                    errors.incrementAndGet();
+                    firstException.compareAndSet(null, e);
+                    e.printStackTrace();
+                } finally {
+                    doneLatch.countDown();
+                }
+            });
+        }
+
+        for (int i = 0; i < 5; i++) {
+            executorService.submit(() -> {
+                try {
+                    startLatch.await();
+                    for (int j = 0; j < operationsPerThread; j++) {
+                        tracker.addMessage(1, 1, 10);
+                        Thread.sleep(1);
+                    }
+                } catch (Exception e) {
+                    errors.incrementAndGet();
+                    firstException.compareAndSet(null, e);
+                    e.printStackTrace();
+                } finally {
+                    doneLatch.countDown();
+                }
+            });
+        }
+
+        for (int i = 0; i < 5; i++) {
+            executorService.submit(() -> {
+                try {
+                    startLatch.await();
+                    for (int j = 0; j < operationsPerThread; j++) {
+                        tracker.getNumberOfDelayedMessages();
+                        Thread.sleep(1);
+                    }
+                } catch (Exception e) {
+                    errors.incrementAndGet();
+                    firstException.compareAndSet(null, e);
+                    e.printStackTrace();
+                } finally {
+                    doneLatch.countDown();
+                }
+            });
+        }
+
+        for (int i = 0; i < 5; i++) {
+            executorService.submit(() -> {
+                try {
+                    startLatch.await();
+                    for (int j = 0; j < operationsPerThread; j++) {
+                        tracker.getScheduledMessages(1);
+                        Thread.sleep(1);
+                    }
+                } catch (Exception e) {
+                    errors.incrementAndGet();
+                    firstException.compareAndSet(null, e);
+                    e.printStackTrace();
+                } finally {
+                    doneLatch.countDown();
+                }
+            });
+        }
+
+        startLatch.countDown();
+        assertTrue(doneLatch.await(30, TimeUnit.SECONDS), "Test should complete within 30 seconds");
+
+        if (errors.get() > 0) {
+            Exception exception = firstException.get();
+            if (exception != null) {
+                System.err.println("First exception caught: " + exception.getMessage());
+                exception.printStackTrace();
+            }
+        }
+        assertEquals(errors.get(), 0, "No exceptions should occur during concurrent operations");
+    }
 }
diff --git a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/PersistentDispatcherMultipleConsumersTest.java b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/PersistentDispatcherMultipleConsumersTest.java
index e5da7850dfd63..cf91f29988e38 100644
--- a/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/PersistentDispatcherMultipleConsumersTest.java
+++ b/pulsar-broker/src/test/java/org/apache/pulsar/broker/service/persistent/PersistentDispatcherMultipleConsumersTest.java
@@ -20,7 +20,12 @@
 
 import com.carrotsearch.hppc.ObjectSet;
 import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
 import lombok.Cleanup;
 import lombok.CustomLog;
 import org.apache.bookkeeper.mledger.ManagedCursor;
@@ -34,6 +39,7 @@
 import org.apache.pulsar.client.api.Producer;
 import org.apache.pulsar.client.api.Schema;
 import org.apache.pulsar.client.api.SubscriptionType;
+import org.apache.pulsar.common.api.proto.MessageMetadata;
 import org.awaitility.reflect.WhiteboxImpl;
 import org.mockito.Mockito;
 import org.testng.Assert;
@@ -152,4 +158,98 @@ public void testSkipReadEntriesFromCloseCursor() throws Exception {
         // Verify: the topic can be deleted successfully.
         admin.topics().delete(topicName, false);
     }
+
+    @Test
+    public void testRaceConditionInTrackDelayedDelivery() throws Exception {
+        final int numThreads = 16;
+        final int operationsPerThread = 2000;
+        final CountDownLatch startLatch = new CountDownLatch(1);
+        final CountDownLatch doneLatch = new CountDownLatch(numThreads);
+        final AtomicInteger errors = new AtomicInteger(0);
+        final AtomicReference<Exception> firstException = new AtomicReference<>();
+
+        final String topicName = newTopicName();
+        final String subscription = "s1";
+
+        // Needed to create the topic
+        @Cleanup
+        Consumer<String> consumer = pulsarClient.newConsumer(Schema.STRING)
+                .topic(topicName).subscriptionName(subscription)
+                .subscriptionType(SubscriptionType.Shared).subscribe();
+
+        PersistentTopic topic = (PersistentTopic) getTopic(topicName, false).join().get();
+
+        ManagedCursor cursor = Mockito.mock(ManagedCursorImpl.class);
+        Mockito.doReturn(subscription).when(cursor).getName();
+
+        Subscription sub = Mockito.mock(PersistentSubscription.class);
+        Mockito.doReturn(topic).when(sub).getTopic();
+
+        PersistentDispatcherMultipleConsumers dispatcher =
+                new PersistentDispatcherMultipleConsumers(topic, cursor, sub);
+
+        // Align all writes to the same bucket
+        // This is the key which triggers the race condition
+        long deliverAt = System.currentTimeMillis() + 5000;
+
+        MessageMetadata messageMetadata = new MessageMetadata()
+                .setSequenceId(1)
+                .setProducerName("testProducer")
+                .setPartitionKeyB64Encoded(false)
+                .setPublishTime(System.currentTimeMillis())
+                .setDeliverAtTime(deliverAt);
+
+        @Cleanup("shutdown")
+        ExecutorService executorService = Executors.newFixedThreadPool(32);
+
+        // Start clear message thread
+        for (int i = 0; i < numThreads / 2; i++) {
+            executorService.submit(() -> {
+                try {
+                    startLatch.await();
+                    for (int j = 0; j < operationsPerThread; j++) {
+                        dispatcher.clearDelayedMessages();
+                        Thread.sleep(1);
+                    }
+                } catch (Exception e) {
+                    errors.incrementAndGet();
+                    firstException.compareAndSet(null, e);
+                    e.printStackTrace();
+                } finally {
+                    doneLatch.countDown();
+                }
+            });
+        }
+
+        // Start track delayed delivery thread
+        for (int i = numThreads / 2; i < numThreads; i++) {
+            executorService.submit(() -> {
+                try {
+                    startLatch.await();
+                    for (int j = 0; j < operationsPerThread; j++) {
+                        dispatcher.trackDelayedDelivery(1, 1, messageMetadata);
+                        Thread.sleep(1);
+                    }
+                } catch (Exception e) {
+                    errors.incrementAndGet();
+                    firstException.compareAndSet(null, e);
+                    e.printStackTrace();
+                } finally {
+                    doneLatch.countDown();
+                }
+            });
+        }
+
+        startLatch.countDown();
+        Assert.assertTrue(doneLatch.await(30, TimeUnit.SECONDS), "Test should complete within 30 seconds");
+
+        if (errors.get() > 0) {
+            Exception exception = firstException.get();
+            if (exception != null) {
+                System.err.println("First exception caught: " + exception.getMessage());
+                exception.printStackTrace();
+            }
+        }
+        Assert.assertEquals(errors.get(), 0, "No exceptions should occur during concurrent operations");
+    }
 }