From 0359ed908516407bf9dc230671ce98513e74aaf8 Mon Sep 17 00:00:00 2001 From: Reuven Lax Date: Tue, 16 Jun 2026 10:59:03 -0700 Subject: [PATCH] handle multi-key bundles in side-input runner --- .../worker/GroupAlsoByWindowsParDoFn.java | 11 +-- .../worker/PartialGroupByKeyParDoFns.java | 68 +++++++++++++----- .../dataflow/worker/SimpleParDoFn.java | 35 +++------ .../dataflow/worker/SimpleParDoFnHelpers.java | 72 ++++++++++++++----- ...treamingKeyedWorkItemSideInputParDoFn.java | 49 +++---------- .../worker/StreamingSideInputDoFnRunner.java | 28 ++++++-- .../worker/StreamingSideInputProcessor.java | 2 +- .../worker/PartialGroupByKeyParDoFnsTest.java | 2 +- .../worker/SimpleParDoFnHelpersTest.java | 7 +- .../StreamingSideInputDoFnRunnerTest.java | 22 +++--- .../StreamingSideInputProcessorTest.java | 2 +- 11 files changed, 171 insertions(+), 127 deletions(-) diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupAlsoByWindowsParDoFn.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupAlsoByWindowsParDoFn.java index e204a78a7d2e..91b97504f76e 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupAlsoByWindowsParDoFn.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/GroupAlsoByWindowsParDoFn.java @@ -221,11 +221,12 @@ private DoFnRunner>> createRunner() { if (hasStreamingSideInput) { return new StreamingSideInputDoFnRunner<>( basicRunner, - new StreamingSideInputFetcher<>( - sideInputViews, - inputCoder, - windowingStrategy, - (StreamingModeExecutionContext.StreamingModeStepContext) stepContext)); + () -> + new StreamingSideInputFetcher<>( + sideInputViews, + inputCoder, + windowingStrategy, + (StreamingModeExecutionContext.StreamingModeStepContext) stepContext)); } else { return basicRunner; } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PartialGroupByKeyParDoFns.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PartialGroupByKeyParDoFns.java index a6d7810412a1..aec7172db97a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PartialGroupByKeyParDoFns.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/PartialGroupByKeyParDoFns.java @@ -23,6 +23,7 @@ import com.google.api.services.dataflow.model.SideInputInfo; import java.util.List; import java.util.Set; +import java.util.function.Supplier; import org.apache.beam.runners.core.GlobalCombineFnRunner; import org.apache.beam.runners.core.GlobalCombineFnRunners; import org.apache.beam.runners.core.NullSideInputReader; @@ -48,6 +49,7 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.WindowedValue; import org.apache.beam.sdk.values.WindowedValues; +import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.CountingOutputStream; @@ -68,7 +70,7 @@ public static ParDoFn create( DataflowExecutionContext executionContext, DataflowOperationContext operationContext) throws Exception { - AppliedCombineFn combineFn; + AppliedCombineFn combineFn; SideInputReader sideInputReader; StepContext stepContext; if (cloudUserFn == null) { @@ -80,8 +82,8 @@ public static ParDoFn create( SerializableUtils.deserializeFromByteArray( getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN), "serialized combine fn"); @SuppressWarnings("unchecked") - AppliedCombineFn combineFnUnchecked = - ((AppliedCombineFn) deserializedFn); + AppliedCombineFn combineFnUnchecked = + 
((AppliedCombineFn) deserializedFn); combineFn = combineFnUnchecked; sideInputReader = @@ -136,13 +138,15 @@ static ParDoFn create( if (sideInputReader.isEmpty()) { return new SimplePartialGroupByKeyParDoFn<>(groupingTable, receiver); } else if (options.as(StreamingOptions.class).isStreaming()) { - StreamingSideInputFetcher, ?> sideInputFetcher = - new StreamingSideInputFetcher<>( - combineFn.getSideInputViews(), - combineFn.getKvCoder(), - combineFn.getWindowingStrategy(), - (StreamingModeExecutionContext.StreamingModeStepContext) stepContext); - return new StreamingSideInputPGBKParDoFn<>(groupingTable, receiver, sideInputFetcher); + Supplier, BoundedWindow>> sideInputFetcherSupplier = + () -> + new StreamingSideInputFetcher<>( + combineFn.getSideInputViews(), + combineFn.getKvCoder(), + (WindowingStrategy) combineFn.getWindowingStrategy(), + (StreamingModeExecutionContext.StreamingModeStepContext) stepContext); + return new StreamingSideInputPGBKParDoFn<>( + groupingTable, receiver, sideInputFetcherSupplier); } else { return new BatchSideInputPGBKParDoFn<>(groupingTable, receiver); } @@ -240,7 +244,7 @@ public WindowingCoderGroupingKeyCreator(Coder coder) { } @Override - public Object createGroupingKey(WindowedValue key) throws Exception { + public Object createGroupingKey(WindowedValue key) { // Ignore timestamp for grouping purposes. // The PGBK output will inherit the timestamp of one of its inputs. 
return WindowedValues.builder(key) @@ -333,19 +337,21 @@ static class StreamingSideInputPGBKParDoFn, InputT, AccumT> groupingTable; private final Receiver receiver; - private final StreamingSideInputFetcher, W> sideInputFetcher; + private final Supplier, W>> sideInputFetcherSupplier; + private StreamingSideInputFetcher, W> sideInputFetcher = null; + private boolean activeKey = false; StreamingSideInputPGBKParDoFn( GroupingTable, InputT, AccumT> groupingTable, Receiver receiver, - StreamingSideInputFetcher, W> sideInputFetcher) { + Supplier, W>> sideInputFetcherSupplier) { this.groupingTable = groupingTable; this.receiver = receiver; - this.sideInputFetcher = sideInputFetcher; + this.sideInputFetcherSupplier = sideInputFetcherSupplier; } - @Override - public void startBundle(Receiver... receivers) throws Exception { + private void onStartKey() throws Exception { + this.sideInputFetcher = sideInputFetcherSupplier.get(); // Find the set of ready windows. Set readyWindows = sideInputFetcher.getReadyWindows(); @@ -361,12 +367,23 @@ public void startBundle(Receiver... receivers) throws Exception { elementsBag.clear(); } sideInputFetcher.releaseBlockedWindows(readyWindows); + this.activeKey = true; + } + + @Override + public void startBundle(Receiver... 
receivers) throws Exception { + this.activeKey = false; } @Override public void processElement(Object elem) throws Exception { @SuppressWarnings({"unchecked"}) WindowedValue> input = (WindowedValue>) elem; + + if (!activeKey) { + onStartKey(); + } + for (BoundedWindow w : input.getWindows()) { WindowedValue> windowsExpandedInput = WindowedValues.of(input.getValue(), input.getTimestamp(), w, input.getPaneInfo()); @@ -378,15 +395,28 @@ public void processElement(Object elem) throws Exception { } @Override - public void processTimers() {} + public void processTimers() throws Exception { + if (!activeKey) { + onStartKey(); + } + } @Override - public void finishKey(Object key) throws Exception {} + public void finishKey(Object key) throws Exception { + if (!activeKey) { + onStartKey(); + } + sideInputFetcher.persist(); + sideInputFetcher = null; + this.activeKey = false; + } @Override public void finishBundle() throws Exception { groupingTable.flush(receiver); - sideInputFetcher.persist(); + if (sideInputFetcher != null) { + sideInputFetcher.persist(); + } } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java index e0f1e0f410cd..01304207d2aa 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFn.java @@ -17,7 +17,6 @@ */ package org.apache.beam.runners.dataflow.worker; -import java.io.Closeable; import java.util.Collection; import java.util.Collections; import java.util.Iterator; @@ -50,7 +49,7 @@ "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) public class SimpleParDoFn implements ParDoFn { - private final SimpleParDoFnHelpers helpers; + private final 
SimpleParDoFnHelpers helpers; private @Nullable StreamingSideInputProcessor sideInputProcessor; /** Creates a {@link SimpleParDoFn} using basic information about the step being executed. */ @@ -76,7 +75,8 @@ public class SimpleParDoFn implements operationContext, doFnSchemaInformation, sideInputMapping, - runnerFactory); + runnerFactory, + this::onStartKey); } @Override @@ -86,17 +86,14 @@ public void startBundle(Receiver... receivers) throws Exception { // There is non-trivial setup that needs to be performed for watermark propagation // even on empty bundles. helpers.reallyStartBundle(); - onStartKey(); } } - protected void onStartKey() { - // TODO(relax): This assumes single-key bundles, which will change! Refactor this to not make - // this assumption. + protected void onStartKey(Object key) { if (helpers.hasStreamingSideInput) { sideInputProcessor = new StreamingSideInputProcessor<>( - new StreamingSideInputFetcher( + new StreamingSideInputFetcher<>( helpers.fnInfo.getSideInputViews(), helpers.fnInfo.getInputCoder(), (WindowingStrategy) helpers.fnInfo.getWindowingStrategy(), @@ -123,20 +120,7 @@ protected void onStartKey() { @Override @SuppressWarnings("unchecked") public void processElement(Object untypedElem) throws Exception { - if (helpers.fnRunner == null) { - // If we need to run reallyStartBundle in here, we need to make sure to switch the state - // sampler into the start state. 
- try (Closeable start = helpers.operationContext.enterStart()) { - helpers.reallyStartBundle(); - onStartKey(); - } - } - helpers.outputsPerElementTracker.onProcessElement(); - - WindowedValue elem = (WindowedValue) untypedElem; - onProcessWindowedValue(elem); - - helpers.outputsPerElementTracker.onProcessElementSuccess(); + helpers.processElement(null, (WindowedValue) untypedElem, this::onProcessWindowedValue); } protected void onProcessWindowedValue(WindowedValue elem) { @@ -169,7 +153,6 @@ protected void onProcessWindowedValue(WindowedValue elem) { @Override public void processTimers() throws Exception { - // Note: We need to get windowCoder to decode the timers. If we haven't already deserialized // the fnInfo, we peek at a new instance to retrieve that. If this extra deserialization becomes // excessively costly, we could either (1) have the DoFnInstanceManager remember the associated @@ -185,18 +168,18 @@ public void processTimers() throws Exception { SimpleParDoFnHelpers.TimerType.USER, helpers.userStepContext, windowCoder, - this::onStartKey, () -> sideInputProcessor); helpers.processTimers( SimpleParDoFnHelpers.TimerType.SYSTEM, helpers.stepContext, windowCoder, - this::onStartKey, () -> sideInputProcessor); } @Override - public void finishKey(Object key) throws Exception {} + public void finishKey(Object key) throws Exception { + helpers.finishKey(sideInputProcessor); + } @Override public void finishBundle() throws Exception { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnHelpers.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnHelpers.java index 964cf2323d51..d8723ce56bcb 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnHelpers.java +++ 
b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnHelpers.java @@ -25,6 +25,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import java.util.function.Supplier; import org.apache.beam.runners.core.DoFnRunner; import org.apache.beam.runners.core.SideInputReader; @@ -69,7 +70,7 @@ "rawtypes", // TODO(https://github.com/apache/beam/issues/20447) "nullness" // TODO(https://github.com/apache/beam/issues/20497) }) -class SimpleParDoFnHelpers { +class SimpleParDoFnHelpers { private static final Logger LOG = LoggerFactory.getLogger(SimpleParDoFnHelpers.class); // TODO: Remove once Distributions has shipped. @@ -103,6 +104,8 @@ class SimpleParDoFnHelpers { // This may additionally be null if it is not a real DoFn but an OldDoFn or // GroupAlsoByWindowViaWindowSetDoFn protected @Nullable DoFnSignature fnSignature; + boolean activeKey = false; + private final Consumer onStartKey; SimpleParDoFnHelpers( PipelineOptions options, @@ -114,7 +117,8 @@ class SimpleParDoFnHelpers { DataflowOperationContext operationContext, DoFnSchemaInformation doFnSchemaInformation, Map> sideInputMapping, - DoFnRunnerFactory runnerFactory) { + DoFnRunnerFactory runnerFactory, + Consumer onStartKey) { this.options = options; this.doFnInstanceManager = doFnInstanceManager; @@ -145,6 +149,11 @@ class SimpleParDoFnHelpers { this.outputsPerElementTracker = createOutputsPerElementTracker(); this.doFnSchemaInformation = doFnSchemaInformation; this.sideInputMapping = sideInputMapping; + this.onStartKey = + k -> { + onStartKey.accept(k); + this.activeKey = true; + }; } boolean hasState() { @@ -157,6 +166,7 @@ void startBundle(Receiver... 
receivers) throws Exception { "unexpected number of receivers for DoFn"); this.receivers = receivers; + this.activeKey = false; } void reallyStartBundle() throws Exception { @@ -233,18 +243,29 @@ public void output(TupleTag tag, WindowedValue output) { fnRunner.startBundle(); } + void finishKey(StreamingSideInputProcessor sideInputProcessor) { + if (!activeKey) { + // This means that there were no elements for this key. Try to unblock any queued elements. + onStartKey.accept((K) stepContext.stateInternals().getKey()); + } + if (sideInputProcessor != null) { + sideInputProcessor.handleFinishKeyOrBundle(); + } + this.activeKey = false; + } + void finishBundle(StreamingSideInputProcessor sideInputProcessor) throws Exception { if (fnRunner != null) { fnRunner.finishBundle(); if (sideInputProcessor != null) { - sideInputProcessor.handleFinishBundle(); + sideInputProcessor.handleFinishKeyOrBundle(); } doFnInstanceManager.complete(fnInfo); fnRunner = null; fnInfo = null; fnSignature = null; - sideInputProcessor = null; } + this.activeKey = false; } void abort() throws Exception { @@ -253,6 +274,28 @@ void abort() throws Exception { fnInfo = null; } + void processElement( + @Nullable K key, WindowedValue element, Consumer> consumer) + throws Exception { + if (fnRunner == null) { + // If we need to run reallyStartBundle in here, we need to make sure to switch the state + // sampler into the start state. 
+ try (Closeable start = operationContext.enterStart()) { + reallyStartBundle(); + } + } + + if (!activeKey) { + if (key == null && stepContext.stateInternals() != null) { + key = (K) stepContext.stateInternals().getKey(); + } + onStartKey.accept(key); + } + outputsPerElementTracker.onProcessElement(); + consumer.accept(element); + outputsPerElementTracker.onProcessElementSuccess(); + } + @VisibleForTesting static final String CLEANUP_TIMER_ID = "cleanup-timer"; enum TimerType { @@ -261,8 +304,7 @@ enum TimerType { public void processTimer( SimpleParDoFnHelpers doFn, TimerInternals.TimerData timer, - Supplier> sideInputProcessor) - throws Exception { + Supplier> sideInputProcessor) { doFn.processUserTimer(timer, sideInputProcessor.get()); } }, @@ -271,8 +313,7 @@ public void processTimer( public void processTimer( SimpleParDoFnHelpers doFn, TimerInternals.TimerData timer, - Supplier> sideInputProcessor) - throws Exception { + Supplier> sideInputProcessor) { throw new UnsupportedOperationException( "Attempt to deliver a timer to a DoFn, but timers are not supported here."); } @@ -282,8 +323,7 @@ public void processTimer( public void processTimer( SimpleParDoFnHelpers doFn, TimerInternals.TimerData timer, - Supplier> sideInputProcessor) - throws Exception { + Supplier> sideInputProcessor) { doFn.processSystemTimer(timer, sideInputProcessor.get()); } }; @@ -293,23 +333,24 @@ public abstract void processTimer( TimerInternals.TimerData timer, Supplier> sideInputProcessor) throws Exception; - }; + } void processTimers( TimerType mode, DataflowExecutionContext.DataflowStepContext context, Coder windowCoder, - Runnable startKey, Supplier> sideInputProcessor) throws Exception { TimerInternals.TimerData timer = context.getNextFiredTimer(windowCoder); - if (timer != null && fnRunner == null) { // If we need to run reallyStartBundle in here, we need to make sure to switch the state // sampler into the start state. 
try (Closeable start = operationContext.enterStart()) { reallyStartBundle(); - startKey.run(); + } + + if (!activeKey) { + this.onStartKey.accept((K) context.stateInternals().getKey()); } } @@ -340,8 +381,7 @@ protected void processUserTimer( } private void processSystemTimer( - TimerInternals.TimerData timer, StreamingSideInputProcessor sideInputProcessor) - throws Exception { + TimerInternals.TimerData timer, StreamingSideInputProcessor sideInputProcessor) { // Timer owned by this class, for cleaning up state in expired windows if (timer.getTimerId().equals(CLEANUP_TIMER_ID)) { checkState( diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingKeyedWorkItemSideInputParDoFn.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingKeyedWorkItemSideInputParDoFn.java index 63de0b8d55db..63ee20782c88 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingKeyedWorkItemSideInputParDoFn.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingKeyedWorkItemSideInputParDoFn.java @@ -19,7 +19,6 @@ import com.google.api.client.util.Lists; import com.google.common.collect.Iterables; -import java.io.Closeable; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -29,15 +28,11 @@ import org.apache.beam.runners.core.KeyedWorkItem; import org.apache.beam.runners.core.KeyedWorkItems; import org.apache.beam.runners.core.SideInputReader; -import org.apache.beam.runners.core.StateNamespaces; -import org.apache.beam.runners.core.StateTag; -import org.apache.beam.runners.core.StateTags; import org.apache.beam.runners.dataflow.worker.util.ValueInEmptyWindows; import org.apache.beam.runners.dataflow.worker.util.common.worker.ParDoFn; import org.apache.beam.runners.dataflow.worker.util.common.worker.Receiver; import 
org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.options.PipelineOptions; -import org.apache.beam.sdk.state.ValueState; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; import org.apache.beam.sdk.util.DoFnInfo; @@ -55,9 +50,8 @@ /* Similar to {@link SimpleParDoFn} but for splittable ProcessFns. */ public class StreamingKeyedWorkItemSideInputParDoFn implements ParDoFn { - private final StateTag> keyAddr; private final Coder inputCoder; - private final SimpleParDoFnHelpers, OutputT, W> helpers; + private final SimpleParDoFnHelpers, OutputT, W> helpers; protected @Nullable StreamingSideInputProcessor sideInputProcessor; StreamingKeyedWorkItemSideInputParDoFn( @@ -84,15 +78,11 @@ public class StreamingKeyedWorkItemSideInputParDoFn keyValue() { - return helpers.stepContext.stateInternals().state(StateNamespaces.global(), keyAddr); - } - @Override public void startBundle(Receiver... receivers) throws Exception { helpers.startBundle(receivers); @@ -100,11 +90,10 @@ public void startBundle(Receiver... receivers) throws Exception { // There is non-trivial setup that needs to be performed for watermark propagation // even on empty bundles. helpers.reallyStartBundle(); - onStartKey(); } } - protected void onStartKey() { + protected void onStartKey(@Nullable K key) { if (helpers.hasStreamingSideInput) { sideInputProcessor = new StreamingSideInputProcessor<>( @@ -118,9 +107,6 @@ protected void onStartKey() { if (sideInputProcessor != null) { boolean hasState = helpers.hasState(); - - // TODO(relax): We should be able to get this without writing it to state! 
- @Nullable K key = keyValue().read(); if (key != null) { sideInputProcessor.tryUnblockElementsAndTimers( (unblockedElements, unblockedTimers) -> { @@ -148,26 +134,13 @@ protected void onStartKey() { @Override @SuppressWarnings("unchecked") public void processElement(Object untypedElem) throws Exception { - if (helpers.fnRunner == null) { - // If we need to run reallyStartBundle in here, we need to make sure to switch the state - // sampler into the start state. - try (Closeable start = helpers.operationContext.enterStart()) { - helpers.reallyStartBundle(); - onStartKey(); - } - } - helpers.outputsPerElementTracker.onProcessElement(); - - WindowedValue> elem = + WindowedValue> typedElem = (WindowedValue>) untypedElem; - onProcessWindowedValue(elem); - - helpers.outputsPerElementTracker.onProcessElementSuccess(); + helpers.processElement(typedElem.getValue().key(), typedElem, this::onProcessWindowedValue); } @Override public void processTimers() throws Exception { - // Note: We need to get windowCoder to decode the timers. If we haven't already deserialized // the fnInfo, we peek at a new instance to retrieve that. If this extra deserialization becomes // excessively costly, we could either (1) have the DoFnInstanceManager remember the associated @@ -183,18 +156,18 @@ public void processTimers() throws Exception { SimpleParDoFnHelpers.TimerType.FAIL_USER, helpers.userStepContext, windowCoder, - this::onStartKey, () -> sideInputProcessor); helpers.processTimers( SimpleParDoFnHelpers.TimerType.SYSTEM, helpers.stepContext, windowCoder, - this::onStartKey, () -> sideInputProcessor); } @Override - public void finishKey(Object key) throws Exception {} + public void finishKey(Object key) throws Exception { + helpers.finishKey(sideInputProcessor); + } @Override public void finishBundle() throws Exception { @@ -208,10 +181,6 @@ public void abort() throws Exception { } protected void onProcessWindowedValue(WindowedValue> elem) { - // TODO: Get rid of this! 
- final K key = elem.getValue().key(); - keyValue().write(key); - boolean hasState = helpers.hasState(); Collection windowsProcessed; if (sideInputProcessor != null) { diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunner.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunner.java index ef1a5922fe5c..9ea169427fa9 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunner.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunner.java @@ -18,10 +18,12 @@ package org.apache.beam.runners.dataflow.worker; import java.util.Iterator; +import java.util.function.Supplier; import org.apache.beam.runners.core.DoFnRunner; import org.apache.beam.sdk.state.TimeDomain; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.util.Preconditions; import org.apache.beam.sdk.values.CausedByDrain; import org.apache.beam.sdk.values.WindowedValue; import org.checkerframework.checker.nullness.qual.Nullable; @@ -37,28 +39,40 @@ public class StreamingSideInputDoFnRunner implements DoFnRunner { private final DoFnRunner simpleDoFnRunner; - private final StreamingSideInputProcessor sideInputProcessor; + private @Nullable StreamingSideInputProcessor sideInputProcessor = null; + private Supplier> sideInputFetcherSupplier; + boolean activeKey = false; public StreamingSideInputDoFnRunner( DoFnRunner simpleDoFnRunner, - StreamingSideInputFetcher sideInputFetcher) { + Supplier> sideInputFetcherSupplier) { this.simpleDoFnRunner = simpleDoFnRunner; - this.sideInputProcessor = new StreamingSideInputProcessor<>(sideInputFetcher); + this.sideInputFetcherSupplier = sideInputFetcherSupplier; } @Override public void 
startBundle() { simpleDoFnRunner.startBundle(); + this.activeKey = false; + } + + private void onNewKey() { + this.sideInputProcessor = new StreamingSideInputProcessor<>(sideInputFetcherSupplier.get()); sideInputProcessor.tryUnblockElements( unblocked -> { for (WindowedValue elem : unblocked) { simpleDoFnRunner.processElement(elem); } }); + this.activeKey = true; } @Override public void processElement(WindowedValue compressedElem) { + if (!activeKey) { + onNewKey(); + } + Preconditions.checkStateNotNull(sideInputProcessor); for (Iterator> it = sideInputProcessor.handleProcessElement(compressedElem); it.hasNext(); ) { @@ -83,13 +97,19 @@ public void onTimer( @Override public void finishKey(KeyT key) { + if (!activeKey) { + // This means that there were no elements for this key. Try to unblock any queued elements. + onNewKey(); + } + Preconditions.checkStateNotNull(sideInputProcessor).handleFinishKeyOrBundle(); simpleDoFnRunner.finishKey(key); + this.activeKey = false; } @Override public void finishBundle() { simpleDoFnRunner.finishBundle(); - sideInputProcessor.handleFinishBundle(); + this.activeKey = false; } @Override diff --git a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputProcessor.java b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputProcessor.java index 34c1a06d54de..0c314108e81a 100644 --- a/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputProcessor.java +++ b/runners/google-cloud-dataflow-java/worker/src/main/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputProcessor.java @@ -85,7 +85,7 @@ void tryUnblockElementsAndTimers( sideInputFetcher.releaseBlockedWindows(readyWindows); } - void handleFinishBundle() { + void handleFinishKeyOrBundle() { sideInputFetcher.persist(); } diff --git 
a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/PartialGroupByKeyParDoFnsTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/PartialGroupByKeyParDoFnsTest.java index 19af6effb44d..3c9579112145 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/PartialGroupByKeyParDoFnsTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/PartialGroupByKeyParDoFnsTest.java @@ -238,7 +238,7 @@ public void testPartialGroupByKeyWithCombinerAndSideInputs() throws Exception { new CoderSizeEstimator(WindowedValues.getValueOnlyCoder(keyCoder)), new CoderSizeEstimator(valueCoder)), receiver, - mockSideInputFetcher); + () -> mockSideInputFetcher); Set readyWindows = ImmutableSet.of(GlobalWindow.INSTANCE); when(mockSideInputFetcher.getReadyWindows()).thenReturn(readyWindows); diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnHelpersTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnHelpersTest.java index 6bbbf953967d..83974c32ba79 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnHelpersTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/SimpleParDoFnHelpersTest.java @@ -65,7 +65,7 @@ public void processElement() {} private TestDoFn doFn = new TestDoFn(); - private SimpleParDoFnHelpers helpers; + private SimpleParDoFnHelpers helpers; @Before @SuppressWarnings("unchecked") @@ -94,7 +94,8 @@ public void setUp() throws Exception { operationContext, DoFnSchemaInformation.create(), ImmutableMap.of(), - runnerFactory); + runnerFactory, + k -> {}); } @Test @@ -117,7 +118,7 @@ public void testFinishBundle() throws 
Exception { helpers.finishBundle(sideInputProcessor); verify(mockRunner).finishBundle(); - verify(sideInputProcessor).handleFinishBundle(); + verify(sideInputProcessor).handleFinishKeyOrBundle(); verify(doFnInstanceManager).complete(any()); } diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java index c110cc0d2bf7..6fea3b2fb449 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputDoFnRunnerTest.java @@ -33,6 +33,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.Supplier; import org.apache.beam.runners.core.DoFnRunner; import org.apache.beam.runners.core.DoFnRunners; import org.apache.beam.runners.core.InMemoryStateInternals; @@ -120,9 +121,8 @@ public void testSideInputReady() throws Exception { ListOutputManager outputManager = new ListOutputManager(); List> views = Arrays.asList(view); - StreamingSideInputFetcher sideInputFetcher = createFetcher(views); StreamingSideInputDoFnRunner runner = - createRunner(outputManager, views, sideInputFetcher); + createRunner(outputManager, views, () -> createFetcher(views)); runner.startBundle(); runner.processElement(createDatum("e", 0)); @@ -146,7 +146,7 @@ public void testSideInputNotReady() throws Exception { List> views = Arrays.asList(view); StreamingSideInputFetcher sideInputFetcher = createFetcher(views); StreamingSideInputDoFnRunner runner = - createRunner(outputManager, views, sideInputFetcher); + createRunner(outputManager, views, () -> sideInputFetcher); runner.startBundle(); runner.processElement(createDatum("e", 0)); @@ -203,7 
+203,7 @@ public void testMultipleWindowsNotReady() throws Exception { SlidingWindows.of(Duration.millis(10)).every(Duration.millis(10)), outputManager, views, - sideInputFetcher); + () -> sideInputFetcher); IntervalWindow window1 = new IntervalWindow(new Instant(0), new Instant(10)); IntervalWindow window2 = new IntervalWindow(new Instant(-5), new Instant(5)); @@ -301,7 +301,7 @@ public void testSideInputNotification() throws Exception { List> views = Arrays.asList(view); StreamingSideInputFetcher sideInputFetcher = createFetcher(views); StreamingSideInputDoFnRunner runner = - createRunner(outputManager, views, sideInputFetcher); + createRunner(outputManager, views, () -> sideInputFetcher); sideInputFetcher.watermarkHold(createWindow(0)).add(new Instant(0)); sideInputFetcher.elementBag(createWindow(0)).add(createDatum("e", 0)); @@ -370,7 +370,7 @@ public void testMultipleSideInputs() throws Exception { List> views = Arrays.asList(view1, view2); StreamingSideInputFetcher sideInputFetcher = createFetcher(views); StreamingSideInputDoFnRunner runner = - createRunner(outputManager, views, sideInputFetcher); + createRunner(outputManager, views, () -> sideInputFetcher); sideInputFetcher.watermarkHold(createWindow(0)).add(new Instant(0)); sideInputFetcher.elementBag(createWindow(0)).add(createDatum("e1", 0)); @@ -391,16 +391,16 @@ public void testMultipleSideInputs() throws Exception { private StreamingSideInputDoFnRunner createRunner( WindowedValueMultiReceiver outputManager, List> views, - StreamingSideInputFetcher sideInputFetcher) + Supplier> sideInputFetcherSupplier) throws Exception { - return createRunner(WINDOW_FN, outputManager, views, sideInputFetcher); + return createRunner(WINDOW_FN, outputManager, views, sideInputFetcherSupplier); } private StreamingSideInputDoFnRunner createRunner( WindowFn windowFn, WindowedValueMultiReceiver outputManager, List> views, - StreamingSideInputFetcher sideInputFetcher) + Supplier> sideInputFetcherSupplier) throws Exception { 
DoFnRunner simpleDoFnRunner = DoFnRunners.simpleRunner( @@ -416,11 +416,11 @@ private StreamingSideInputDoFnRunner createRunne WindowingStrategy.of(windowFn), DoFnSchemaInformation.create(), Collections.emptyMap()); - return new StreamingSideInputDoFnRunner<>(simpleDoFnRunner, sideInputFetcher); + return new StreamingSideInputDoFnRunner<>(simpleDoFnRunner, sideInputFetcherSupplier); } private StreamingSideInputFetcher createFetcher( - List> views) throws Exception { + List> views) { @SuppressWarnings({"unchecked", "rawtypes"}) Iterable> typedViews = (Iterable) views; diff --git a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputProcessorTest.java b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputProcessorTest.java index 19e22b038839..d87ad41c22d2 100644 --- a/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputProcessorTest.java +++ b/runners/google-cloud-dataflow-java/worker/src/test/java/org/apache/beam/runners/dataflow/worker/StreamingSideInputProcessorTest.java @@ -127,7 +127,7 @@ public void testHandleFinishBundle() { doNothing().when(mockFetcher).persist(); // When - processor.handleFinishBundle(); + processor.handleFinishKeyOrBundle(); // Then verify(mockFetcher).persist();