diff --git a/README.md b/README.md index 7aaef58..cde8973 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,9 @@ List names = bounded.map(userIds, id -> fetchName(id)); // Collect results without failing fast — failed tasks return Either.fail() List> results = parallel.toEither(ids, id -> riskyFetch(id)); + +// Race — first successful result wins, remaining tasks are interrupted +var result = parallel.race(() -> fetchFromPrimary(), () -> fetchFromReplica()); ``` ## Context propagation diff --git a/blockless-tests/src/test/java/io/github/pjlabs/blockless/ParallelTest.java b/blockless-tests/src/test/java/io/github/pjlabs/blockless/ParallelTest.java index c2ced5c..ebb6f9c 100644 --- a/blockless-tests/src/test/java/io/github/pjlabs/blockless/ParallelTest.java +++ b/blockless-tests/src/test/java/io/github/pjlabs/blockless/ParallelTest.java @@ -38,7 +38,7 @@ void mapReturnsResultsInInputOrder() { i -> { try { Thread.sleep(i * 20L); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); } return i * 10; @@ -60,7 +60,7 @@ void mapRunsConcurrently() { maxConcurrent.updateAndGet(max -> Math.max(max, c)); try { Thread.sleep(50); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); } current.decrementAndGet(); @@ -172,7 +172,7 @@ void limitsConcurrentTasks() { maxConcurrent.updateAndGet(max -> Math.max(max, c)); try { Thread.sleep(50); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); } current.decrementAndGet(); @@ -196,7 +196,7 @@ void stillRunsConcurrently() { maxConcurrent.updateAndGet(max -> Math.max(max, c)); try { Thread.sleep(50); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); } current.decrementAndGet(); @@ -230,4 +230,88 @@ void rejectsZeroConcurrency() { () -> Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(0)); } } + + @Nested + class Race { + + @Test + void returnsFirstSuccessfulResult() { + final var result = + parallel.race( + () -> { + sleepQuietly(200); + return "slow"; + }, + () -> "fast"); + assertEquals("fast", result); + } + + @Test + void interruptsRemainingTasks() { + final var interrupted = new AtomicInteger(0); + parallel.race( + () -> { + try { + Thread.sleep(5000); + } catch (final InterruptedException e) { + interrupted.incrementAndGet(); + } + return "slow"; + }, + () -> "fast"); + + sleepQuietly(50); + assertEquals(1, interrupted.get(), "slow task should have been interrupted"); + } + + @Test + void throwsWhenAllFail() { + assertThrows( + IllegalStateException.class, + () -> + parallel.race( + List.of( + () -> { + throw new IllegalStateException("fail1"); + }, + () -> { + throw new IllegalStateException("fail2"); + }))); + } + + @Test + void propagatesMdc() { + MDC.put("traceId", "race-trace"); + final var result = parallel.race(() -> MDC.get("traceId"), () -> MDC.get("traceId")); + assertEquals("race-trace", result); + } + + @Test + void rejectsEmptyList() { + assertThrows(IllegalArgumentException.class, () -> parallel.race(List.of())); + } + + @Test + void succeedsEvenIfSomeFail() { + final var result = + parallel.race( + List.of( + () -> { + throw new RuntimeException("fail"); + }, + () -> { + sleepQuietly(50); + return "success"; + })); + assertEquals("success", result); + } + + private void sleepQuietly(long millis) { + try { + Thread.sleep(millis); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + } } diff --git a/blockless/src/main/java/io/github/pjlabs/blockless/Blockless.java b/blockless/src/main/java/io/github/pjlabs/blockless/Blockless.java index 6a4fa51..14a8955 100644 --- a/blockless/src/main/java/io/github/pjlabs/blockless/Blockless.java +++ b/blockless/src/main/java/io/github/pjlabs/blockless/Blockless.java @@ -125,7 +125,7 @@ public static T get(Callable callable) { /** * Rethrows RuntimeException subclasses directly; wraps checked exceptions in RuntimeException. */ - private static RuntimeException wrapIfChecked(Throwable t) { + static RuntimeException wrapIfChecked(Throwable t) { if (t instanceof RuntimeException re) { return re; } diff --git a/blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java b/blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java index 2a32fab..eb0f412 100644 --- a/blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java +++ b/blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java @@ -7,7 +7,9 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Semaphore; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import java.util.function.Supplier; @@ -74,6 +76,81 @@ private Supplier bounded(Supplier task) { }; } + /** + * Runs all suppliers concurrently on virtual threads with context propagation. Returns the result + * of the first supplier to complete successfully. Remaining tasks are interrupted. + * + *

If all suppliers fail, the exception from the last completed task is thrown. + * + * @throws IllegalArgumentException if tasks is empty + */ + @SafeVarargs + public final T race(Supplier... tasks) { + return race(List.of(tasks)); + } + + /** + * Runs all suppliers concurrently on virtual threads with context propagation. Returns the result + * of the first supplier to complete successfully. Remaining tasks are interrupted. + * + *

If all suppliers fail, the exception from the last completed task is thrown. + * + * @throws IllegalArgumentException if tasks is empty + */ + public T race(List> tasks) { + Objects.requireNonNull(tasks, "tasks"); + if (tasks.isEmpty()) { + throw new IllegalArgumentException("tasks must not be empty"); + } + + final var result = new AtomicReference(); + final var lastFailure = new AtomicReference(); + final var success = new CountDownLatch(1); + final var remaining = new CountDownLatch(tasks.size()); + final var threads = new ArrayList(tasks.size()); + + for (final var task : tasks) { + final var wrapped = + CallableContext.wrap((semaphore != null ? bounded(task) : task)::get, propagators); + final var thread = + Thread.startVirtualThread( + () -> { + try { + result.set(wrapped.call()); + success.countDown(); + } catch (final Exception e) { + lastFailure.set(e); + } finally { + remaining.countDown(); + } + }); + threads.add(thread); + } + + try { + // Wait until one succeeds or all finish + while (success.getCount() > 0 && remaining.getCount() > 0) { + remaining.await(1, java.util.concurrent.TimeUnit.MILLISECONDS); + } + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } finally { + for (final var thread : threads) { + thread.interrupt(); + } + } + + if (success.getCount() == 0) { + return result.get(); + } + + final var failure = lastFailure.get(); + throw failure != null + ? Blockless.wrapIfChecked(failure) + : new RuntimeException("All tasks failed"); + } + /** * Applies {@code fn} to each element on virtual threads with context propagation, returning * results in input order. Blocks until all tasks complete.