Skip to content
11 changes: 9 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,15 @@ Supplier<String> data = parallel.async(() -> fetchData());
Map<String, Profile> profiles = parallel.asMap(userIds, id -> loadProfile(id));

// Limit concurrent tasks (extras park until a permit frees up)
var bounded = parallel.withMaxConcurrency(10);
List<String> names = bounded.map(userIds, id -> fetchName(id));
var boundedParallel = parallel.withMaxConcurrency(10);
List<String> names = boundedParallel.map(userIds, id -> fetchName(id));

// Per-task timeout — thread is interrupted if task exceeds the deadline
var timedParallel = parallel.withTimeout(Duration.ofSeconds(5));
List<String> names = timedParallel.map(userIds, id -> fetchName(id));

// Combines with withMaxConcurrency
var safeParallel = parallel.withMaxConcurrency(10).withTimeout(Duration.ofSeconds(5));

// Collect results without failing fast — failed tasks return Either.fail()
List<Either<String, Throwable>> results = parallel.toEither(ids, id -> riskyFetch(id));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import static org.junit.jupiter.api.Assertions.assertTrue;

import io.github.pjlabs.blockless.context.slf4j.Slf4jMdcContextPropagator;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.AfterEach;
Expand Down Expand Up @@ -38,7 +39,7 @@ void mapReturnsResultsInInputOrder() {
i -> {
try {
Thread.sleep(i * 20L);
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
return i * 10;
Expand All @@ -60,7 +61,7 @@ void mapRunsConcurrently() {
maxConcurrent.updateAndGet(max -> Math.max(max, c));
try {
Thread.sleep(50);
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
current.decrementAndGet();
Expand Down Expand Up @@ -161,18 +162,19 @@ class BoundedConcurrency {

@Test
void limitsConcurrentTasks() {
final var bounded = Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(2);
final var boundedParallel =
Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(2);
final var maxConcurrent = new AtomicInteger(0);
final var current = new AtomicInteger(0);

bounded.map(
boundedParallel.map(
List.of(1, 2, 3, 4, 5),
i -> {
final int c = current.incrementAndGet();
maxConcurrent.updateAndGet(max -> Math.max(max, c));
try {
Thread.sleep(50);
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
current.decrementAndGet();
Expand All @@ -185,18 +187,19 @@ void limitsConcurrentTasks() {

@Test
void stillRunsConcurrently() {
final var bounded = Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(3);
final var boundedParallel =
Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(3);
final var maxConcurrent = new AtomicInteger(0);
final var current = new AtomicInteger(0);

bounded.map(
boundedParallel.map(
List.of(1, 2, 3, 4, 5),
i -> {
final int c = current.incrementAndGet();
maxConcurrent.updateAndGet(max -> Math.max(max, c));
try {
Thread.sleep(50);
} catch (InterruptedException e) {
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
current.decrementAndGet();
Expand All @@ -210,16 +213,18 @@ void stillRunsConcurrently() {

@Test
void preservesResultOrder() {
final var bounded = Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(2);
final var results = bounded.map(List.of(3, 1, 2), i -> i * 10);
final var boundedParallel =
Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(2);
final var results = boundedParallel.map(List.of(3, 1, 2), i -> i * 10);
assertEquals(List.of(30, 10, 20), results);
}

@Test
void propagatesMdc() {
MDC.put("traceId", "bounded-trace");
final var bounded = Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(2);
final var results = bounded.map(List.of(1, 2, 3), i -> MDC.get("traceId"));
final var boundedParallel =
Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(2);
final var results = boundedParallel.map(List.of(1, 2, 3), i -> MDC.get("traceId"));
assertTrue(results.stream().allMatch("bounded-trace"::equals));
}

Expand All @@ -230,4 +235,69 @@ void rejectsZeroConcurrency() {
() -> Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(0));
}
}

@Nested
class Timeout {

@Test
void completesWithinTimeout() {
final var timedParallel =
Parallel.create(new Slf4jMdcContextPropagator()).withTimeout(Duration.ofSeconds(5));
final var results = timedParallel.map(List.of(1, 2, 3), i -> i * 10);
assertEquals(List.of(10, 20, 30), results);
}

@Test
void interruptsSlowTask() {
final var timedParallel =
Parallel.create(new Slf4jMdcContextPropagator()).withTimeout(Duration.ofMillis(50));
assertThrows(
RuntimeException.class,
() ->
timedParallel.map(
List.of(1),
i -> {
try {
Thread.sleep(5000);
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
return i;
}));
}

@Test
void fastTasksUnaffected() {
final var timedParallel =
Parallel.create(new Slf4jMdcContextPropagator()).withTimeout(Duration.ofSeconds(1));
final var result = timedParallel.map(List.of("a", "b"), s -> s.toUpperCase());
assertEquals(List.of("A", "B"), result);
}

@Test
void combinesWithMaxConcurrency() {
final var timedParallel =
Parallel.create(new Slf4jMdcContextPropagator())
.withMaxConcurrency(2)
.withTimeout(Duration.ofSeconds(5));
final var results = timedParallel.map(List.of(1, 2, 3, 4), i -> i * 10);
assertEquals(List.of(10, 20, 30, 40), results);
}

@Test
void rejectsZeroTimeout() {
assertThrows(
IllegalArgumentException.class,
() -> Parallel.create(new Slf4jMdcContextPropagator()).withTimeout(Duration.ZERO));
}

@Test
void rejectsNegativeTimeout() {
assertThrows(
IllegalArgumentException.class,
() ->
Parallel.create(new Slf4jMdcContextPropagator()).withTimeout(Duration.ofMillis(-1)));
}
}
}
58 changes: 52 additions & 6 deletions blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.github.pjlabs.blockless;

import java.time.Duration;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
Expand All @@ -8,6 +9,8 @@
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.function.Supplier;

Expand All @@ -26,10 +29,12 @@ public final class Parallel {

private final List<ContextPropagator> propagators;
private final Semaphore semaphore;
private final Duration timeout;

private Parallel(List<ContextPropagator> propagators, Semaphore semaphore) {
private Parallel(List<ContextPropagator> propagators, Semaphore semaphore, Duration timeout) {
this.propagators = List.copyOf(propagators);
this.semaphore = semaphore;
this.timeout = timeout;
}

/** Creates an unbounded {@link Parallel} instance with the given propagators. */
Expand All @@ -39,7 +44,7 @@ public static Parallel create(ContextPropagator... propagators) {

/** Creates an unbounded {@link Parallel} instance with the given propagators. */
public static Parallel create(List<ContextPropagator> propagators) {
return new Parallel(propagators, null);
return new Parallel(propagators, null, null);
}

/**
Expand All @@ -50,7 +55,20 @@ public Parallel withMaxConcurrency(int maxConcurrency) {
if (maxConcurrency < 1) {
throw new IllegalArgumentException("maxConcurrency must be at least 1");
}
return new Parallel(propagators, new Semaphore(maxConcurrency));
return new Parallel(propagators, new Semaphore(maxConcurrency), timeout);
}

/**
* Returns a new {@link Parallel} with a per-task timeout. If a task does not complete within the
* duration, its thread is interrupted and a {@link TimeoutException} is thrown (wrapped in {@link
* RuntimeException}).
*/
public Parallel withTimeout(Duration timeout) {
Objects.requireNonNull(timeout, "timeout");
if (timeout.isNegative() || timeout.isZero()) {
throw new IllegalArgumentException("timeout must be positive");
}
return new Parallel(propagators, semaphore, timeout);
}

/**
Expand All @@ -59,11 +77,39 @@ public Parallel withMaxConcurrency(int maxConcurrency) {
*/
public <T> Supplier<T> async(Supplier<T> task) {
Objects.requireNonNull(task, "task");
final var effective = semaphore != null ? bounded(task) : task;
return Blockless.supplier(CallableContext.wrap(effective::get, propagators));
var wrappedSupplier = semaphore != null ? boundedSupplier(task) : task;
if (timeout != null) {
wrappedSupplier = timedSupplier(wrappedSupplier);
}
return Blockless.supplier(CallableContext.wrap(wrappedSupplier::get, propagators));
}

private <T> Supplier<T> timedSupplier(Supplier<T> task) {
return () -> {
final var taskThread = Thread.currentThread();
final var done = new AtomicBoolean(false);
final var timer =
Thread.startVirtualThread(
() -> {
try {
Thread.sleep(timeout);
Comment thread
jimbao marked this conversation as resolved.
if (!done.get()) {
taskThread.interrupt();
}
} catch (InterruptedException ignored) {
// Timer cancelled — task completed in time
}
});
try {
return task.get();
} finally {
done.set(true);
timer.interrupt();
}
};
}

private <T> Supplier<T> bounded(Supplier<T> task) {
private <T> Supplier<T> boundedSupplier(Supplier<T> task) {
return () -> {
semaphore.acquireUninterruptibly();
try {
Expand Down
Loading