diff --git a/sdks/python/apache_beam/io/watch.py b/sdks/python/apache_beam/io/watch.py new file mode 100644 index 000000000000..8ccbd4f8e6e5 --- /dev/null +++ b/sdks/python/apache_beam/io/watch.py @@ -0,0 +1,676 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Experimental ``Watch`` transform for the Python SDK. + +``Watch`` continuously watches a growing set of outputs for each input element, +calling a user poll function on an interval until a per-input termination +condition fires. It is the engine behind periodic file-discovery and any +periodic polling source. + +For every input element the transform runs an independent loop:: + + poll -> keep never-seen-before outputs -> emit them (timestamped) -> + update watermark -> check termination -> wait(poll_interval) -> poll -> ... + +The output is an unbounded ``PCollection`` of ``(input, output)`` pairs. Each +output carries the event time the poll function first reported it. Dedup uses a +stable 128-bit hash of the encoded output, so the output coder must be +deterministic for dedup to hold across workers and restarts. 
+ +Example:: + + from apache_beam.io.watch import Watch, PollResult, after_total_of + from apache_beam.transforms.window import TimestampedValue + from apache_beam.utils.timestamp import Duration, Timestamp + + def poll(prefix): + now = Timestamp.now() + outputs = [TimestampedValue(prefix + str(i), now) for i in range(3)] + return PollResult.complete(outputs) + + watched = (inputs + | Watch.growth_of(poll) + .with_poll_interval(Duration(seconds=5)) + .with_termination_per_input(after_total_of(60))) + +This API is experimental and may change in backwards-incompatible ways. +""" + +import collections +import dataclasses +import hashlib +import logging +import time +from typing import Any +from typing import Callable +from typing import Iterable +from typing import List +from typing import Optional +from typing import Tuple + +from apache_beam import coders +from apache_beam.coders.coders import Coder +from apache_beam.coders.coders import NullableCoder +from apache_beam.coders.coders import TimestampCoder +from apache_beam.coders.coders import TupleCoder +from apache_beam.io import iobase +from apache_beam.io.watermark_estimators import ManualWatermarkEstimator +from apache_beam.runners import sdf_utils +from apache_beam.transforms import PTransform +from apache_beam.transforms import core +from apache_beam.transforms.window import TimestampedValue +from apache_beam.utils.timestamp import MAX_TIMESTAMP +from apache_beam.utils.timestamp import Duration +from apache_beam.utils.timestamp import Timestamp + +__all__ = [ + 'Watch', + 'PollResult', + 'PollFn', + 'TerminationCondition', + 'never', + 'after_total_of', +] + +_LOGGER = logging.getLogger(__name__) + +_HASH_DIGEST_SIZE = 16 # 128-bit digest width. + +# ------------------------------------------------------------------------------ +# Public API. 
# ------------------------------------------------------------------------------


@dataclasses.dataclass(frozen=True)
class PollResult:
  """Outputs produced by one poll, plus an optional explicit watermark.

  ``watermark`` of ``None`` lets the transform infer the watermark from the
  earliest new output. A watermark of ``MAX_TIMESTAMP`` (set by
  :meth:`complete`) marks the input finished, so polling stops.
  """
  outputs: Tuple[TimestampedValue, ...]
  watermark: Optional[Timestamp] = None

  @property
  def is_complete(self) -> bool:
    """True when the watermark equals ``MAX_TIMESTAMP``."""
    return self.watermark == MAX_TIMESTAMP

  @staticmethod
  def _normalize(outputs, timestamp) -> Tuple[TimestampedValue, ...]:
    """Wraps every raw output in a ``TimestampedValue``.

    Outputs that already are ``TimestampedValue`` pass through untouched; the
    others get ``timestamp`` (or the current time when it is ``None``).
    """
    if timestamp is None:
      default_ts = Timestamp.now()
    else:
      default_ts = Timestamp.of(timestamp)
    normalized = []
    for output in outputs:
      if isinstance(output, TimestampedValue):
        normalized.append(output)
      else:
        normalized.append(TimestampedValue(output, default_ts))
    return tuple(normalized)

  @staticmethod
  def incomplete(outputs: Iterable, timestamp=None) -> 'PollResult':
    """Reports outputs and expects more; the transform infers the watermark.

    A raw (non-:class:`TimestampedValue`) output is stamped with ``timestamp``
    when given, else with the current processing time.
    """
    return PollResult(PollResult._normalize(outputs, timestamp), watermark=None)

  @staticmethod
  def complete(outputs: Iterable, timestamp=None) -> 'PollResult':
    """Reports the final outputs for an input, after which polling stops.

    A raw (non-:class:`TimestampedValue`) output is stamped with ``timestamp``
    when given, else with the current processing time.
    """
    return PollResult(
        PollResult._normalize(outputs, timestamp), watermark=MAX_TIMESTAMP)

  def with_watermark(self, watermark) -> 'PollResult':
    """Returns a copy of this result carrying an explicit ``watermark``."""
    return dataclasses.replace(self, watermark=Timestamp.of(watermark))


class PollFn(object):
  """Optional base for a poll function ``input -> PollResult``.

  Any callable with that signature works; subclass only to attach an output
  coder hint via :meth:`default_output_coder`.
  """
  def __call__(self, element: Any) -> PollResult:
    raise NotImplementedError

  def default_output_coder(self) -> Optional[Coder]:
    return None


class TerminationCondition(object):
  """Per-input stop policy with immutable, encodable state.

  Hooks follow the lifecycle of one input's polling loop. ``state`` flows from
  :meth:`for_new_input` through the per-round hooks and is serialized with
  :meth:`state_coder`.
  """
  def for_new_input(self, now: Timestamp, element: Any) -> Any:
    raise NotImplementedError

  def on_seen_new_output(self, now: Timestamp, state: Any) -> Any:
    return state

  def on_poll_complete(self, state: Any) -> Any:
    return state

  def can_stop_polling(self, now: Timestamp, state: Any) -> bool:
    raise NotImplementedError

  def state_coder(self) -> Coder:
    raise NotImplementedError


class _Never(TerminationCondition):
  """Polls until the poll function returns :meth:`PollResult.complete`."""
  def for_new_input(self, now, element):
    return 0

  def can_stop_polling(self, now, state):
    return False

  def state_coder(self):
    return coders.VarIntCoder()


class _AfterTotalOf(TerminationCondition):
  """Stops once the wall-clock time since the input was first seen exceeds a
  fixed duration."""
  def __init__(self, duration: Duration):
    self._duration_micros = duration.micros

  def for_new_input(self, now, element):
    # State is (first-seen time, allowed duration in micros).
    return (now, self._duration_micros)

  def can_stop_polling(self, now, state):
    start, duration_micros = state
    # Strictly greater: polling continues while elapsed == duration.
    return (now - start).micros > duration_micros

  def state_coder(self):
    return TupleCoder([TimestampCoder(), coders.VarIntCoder()])


def never() -> TerminationCondition:
  """Polls until :meth:`PollResult.complete`."""
  return _Never()


def after_total_of(duration) -> TerminationCondition:
  """Stops polling an input after ``duration`` (a :class:`Duration` or seconds)
  has elapsed since it was first seen."""
  return _AfterTotalOf(_as_duration(duration))


# ------------------------------------------------------------------------------
# Restriction state.
# ------------------------------------------------------------------------------


@dataclasses.dataclass(frozen=True)
class _PollingGrowthState:
  """Keep-polling state: emitted-output hashes, watermark, termination state.

  ``completed`` maps a 16-byte output hash to the event time it was first seen.
  It is insertion-ordered and treated as immutable; the tracker builds a new
  mapping for each residual.
  """
  completed: 'collections.OrderedDict[bytes, Timestamp]'
  poll_watermark: Optional[Timestamp]
  termination_state: Any


@dataclasses.dataclass(frozen=True)
class _NonPollingGrowthState:
  """Replay-then-stop state: the outputs already emitted this round.

  Produced as the checkpoint primary so a bundle retry re-emits exactly those
  outputs.
  """
  pending: PollResult


_GrowthState = Any  # Union[_PollingGrowthState, _NonPollingGrowthState]

# ------------------------------------------------------------------------------
# Coders.
# ------------------------------------------------------------------------------


class _HashCode128Coder(Coder):
  """Fixed-width coder for a 16-byte output hash.

  Encodes and decodes exactly 16 bytes and raises on any other length, so a
  corrupt restriction surfaces at decode time.
  """
  def encode(self, value: bytes) -> bytes:
    if len(value) != _HASH_DIGEST_SIZE:
      raise ValueError(
          'hash must be %d bytes, got %d' % (_HASH_DIGEST_SIZE, len(value)))
    return value

  def decode(self, encoded: bytes) -> bytes:
    if len(encoded) != _HASH_DIGEST_SIZE:
      raise ValueError(
          'hash must be %d bytes, got %d' % (_HASH_DIGEST_SIZE, len(encoded)))
    return encoded

  def is_deterministic(self) -> bool:
    return True


class _TimestampedValueCoder(Coder):
  """Coder for :class:`TimestampedValue`.

  The Python SDK ships no coder for ``TimestampedValue``, so this encodes the
  ``(value, timestamp)`` pair with a :class:`TupleCoder` and rebuilds the
  ``TimestampedValue`` on decode.
  """
  def __init__(self, value_coder: Coder):
    self._tuple_coder = TupleCoder([value_coder, TimestampCoder()])

  def encode(self, value: TimestampedValue) -> bytes:
    return self._tuple_coder.encode((value.value, value.timestamp))

  def decode(self, encoded: bytes) -> TimestampedValue:
    value, timestamp = self._tuple_coder.decode(encoded)
    return TimestampedValue(value, timestamp)

  def is_deterministic(self) -> bool:
    return self._tuple_coder.is_deterministic()


class _GrowthStateCoder(Coder):
  """Encodes a :class:`_PollingGrowthState` or :class:`_NonPollingGrowthState`.

  A ``(tag, payload)`` envelope selects the variant; the payload is a
  variant-specific :class:`TupleCoder`. ``completed`` is encoded as an ordered
  list of ``(hash, timestamp)`` pairs so insertion order survives a round trip.
  This format is internal to the Python SDK.
  """
  def __init__(self, output_coder: Coder, termination: TerminationCondition):
    nullable_ts = NullableCoder(TimestampCoder())
    self._envelope_coder = TupleCoder(
        [coders.VarIntCoder(), coders.BytesCoder()])
    self._polling_coder = TupleCoder([
        termination.state_coder(),
        nullable_ts,
        coders.ListCoder(TupleCoder([_HashCode128Coder(), TimestampCoder()])),
    ])
    self._non_polling_coder = TupleCoder([
        nullable_ts,
        coders.ListCoder(_TimestampedValueCoder(output_coder)),
    ])

  def encode(self, state: _GrowthState) -> bytes:
    # Tag 0 = polling state, tag 1 = non-polling (replay) state.
    if isinstance(state, _PollingGrowthState):
      payload = self._polling_coder.encode((
          state.termination_state,
          state.poll_watermark,
          list(state.completed.items())))
      return self._envelope_coder.encode((0, payload))
    payload = self._non_polling_coder.encode(
        (state.pending.watermark, list(state.pending.outputs)))
    return self._envelope_coder.encode((1, payload))

  def decode(self, encoded: bytes) -> _GrowthState:
    tag, payload = self._envelope_coder.decode(encoded)
    if tag == 0:
      termination_state, poll_watermark, items = self._polling_coder.decode(
          payload)
      return _PollingGrowthState(
          collections.OrderedDict(items), poll_watermark, termination_state)
    if tag == 1:
      watermark, outputs = self._non_polling_coder.decode(payload)
      return _NonPollingGrowthState(PollResult(tuple(outputs), watermark))
    raise ValueError('unknown Watch growth state tag: %r' % (tag, ))

  def is_deterministic(self) -> bool:
    return False


# ------------------------------------------------------------------------------
# Restriction tracker.
# ------------------------------------------------------------------------------


class _GrowthRestrictionTracker(iobase.RestrictionTracker):
  """Drives one input's polling loop.

  ``process()`` only sees a ``RestrictionTrackerView`` whose ``try_claim``
  returns a bool, so the poll happens inside ``try_claim`` and its result is
  returned through a two-slot holder list passed as the claim position:
  ``holder[0]`` carries the input element in, ``holder[1]`` carries the work
  out. At most one claim succeeds per ``process()``.

  At most one checkpoint is handed out per tracker: once ``try_split(0)`` has
  returned a residual, or a replay claim has consumed the whole restriction,
  further splits return ``None`` so no work is ever scheduled twice.

  The poll runs while the tracker lock is held, so a ``PollFn`` must be bounded
  or timeout-safe; a blocking poll delays runner-initiated checkpoints.
  """
  def __init__(
      self,
      restriction: _GrowthState,
      poll_fn: Callable[[Any], PollResult],
      key_coder: Coder,
      termination: TerminationCondition,
      now_fn: Callable[[], float]):
    self._restriction = restriction
    self._poll_fn = poll_fn
    self._key_coder = key_coder
    self._termination = termination
    self._now = now_fn
    self._should_stop = False
    # True once try_split has handed out a (primary, residual) pair.
    self._checkpointed = False
    self._primary = None  # type: Optional[_GrowthState]
    self._residual = None  # type: Optional[_GrowthState]

  def current_restriction(self) -> _GrowthState:
    return self._restriction

  def _hash_output(self, value: Any) -> bytes:
    """Returns the 16-byte blake2b digest of ``value`` encoded by the key
    coder; this digest is the dedup key."""
    return hashlib.blake2b(
        self._key_coder.encode(value), digest_size=_HASH_DIGEST_SIZE).digest()

  def try_claim(self, holder: list) -> bool:
    """Performs one poll round (or one replay) and reports it via ``holder``.

    On success ``holder[1]`` is either ``('replay', PollResult)`` or
    ``('poll', new_outputs, watermark, stop)``.

    Returns ``False`` only when a checkpoint already stopped this invocation,
    in which case ``process()`` must emit nothing.
    """
    if self._should_stop:
      return False
    restriction = self._restriction
    if isinstance(restriction, _NonPollingGrowthState):
      holder[1] = ('replay', restriction.pending)
      # The replay consumes the whole restriction. Record it as the primary
      # with no residual so that a checkpoint arriving after this claim cannot
      # hand the already-emitted outputs out again as a residual.
      self._primary = restriction
      self._residual = None
      self._should_stop = True
      return True

    element = holder[0]
    now = Timestamp.of(self._now())
    result = self._poll_fn(element)

    new_outputs = []  # type: List[TimestampedValue]
    claimed = []  # type: List[Tuple[bytes, Timestamp]]
    seen_this_round = set()  # type: set
    for output in result.outputs:
      key_hash = self._hash_output(output.value)
      # Drop outputs seen in earlier rounds or earlier in this same poll.
      if key_hash in restriction.completed or key_hash in seen_this_round:
        continue
      seen_this_round.add(key_hash)
      new_outputs.append(output)
      claimed.append((key_hash, output.timestamp))
    new_outputs.sort(key=lambda output: output.timestamp)

    termination_state = restriction.termination_state
    if new_outputs:
      termination_state = self._termination.on_seen_new_output(
          now, termination_state)
    termination_state = self._termination.on_poll_complete(termination_state)

    if result.watermark is not None:
      watermark = result.watermark
    elif new_outputs:
      # new_outputs is sorted, so index 0 is the earliest new output.
      watermark = new_outputs[0].timestamp
    else:
      watermark = None

    # A watermark at MAX means no more output is possible, so polling stops.
    reached_max = watermark is not None and watermark >= MAX_TIMESTAMP
    stop = (
        result.is_complete or reached_max or
        self._termination.can_stop_polling(now, termination_state))

    self._primary = _NonPollingGrowthState(
        PollResult(tuple(new_outputs), watermark))
    if stop:
      # Terminal round: no polling work remains, so a checkpoint (runner-
      # initiated or via defer_remainder) resumes a state that emits nothing.
      self._residual = _NonPollingGrowthState(PollResult((), watermark))
    else:
      merged = collections.OrderedDict(restriction.completed)
      for key_hash, first_seen in claimed:
        merged[key_hash] = first_seen
      residual_watermark = self._max_watermark(
          restriction.poll_watermark, watermark)
      self._residual = _PollingGrowthState(
          merged, residual_watermark, termination_state)
    holder[1] = ('poll', new_outputs, watermark, stop)
    self._should_stop = True
    return True

  @staticmethod
  def _max_watermark(left: Optional[Timestamp],
                     right: Optional[Timestamp]) -> Optional[Timestamp]:
    """Returns the later of two optional watermarks (``None`` = unknown)."""
    if left is None:
      return right
    if right is None:
      return left
    return max(left, right)

  def try_split(self, fraction_of_remainder):
    """Self-checkpoints the restriction at ``fraction_of_remainder == 0``.

    Returns ``(primary, residual)`` the first time it is called with fraction
    0 while work remains. Returns ``None`` for any non-zero fraction, when a
    checkpoint was already handed out, or when a replay claim has consumed the
    whole restriction.
    """
    # Only self-checkpoint (fraction 0) is supported; decline dynamic splits.
    if fraction_of_remainder != 0:
      return None
    if self._checkpointed:
      # The residual was already handed out; returning it again would make the
      # runner schedule the same polling state twice.
      return None
    if self._primary is None:
      # No claim happened this invocation: keep the whole state as the residual.
      primary = _NonPollingGrowthState(PollResult((), None))
      residual = self._restriction
    elif self._residual is None:
      # A replay claim consumed everything; nothing is left to split off.
      return None
    else:
      primary, residual = self._primary, self._residual
    self._restriction = primary
    self._should_stop = True
    self._checkpointed = True
    return primary, residual

  def check_done(self) -> bool:
    # Called after every process(); the single claim or a split sets the flag.
    if self._should_stop:
      return True
    raise ValueError(
        'Watch restriction was neither claimed nor checkpointed: %r' %
        (self._restriction, ))

  def current_progress(self) -> 'iobase.RestrictionProgress':
    if self._should_stop:
      return iobase.RestrictionProgress(completed=1.0, remaining=0.0)
    return iobase.RestrictionProgress(completed=0.0, remaining=1.0)

  def is_bounded(self) -> bool:
    # A polling restriction is unbounded; a replay-then-stop one is bounded.
    return isinstance(self._restriction, _NonPollingGrowthState)


# ------------------------------------------------------------------------------
# Splittable DoFn (its own restriction provider).
# ------------------------------------------------------------------------------


class _WatchGrowthDoFn(core.DoFn, core.RestrictionProvider):
  """Polling SDF that emits ``(input, output)`` pairs.

  The DoFn is its own ``RestrictionProvider``: ``RestrictionParam()`` with no
  argument resolves the provider to the DoFn instance, so the provider methods
  read the transform-level spec (poll function, coders, termination) off
  ``self``. Provider methods run on a separately deserialized copy and before
  ``setup()``, so the spec is immutable state set in ``__init__``.
  """
  def __init__(
      self,
      poll_fn: Callable[[Any], PollResult],
      termination: TerminationCondition,
      poll_interval: Duration,
      output_coder: Coder,
      now_fn: Optional[Callable[[], float]] = None):
    self._poll_fn = poll_fn
    self._termination = termination
    self._poll_interval = poll_interval
    self._output_coder = output_coder
    # The output coder doubles as the dedup key coder.
    self._key_coder = output_coder
    self._now = now_fn or time.time
    self._restriction_coder = _GrowthStateCoder(output_coder, termination)

  def initial_restriction(self, element) -> _PollingGrowthState:
    now = Timestamp.of(self._now())
    return _PollingGrowthState(
        collections.OrderedDict(),
        None,
        self._termination.for_new_input(now, element))

  def create_tracker(self, restriction) -> _GrowthRestrictionTracker:
    return _GrowthRestrictionTracker(
        restriction,
        self._poll_fn,
        self._key_coder,
        self._termination,
        self._now)

  def split(self, element, restriction):
    # Watch fans out by input element, so each restriction stays whole.
    yield restriction

  def restriction_coder(self) -> Coder:
    return self._restriction_coder

  def restriction_size(self, element, restriction) -> int:
    return 1

  def truncate(self, element, restriction):
    # On drain, replay a pending NonPolling state and stop further polling.
    if isinstance(restriction, _NonPollingGrowthState):
      return restriction
    return None

  @core.DoFn.unbounded_per_element()
  def process(
      self,
      element,
      timestamp=core.DoFn.TimestampParam,
      tracker=core.DoFn.RestrictionParam(),
      watermark_estimator=core.DoFn.WatermarkEstimatorParam(
          ManualWatermarkEstimator.default_provider())):
    """Runs one claim and emits its outputs as ``(element, value)`` pairs.

    A non-terminal poll round ends by deferring the remainder for
    ``poll_interval``; a terminal round advances the watermark to
    ``MAX_TIMESTAMP`` instead.
    """
    assert isinstance(tracker, sdf_utils.RestrictionTrackerView)
    holder = [element, None]
    if not tracker.try_claim(holder):
      # A checkpoint already stopped this invocation; emit nothing.
      return
    # Seed the watermark hold from the input event time after the claim.
    _set_watermark_if_greater(watermark_estimator, timestamp)
    work = holder[1]
    if work[0] == 'replay':
      for output in work[1].outputs:
        yield TimestampedValue((element, output.value), output.timestamp)
      return
    new_outputs, watermark, stop = work[1], work[2], work[3]
    for output in new_outputs:
      yield TimestampedValue((element, output.value), output.timestamp)
    if stop:
      # The input is finished, so release the watermark hold to MAX.
      _set_watermark_if_greater(watermark_estimator, MAX_TIMESTAMP)
      return
    if watermark is not None:
      _set_watermark_if_greater(watermark_estimator, watermark)
    tracker.defer_remainder(self._poll_interval)


def _set_watermark_if_greater(watermark_estimator, new_watermark) -> None:
  """Advances the estimator to ``new_watermark`` unless that would regress."""
  # set_watermark raises on regression, so only ever advance the watermark.
  current = watermark_estimator.current_watermark()
  if current is None or new_watermark > current:
    watermark_estimator.set_watermark(new_watermark)


# ------------------------------------------------------------------------------
# Public PTransform.
# ------------------------------------------------------------------------------


class Watch(PTransform):
  """Watches a growing set of outputs per input via a periodic poll function.

  Build with :meth:`growth_of` and the ``with_*`` methods. The output is an
  unbounded ``PCollection`` of ``(input, output)`` pairs.
  """
  def __init__(
      self,
      poll_fn: Callable[[Any], PollResult],
      termination: Optional[TerminationCondition] = None,
      poll_interval: Optional[Duration] = None,
      output_coder: Optional[Coder] = None,
      now_fn: Optional[Callable[[], float]] = None):
    super().__init__()
    self._poll_fn = poll_fn
    # With no explicit condition, poll until the poll function says complete.
    self._termination = termination or never()
    self._poll_interval = poll_interval
    self._output_coder = output_coder
    self._now = now_fn

  @classmethod
  def growth_of(cls, poll_fn: Callable[[Any], PollResult]) -> 'Watch':
    """Starts a builder chain for ``poll_fn`` with default settings."""
    return cls(poll_fn)

  def _replace(self, **changes) -> 'Watch':
    """Returns a copy of this transform with ``changes`` applied.

    The copy is built with ``type(self)`` so that, like :meth:`growth_of`
    (which uses ``cls``), the ``with_*`` builders keep a subclass's type
    instead of downgrading it to ``Watch``.
    """
    spec = dict(
        poll_fn=self._poll_fn,
        termination=self._termination,
        poll_interval=self._poll_interval,
        output_coder=self._output_coder,
        now_fn=self._now)
    spec.update(changes)
    return type(self)(**spec)

  def with_poll_interval(self, poll_interval) -> 'Watch':
    """Sets the delay between polls (a :class:`Duration` or seconds)."""
    return self._replace(poll_interval=_as_duration(poll_interval))

  def with_termination_per_input(
      self, termination: TerminationCondition) -> 'Watch':
    """Sets the per-input stop policy."""
    return self._replace(termination=termination)

  def with_output_coder(self, output_coder: Coder) -> 'Watch':
    """Sets the coder used for outputs and for the dedup hash."""
    return self._replace(output_coder=output_coder)

  def expand(self, pcoll):
    """Applies the polling SDF to ``pcoll``.

    Raises:
      ValueError: if no poll interval was configured.
    """
    if self._poll_interval is None:
      raise ValueError('Watch requires with_poll_interval(...)')
    output_coder = self._output_coder
    if output_coder is None:
      # Fall back to the PollFn's coder hint, then to pickling.
      hint = self._poll_fn.default_output_coder() if isinstance(
          self._poll_fn, PollFn) else None
      output_coder = hint or coders.PickleCoder()
    if not output_coder.is_deterministic():
      _LOGGER.warning(
          'Watch dedup uses a non-deterministic output coder (%s); equal '
          'outputs may be emitted more than once. Pass a deterministic coder '
          'via with_output_coder() for reliable dedup.',
          type(output_coder).__name__)
    return pcoll | core.ParDo(
        _WatchGrowthDoFn(
            self._poll_fn,
            self._termination,
            self._poll_interval,
            output_coder,
            self._now))


def _as_duration(value) -> Duration:
  """Returns ``value`` unchanged if it is a ``Duration``, else wraps it with
  ``Duration(value)``."""
  return value if isinstance(value, Duration) else Duration(value)


# Patch boundary: the header below starts the second file in this diff.
# diff --git a/sdks/python/apache_beam/io/watch_test.py b/sdks/python/apache_beam/io/watch_test.py
# new file mode 100644
# index 000000000000..8b47d31cf56a
# --- /dev/null
# +++ b/sdks/python/apache_beam/io/watch_test.py
# @@ -0,0 +1,280 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Tests for the Watch transform."""

import collections
import unittest

import apache_beam as beam
from apache_beam.coders.coders import StrUtf8Coder
from apache_beam.io.watch import PollResult
from apache_beam.io.watch import Watch
from apache_beam.io.watch import _GrowthRestrictionTracker
from apache_beam.io.watch import _GrowthStateCoder
from apache_beam.io.watch import _NonPollingGrowthState
from apache_beam.io.watch import _PollingGrowthState
from apache_beam.io.watch import after_total_of
from apache_beam.io.watch import never
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.runners.sdf_utils import RestrictionTrackerView
from apache_beam.runners.sdf_utils import ThreadsafeRestrictionTracker
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.testing.util import TestWindowedValue
from apache_beam.transforms.window import FixedWindows
from apache_beam.transforms.window import GlobalWindow
from apache_beam.transforms.window import TimestampedValue
from apache_beam.utils.timestamp import MAX_TIMESTAMP
from apache_beam.utils.timestamp import Duration
from apache_beam.utils.timestamp import Timestamp


# Shorthand for a TimestampedValue at `timestamp` seconds.
def _ts(value, timestamp):
  return TimestampedValue(value, Timestamp(timestamp))


# Builds a tracker with a string key coder, the never() termination condition
# and a clock frozen at `now`.
def _new_tracker(restriction, poll_fn, now=0.0):
  return _GrowthRestrictionTracker(
      restriction, poll_fn, StrUtf8Coder(), never(), lambda: now)


# Builds the empty polling state a fresh input would start from.
def _initial_polling(termination=None, now=Timestamp(0)):
  termination = termination or never()
  return _PollingGrowthState(
      collections.OrderedDict(), None, termination.for_new_input(now, 'input'))


class GrowthStateCoderTest(unittest.TestCase):
  # Round-trips each restriction variant through _GrowthStateCoder.
  def test_polling_round_trip_preserves_resume_state(self):
    termination = after_total_of(Duration(30))
    coder = _GrowthStateCoder(StrUtf8Coder(), termination)
    # Three 16-byte hashes; the order must survive the round trip.
    completed = collections.OrderedDict([
        (b'a' * 16, Timestamp(1)),
        (b'b' * 16, Timestamp(2)),
        (b'c' * 16, Timestamp(3)),
    ])
    termination_state = termination.for_new_input(Timestamp(7), 'input')
    state = _PollingGrowthState(completed, Timestamp(5), termination_state)
    decoded = coder.decode(coder.encode(state))
    self.assertEqual(list(completed.items()), list(decoded.completed.items()))
    self.assertEqual(Timestamp(5), decoded.poll_watermark)
    self.assertEqual(termination_state, decoded.termination_state)

  def test_non_polling_round_trip_preserves_pending_outputs(self):
    coder = _GrowthStateCoder(StrUtf8Coder(), never())
    pending = PollResult((_ts('a', 1), _ts('b', 2)), MAX_TIMESTAMP)
    state = _NonPollingGrowthState(pending)
    decoded = coder.decode(coder.encode(state))
    self.assertEqual(MAX_TIMESTAMP, decoded.pending.watermark)
    # Compare (value, timestamp) pairs rather than TimestampedValue objects.
    self.assertEqual(
        [('a', Timestamp(1)), ('b', Timestamp(2))],
        [(o.value, o.timestamp) for o in decoded.pending.outputs])


class GrowthTrackerTest(unittest.TestCase):
  # Exercises _GrowthRestrictionTracker directly, without a pipeline.
  def test_poll_claims_dedups_and_checkpoints(self):
    def poll(unused_element):
      # 'a' is reported twice to exercise dedup within a single poll.
      return PollResult.incomplete([_ts('a', 1), _ts('a', 1), _ts('b', 2)])

    tracker = _new_tracker(_initial_polling(), poll)
    holder = ['input', None]
    self.assertTrue(tracker.try_claim(holder))
    work = holder[1]
    kind, outputs, watermark, stop = work[0], work[1], work[2], work[3]
    self.assertEqual('poll', kind)
    self.assertEqual(['a', 'b'], [o.value for o in outputs])
    self.assertFalse(stop)
    # No explicit watermark, so it is inferred from the earliest new output.
    self.assertEqual(Timestamp(1), watermark)
    self.assertFalse(tracker.is_bounded())

    primary, residual = tracker.try_split(0)
    self.assertIsInstance(primary, _NonPollingGrowthState)
    self.assertIsInstance(residual, _PollingGrowthState)
    self.assertEqual(2, len(residual.completed))
    self.assertTrue(tracker.check_done())

    def explicit_watermark_poll(unused_element):
      return PollResult.incomplete([_ts('c', 3)]).with_watermark(5)

    # An explicit watermark takes precedence over the inferred one.
    explicit_tracker = _new_tracker(_initial_polling(), explicit_watermark_poll)
    holder = ['input', None]
    self.assertTrue(explicit_tracker.try_claim(holder))
    self.assertEqual(Timestamp(5), holder[1][2])
    _, residual = explicit_tracker.try_split(0)
    self.assertEqual(Timestamp(5), residual.poll_watermark)

  def test_second_round_repolls_and_dedups_against_completed(self):
    polls = []

    def poll(unused_element):
      # First call returns a and b; later calls return a and c.
      polls.append(len(polls))
      if len(polls) == 1:
        return PollResult.incomplete([_ts('a', 1), _ts('b', 2)])
      return PollResult.incomplete([_ts('a', 1), _ts('c', 3)])

    tracker = _new_tracker(_initial_polling(), poll)
    holder = ['input', None]
    tracker.try_claim(holder)
    _, residual = tracker.try_split(0)

    # Resume from the residual: 'a' is already in `completed`, so only 'c' is
    # new.
    resumed = _new_tracker(residual, poll)
    holder = ['input', None]
    self.assertTrue(resumed.try_claim(holder))
    outputs = holder[1][1]
    self.assertEqual(2, len(polls))
    self.assertEqual(['c'], [o.value for o in outputs])

  def test_termination_condition_sets_stop(self):
    def poll(unused_element):
      return PollResult.incomplete([_ts('a', 1)])

    termination = after_total_of(10)
    # The input is first seen at t=0. At exactly 10s polling continues; at 11s
    # the condition fires.
    for now, expected_stop in [(10.0, False), (11.0, True)]:
      with self.subTest(now=now):
        tracker = _GrowthRestrictionTracker(
            _initial_polling(termination, Timestamp(0)),
            poll,
            StrUtf8Coder(),
            termination, lambda now=now: now)
        holder = ['input', None]
        self.assertTrue(tracker.try_claim(holder))
        self.assertEqual(expected_stop, holder[1][3])

  def test_non_polling_replays(self):
    pending = PollResult((_ts('a', 1), _ts('b', 2)), MAX_TIMESTAMP)
    # The poll function returns None; the assertions below only pass if the
    # replay path never uses its result.
    tracker = _new_tracker(_NonPollingGrowthState(pending), lambda e: None)
    holder = ['input', None]
    self.assertTrue(tracker.try_claim(holder))
    work = holder[1]
    kind, replayed = work[0], work[1]
    self.assertEqual('replay', kind)
    self.assertEqual(['a', 'b'], [o.value for o in replayed.outputs])
    self.assertTrue(tracker.check_done())

  def test_terminal_split_residual_is_empty_for_all_stop_causes(self):
    termination = after_total_of(Duration(10))
    # Each case is (name, condition, restriction, poll_fn, now) and triggers
    # `stop` a different way: watermark at MAX, PollResult.complete, or an
    # expired termination condition.
    cases = [
        ('reached_max', never(), _initial_polling(),
         lambda element: PollResult((TimestampedValue('a', MAX_TIMESTAMP), ),
                                    None), 0.0),
        ('complete', never(), _initial_polling(),
         lambda element: PollResult.complete([_ts('a', 1)]), 0.0),
        ('after_total_of', termination,
         _initial_polling(termination, Timestamp(0)),
         lambda element: PollResult.incomplete([_ts('a', 1)]), 100.0),
    ]
    for name, condition, restriction, poll_fn, now in cases:
      with self.subTest(name=name):
        tracker = _GrowthRestrictionTracker(
            restriction, poll_fn, StrUtf8Coder(), condition,
            lambda now=now: now)
        holder = ['input', None]
        self.assertTrue(tracker.try_claim(holder))
        self.assertTrue(holder[1][3])
        _, residual = tracker.try_split(0)
        # A terminal round leaves nothing to emit on resume.
        self.assertIsInstance(residual, _NonPollingGrowthState)
        self.assertEqual((), residual.pending.outputs)

  def test_wrapper_chain_defers_merged_residual(self):
    def poll(unused_element):
      return PollResult.incomplete([_ts('a', 1), _ts('b', 2)])

    # Drive the tracker through the same wrappers process() sees.
    threadsafe = ThreadsafeRestrictionTracker(
        _new_tracker(_initial_polling(), poll))
    view = RestrictionTrackerView(threadsafe)
    holder = ['input', None]
    self.assertTrue(view.try_claim(holder))
    self.assertEqual('poll', holder[1][0])
    view.defer_remainder(Duration(5))
    residual, _ = threadsafe.deferred_status()
    self.assertIsInstance(residual, _PollingGrowthState)
    self.assertEqual(2, len(residual.completed))


# Module-level so the poll function pickles by reference; the call counter is
# shared within the single in-memory DirectRunner process.
_POLL_CALLS = collections.defaultdict(int)


# Returns one more output on each call for a given prefix and reports the
# result as complete from the third call onward.
def _growing_poll(prefix):
  _POLL_CALLS[prefix] += 1
  count = _POLL_CALLS[prefix]
  outputs = [_ts('%s%d' % (prefix, i), i + 1) for i in range(count)]
  if count >= 3:
    return PollResult.complete(outputs)
  return PollResult.incomplete(outputs)


# Returns two outputs and completes on the first poll.
def _complete_poll(prefix):
  return PollResult.complete([_ts(prefix + 'a', 1), _ts(prefix + 'b', 2)])


# Maps a grouped (key, values) pair to ((window start, window end), sorted
# values).
def _windowed_group(kv, window=beam.DoFn.WindowParam):
  return ((window.start, window.end), sorted(kv[1]))


class WatchEndToEndTest(unittest.TestCase):
  # Runs Watch inside a real pipeline on the in-memory direct runner.
  def _in_memory_pipeline(self):
    return TestPipeline(
        options=PipelineOptions(direct_running_mode='in_memory'))

  def test_complete_outputs_values_and_timestamps(self):
    with self._in_memory_pipeline() as p:
      output = (
          p | beam.Create(['k:'])
          | Watch.growth_of(_complete_poll).with_poll_interval(Duration(1)))
      # reify_windows exposes each element's timestamp and window so the
      # assertion can check them.
      assert_that(
          output,
          equal_to([
              TestWindowedValue(('k:', 'k:a'), Timestamp(1), [GlobalWindow()]),
              TestWindowedValue(('k:', 'k:b'), Timestamp(2), [GlobalWindow()]),
          ]),
          reify_windows=True)

  def test_complete_advances_watermark_for_windowed_pipeline(self):
    with self._in_memory_pipeline() as p:
      output = (
          p | beam.Create(['k:'])
          | Watch.growth_of(_complete_poll).with_poll_interval(Duration(1)))
      # Both outputs (timestamps 1 and 2) fall in the [0, 10) fixed window.
      grouped = (
          output
          | beam.WindowInto(FixedWindows(10))
          | beam.Map(lambda kv: ('all', kv[1]))
          | beam.GroupByKey()
          | beam.Map(_windowed_group))
      assert_that(grouped, equal_to([
          ((Timestamp(0), Timestamp(10)), ['k:a', 'k:b']),
      ]))

  def test_multi_round_dedups_stops_and_is_per_input(self):
    _POLL_CALLS.clear()
    with self._in_memory_pipeline() as p:
      output = (
          p | beam.Create(['x:', 'y:'])
          | Watch.growth_of(_growing_poll).with_poll_interval(Duration(0.05)))
      # Each round re-reports earlier outputs; each must be emitted only once.
      assert_that(
          output,
          equal_to([('x:', 'x:0'), ('x:', 'x:1'), ('x:', 'x:2'),
                    ('y:', 'y:0'), ('y:', 'y:1'), ('y:', 'y:2')]))
    # Checked after the pipeline has run: three polls per input.
    self.assertEqual(3, _POLL_CALLS['x:'])
    self.assertEqual(3, _POLL_CALLS['y:'])


if __name__ == '__main__':
  unittest.main()