Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ List<String> names = bounded.map(userIds, id -> fetchName(id));

// Collect results without failing fast — failed tasks return Either.fail()
List<Either<String, Throwable>> results = parallel.toEither(ids, id -> riskyFetch(id));

// Race — first successful result wins, remaining tasks are interrupted
var result = parallel.race(() -> fetchFromPrimary(), () -> fetchFromReplica());
```

## Context propagation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void mapReturnsResultsInInputOrder() {
i -> {
try {
Thread.sleep(i * 20L);
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
return i * 10;
Expand All @@ -60,7 +60,7 @@ void mapRunsConcurrently() {
maxConcurrent.updateAndGet(max -> Math.max(max, c));
try {
Thread.sleep(50);
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
current.decrementAndGet();
Expand Down Expand Up @@ -172,7 +172,7 @@ void limitsConcurrentTasks() {
maxConcurrent.updateAndGet(max -> Math.max(max, c));
try {
Thread.sleep(50);
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
current.decrementAndGet();
Expand All @@ -196,7 +196,7 @@ void stillRunsConcurrently() {
maxConcurrent.updateAndGet(max -> Math.max(max, c));
try {
Thread.sleep(50);
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
current.decrementAndGet();
Expand Down Expand Up @@ -230,4 +230,88 @@ void rejectsZeroConcurrency() {
() -> Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(0));
}
}

@Nested
class Race {

@Test
void returnsFirstSuccessfulResult() {
final var result =
parallel.race(
() -> {
sleepQuietly(200);
return "slow";
},
() -> "fast");
assertEquals("fast", result);
}

@Test
void interruptsRemainingTasks() {
final var interrupted = new AtomicInteger(0);
parallel.race(
() -> {
try {
Thread.sleep(5000);
} catch (final InterruptedException e) {
interrupted.incrementAndGet();
}
return "slow";
},
() -> "fast");

sleepQuietly(50);
assertEquals(1, interrupted.get(), "slow task should have been interrupted");
}

@Test
void throwsWhenAllFail() {
assertThrows(
IllegalStateException.class,
() ->
parallel.race(
List.of(
() -> {
throw new IllegalStateException("fail1");
},
() -> {
throw new IllegalStateException("fail2");
})));
}

@Test
void propagatesMdc() {
MDC.put("traceId", "race-trace");
final var result = parallel.race(() -> MDC.get("traceId"), () -> MDC.get("traceId"));
assertEquals("race-trace", result);
}

@Test
void rejectsEmptyList() {
assertThrows(IllegalArgumentException.class, () -> parallel.race(List.of()));
}

@Test
void succeedsEvenIfSomeFail() {
final var result =
parallel.race(
List.of(
() -> {
throw new RuntimeException("fail");
},
() -> {
sleepQuietly(50);
return "success";
}));
assertEquals("success", result);
}

private void sleepQuietly(long millis) {
try {
Thread.sleep(millis);
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public static <T> T get(Callable<T> callable) {
/**
* Rethrows RuntimeException subclasses directly; wraps checked exceptions in RuntimeException.
*/
private static RuntimeException wrapIfChecked(Throwable t) {
static RuntimeException wrapIfChecked(Throwable t) {
if (t instanceof RuntimeException re) {
return re;
}
Expand Down
77 changes: 77 additions & 0 deletions blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;

Expand Down Expand Up @@ -74,6 +76,81 @@ private <T> Supplier<T> bounded(Supplier<T> task) {
};
}

/**
* Runs all suppliers concurrently on virtual threads with context propagation. Returns the result
* of the first supplier to complete successfully. Remaining tasks are interrupted.
*
* <p>If all suppliers fail, the exception from the last completed task is thrown.
*
* @throws IllegalArgumentException if tasks is empty
*/
@SafeVarargs
public final <T> T race(Supplier<T>... tasks) {
return race(List.of(tasks));
}

/**
* Runs all suppliers concurrently on virtual threads with context propagation. Returns the result
* of the first supplier to complete successfully. Remaining tasks are interrupted.
*
* <p>If all suppliers fail, the exception from the last completed task is thrown.
*
* @throws IllegalArgumentException if tasks is empty
*/
public <T> T race(List<Supplier<T>> tasks) {
Objects.requireNonNull(tasks, "tasks");
if (tasks.isEmpty()) {
throw new IllegalArgumentException("tasks must not be empty");
}

final var result = new AtomicReference<T>();
final var lastFailure = new AtomicReference<Throwable>();
final var success = new CountDownLatch(1);
final var remaining = new CountDownLatch(tasks.size());
final var threads = new ArrayList<Thread>(tasks.size());

for (final var task : tasks) {
final var wrapped =
CallableContext.wrap((semaphore != null ? bounded(task) : task)::get, propagators);
final var thread =
Thread.startVirtualThread(
() -> {
try {
result.set(wrapped.call());
success.countDown();
} catch (final Exception e) {
lastFailure.set(e);
} finally {
remaining.countDown();
}
});
threads.add(thread);
}

try {
// Wait until one succeeds or all finish
while (success.getCount() > 0 && remaining.getCount() > 0) {
remaining.await(1, java.util.concurrent.TimeUnit.MILLISECONDS);
}
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
} finally {
for (final var thread : threads) {
thread.interrupt();
}
}

if (success.getCount() == 0) {
return result.get();
}

final var failure = lastFailure.get();
throw failure != null
? Blockless.wrapIfChecked(failure)
: new RuntimeException("All tasks failed");
}

/**
* Applies {@code fn} to each element on virtual threads with context propagation, returning
* results in input order. Blocks until all tasks complete.
Expand Down
Loading