From 9d316a24f96b4a908630402ed6e779a802654000 Mon Sep 17 00:00:00 2001 From: jrmccluskey Date: Wed, 17 Jun 2026 11:45:17 -0400 Subject: [PATCH 1/3] Add retry filter support to the Java Remote RetryHandler --- .../sdk/ml/inference/remote/RetryHandler.java | 26 +++- .../ml/inference/remote/RetryHandlerTest.java | 145 ++++++++++++++++++ 2 files changed, 169 insertions(+), 2 deletions(-) create mode 100644 sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RetryHandlerTest.java diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java index cf1b2f282c6c..78607f4a5c0d 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java @@ -34,13 +34,24 @@ public class RetryHandler implements Serializable { private final Duration initialBackoff; private final Duration maxBackoff; private final Duration maxCumulativeBackoff; + private final RetryFilter retryFilter; + + @FunctionalInterface + public interface RetryFilter extends Serializable { + boolean shouldRetry(Exception e); + } private RetryHandler( - int maxRetries, Duration initialBackoff, Duration maxBackoff, Duration maxCumulativeBackoff) { + int maxRetries, + Duration initialBackoff, + Duration maxBackoff, + Duration maxCumulativeBackoff, + RetryFilter retryFilter) { this.maxRetries = maxRetries; this.initialBackoff = initialBackoff; this.maxBackoff = maxBackoff; this.maxCumulativeBackoff = maxCumulativeBackoff; + this.retryFilter = retryFilter; } public static RetryHandler withDefaults() { @@ -48,10 +59,16 @@ public static RetryHandler withDefaults() { 3, // maxRetries Duration.standardSeconds(1), // initialBackoff Duration.standardSeconds(10), // maxBackoff per retry - Duration.standardMinutes(1) // maxCumulativeBackoff + Duration.standardMinutes(1), // maxCumulativeBackoff + (RetryFilter) e -> true // retryFilter default to retrying all exceptions ); } + public RetryHandler withRetryFilter(RetryFilter retryFilter) { + return new RetryHandler( + maxRetries, initialBackoff, maxBackoff, maxCumulativeBackoff, retryFilter); + } + public T execute(RetryableRequest request) throws Exception { BackOff backoff = FluentBackoff.DEFAULT @@ -72,6 +89,11 @@ public T execute(RetryableRequest request) throws Exception { } catch (Exception e) { lastException = e; + if (!retryFilter.shouldRetry(e)) { + LOG.warn("Exception not eligible for retry. Failing immediately.", e); + throw e; + } + long backoffMillis = backoff.nextBackOffMillis(); if (backoffMillis == BackOff.STOP) { diff --git a/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RetryHandlerTest.java b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RetryHandlerTest.java new file mode 100644 index 000000000000..30f93dbdc0ee --- /dev/null +++ b/sdks/java/ml/inference/remote/src/test/java/org/apache/beam/sdk/ml/inference/remote/RetryHandlerTest.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.beam.sdk.ml.inference.remote; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; + +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class RetryHandlerTest { + + private static class NonRetryableException extends Exception { + NonRetryableException(String message) { + super(message); + } + } + + private static class RetryableException extends Exception { + RetryableException(String message) { + super(message); + } + } + + @Test + public void testRetryWithDefaultFilter() throws Exception { + RetryHandler handler = RetryHandler.withDefaults(); + AtomicInteger attempts = new AtomicInteger(0); + + RuntimeException thrown = + assertThrows( + RuntimeException.class, + () -> + handler.execute( + () -> { + attempts.incrementAndGet(); + throw new Exception("Always fails"); + })); + + assertTrue(thrown.getMessage().contains("exhausting retries")); + assertEquals(4, attempts.get()); // 1 initial attempt + 3 retries + } + + @Test + public void testRetryWithCustomFilter_ShouldNotRetry() { + RetryHandler handler = + RetryHandler.withDefaults() + .withRetryFilter( + e -> { + if (e instanceof NonRetryableException) { + return false; + } + return true; + }); + + AtomicInteger attempts = new AtomicInteger(0); + + NonRetryableException thrown = + assertThrows( + NonRetryableException.class, + () -> + handler.execute( + () -> { + attempts.incrementAndGet(); + throw new NonRetryableException("Should not retry"); + })); + + assertEquals("Should not retry", thrown.getMessage()); + assertEquals(1, attempts.get()); // 1 initial attempt, no retries + } + + @Test + public void testRetryWithCustomFilter_ShouldRetry() { + RetryHandler handler = + RetryHandler.withDefaults() + .withRetryFilter( + e -> { + if (e instanceof NonRetryableException) { + return false; + } + return true; + }); + + AtomicInteger attempts = new AtomicInteger(0); + + RuntimeException thrown = + assertThrows( + RuntimeException.class, + () -> + handler.execute( + () -> { + attempts.incrementAndGet(); + throw new RetryableException("Should retry"); + })); + + assertTrue(thrown.getMessage().contains("exhausting retries")); + assertEquals(4, attempts.get()); // 1 initial attempt + 3 retries + } + + @Test + public void testRetryWithCustomFilter_EventualSuccess() throws Exception { + RetryHandler handler = + RetryHandler.withDefaults() + .withRetryFilter( + e -> { + if (e instanceof NonRetryableException) { + return false; + } + return true; + }); + + AtomicInteger attempts = new AtomicInteger(0); + + String result = + handler.execute( + () -> { + if (attempts.incrementAndGet() < 3) { + throw new RetryableException("Temporary failure"); + } + return "success"; + }); + + assertEquals("success", result); + assertEquals(3, attempts.get()); // 1 initial attempt + 2 retries + } +} From 352185d729b8c2822245b5cebd9edb86a72f8846 Mon Sep 17 00:00:00 2001 From: Jack McCluskey <34928439+jrmccluskey@users.noreply.github.com> Date: Wed, 17 Jun 2026 11:55:29 -0400 Subject: [PATCH 2/3] Apply suggestions from code review Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- .../apache/beam/sdk/ml/inference/remote/RetryHandler.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java index 78607f4a5c0d..dc21e59fd5bd 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java @@ -51,7 +51,7 @@ private RetryHandler( this.initialBackoff = initialBackoff; this.maxBackoff = maxBackoff; this.maxCumulativeBackoff = maxCumulativeBackoff; - this.retryFilter = retryFilter; + this.retryFilter = java.util.Objects.requireNonNull(retryFilter, "retryFilter must not be null"); } public static RetryHandler withDefaults() { @@ -60,7 +60,7 @@ public static RetryHandler withDefaults() { Duration.standardSeconds(1), // initialBackoff Duration.standardSeconds(10), // maxBackoff per retry Duration.standardMinutes(1), // maxCumulativeBackoff - (RetryFilter) e -> true // retryFilter default to retrying all exceptions + e -> true // retryFilter default to retrying all exceptions ); } @@ -89,7 +89,7 @@ public T execute(RetryableRequest request) throws Exception { } catch (Exception e) { lastException = e; - if (!retryFilter.shouldRetry(e)) { + if (retryFilter != null && !retryFilter.shouldRetry(e)) { LOG.warn("Exception not eligible for retry. Failing immediately.", e); throw e; } From 7b04c86435147d9b2ddf1340524f1c4356c5ef5b Mon Sep 17 00:00:00 2001 From: jrmccluskey Date: Wed, 17 Jun 2026 13:10:24 -0400 Subject: [PATCH 3/3] spotless --- .../org/apache/beam/sdk/ml/inference/remote/RetryHandler.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java index dc21e59fd5bd..b124840c8ba3 100644 --- a/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java +++ b/sdks/java/ml/inference/remote/src/main/java/org/apache/beam/sdk/ml/inference/remote/RetryHandler.java @@ -51,7 +51,8 @@ private RetryHandler( this.initialBackoff = initialBackoff; this.maxBackoff = maxBackoff; this.maxCumulativeBackoff = maxCumulativeBackoff; - this.retryFilter = java.util.Objects.requireNonNull(retryFilter, "retryFilter must not be null"); + this.retryFilter = + java.util.Objects.requireNonNull(retryFilter, "retryFilter must not be null"); } public static RetryHandler withDefaults() {