Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ var results = parallel.withMaxConcurrency(10).map(ids, this::fetchById);
```

Each task runs on its own virtual thread. Results stay in input order.
MDC and trace context survive the hop.
MDC and trace context survive the hop. If any task fails, remaining tasks
are interrupted — no wasted work running in the background.

### Usage

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void mapReturnsResultsInInputOrder() {
i -> {
try {
Thread.sleep(i * 20L);
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
return i * 10;
Expand All @@ -60,7 +60,7 @@ void mapRunsConcurrently() {
maxConcurrent.updateAndGet(max -> Math.max(max, c));
try {
Thread.sleep(50);
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
current.decrementAndGet();
Expand Down Expand Up @@ -172,7 +172,7 @@ void limitsConcurrentTasks() {
maxConcurrent.updateAndGet(max -> Math.max(max, c));
try {
Thread.sleep(50);
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
current.decrementAndGet();
Expand All @@ -196,7 +196,7 @@ void stillRunsConcurrently() {
maxConcurrent.updateAndGet(max -> Math.max(max, c));
try {
Thread.sleep(50);
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
current.decrementAndGet();
Expand Down Expand Up @@ -230,4 +230,91 @@ void rejectsZeroConcurrency() {
() -> Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(0));
}
}

@Nested
class CancelOnFailure {

@Test
void mapInterruptsRemainingTasksOnFailure() {
final var interrupted = new AtomicInteger(0);

assertThrows(
RuntimeException.class,
() ->
parallel.map(
List.of(1, 2, 3),
i -> {
if (i == 1) {
throw new RuntimeException("fail fast");
}
try {
Thread.sleep(5000);
} catch (final InterruptedException e) {
interrupted.incrementAndGet();
}
return i;
}));

// Give interrupts a moment to propagate
try {
Thread.sleep(100);
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
assertTrue(
interrupted.get() > 0, "remaining tasks should have been interrupted, but none were");
}

@Test
void asMapInterruptsRemainingTasksOnFailure() {
final var interrupted = new AtomicInteger(0);

assertThrows(
RuntimeException.class,
() ->
parallel.asMap(
List.of("a", "b", "c"),
key -> {
if ("a".equals(key)) {
throw new RuntimeException("fail fast");
}
try {
Thread.sleep(5000);
} catch (final InterruptedException e) {
interrupted.incrementAndGet();
}
return key.toUpperCase();
}));

try {
Thread.sleep(100);
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
assertTrue(
interrupted.get() > 0, "remaining tasks should have been interrupted, but none were");
}

@Test
void toEitherDoesNotInterruptOnFailure() {
final var completed = new AtomicInteger(0);

parallel.toEither(
List.of(1, 2, 3),
i -> {
if (i == 1) {
throw new RuntimeException("fail");
}
try {
Thread.sleep(50);
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
completed.incrementAndGet();
return i;
});

assertEquals(2, completed.get(), "toEither should let all tasks run to completion");
}
}
}
56 changes: 50 additions & 6 deletions blockless/src/main/java/io/github/pjlabs/blockless/Blockless.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ public static <T> Supplier<T> supplier(CompletionStage<T> stage) {
() -> {
try {
latch.await();
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
throwable.set(e);
}
});
return () -> {
try {
thread.join();
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
throwable.set(e);
}
Expand Down Expand Up @@ -89,17 +89,17 @@ public static <T> Supplier<T> supplier(Callable<T> callable) {
() -> {
try {
result.set(callable.call());
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
throwable.set(e);
} catch (Exception e) {
} catch (final Exception e) {
throwable.set(e);
}
});
return () -> {
try {
thread.join();
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
throwable.set(e);
}
Expand All @@ -122,10 +122,54 @@ public static <T> T get(Callable<T> callable) {
return supplier(callable).get();
}

/** A task handle that exposes both the result supplier and the underlying virtual thread. */
record AsyncTask<T>(Supplier<T> supplier, Thread thread) {

T get() {
return supplier.get();
}

void interrupt() {
thread.interrupt();
}
}

/** Like {@link #supplier(Callable)} but returns the thread handle for cancellation. */
static <T> AsyncTask<T> asyncTask(Callable<T> callable) {
final var result = new AtomicReference<T>();
final var throwable = new AtomicReference<Throwable>();
final var thread =
Thread.startVirtualThread(
() -> {
try {
result.set(callable.call());
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
throwable.set(e);
} catch (final Exception e) {
throwable.set(e);
}
});
final Supplier<T> supplier =
() -> {
try {
thread.join();
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
throwable.set(e);
}
if (throwable.get() != null) {
throw wrapIfChecked(throwable.get());
}
return result.get();
};
return new AsyncTask<>(supplier, thread);
}

/**
* Rethrows RuntimeException subclasses directly; wraps checked exceptions in RuntimeException.
*/
private static RuntimeException wrapIfChecked(Throwable t) {
static RuntimeException wrapIfChecked(Throwable t) {
if (t instanceof RuntimeException re) {
return re;
}
Expand Down
37 changes: 28 additions & 9 deletions blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,35 +74,54 @@ private <T> Supplier<T> bounded(Supplier<T> task) {
};
}

/** Creates an AsyncTask with thread handle for cancellation support. */
private <T> Blockless.AsyncTask<T> fork(Supplier<T> task) {
Objects.requireNonNull(task, "task");
final var effective = semaphore != null ? bounded(task) : task;
return Blockless.asyncTask(CallableContext.wrap(effective::get, propagators));
}

/**
* Applies {@code fn} to each element on virtual threads with context propagation, returning
* results in input order. Blocks until all tasks complete.
* results in input order. Blocks until all tasks complete. If any task fails, remaining tasks are
* interrupted.
*/
public <T, R> List<R> map(List<T> items, Function<T, R> fn) {
Objects.requireNonNull(items, "items");
Objects.requireNonNull(fn, "fn");

final var suppliers = items.stream().map(item -> async(() -> fn.apply(item))).toList();
final var tasks = items.stream().map(item -> fork(() -> fn.apply(item))).toList();

return suppliers.stream().map(Supplier::get).toList();
try {
return tasks.stream().map(Blockless.AsyncTask::get).toList();
} catch (final RuntimeException e) {
tasks.forEach(Blockless.AsyncTask::interrupt);
throw e;
}
}

/**
* Computes a value for each key on virtual threads with context propagation, returning a map
* preserving key iteration order. Blocks until all tasks complete.
* preserving key iteration order. Blocks until all tasks complete. If any task fails, remaining
* tasks are interrupted.
*/
public <K, V> Map<K, V> asMap(Collection<K> keys, Function<K, V> valueMapper) {
Objects.requireNonNull(keys, "keys");
Objects.requireNonNull(valueMapper, "valueMapper");

final var entries =
keys.stream().map(key -> Map.entry(key, async(() -> valueMapper.apply(key)))).toList();
keys.stream().map(key -> Map.entry(key, fork(() -> valueMapper.apply(key)))).toList();

final var result = new LinkedHashMap<K, V>();
for (var entry : entries) {
result.put(entry.getKey(), entry.getValue().get());
try {
final var result = new LinkedHashMap<K, V>();
for (final var entry : entries) {
result.put(entry.getKey(), entry.getValue().get());
}
return result;
} catch (final RuntimeException e) {
entries.forEach(entry -> entry.getValue().interrupt());
throw e;
}
return result;
}

/**
Expand Down
Loading