From 957335045ebe59c5e7b10a4bc6ef83ef4b7b8c39 Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Tue, 16 Jun 2026 22:27:37 +0000 Subject: [PATCH 1/2] [Dataflow Streaming] Access state/timerinternals via StepContext --- .../apache/beam/runners/core/DoFnRunners.java | 4 +- .../core/LateDataDroppingDoFnRunner.java | 17 +++--- .../core/LateDataDroppingDoFnRunnerTest.java | 4 +- .../beam/runners/core/ReduceFnTester.java | 20 ++++++- .../streaming/WindowDoFnOperator.java | 2 +- .../worker/GroupAlsoByWindowsParDoFn.java | 2 +- .../worker/StreamingModeExecutionContext.java | 20 +++++++ .../worker/WindmillTimerInternals.java | 23 +++++++ .../state/WindmillStateInternals.java | 17 ++++++ .../StreamingGroupAlsoByWindowFnsTest.java | 3 +- .../StreamingModeExecutionContextTest.java | 60 +++++++++++++++++++ 11 files changed, 155 insertions(+), 17 deletions(-) diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java index 2cab758d36e0..0608ea24b049 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/DoFnRunners.java @@ -85,9 +85,9 @@ public static DoFnRunner simpleRunner( public static DoFnRunner, KV> lateDataDroppingRunner( DoFnRunner, KV> wrappedRunner, - TimerInternals timerInternals, + StepContext stepContext, WindowingStrategy windowingStrategy) { - return new LateDataDroppingDoFnRunner<>(wrappedRunner, windowingStrategy, timerInternals); + return new LateDataDroppingDoFnRunner<>(wrappedRunner, windowingStrategy, stepContext); } /** diff --git a/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunner.java b/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunner.java index dfab198f8932..1d5908c24975 100644 --- a/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunner.java +++ b/runners/core-java/src/main/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunner.java @@ -58,9 +58,9 @@ public class LateDataDroppingDoFnRunner, KV> doFnRunner, WindowingStrategy windowingStrategy, - TimerInternals timerInternals) { + StepContext stepContext) { this.doFnRunner = doFnRunner; - lateDataFilter = new LateDataFilter(windowingStrategy, timerInternals); + lateDataFilter = new LateDataFilter(windowingStrategy, stepContext); } @Override @@ -116,13 +116,12 @@ public void onWindowExpiration(BoundedWindow window, Instant timestamp, K @VisibleForTesting static class LateDataFilter { private final WindowingStrategy windowingStrategy; - private final TimerInternals timerInternals; + private final StepContext stepContext; private final Counter droppedDueToLateness; - public LateDataFilter( - WindowingStrategy windowingStrategy, TimerInternals timerInternals) { + public LateDataFilter(WindowingStrategy windowingStrategy, StepContext stepContext) { this.windowingStrategy = windowingStrategy; - this.timerInternals = timerInternals; + this.stepContext = stepContext; this.droppedDueToLateness = Metrics.counter(LateDataDroppingDoFnRunner.class, DROPPED_DUE_TO_LATENESS); } @@ -146,8 +145,8 @@ public Iterable> filter( element.getTimestamp(), key, window, - timerInternals.currentInputWatermarkTime(), - timerInternals.currentOutputWatermarkTime()); + stepContext.timerInternals().currentInputWatermarkTime(), + stepContext.timerInternals().currentOutputWatermarkTime()); } else { nonLateElements.add( WindowedValues.of( @@ -173,7 +172,7 @@ public Iterable> filter( * @return True if element can be dropped. */ private boolean canDropDueToExpiredWindow(BoundedWindow window) { - Instant inputWM = timerInternals.currentInputWatermarkTime(); + Instant inputWM = stepContext.timerInternals().currentInputWatermarkTime(); return LateDataUtils.garbageCollectionTime(window, windowingStrategy).isBefore(inputWM); } } diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunnerTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunnerTest.java index 5a6257ba0a1b..3ce098ee05ea 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunnerTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/LateDataDroppingDoFnRunnerTest.java @@ -49,6 +49,7 @@ public class LateDataDroppingDoFnRunnerTest { private static final FixedWindows WINDOW_FN = FixedWindows.of(Duration.millis(10)); @Mock private TimerInternals mockTimerInternals; + @Mock private StepContext mockStepContext; @Before public void setUp() { @@ -60,9 +61,10 @@ public void testLateDataFilter() throws Exception { MetricsContainerImpl container = new MetricsContainerImpl("any"); MetricsEnvironment.setCurrentContainer(container); when(mockTimerInternals.currentInputWatermarkTime()).thenReturn(new Instant(15L)); + when(mockStepContext.timerInternals()).thenReturn(mockTimerInternals); LateDataFilter lateDataFilter = - new LateDataFilter(WindowingStrategy.of(WINDOW_FN), mockTimerInternals); + new LateDataFilter(WindowingStrategy.of(WINDOW_FN), mockStepContext); Iterable> actual = lateDataFilter.filter( diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java index fe2c0a253467..6ff3b99b8605 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/ReduceFnTester.java @@ -47,6 +47,7 @@ import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.transforms.Combine.CombineFn; import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext; +import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.transforms.windowing.GlobalWindow; import org.apache.beam.sdk.transforms.windowing.PaneInfo; @@ -94,6 +95,23 @@ public class ReduceFnTester { private final TestInMemoryStateInternals stateInternals = new TestInMemoryStateInternals<>(KEY); private final InMemoryTimerInternals timerInternals = new InMemoryTimerInternals(); + private final StepContext stepContext = + new StepContext() { + @Override + public StateInternals stateInternals() { + return stateInternals; + } + + @Override + public TimerInternals timerInternals() { + return timerInternals; + } + + @Override + public BundleFinalizer bundleFinalizer() { + throw new UnsupportedOperationException(); + } + }; private final WindowFn windowFn; private final TestWindowedValueReceiver testOutputter; @@ -577,7 +595,7 @@ public final void injectElements(List> values) throws E ReduceFnRunner runner = createRunner(); runner.processElements( - new LateDataDroppingDoFnRunner.LateDataFilter(objectStrategy, timerInternals) + new LateDataDroppingDoFnRunner.LateDataFilter(objectStrategy, stepContext) .filter(KEY, inputs)); // Persist after each bundle. diff --git a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java index fad97e2d7dfd..5043ed9e5303 100644 --- a/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java +++ b/runners/flink/src/main/java/org/apache/beam/runners/flink/translation/wrappers/streaming/WindowDoFnOperator.java @@ -118,7 +118,7 @@ protected DoFnRunner, KV> createWrappingDoF // for some K, V return DoFnRunners.lateDataDroppingRunner( - (DoFnRunner) doFnRunner, timerInternals, windowingStrategy); + (DoFnRunner) doFnRunner, stepContext, windowingStrategy); } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupAlsoByWindowsParDoFn.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupAlsoByWindowsParDoFn.java index e204a78a7d2e..6c92b7f10ea1 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupAlsoByWindowsParDoFn.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupAlsoByWindowsParDoFn.java @@ -216,7 +216,7 @@ private DoFnRunner>> createRunner() { } return (DoFnRunner>>) DoFnRunners., W>lateDataDroppingRunner( - streamingGABWRunner, stepContext.timerInternals(), windowingStrategy); + streamingGABWRunner, stepContext, windowingStrategy); } else { if (hasStreamingSideInput) { return new StreamingSideInputDoFnRunner<>( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java index d88864745648..66df22333944 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContext.java @@ -891,6 +891,26 @@ public void flushState() { checkStateNotNull(systemTimerInternals).persistTo(builder); checkStateNotNull(userTimerInternals).persistTo(builder); } + poisonStateAndTimerInternals(); + } + + /** + * Poisons the state and timer internals to prevent any subsequent (stale) usage. + * + *

This ensures that if these key-specific internals are incorrectly cached and used after + * the key's execution context has finished, it will fail fast with a clear error rather than + * silently corrupting state. + */ + private void poisonStateAndTimerInternals() { + if (stateInternals != null) { + stateInternals.poison(); + } + if (systemTimerInternals != null) { + systemTimerInternals.poison(); + } + if (userTimerInternals != null) { + userTimerInternals.poison(); + } } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillTimerInternals.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillTimerInternals.java index 07ce62d59339..9b8757769db0 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillTimerInternals.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/WindmillTimerInternals.java @@ -64,6 +64,19 @@ class WindmillTimerInternals implements TimerInternals { private final Consumer onTimerModified; private final WindmillTagEncoding windmillTagEncoding; + private boolean poisoned = false; + + public void poison() { + this.poisoned = true; + } + + private void checkNotPoisoned() { + if (poisoned) { + throw new IllegalStateException( + "WindmillTimerInternals is poisoned and cannot be used after flushState()."); + } + } + public WindmillTimerInternals( String stateFamily, // unique identifies a step WindmillTimerType type, @@ -80,12 +93,14 @@ public WindmillTimerInternals( } public WindmillTimerInternals withType(WindmillTimerType type) { + checkNotPoisoned(); return new WindmillTimerInternals( stateFamily, type, processingTime, watermarks, windmillTagEncoding, onTimerModified); } @Override public void setTimer(TimerData timerKey) { + checkNotPoisoned(); String timerDataKey = getTimerDataKey(timerKey.getTimerId(), timerKey.getTimerFamilyId()); timerMap.put( new SimpleEntry<>(timerDataKey, timerKey.getNamespace()), @@ -101,6 +116,7 @@ public void setTimer( Instant timestamp, Instant outputTimestamp, TimeDomain timeDomain) { + checkNotPoisoned(); TimerData timer = TimerData.of( timerId, @@ -124,6 +140,7 @@ private static String getTimerDataKey(String timerId, String timerFamilyId) { @Override public void deleteTimer(TimerData timerKey) { + checkNotPoisoned(); String timerDataKey = getTimerDataKey(timerKey.getTimerId(), timerKey.getTimerFamilyId()); timerMap.put( new SimpleEntry<>(timerDataKey, timerKey.getNamespace()), @@ -139,6 +156,7 @@ public void deleteTimer(StateNamespace namespace, String timerId, String timerFa @Override public void deleteTimer( StateNamespace namespace, String timerId, String timerFamilyId, TimeDomain timeDomain) { + checkNotPoisoned(); deleteTimer( TimerData.of( timerId, @@ -152,12 +170,14 @@ public void deleteTimer( @Override public Instant currentProcessingTime() { + checkNotPoisoned(); Instant now = Instant.now(); return processingTime.isAfter(now) ? processingTime : now; } @Override public @Nullable Instant currentSynchronizedProcessingTime() { + checkNotPoisoned(); return watermarks.synchronizedProcessingTime(); } @@ -172,6 +192,7 @@ public Instant currentProcessingTime() { */ @Override public Instant currentInputWatermarkTime() { + checkNotPoisoned(); return watermarks.inputDataWatermark(); } @@ -186,10 +207,12 @@ public Instant currentInputWatermarkTime() { */ @Override public @Nullable Instant currentOutputWatermarkTime() { + checkNotPoisoned(); return watermarks.outputDataWatermark(); } public void persistTo(Windmill.WorkItemCommitRequest.Builder outputBuilder) { + checkNotPoisoned(); for (Entry value : timerMap.values()) { // Regardless of whether it is set or not, it must have some TimerData stored so we // can know its time domain diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java index db036bee43c3..f10ba70212ef 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/windmill/state/WindmillStateInternals.java @@ -55,6 +55,19 @@ public class WindmillStateInternals implements StateInternals { private final CachingStateTable workItemDerivedState; private final Supplier scopedReadStateSupplier; + private boolean poisoned = false; + + public void poison() { + this.poisoned = true; + } + + private void checkNotPoisoned() { + if (poisoned) { + throw new IllegalStateException( + "WindmillStateInternals is poisoned and cannot be used after flushState()."); + } + } + public WindmillStateInternals( @Nullable K key, String stateFamily, @@ -78,6 +91,7 @@ public WindmillStateInternals( @Override public @Nullable K getKey() { + checkNotPoisoned(); return key; } @@ -104,6 +118,7 @@ private void persist( } public void persist(final Windmill.WorkItemCommitRequest.Builder commitBuilder) { + checkNotPoisoned(); List> commitsToMerge = new ArrayList<>(); // Call persist on each first, which may schedule some futures for reading. @@ -126,12 +141,14 @@ public void persist(final Windmill.WorkItemCommitRequest.Builder commitBuilder) @Override public T state(StateNamespace namespace, StateTag address) { + checkNotPoisoned(); return workItemState.get(namespace, address, StateContexts.nullContext()); } @Override public T state( StateNamespace namespace, StateTag address, StateContext c) { + checkNotPoisoned(); return workItemState.get(namespace, address, c); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingGroupAlsoByWindowFnsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingGroupAlsoByWindowFnsTest.java index c31c39de9ccb..fe5ba60acec3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingGroupAlsoByWindowFnsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingGroupAlsoByWindowFnsTest.java @@ -753,8 +753,7 @@ DoFnRunner, KV> makeRunnerForGabw NullSideInputReader.empty(), outputManager, stepContext); - return DoFnRunners.lateDataDroppingRunner( - doFnRunner, stepContext.timerInternals(), windowingStrategy); + return DoFnRunners.lateDataDroppingRunner(doFnRunner, stepContext, windowingStrategy); } private IntervalWindow window(long start, long end) { diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java index 13601410bfd9..056185c587f3 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingModeExecutionContextTest.java @@ -42,6 +42,7 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.beam.runners.core.SideInputReader; +import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespaceForTest; import org.apache.beam.runners.core.TimerInternals; import org.apache.beam.runners.core.TimerInternals.TimerData; @@ -460,4 +461,63 @@ public void testSetBacklogBytes() { assertEquals(1234, outputBuilder.getSourceBacklogBytes()); } + + @Test + public void testInternalsPoisonedAfterFlushState() throws Exception { + Windmill.WorkItemCommitRequest.Builder outputBuilder = + Windmill.WorkItemCommitRequest.newBuilder(); + NameContext nameContext = NameContextsForTests.nameContextForTest(); + DataflowOperationContext operationContext = + executionContext.createOperationContext(nameContext); + StreamingModeExecutionContext.StepContext stepContext = + executionContext.getStepContext(operationContext); + + executionContext.start( + "key", + createMockWork( + Windmill.WorkItem.newBuilder().setKey(ByteString.EMPTY).setWorkToken(17L).build(), + Watermarks.builder().setInputDataWatermark(new Instant(1000)).build()), + stateReader, + sideInputStateFetcher, + outputBuilder, + workExecutor); + + TimerInternals timerInternals = stepContext.timerInternals(); + StateInternals stateInternals = stepContext.stateInternals(); + + executionContext.finishKey(); + executionContext.flushState(); + + // Verify timerInternals is poisoned + try { + timerInternals.currentProcessingTime(); + org.junit.Assert.fail("Expected IllegalStateException"); + } catch (IllegalStateException e) { + assertThat(e.getMessage(), Matchers.containsString("poisoned")); + } + + // Verify stateInternals is poisoned + try { + stateInternals.getKey(); + org.junit.Assert.fail("Expected IllegalStateException"); + } catch (IllegalStateException e) { + assertThat(e.getMessage(), Matchers.containsString("poisoned")); + } + + // Verify stepContext.stateInternals() returns poisoned instance + try { + stepContext.stateInternals().getKey(); + org.junit.Assert.fail("Expected IllegalStateException"); + } catch (IllegalStateException e) { + assertThat(e.getMessage(), Matchers.containsString("poisoned")); + } + + // Verify stepContext.timerInternals() returns poisoned instance + try { + stepContext.timerInternals().currentProcessingTime(); + org.junit.Assert.fail("Expected IllegalStateException"); + } catch (IllegalStateException e) { + assertThat(e.getMessage(), Matchers.containsString("poisoned")); + } + } } From dbaf7ce37c89fc3b6b5df21c8365a8f0f412376e Mon Sep 17 00:00:00 2001 From: Arun Pandian Date: Tue, 16 Jun 2026 22:38:31 +0000 Subject: [PATCH 2/2] trigger postsubmit tests --- ...stCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json index e623d3373a93..50d17c108f2e 100644 --- a/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json +++ b/.github/trigger_files/beam_PostCommit_Java_ValidatesRunner_Dataflow_Streaming_Engine.json @@ -1,4 +1,4 @@ { "comment": "Modify this file in a trivial way to cause this test suite to run!", - "modification": 1, + "modification": 2, }