diff --git a/README.md b/README.md index 7aaef58..d558912 100644 --- a/README.md +++ b/README.md @@ -76,7 +76,8 @@ var results = parallel.withMaxConcurrency(10).map(ids, this::fetchById); ``` Each task runs on its own virtual thread. Results stay in input order. -MDC and trace context survive the hop. +MDC and trace context survive the hop. If any task fails, remaining tasks +are interrupted — no wasted work running in the background. ### Usage diff --git a/blockless-tests/src/test/java/io/github/pjlabs/blockless/ParallelTest.java b/blockless-tests/src/test/java/io/github/pjlabs/blockless/ParallelTest.java index c2ced5c..e6b312f 100644 --- a/blockless-tests/src/test/java/io/github/pjlabs/blockless/ParallelTest.java +++ b/blockless-tests/src/test/java/io/github/pjlabs/blockless/ParallelTest.java @@ -38,7 +38,7 @@ void mapReturnsResultsInInputOrder() { i -> { try { Thread.sleep(i * 20L); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); } return i * 10; @@ -60,7 +60,7 @@ void mapRunsConcurrently() { maxConcurrent.updateAndGet(max -> Math.max(max, c)); try { Thread.sleep(50); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); } current.decrementAndGet(); @@ -172,7 +172,7 @@ void limitsConcurrentTasks() { maxConcurrent.updateAndGet(max -> Math.max(max, c)); try { Thread.sleep(50); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); } current.decrementAndGet(); @@ -196,7 +196,7 @@ void stillRunsConcurrently() { maxConcurrent.updateAndGet(max -> Math.max(max, c)); try { Thread.sleep(50); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); } current.decrementAndGet(); @@ -230,4 +230,91 @@ void rejectsZeroConcurrency() { () -> Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(0)); } } + + @Nested + class CancelOnFailure { + + @Test + void mapInterruptsRemainingTasksOnFailure() { + final var interrupted = new AtomicInteger(0); + + assertThrows( + RuntimeException.class, + () -> + parallel.map( + List.of(1, 2, 3), + i -> { + if (i == 1) { + throw new RuntimeException("fail fast"); + } + try { + Thread.sleep(5000); + } catch (final InterruptedException e) { + interrupted.incrementAndGet(); + } + return i; + })); + + // Give interrupts a moment to propagate + try { + Thread.sleep(100); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + } + assertTrue( + interrupted.get() > 0, "remaining tasks should have been interrupted, but none were"); + } + + @Test + void asMapInterruptsRemainingTasksOnFailure() { + final var interrupted = new AtomicInteger(0); + + assertThrows( + RuntimeException.class, + () -> + parallel.asMap( + List.of("a", "b", "c"), + key -> { + if ("a".equals(key)) { + throw new RuntimeException("fail fast"); + } + try { + Thread.sleep(5000); + } catch (final InterruptedException e) { + interrupted.incrementAndGet(); + } + return key.toUpperCase(); + })); + + try { + Thread.sleep(100); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + } + assertTrue( + interrupted.get() > 0, "remaining tasks should have been interrupted, but none were"); + } + + @Test + void toEitherDoesNotInterruptOnFailure() { + final var completed = new AtomicInteger(0); + + parallel.toEither( + List.of(1, 2, 3), + i -> { + if (i == 1) { + throw new RuntimeException("fail"); + } + try { + Thread.sleep(50); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + } + completed.incrementAndGet(); + return i; + }); + + assertEquals(2, completed.get(), "toEither should let all tasks run to completion"); + } + } } diff --git a/blockless/src/main/java/io/github/pjlabs/blockless/Blockless.java b/blockless/src/main/java/io/github/pjlabs/blockless/Blockless.java index 6a4fa51..127abf9 100644 --- a/blockless/src/main/java/io/github/pjlabs/blockless/Blockless.java +++ b/blockless/src/main/java/io/github/pjlabs/blockless/Blockless.java @@ -43,7 +43,7 @@ public static Supplier supplier(CompletionStage stage) { () -> { try { latch.await(); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); throwable.set(e); } @@ -51,7 +51,7 @@ public static Supplier supplier(CompletionStage stage) { return () -> { try { thread.join(); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); throwable.set(e); } @@ -89,17 +89,17 @@ public static Supplier supplier(Callable callable) { () -> { try { result.set(callable.call()); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); throwable.set(e); - } catch (Exception e) { + } catch (final Exception e) { throwable.set(e); } }); return () -> { try { thread.join(); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); throwable.set(e); } @@ -122,10 +122,54 @@ public static T get(Callable callable) { return supplier(callable).get(); } + /** A task handle that exposes both the result supplier and the underlying virtual thread. */ + record AsyncTask(Supplier supplier, Thread thread) { + + T get() { + return supplier.get(); + } + + void interrupt() { + thread.interrupt(); + } + } + + /** Like {@link #supplier(Callable)} but returns the thread handle for cancellation. */ + static AsyncTask asyncTask(Callable callable) { + final var result = new AtomicReference(); + final var throwable = new AtomicReference(); + final var thread = + Thread.startVirtualThread( + () -> { + try { + result.set(callable.call()); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throwable.set(e); + } catch (final Exception e) { + throwable.set(e); + } + }); + final Supplier supplier = + () -> { + try { + thread.join(); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throwable.set(e); + } + if (throwable.get() != null) { + throw wrapIfChecked(throwable.get()); + } + return result.get(); + }; + return new AsyncTask<>(supplier, thread); + } + /** * Rethrows RuntimeException subclasses directly; wraps checked exceptions in RuntimeException. */ - private static RuntimeException wrapIfChecked(Throwable t) { + static RuntimeException wrapIfChecked(Throwable t) { if (t instanceof RuntimeException re) { return re; } diff --git a/blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java b/blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java index 2a32fab..6670744 100644 --- a/blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java +++ b/blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java @@ -74,35 +74,54 @@ private Supplier bounded(Supplier task) { }; } + /** Creates an AsyncTask with thread handle for cancellation support. */ + private Blockless.AsyncTask fork(Supplier task) { + Objects.requireNonNull(task, "task"); + final var effective = semaphore != null ? bounded(task) : task; + return Blockless.asyncTask(CallableContext.wrap(effective::get, propagators)); + } + /** * Applies {@code fn} to each element on virtual threads with context propagation, returning - * results in input order. Blocks until all tasks complete. + * results in input order. Blocks until all tasks complete. If any task fails, remaining tasks are + * interrupted. */ public List map(List items, Function fn) { Objects.requireNonNull(items, "items"); Objects.requireNonNull(fn, "fn"); - final var suppliers = items.stream().map(item -> async(() -> fn.apply(item))).toList(); + final var tasks = items.stream().map(item -> fork(() -> fn.apply(item))).toList(); - return suppliers.stream().map(Supplier::get).toList(); + try { + return tasks.stream().map(Blockless.AsyncTask::get).toList(); + } catch (final RuntimeException e) { + tasks.forEach(Blockless.AsyncTask::interrupt); + throw e; + } } /** * Computes a value for each key on virtual threads with context propagation, returning a map - * preserving key iteration order. Blocks until all tasks complete. + * preserving key iteration order. Blocks until all tasks complete. If any task fails, remaining + * tasks are interrupted. */ public Map asMap(Collection keys, Function valueMapper) { Objects.requireNonNull(keys, "keys"); Objects.requireNonNull(valueMapper, "valueMapper"); final var entries = - keys.stream().map(key -> Map.entry(key, async(() -> valueMapper.apply(key)))).toList(); + keys.stream().map(key -> Map.entry(key, fork(() -> valueMapper.apply(key)))).toList(); - final var result = new LinkedHashMap(); - for (var entry : entries) { - result.put(entry.getKey(), entry.getValue().get()); + try { + final var result = new LinkedHashMap(); + for (final var entry : entries) { + result.put(entry.getKey(), entry.getValue().get()); + } + return result; + } catch (final RuntimeException e) { + entries.forEach(entry -> entry.getValue().interrupt()); + throw e; } - return result; } /**