Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ private static long trimLowerBit(long timestamp, int bits) {
}

@Override
public boolean addMessage(long ledgerId, long entryId, long deliverAt) {
public synchronized boolean addMessage(long ledgerId, long entryId, long deliverAt) {
if (deliverAt < 0 || deliverAt <= getCutoffTime()) {
messagesHaveFixedDelay = false;
return false;
Expand Down Expand Up @@ -161,7 +161,7 @@ private void checkAndUpdateHighest(long deliverAt) {
* Return true if there's at least a message that is scheduled to be delivered already.
*/
@Override
public boolean hasMessageAvailable() {
public synchronized boolean hasMessageAvailable() {
boolean hasMessageAvailable = !delayedMessageMap.isEmpty()
&& delayedMessageMap.firstKey() <= getCutoffTime();
if (!hasMessageAvailable) {
Expand All @@ -174,7 +174,7 @@ public boolean hasMessageAvailable() {
* Get a set of position of messages that have already reached.
*/
@Override
public NavigableSet<Position> getScheduledMessages(int maxMessages) {
public synchronized NavigableSet<Position> getScheduledMessages(int maxMessages) {
int n = maxMessages;
NavigableSet<Position> positions = new TreeSet<>();
long cutoffTime = getCutoffTime();
Expand Down Expand Up @@ -237,7 +237,7 @@ public NavigableSet<Position> getScheduledMessages(int maxMessages) {
}

@Override
public CompletableFuture<Void> clear() {
public synchronized CompletableFuture<Void> clear() {
this.delayedMessageMap.clear();
this.delayedMessagesCount.set(0);
return CompletableFuture.completedFuture(null);
Expand All @@ -262,7 +262,7 @@ public long getBufferMemoryUsage() {
}

@Override
public void close() {
public synchronized void close() {
super.close();
}

Expand All @@ -275,7 +275,7 @@ && getNumberOfDelayedMessages() >= fixedDelayDetectionLookahead
&& !hasMessageAvailable();
}

protected long nextDeliveryTime() {
protected synchronized long nextDeliveryTime() {
return delayedMessageMap.firstKey();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ public synchronized void readMoreEntries() {
}
}

protected Predicate<Position> createReadEntriesSkipConditionForNormalRead() {
protected synchronized Predicate<Position> createReadEntriesSkipConditionForNormalRead() {
Predicate<Position> skipCondition = null;
// Filter out and skip read delayed messages exist in DelayedDeliveryTracker
if (delayedDeliveryTracker.isPresent()) {
Expand Down Expand Up @@ -1378,7 +1378,7 @@ protected synchronized boolean shouldPauseDeliveryForDelayTracker() {
}

@Override
public long getNumberOfDelayedMessages() {
public synchronized long getNumberOfDelayedMessages() {
return delayedDeliveryTracker.map(DelayedDeliveryTracker::getNumberOfDelayedMessages).orElse(0L);
}

Expand All @@ -1389,7 +1389,9 @@ public CompletableFuture<Void> clearDelayedMessages() {
}

if (delayedDeliveryTracker.isPresent()) {
return this.delayedDeliveryTracker.get().clear();
synchronized (this) {
return this.delayedDeliveryTracker.get().clear();
}
} else {
DelayedDeliveryTrackerFactory delayedDeliveryTrackerFactory =
topic.getBrokerService().getDelayedDeliveryTrackerFactory();
Expand Down Expand Up @@ -1464,11 +1466,11 @@ public PersistentTopic getTopic() {
}


public long getDelayedTrackerMemoryUsage() {
public synchronized long getDelayedTrackerMemoryUsage() {
return delayedDeliveryTracker.map(DelayedDeliveryTracker::getBufferMemoryUsage).orElse(0L);
}

public Map<String, TopicMetricBean> getBucketDelayedIndexStats() {
public synchronized Map<String, TopicMetricBean> getBucketDelayedIndexStats() {
if (delayedDeliveryTracker.isEmpty()) {
return Collections.emptyMap();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@
import io.netty.util.concurrent.DefaultThreadFactory;
import java.lang.reflect.Method;
import java.time.Clock;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.NavigableMap;
import java.util.Set;
import java.util.TreeMap;
Expand Down Expand Up @@ -274,4 +280,100 @@ public void testDelaySequence(InMemoryDelayedDeliveryTracker tracker) throws Exc
tracker.close();
}

/**
 * Drives one tracker from many threads at once, mixing
 * {@code getNumberOfDelayedMessages}, {@code addMessage}, {@code clear} and
 * {@code getScheduledMessages}, and fails if any worker throws.
 *
 * <p>The completion latch and the pool size are both derived from the number of tasks
 * actually submitted. Otherwise the final assertions could run while workers are still
 * executing (latch too small) or some workers could not start until others finished
 * (pool too small).
 *
 * @param tracker the tracker instance supplied by the {@code delayedTracker} data provider
 */
@Test(dataProvider = "delayedTracker")
public void testRaceConditionInUpdateTimer(InMemoryDelayedDeliveryTracker tracker) throws Exception {
    final int numThreads = 15;
    final int operationsPerThread = 2000;
    final int readerTasks = numThreads / 4;
    final int tasksPerOtherGroup = numThreads - readerTasks;

    // Each operation is paired (by index) with the number of workers that run it.
    final Runnable[] operations = {
        tracker::getNumberOfDelayedMessages,
        () -> tracker.addMessage(1, 1, 10),
        tracker::clear,
        () -> tracker.getScheduledMessages(1),
    };
    final int[] taskCounts = {readerTasks, tasksPerOtherGroup, tasksPerOtherGroup, tasksPerOtherGroup};

    int totalTasks = 0;
    for (int count : taskCounts) {
        totalTasks += count;
    }

    final CountDownLatch startLatch = new CountDownLatch(1);
    // Must equal the number of submitted tasks, otherwise await() returns too early.
    final CountDownLatch doneLatch = new CountDownLatch(totalTasks);
    final AtomicInteger errors = new AtomicInteger(0);
    final AtomicReference<Exception> firstException = new AtomicReference<>();

    // One thread per task so every worker is parked on startLatch before the race begins.
    @Cleanup("shutdown")
    ExecutorService executorService = Executors.newFixedThreadPool(totalTasks);

    for (int group = 0; group < operations.length; group++) {
        final Runnable operation = operations[group];
        for (int i = 0; i < taskCounts[group]; i++) {
            executorService.submit(() -> {
                try {
                    startLatch.await();
                    for (int j = 0; j < operationsPerThread; j++) {
                        operation.run();
                        Thread.sleep(1);
                    }
                } catch (Exception e) {
                    errors.incrementAndGet();
                    firstException.compareAndSet(null, e);
                    e.printStackTrace();
                } finally {
                    doneLatch.countDown();
                }
            });
        }
    }

    startLatch.countDown();
    assertTrue(doneLatch.await(30, TimeUnit.SECONDS), "Test should complete within 30 seconds");

    if (errors.get() > 0) {
        Exception exception = firstException.get();
        if (exception != null) {
            System.err.println("First exception caught: " + exception.getMessage());
            exception.printStackTrace();
        }
    }
    assertEquals(errors.get(), 0, "No exceptions should occur during concurrent operations");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@

import com.carrotsearch.hppc.ObjectSet;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import lombok.Cleanup;
import lombok.CustomLog;
import org.apache.bookkeeper.mledger.ManagedCursor;
Expand All @@ -32,10 +37,13 @@
import org.apache.pulsar.client.api.Consumer;
import org.apache.pulsar.client.api.MessageId;
import org.apache.pulsar.client.api.Producer;
import org.apache.pulsar.common.api.proto.MessageMetadata;
import org.apache.pulsar.client.api.Schema;
import org.apache.pulsar.client.api.SubscriptionType;
import org.awaitility.reflect.WhiteboxImpl;
import org.mockito.Mockito;
import static org.testng.Assert.assertTrue;
import static org.testng.Assert.assertEquals;
import org.testng.Assert;
import org.testng.annotations.Test;

Expand Down Expand Up @@ -152,4 +160,97 @@ public void testSkipReadEntriesFromCloseCursor() throws Exception {
// Verify: the topic can be deleted successfully.
admin.topics().delete(topicName, false);
}

@Test
public void testRaceConditionInTrackDelayedDelivery() throws Exception {
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is based on pulsar-broker/src/test/java/org/apache/pulsar/broker/delayed/bucket/BucketDelayedDeliveryTrackerThreadSafetyTest.java

final int numThreads = 16;
final int operationsPerThread = 2000;
final CountDownLatch startLatch = new CountDownLatch(1);
final CountDownLatch doneLatch = new CountDownLatch(numThreads);
final AtomicInteger errors = new AtomicInteger(0);
final AtomicReference<Exception> firstException = new AtomicReference<>();

final String topicName = newTopicName();
final String subscription = "s1";

// Needed to create the topic
Consumer<String> consumer = pulsarClient.newConsumer(Schema.STRING)
.topic(topicName).subscriptionName(subscription)
.subscriptionType(SubscriptionType.Shared).subscribe();

PersistentTopic topic = (PersistentTopic) getTopic(topicName, false).join().get();

ManagedCursor cursor = Mockito.mock(ManagedCursorImpl.class);
Mockito.doReturn(subscription).when(cursor).getName();

Subscription sub = Mockito.mock(PersistentSubscription.class);
Mockito.doReturn(topic).when(sub).getTopic();

PersistentDispatcherMultipleConsumers dispatcher =
new PersistentDispatcherMultipleConsumers(topic, cursor, sub);

// Align all writes to the same bucket
// This is the key which triggers the race condition
long deliverAt = System.currentTimeMillis() + 5000;

MessageMetadata messageMetadata = new MessageMetadata()
.setSequenceId(1)
.setProducerName("testProducer")
.setPartitionKeyB64Encoded(false)
.setPublishTime(System.currentTimeMillis())
.setDeliverAtTime(deliverAt);

@Cleanup("shutdown")
ExecutorService executorService = Executors.newFixedThreadPool(32);
Comment thread
chamons marked this conversation as resolved.

// Start clear message thread
for (int i = 0; i < numThreads / 2; i++) {
executorService.submit(() -> {
try {
startLatch.await();
for (int j = 0; j < operationsPerThread; j++) {
dispatcher.clearDelayedMessages();
Thread.sleep(1);
}
} catch (Exception e) {
errors.incrementAndGet();
firstException.compareAndSet(null, e);
e.printStackTrace();
} finally {
doneLatch.countDown();
}
});
}

// Start track delayed delivery thread
for (int i = numThreads / 2; i < numThreads; i++) {
executorService.submit(() -> {
try {
startLatch.await();
for (int j = 0; j < operationsPerThread; j++) {
dispatcher.trackDelayedDelivery(1, 1, messageMetadata);
Thread.sleep(1);
}
} catch (Exception e) {
errors.incrementAndGet();
firstException.compareAndSet(null, e);
e.printStackTrace();
} finally {
doneLatch.countDown();
}
});
}

startLatch.countDown();
assertTrue(doneLatch.await(30, TimeUnit.SECONDS), "Test should complete within 30 seconds");

if (errors.get() > 0) {
Exception exception = firstException.get();
if (exception != null) {
System.err.println("First exception caught: " + exception.getMessage());
exception.printStackTrace();
}
}
assertEquals(errors.get(), 0, "No exceptions should occur during concurrent operations");
}
}
Loading