diff --git a/plugboard-schemas/plugboard_schemas/_validation.py b/plugboard-schemas/plugboard_schemas/_validation.py
index 98eb886c..ee71c175 100644
--- a/plugboard-schemas/plugboard_schemas/_validation.py
+++ b/plugboard-schemas/plugboard_schemas/_validation.py
@@ -18,6 +18,9 @@
 from ._validator_registry import validator
 
 
+_SYSTEM_STOP_EVENT = "system_stop"
+
+
 def _build_component_graph(
     connectors: dict[str, dict[str, _t.Any]],
 ) -> dict[str, set[str]]:
@@ -98,9 +101,11 @@ def validate_all_inputs_connected(
     for comp_name, comp_data in components.items():
         io = comp_data.get("io", {})
         all_inputs = set(io.get("inputs", []))
+        input_events = set(io.get("input_events", []))
+        has_non_system_input_events = bool(input_events - {_SYSTEM_STOP_EVENT})
         connected = connected_inputs.get(comp_name, set())
         unconnected = all_inputs - connected
-        if unconnected:
+        if unconnected and not has_non_system_input_events:
             errors.append(f"Component '{comp_name}' has unconnected inputs: {sorted(unconnected)}")
     return errors
 
diff --git a/plugboard/component/component.py b/plugboard/component/component.py
index 6fe0ad20..2db9e92b 100644
--- a/plugboard/component/component.py
+++ b/plugboard/component/component.py
@@ -356,7 +356,7 @@ async def _wrapper() -> None:
                     raise e
                 self._bind_outputs()
                 await self.io.write()
-                self._field_inputs_ready = False
+                self._reset_input_trackers()
             await self._set_status(Status.WAITING, publish=not self._is_running)
 
         return _wrapper
@@ -365,6 +365,11 @@ async def _wrapper() -> None:
     def _has_field_inputs(self) -> bool:
         return len(self.io.inputs) > 0
 
+    @property
+    def _has_connected_field_inputs(self) -> bool:
+        """Whether any declared field inputs are connected via input channels."""
+        return self.io.has_connected_field_inputs
+
     @cached_property
     def _has_event_inputs(self) -> bool:
         input_events = set([evt.safe_type() for evt in self.io.input_events])
@@ -409,7 +414,7 @@ async def _io_read_with_status_check(self) -> None:
             task.cancel()
         for task in done:
             exc = task.exception()
-            if isinstance(exc, EventStreamClosedError) and len(self.io.inputs) == 0:
+            if isinstance(exc, EventStreamClosedError) and not self._has_connected_field_inputs:
                 await self.io.close()  # Call close for final wait and flush event buffer
             elif exc is not None:
                 raise exc
@@ -422,7 +427,7 @@ async def _periodic_status_check(self) -> None:
             # TODO : Eventually producer graph update will be event driven. For now,
             #      : the update is performed periodically, so it's called here along
             #      : with the status check.
-            if len(self.io.inputs) == 0:
+            if not self._has_connected_field_inputs:
                 await self._update_producer_graph()
 
     async def _status_check(self) -> None:
@@ -455,8 +460,11 @@ def _bind_inputs(self) -> None:
         for field in self.io.inputs:
             field_default = getattr(self, field, None)
             value = self._field_inputs.get(field, field_default)
-            setattr(self, field, value)
+            super().__setattr__(field, value)
+
+    def _reset_input_trackers(self) -> None:
         self._field_inputs = {}
+        self._field_inputs_ready = False
 
     def _bind_outputs(self) -> None:
         """Binds component fields to output fields."""
diff --git a/plugboard/component/io_controller.py b/plugboard/component/io_controller.py
index 7500aee2..5ac67f7c 100644
--- a/plugboard/component/io_controller.py
+++ b/plugboard/component/io_controller.py
@@ -86,8 +86,9 @@ def is_closed(self) -> bool:
         """Returns `True` if the `IOController` is closed, `False` otherwise."""
         return self._is_closed
 
-    @cached_property
-    def _has_field_inputs(self) -> bool:
+    @property
+    def has_connected_field_inputs(self) -> bool:
+        """Returns whether any field inputs are connected via channels."""
         return len(self._input_channels) > 0
 
     @cached_property
@@ -96,7 +97,7 @@ def _has_event_inputs(self) -> bool:
 
     @cached_property
     def _has_inputs(self) -> bool:
-        return self._has_field_inputs or self._has_event_inputs
+        return self.has_connected_field_inputs or self._has_event_inputs
 
     async def read(self, timeout: float | None = None) -> None:
         """Reads data and/or events from input channels.
@@ -139,7 +140,7 @@ async def read(self, timeout: float | None = None) -> None:
 
     def _set_read_tasks(self) -> list[asyncio.Task]:
         read_tasks: list[asyncio.Task] = []
-        if self._has_field_inputs:
+        if self.has_connected_field_inputs:
             if _fields_read_task not in self._read_tasks:
                 read_fields_task = asyncio.create_task(self._read_fields(), name=_fields_read_task)
                 self._read_tasks[_fields_read_task] = read_fields_task
@@ -374,7 +375,7 @@ def _add_channel_for_event(
 
     def _create_input_field_group_tasks(self) -> None:
         """Groups input field channels by field name and launches read tasks for group inputs."""
-        if not self._has_field_inputs:
+        if not self.has_connected_field_inputs:
             return
         field_channels: dict[str, list[tuple[_t_field_key, Channel]]] = defaultdict(list)
         for key, chan in self._input_channels.items():
diff --git a/plugboard/library/data_writer.py b/plugboard/library/data_writer.py
index 96d14538..1c0dbf72 100644
--- a/plugboard/library/data_writer.py
+++ b/plugboard/library/data_writer.py
@@ -43,6 +43,7 @@ def __init__(
             **kwargs: Additional keyword arguments for [`Component`][plugboard.component.Component].
         """
         super().__init__(**kwargs)
+        # Use a single buffer to track everything
         self._buffer: dict[str, deque] = defaultdict(deque)
         self._chunk_size = chunk_size
         self.io = IOController(
@@ -76,22 +77,43 @@ async def _convert(self, data: dict[str, deque]) -> _t.Any:
     def _bind_inputs(self) -> None:
         """Binds input fields to component fields and append to internal buffer."""
         super()._bind_inputs()
-        for field in self.io.inputs:
+        for field in self._field_inputs:
             value = getattr(self, field, None)
             self._buffer[field].append(value)
 
+    @property
+    def _completed_rows(self) -> int:
+        """Calculates how many fully formed rows exist in the buffer."""
+        if not self.io.inputs:
+            return 0
+        return min((len(self._buffer[f]) for f in self.io.inputs), default=0)
+
+    @property
+    def _can_step(self) -> bool:
+        """We can step if we have at least one fully formed row."""
+        return self._completed_rows > 0
+
     async def _save_chunk(self) -> None:
-        """Write data from the buffer."""
+        """Write completed data rows from the buffer."""
+        completed_rows = self._completed_rows
+        if completed_rows == 0:
+            return
+
         if self._task is not None:
             await self._task
-        # Create task to save next chunk of data
-        chunk = await self._convert(self._buffer)
+
+        # Extract only the completed rows into a new chunk
+        chunk_data = {
+            field: deque([self._buffer[field].popleft() for _ in range(completed_rows)])
+            for field in self.io.inputs
+        }
+
+        chunk = await self._convert(chunk_data)
         self._task = asyncio.create_task(self._save(chunk))
-        self._buffer = defaultdict(deque)
 
     async def step(self) -> None:
         """Trigger save when buffer is at target size."""
-        if self._chunk_size and len(self._buffer[self.io.inputs[0]]) >= self._chunk_size:
+        if self._chunk_size and self._completed_rows >= self._chunk_size:
             await self._save_chunk()
 
     async def run(self) -> None:
diff --git a/tests/integration/test_process_with_components_run.py b/tests/integration/test_process_with_components_run.py
index fe047ae8..8f48a2dc 100644
--- a/tests/integration/test_process_with_components_run.py
+++ b/tests/integration/test_process_with_components_run.py
@@ -23,6 +23,7 @@
 )
 from plugboard.events import Event
 from plugboard.exceptions import ConstraintError, NotInitialisedError, ProcessStatusError
+from plugboard.library import FileWriter
 from plugboard.process import LocalProcess, Process, RayProcess
 from plugboard.schemas import ConnectorSpec, Status
 from tests.conftest import ComponentTestHelper, zmq_connector_cls
@@ -459,6 +460,85 @@ async def test_event_driven_process_shutdown(
     await process.destroy()
 
 
+class MessageEventData(BaseModel):
+    """Data for a message event."""
+
+    message: str
+
+
+class MessageEvent(Event):
+    """Event carrying a file-writer message."""
+
+    type: _t.ClassVar[str] = "message_event"
+    data: MessageEventData
+
+
+class MessageEventGenerator(ComponentTestHelper):
+    """Produces a fixed number of message events."""
+
+    io = IO(output_events=[MessageEvent])
+
+    def __init__(self, iters: int, *args: _t.Any, **kwargs: _t.Any) -> None:
+        super().__init__(*args, **kwargs)
+        self._iters = iters
+
+    async def init(self) -> None:
+        await super().init()
+        self._seq = iter(range(self._iters))
+
+    async def step(self) -> None:
+        try:
+            idx = next(self._seq)
+        except StopIteration:
+            await self.io.close()
+        else:
+            evt = MessageEvent(
+                source=self.name,
+                data=MessageEventData(message=f"Message {idx}"),
+            )
+            self.io.queue_event(evt)
+            await super().step()
+
+
+class EventReaderFileWriter(FileWriter):
+    """`FileWriter` variant that adds event handling instead of a connector for `message`."""
+
+    io = IO(input_events=[MessageEvent])
+
+    @MessageEvent.handler
+    async def handle_message(self, event: MessageEvent) -> None:
+        self.message = event.data.message
+
+
+@pytest.mark.asyncio
+async def test_event_driven_file_writer_reuse(tmp_path: Path) -> None:
+    """Test that field-input components can be reused in event-driven processes."""
+    output_path = tmp_path / "output_messages.csv"
+    components = [
+        MessageEventGenerator(iters=3, name="message_event_generator"),
+        EventReaderFileWriter(
+            path=output_path,
+            name="event_reader_file_writer",
+            field_names=["message"],
+        ),
+    ]
+    event_connectors = AsyncioConnector.builder().build_event_connectors(components)
+    process = LocalProcess(components=components, connectors=event_connectors)
+
+    await process.init()
+    await process.run()
+
+    assert process.status == Status.COMPLETED
+    assert output_path.read_text().splitlines() == [
+        "message",
+        "Message 0",
+        "Message 1",
+        "Message 2",
+    ]
+
+    await process.destroy()
+
+
 _SHORT_TIMEOUT = 0.1
 
 
diff --git a/tests/unit/test_process_validation.py b/tests/unit/test_process_validation.py
index 02e0a4d2..b0ec7482 100644
--- a/tests/unit/test_process_validation.py
+++ b/tests/unit/test_process_validation.py
@@ -303,6 +303,21 @@ def test_no_inputs_no_errors(self) -> None:
         errors = validate_all_inputs_connected(pd)
         assert errors == []
 
+    def test_missing_inputs_allowed_for_event_driven_component_reuse(self) -> None:
+        """Unconnected inputs are allowed when non-system input events can populate them."""
+        pd = _make_process_dict(
+            components={
+                "producer": _make_component("producer", output_events=["message_event"]),
+                "writer": _make_component(
+                    "writer",
+                    inputs=["message"],
+                    input_events=["system_stop", "message_event"],
+                ),
+            },
+        )
+        errors = validate_all_inputs_connected(pd)
+        assert errors == []
+
 
 # ---------------------------------------------------------------------------
 # Tests for validate_input_events