Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run!",
"modification": 1,
"modification": 2,
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ public static <InputT, OutputT> DoFnRunner<InputT, OutputT> simpleRunner(
public static <K, InputT, OutputT, W extends BoundedWindow>
DoFnRunner<KeyedWorkItem<K, InputT>, KV<K, OutputT>> lateDataDroppingRunner(
DoFnRunner<KeyedWorkItem<K, InputT>, KV<K, OutputT>> wrappedRunner,
TimerInternals timerInternals,
StepContext stepContext,
WindowingStrategy<?, W> windowingStrategy) {
return new LateDataDroppingDoFnRunner<>(wrappedRunner, windowingStrategy, timerInternals);
return new LateDataDroppingDoFnRunner<>(wrappedRunner, windowingStrategy, stepContext);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ public class LateDataDroppingDoFnRunner<K, InputT, OutputT, W extends BoundedWin
public LateDataDroppingDoFnRunner(
DoFnRunner<KeyedWorkItem<K, InputT>, KV<K, OutputT>> doFnRunner,
WindowingStrategy<?, ?> windowingStrategy,
TimerInternals timerInternals) {
StepContext stepContext) {
this.doFnRunner = doFnRunner;
lateDataFilter = new LateDataFilter(windowingStrategy, timerInternals);
lateDataFilter = new LateDataFilter(windowingStrategy, stepContext);
}

@Override
Expand Down Expand Up @@ -116,13 +116,12 @@ public <KeyT> void onWindowExpiration(BoundedWindow window, Instant timestamp, K
@VisibleForTesting
static class LateDataFilter {
private final WindowingStrategy<?, ?> windowingStrategy;
private final TimerInternals timerInternals;
private final StepContext stepContext;
private final Counter droppedDueToLateness;

public LateDataFilter(
WindowingStrategy<?, ?> windowingStrategy, TimerInternals timerInternals) {
public LateDataFilter(WindowingStrategy<?, ?> windowingStrategy, StepContext stepContext) {
this.windowingStrategy = windowingStrategy;
this.timerInternals = timerInternals;
this.stepContext = stepContext;
Comment thread
arunpandianp marked this conversation as resolved.
this.droppedDueToLateness =
Metrics.counter(LateDataDroppingDoFnRunner.class, DROPPED_DUE_TO_LATENESS);
}
Expand All @@ -146,8 +145,8 @@ public <K, InputT> Iterable<WindowedValue<InputT>> filter(
element.getTimestamp(),
key,
window,
timerInternals.currentInputWatermarkTime(),
timerInternals.currentOutputWatermarkTime());
stepContext.timerInternals().currentInputWatermarkTime(),
stepContext.timerInternals().currentOutputWatermarkTime());
} else {
nonLateElements.add(
WindowedValues.of(
Expand All @@ -173,7 +172,7 @@ public <K, InputT> Iterable<WindowedValue<InputT>> filter(
* @return True if element can be dropped.
*/
private boolean canDropDueToExpiredWindow(BoundedWindow window) {
Instant inputWM = timerInternals.currentInputWatermarkTime();
Instant inputWM = stepContext.timerInternals().currentInputWatermarkTime();
return LateDataUtils.garbageCollectionTime(window, windowingStrategy).isBefore(inputWM);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ public class LateDataDroppingDoFnRunnerTest {
private static final FixedWindows WINDOW_FN = FixedWindows.of(Duration.millis(10));

@Mock private TimerInternals mockTimerInternals;
@Mock private StepContext mockStepContext;

@Before
public void setUp() {
Expand All @@ -60,9 +61,10 @@ public void testLateDataFilter() throws Exception {
MetricsContainerImpl container = new MetricsContainerImpl("any");
MetricsEnvironment.setCurrentContainer(container);
when(mockTimerInternals.currentInputWatermarkTime()).thenReturn(new Instant(15L));
when(mockStepContext.timerInternals()).thenReturn(mockTimerInternals);

LateDataFilter lateDataFilter =
new LateDataFilter(WindowingStrategy.of(WINDOW_FN), mockTimerInternals);
new LateDataFilter(WindowingStrategy.of(WINDOW_FN), mockStepContext);

Iterable<WindowedValue<Integer>> actual =
lateDataFilter.filter(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.CombineWithContext.CombineFnWithContext;
import org.apache.beam.sdk.transforms.DoFn.BundleFinalizer;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
Expand Down Expand Up @@ -94,6 +95,23 @@ public class ReduceFnTester<InputT, OutputT, W extends BoundedWindow> {
private final TestInMemoryStateInternals<String> stateInternals =
new TestInMemoryStateInternals<>(KEY);
private final InMemoryTimerInternals timerInternals = new InMemoryTimerInternals();
private final StepContext stepContext =
new StepContext() {
@Override
public StateInternals stateInternals() {
return stateInternals;
}

@Override
public TimerInternals timerInternals() {
return timerInternals;
}

@Override
public BundleFinalizer bundleFinalizer() {
throw new UnsupportedOperationException();
}
};

private final WindowFn<Object, W> windowFn;
private final TestWindowedValueReceiver testOutputter;
Expand Down Expand Up @@ -577,7 +595,7 @@ public final void injectElements(List<TimestampedValue<InputT>> values) throws E

ReduceFnRunner<String, InputT, OutputT, W> runner = createRunner();
runner.processElements(
new LateDataDroppingDoFnRunner.LateDataFilter(objectStrategy, timerInternals)
new LateDataDroppingDoFnRunner.LateDataFilter(objectStrategy, stepContext)
.filter(KEY, inputs));

// Persist after each bundle.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ protected DoFnRunner<KeyedWorkItem<K, InputT>, KV<K, OutputT>> createWrappingDoF
// for some K, V

return DoFnRunners.lateDataDroppingRunner(
(DoFnRunner) doFnRunner, timerInternals, windowingStrategy);
(DoFnRunner) doFnRunner, stepContext, windowingStrategy);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ private DoFnRunner<InputT, KV<K, Iterable<V>>> createRunner() {
}
return (DoFnRunner<InputT, KV<K, Iterable<V>>>)
DoFnRunners.<K, V, Iterable<V>, W>lateDataDroppingRunner(
streamingGABWRunner, stepContext.timerInternals(), windowingStrategy);
streamingGABWRunner, stepContext, windowingStrategy);
} else {
if (hasStreamingSideInput) {
return new StreamingSideInputDoFnRunner<>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -891,6 +891,26 @@ public void flushState() {
checkStateNotNull(systemTimerInternals).persistTo(builder);
checkStateNotNull(userTimerInternals).persistTo(builder);
}
poisonStateAndTimerInternals();
}

/**
* Poisons the state and timer internals to prevent any subsequent (stale) usage.
*
* <p>This ensures that if these key-specific internals are incorrectly cached and used after
* the key's execution context has finished, it will fail fast with a clear error rather than
* silently corrupting state.
*/
private void poisonStateAndTimerInternals() {
if (stateInternals != null) {
stateInternals.poison();
}
if (systemTimerInternals != null) {
systemTimerInternals.poison();
}
if (userTimerInternals != null) {
userTimerInternals.poison();
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,19 @@ class WindmillTimerInternals implements TimerInternals {
private final Consumer<TimerData> onTimerModified;
private final WindmillTagEncoding windmillTagEncoding;

private boolean poisoned = false;
Comment thread
arunpandianp marked this conversation as resolved.

public void poison() {
this.poisoned = true;
}

private void checkNotPoisoned() {
if (poisoned) {
throw new IllegalStateException(
"WindmillTimerInternals is poisoned and cannot be used after flushState().");
}
}

public WindmillTimerInternals(
String stateFamily, // unique identifies a step
WindmillTimerType type,
Expand All @@ -80,12 +93,14 @@ public WindmillTimerInternals(
}

public WindmillTimerInternals withType(WindmillTimerType type) {
checkNotPoisoned();
return new WindmillTimerInternals(
stateFamily, type, processingTime, watermarks, windmillTagEncoding, onTimerModified);
}

@Override
public void setTimer(TimerData timerKey) {
checkNotPoisoned();
String timerDataKey = getTimerDataKey(timerKey.getTimerId(), timerKey.getTimerFamilyId());
timerMap.put(
new SimpleEntry<>(timerDataKey, timerKey.getNamespace()),
Expand All @@ -101,6 +116,7 @@ public void setTimer(
Instant timestamp,
Instant outputTimestamp,
TimeDomain timeDomain) {
checkNotPoisoned();
TimerData timer =
TimerData.of(
timerId,
Expand All @@ -124,6 +140,7 @@ private static String getTimerDataKey(String timerId, String timerFamilyId) {

@Override
public void deleteTimer(TimerData timerKey) {
checkNotPoisoned();
String timerDataKey = getTimerDataKey(timerKey.getTimerId(), timerKey.getTimerFamilyId());
timerMap.put(
new SimpleEntry<>(timerDataKey, timerKey.getNamespace()),
Expand All @@ -139,6 +156,7 @@ public void deleteTimer(StateNamespace namespace, String timerId, String timerFa
@Override
public void deleteTimer(
StateNamespace namespace, String timerId, String timerFamilyId, TimeDomain timeDomain) {
checkNotPoisoned();
deleteTimer(
TimerData.of(
timerId,
Expand All @@ -152,12 +170,14 @@ public void deleteTimer(

@Override
public Instant currentProcessingTime() {
checkNotPoisoned();
Instant now = Instant.now();
return processingTime.isAfter(now) ? processingTime : now;
}

@Override
public @Nullable Instant currentSynchronizedProcessingTime() {
checkNotPoisoned();
return watermarks.synchronizedProcessingTime();
}

Expand All @@ -172,6 +192,7 @@ public Instant currentProcessingTime() {
*/
@Override
public Instant currentInputWatermarkTime() {
checkNotPoisoned();
return watermarks.inputDataWatermark();
}

Expand All @@ -186,10 +207,12 @@ public Instant currentInputWatermarkTime() {
*/
@Override
public @Nullable Instant currentOutputWatermarkTime() {
checkNotPoisoned();
return watermarks.outputDataWatermark();
}

public void persistTo(Windmill.WorkItemCommitRequest.Builder outputBuilder) {
checkNotPoisoned();
for (Entry<TimerData, Boolean> value : timerMap.values()) {
// Regardless of whether it is set or not, it must have some TimerData stored so we
// can know its time domain
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,19 @@ public class WindmillStateInternals<K> implements StateInternals {
private final CachingStateTable workItemDerivedState;
private final Supplier<Closeable> scopedReadStateSupplier;

private boolean poisoned = false;
Comment thread
arunpandianp marked this conversation as resolved.

public void poison() {
this.poisoned = true;
}

private void checkNotPoisoned() {
if (poisoned) {
throw new IllegalStateException(
"WindmillStateInternals is poisoned and cannot be used after flushState().");
}
}

public WindmillStateInternals(
@Nullable K key,
String stateFamily,
Expand All @@ -78,6 +91,7 @@ public WindmillStateInternals(

@Override
public @Nullable K getKey() {
checkNotPoisoned();
return key;
}

Expand All @@ -104,6 +118,7 @@ private void persist(
}

public void persist(final Windmill.WorkItemCommitRequest.Builder commitBuilder) {
checkNotPoisoned();
List<Future<WorkItemCommitRequest>> commitsToMerge = new ArrayList<>();

// Call persist on each first, which may schedule some futures for reading.
Expand All @@ -126,12 +141,14 @@ public void persist(final Windmill.WorkItemCommitRequest.Builder commitBuilder)

@Override
public <T extends State> T state(StateNamespace namespace, StateTag<T> address) {
checkNotPoisoned();
return workItemState.get(namespace, address, StateContexts.nullContext());
}

@Override
public <T extends State> T state(
StateNamespace namespace, StateTag<T> address, StateContext<?> c) {
checkNotPoisoned();
return workItemState.get(namespace, address, c);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -753,8 +753,7 @@ DoFnRunner<KeyedWorkItem<String, InputT>, KV<String, OutputT>> makeRunnerForGabw
NullSideInputReader.empty(),
outputManager,
stepContext);
return DoFnRunners.lateDataDroppingRunner(
doFnRunner, stepContext.timerInternals(), windowingStrategy);
return DoFnRunners.lateDataDroppingRunner(doFnRunner, stepContext, windowingStrategy);
}

private IntervalWindow window(long start, long end) {
Expand Down
Loading
Loading