diff --git a/README.md b/README.md index 7aaef58..45f1775 100644 --- a/README.md +++ b/README.md @@ -93,8 +93,15 @@ Supplier data = parallel.async(() -> fetchData()); Map profiles = parallel.asMap(userIds, id -> loadProfile(id)); // Limit concurrent tasks (extras park until a permit frees up) -var bounded = parallel.withMaxConcurrency(10); -List names = bounded.map(userIds, id -> fetchName(id)); +var boundedParallel = parallel.withMaxConcurrency(10); +List names = boundedParallel.map(userIds, id -> fetchName(id)); + +// Per-task timeout — thread is interrupted if task exceeds the deadline +var timedParallel = parallel.withTimeout(Duration.ofSeconds(5)); +List names = timedParallel.map(userIds, id -> fetchName(id)); + +// Combines with withMaxConcurrency +var safeParallel = parallel.withMaxConcurrency(10).withTimeout(Duration.ofSeconds(5)); // Collect results without failing fast — failed tasks return Either.fail() List> results = parallel.toEither(ids, id -> riskyFetch(id)); diff --git a/blockless-tests/src/test/java/io/github/pjlabs/blockless/ParallelTest.java b/blockless-tests/src/test/java/io/github/pjlabs/blockless/ParallelTest.java index c2ced5c..dc8f57f 100644 --- a/blockless-tests/src/test/java/io/github/pjlabs/blockless/ParallelTest.java +++ b/blockless-tests/src/test/java/io/github/pjlabs/blockless/ParallelTest.java @@ -5,6 +5,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import io.github.pjlabs.blockless.context.slf4j.Slf4jMdcContextPropagator; +import java.time.Duration; import java.util.List; import java.util.concurrent.atomic.AtomicInteger; import org.junit.jupiter.api.AfterEach; @@ -38,7 +39,7 @@ void mapReturnsResultsInInputOrder() { i -> { try { Thread.sleep(i * 20L); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); } return i * 10; @@ -60,7 +61,7 @@ void mapRunsConcurrently() { maxConcurrent.updateAndGet(max -> Math.max(max, c)); try { Thread.sleep(50); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); } current.decrementAndGet(); @@ -161,18 +162,19 @@ class BoundedConcurrency { @Test void limitsConcurrentTasks() { - final var bounded = Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(2); + final var boundedParallel = + Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(2); final var maxConcurrent = new AtomicInteger(0); final var current = new AtomicInteger(0); - bounded.map( + boundedParallel.map( List.of(1, 2, 3, 4, 5), i -> { final int c = current.incrementAndGet(); maxConcurrent.updateAndGet(max -> Math.max(max, c)); try { Thread.sleep(50); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); } current.decrementAndGet(); @@ -185,18 +187,19 @@ void limitsConcurrentTasks() { @Test void stillRunsConcurrently() { - final var bounded = Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(3); + final var boundedParallel = + Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(3); final var maxConcurrent = new AtomicInteger(0); final var current = new AtomicInteger(0); - bounded.map( + boundedParallel.map( List.of(1, 2, 3, 4, 5), i -> { final int c = current.incrementAndGet(); maxConcurrent.updateAndGet(max -> Math.max(max, c)); try { Thread.sleep(50); - } catch (InterruptedException e) { + } catch (final InterruptedException e) { Thread.currentThread().interrupt(); } current.decrementAndGet(); @@ -210,16 +213,18 @@ void stillRunsConcurrently() { @Test void preservesResultOrder() { - final var bounded = Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(2); - final var results = bounded.map(List.of(3, 1, 2), i -> i * 10); + final var boundedParallel = + Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(2); + final var results = boundedParallel.map(List.of(3, 1, 2), i -> i * 10); assertEquals(List.of(30, 10, 20), results); } @Test void propagatesMdc() { MDC.put("traceId", "bounded-trace"); - final var bounded = Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(2); - final var results = bounded.map(List.of(1, 2, 3), i -> MDC.get("traceId")); + final var boundedParallel = + Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(2); + final var results = boundedParallel.map(List.of(1, 2, 3), i -> MDC.get("traceId")); assertTrue(results.stream().allMatch("bounded-trace"::equals)); } @@ -230,4 +235,69 @@ void rejectsZeroConcurrency() { () -> Parallel.create(new Slf4jMdcContextPropagator()).withMaxConcurrency(0)); } } + + @Nested + class Timeout { + + @Test + void completesWithinTimeout() { + final var timedParallel = + Parallel.create(new Slf4jMdcContextPropagator()).withTimeout(Duration.ofSeconds(5)); + final var results = timedParallel.map(List.of(1, 2, 3), i -> i * 10); + assertEquals(List.of(10, 20, 30), results); + } + + @Test + void interruptsSlowTask() { + final var timedParallel = + Parallel.create(new Slf4jMdcContextPropagator()).withTimeout(Duration.ofMillis(50)); + assertThrows( + RuntimeException.class, + () -> + timedParallel.map( + List.of(1), + i -> { + try { + Thread.sleep(5000); + } catch (final InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + return i; + })); + } + + @Test + void fastTasksUnaffected() { + final var timedParallel = + Parallel.create(new Slf4jMdcContextPropagator()).withTimeout(Duration.ofSeconds(1)); + final var result = timedParallel.map(List.of("a", "b"), s -> s.toUpperCase()); + assertEquals(List.of("A", "B"), result); + } + + @Test + void combinesWithMaxConcurrency() { + final var timedParallel = + Parallel.create(new Slf4jMdcContextPropagator()) + .withMaxConcurrency(2) + .withTimeout(Duration.ofSeconds(5)); + final var results = timedParallel.map(List.of(1, 2, 3, 4), i -> i * 10); + assertEquals(List.of(10, 20, 30, 40), results); + } + + @Test + void rejectsZeroTimeout() { + assertThrows( + IllegalArgumentException.class, + () -> Parallel.create(new Slf4jMdcContextPropagator()).withTimeout(Duration.ZERO)); + } + + @Test + void rejectsNegativeTimeout() { + assertThrows( + IllegalArgumentException.class, + () -> + Parallel.create(new Slf4jMdcContextPropagator()).withTimeout(Duration.ofMillis(-1))); + } + } } diff --git a/blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java b/blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java index 2a32fab..7b894fc 100644 --- a/blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java +++ b/blockless/src/main/java/io/github/pjlabs/blockless/Parallel.java @@ -1,5 +1,6 @@ package io.github.pjlabs.blockless; +import java.time.Duration; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -8,6 +9,8 @@ import java.util.Map; import java.util.Objects; import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import java.util.function.Supplier; @@ -26,10 +29,12 @@ public final class Parallel { private final List propagators; private final Semaphore semaphore; + private final Duration timeout; - private Parallel(List propagators, Semaphore semaphore) { + private Parallel(List propagators, Semaphore semaphore, Duration timeout) { this.propagators = List.copyOf(propagators); this.semaphore = semaphore; + this.timeout = timeout; } /** Creates an unbounded {@link Parallel} instance with the given propagators. */ @@ -39,7 +44,7 @@ public static Parallel create(ContextPropagator... propagators) { /** Creates an unbounded {@link Parallel} instance with the given propagators. */ public static Parallel create(List propagators) { - return new Parallel(propagators, null); + return new Parallel(propagators, null, null); } /** @@ -50,7 +55,20 @@ public Parallel withMaxConcurrency(int maxConcurrency) { if (maxConcurrency < 1) { throw new IllegalArgumentException("maxConcurrency must be at least 1"); } - return new Parallel(propagators, new Semaphore(maxConcurrency)); + return new Parallel(propagators, new Semaphore(maxConcurrency), timeout); + } + + /** + * Returns a new {@link Parallel} with a per-task timeout. If a task does not complete within the + * duration, its thread is interrupted and a {@link TimeoutException} is thrown (wrapped in {@link + * RuntimeException}). + */ + public Parallel withTimeout(Duration timeout) { + Objects.requireNonNull(timeout, "timeout"); + if (timeout.isNegative() || timeout.isZero()) { + throw new IllegalArgumentException("timeout must be positive"); + } + return new Parallel(propagators, semaphore, timeout); } /** @@ -59,11 +77,39 @@ public Parallel withMaxConcurrency(int maxConcurrency) { */ public Supplier async(Supplier task) { Objects.requireNonNull(task, "task"); - final var effective = semaphore != null ? bounded(task) : task; - return Blockless.supplier(CallableContext.wrap(effective::get, propagators)); + var wrappedSupplier = semaphore != null ? boundedSupplier(task) : task; + if (timeout != null) { + wrappedSupplier = timedSupplier(wrappedSupplier); + } + return Blockless.supplier(CallableContext.wrap(wrappedSupplier::get, propagators)); + } + + private Supplier timedSupplier(Supplier task) { + return () -> { + final var taskThread = Thread.currentThread(); + final var done = new AtomicBoolean(false); + final var timer = + Thread.startVirtualThread( + () -> { + try { + Thread.sleep(timeout); + if (!done.get()) { + taskThread.interrupt(); + } + } catch (InterruptedException ignored) { + // Timer cancelled — task completed in time + } + }); + try { + return task.get(); + } finally { + done.set(true); + timer.interrupt(); + } + }; } - private Supplier bounded(Supplier task) { + private Supplier boundedSupplier(Supplier task) { return () -> { semaphore.acquireUninterruptibly(); try {