diff --git a/python/src/waymark/serialization.py b/python/src/waymark/serialization.py
index 64665055..a5cab83d 100644
--- a/python/src/waymark/serialization.py
+++ b/python/src/waymark/serialization.py
@@ -13,6 +13,7 @@
 from pydantic import BaseModel
 
 from waymark.proto import messages_pb2 as pb2
+from waymark.type_coercion import instantiate_typed_model
 
 NULL_VALUE = struct_pb2.NULL_VALUE  # type: ignore[attr-defined]
 
@@ -255,9 +256,7 @@ def _primitive_to_python(primitive: pb2.PrimitiveWorkflowArgument) -> Any:
 
 def _instantiate_serialized_model(module: str, name: str, model_data: dict[str, Any]) -> Any:
     cls = _import_symbol(module, name)
-    if hasattr(cls, "model_validate"):
-        return cls.model_validate(model_data)  # type: ignore[attr-defined]
-    return cls(**model_data)
+    return instantiate_typed_model(cls, model_data)
 
 
 def _is_base_model(value: Any) -> bool:
diff --git a/python/src/waymark/type_coercion.py b/python/src/waymark/type_coercion.py
new file mode 100644
index 00000000..2318bc04
--- /dev/null
+++ b/python/src/waymark/type_coercion.py
@@ -0,0 +1,288 @@
+"""Annotation-driven coercion of serialized values into typed Python objects.
+
+Serialized values (strings, numbers, lists, dicts) are converted back to the
+native Python types named by a type annotation: UUID, datetime, Decimal and
+similar primitives, Pydantic models, dataclasses, and generic containers.
+"""
+
+import dataclasses
+from base64 import b64decode
+from datetime import date, datetime, time, timedelta
+from decimal import Decimal
+from pathlib import PurePath
+from types import UnionType
+from typing import Any, Union, cast, get_args, get_origin, get_type_hints
+from uuid import UUID
+
+from pydantic import BaseModel
+
+# Types that can be coerced from serialized form
+COERCIBLE_TYPES = (UUID, datetime, date, time, timedelta, Decimal, bytes, PurePath)
+
+
+def instantiate_typed_model(target_type: type, value: dict[str, Any]) -> Any:
+    """Instantiate a structured model type from a plain mapping payload.
+
+    Supported target types:
+    - Pydantic ``BaseModel`` subclasses, validated with ``model_validate``.
+    - Dataclass types, coerced via ``_coerce_dict_to_dataclass`` so nested
+      field types are honored, omitted fields can use dataclass defaults, and
+      unexpected keys are rejected.
+    - Plain Python classes that accept keyword arguments matching the payload
+      keys, instantiated directly as ``target_type(**value)``.
+
+    Primitive values and container coercion are handled by ``coerce_value``.
+    This helper is specifically for dict-like payloads that should become a
+    structured object instance.
+    """
+    if is_pydantic_model_type(target_type):
+        model_type = cast(type[BaseModel], target_type)
+        return model_type.model_validate(value)
+    if is_dataclass_type(target_type):
+        return _coerce_dict_to_dataclass(value, target_type)
+    return target_type(**value)
+
+
+def is_pydantic_model_type(target_type: Any) -> bool:
+    """Check if ``target_type`` is a Pydantic ``BaseModel`` subclass."""
+    try:
+        return isinstance(target_type, type) and issubclass(target_type, BaseModel)
+    except TypeError:
+        return False
+
+
+def is_dataclass_type(target_type: Any) -> bool:
+    """Check if ``target_type`` is a dataclass type (not an instance)."""
+    return isinstance(target_type, type) and dataclasses.is_dataclass(target_type)
+
+
+def _is_coercible_primitive_type(target_type: Any) -> bool:
+    """Check if ``target_type`` is one of ``COERCIBLE_TYPES`` or a subclass."""
+    try:
+        return isinstance(target_type, type) and issubclass(target_type, COERCIBLE_TYPES)
+    except TypeError:
+        # Mirrors is_pydantic_model_type: issubclass() rejects objects that pass
+        # the isinstance check without being real classes (e.g. generic aliases).
+        return False
+
+
+def coerce_value(value: Any, target_type: type) -> Any:
+    """Coerce a serialized value to ``target_type``.
+
+    Handles:
+    - Unions and optionals, trying each member in declaration order
+    - Primitive types (UUID, datetime, etc.)
+    - Pydantic models and dataclasses (from dicts)
+    - Generic collections like list[UUID], set[datetime]
+
+    ``None`` values, ``Any`` targets, and values that match no rule are
+    returned unchanged.
+    """
+    if value is None or target_type is Any:
+        return value
+
+    origin = get_origin(target_type)
+    if origin is UnionType or origin is Union:
+        return _coerce_union_value(value, target_type)
+
+    if _is_coercible_primitive_type(target_type):
+        return _coerce_primitive(value, target_type)
+
+    if isinstance(value, dict) and (
+        is_pydantic_model_type(target_type) or is_dataclass_type(target_type)
+    ):
+        return instantiate_typed_model(target_type, value)
+
+    if origin is None:
+        return value
+
+    args = get_args(target_type)
+
+    # Handle list[T]
+    if origin is list and isinstance(value, list) and args:
+        item_type = args[0]
+        return [coerce_value(item, item_type) for item in value]
+
+    # Handle set[T] (serialized as list)
+    if origin is set and isinstance(value, list) and args:
+        item_type = args[0]
+        return {coerce_value(item, item_type) for item in value}
+
+    # Handle frozenset[T] (serialized as list)
+    if origin is frozenset and isinstance(value, list) and args:
+        item_type = args[0]
+        return frozenset(coerce_value(item, item_type) for item in value)
+
+    # Handle tuple[T, ...] (serialized as list)
+    if origin is tuple and isinstance(value, (list, tuple)) and args:
+        # Variable length tuple like tuple[int, ...]
+        if len(args) == 2 and args[1] is ...:
+            item_type = args[0]
+            return tuple(coerce_value(item, item_type) for item in value)
+        # Fixed length tuple like tuple[int, str, UUID]. Items beyond the
+        # annotated positions are kept uncoerced instead of being dropped.
+        coerced_items = tuple(
+            coerce_value(item, item_type) for item, item_type in zip(value, args, strict=False)
+        )
+        return coerced_items + tuple(value[len(args) :])
+
+    # Handle dict[K, V]
+    if origin is dict and isinstance(value, dict) and len(args) == 2:
+        key_type, value_type = args
+        return {
+            coerce_value(key, key_type): coerce_value(item, value_type)
+            for key, item in value.items()
+        }
+
+    return value
+
+
+def _coerce_union_value(value: Any, target_type: type) -> Any:
+    """Coerce ``value`` to the first union member that accepts it.
+
+    Members are tried in declaration order. A member wins when coercion
+    produces a new object, or when ``value`` is already an instance of it.
+    If no member matches, ``value`` is returned unchanged.
+    """
+    for union_type in get_args(target_type):
+        if union_type is type(None):
+            if value is None:
+                return None
+            continue
+        try:
+            coerced = coerce_value(value, union_type)
+        except Exception:
+            # Best effort: a member that cannot parse the value (bad UUID
+            # string, model validation error, ...) just does not match.
+            continue
+        if coerced is not value:
+            return coerced
+        if isinstance(union_type, type) and isinstance(value, union_type):
+            return value
+    return value
+
+
+def _coerce_primitive(value: Any, target_type: type) -> Any:
+    """Coerce a value to a primitive type based on target_type.
+
+    Handles conversion of serialized values (strings, floats) back to their
+    native Python types (UUID, datetime, etc.). Values that are already of the
+    target type, or that are not in a recognized serialized form, are returned
+    unchanged.
+    """
+    # UUID from string
+    if target_type is UUID:
+        if isinstance(value, UUID):
+            return value
+        if isinstance(value, str):
+            return UUID(value)
+        return value
+
+    # datetime from ISO string
+    if target_type is datetime:
+        if isinstance(value, datetime):
+            return value
+        if isinstance(value, str):
+            return datetime.fromisoformat(value)
+        return value
+
+    # date from ISO string
+    if target_type is date:
+        if isinstance(value, date):
+            return value
+        if isinstance(value, str):
+            return date.fromisoformat(value)
+        return value
+
+    # time from ISO string
+    if target_type is time:
+        if isinstance(value, time):
+            return value
+        if isinstance(value, str):
+            return time.fromisoformat(value)
+        return value
+
+    # timedelta from total seconds
+    if target_type is timedelta:
+        if isinstance(value, timedelta):
+            return value
+        # bool is an int subclass; True/False is not a number of seconds
+        if isinstance(value, (int, float)) and not isinstance(value, bool):
+            return timedelta(seconds=value)
+        return value
+
+    # Decimal from string
+    if target_type is Decimal:
+        if isinstance(value, Decimal):
+            return value
+        # bool is an int subclass; Decimal("True") would raise InvalidOperation
+        if isinstance(value, (str, int, float)) and not isinstance(value, bool):
+            return Decimal(str(value))
+        return value
+
+    # bytes from base64 string
+    if target_type is bytes:
+        if isinstance(value, bytes):
+            return value
+        if isinstance(value, str):
+            return b64decode(value)
+        return value
+
+    # Path (or any PurePath subclass) from string
+    if issubclass(target_type, PurePath):
+        if isinstance(value, PurePath):
+            if isinstance(value, target_type):
+                return value
+            return target_type(str(value))
+        if isinstance(value, str):
+            return target_type(value)
+        return value
+
+    return value
+
+
+def _coerce_dict_to_dataclass(value: dict[str, Any], target_type: type) -> Any:
+    """Build a ``target_type`` dataclass instance from a dict payload.
+
+    Each payload entry is coerced to its field's type hint. Fields missing
+    from the payload fall back to their dataclass defaults, and fields declared
+    with ``init=False`` are assigned after construction.
+
+    Raises:
+        TypeError: If the payload contains keys that are not dataclass fields.
+    """
+    try:
+        field_types = get_type_hints(target_type)
+    except Exception:
+        # Unresolvable annotations (e.g. forward references) disable per-field
+        # coercion but should not prevent instantiation.
+        field_types = {}
+
+    init_values: dict[str, Any] = {}
+    deferred_values: dict[str, Any] = {}
+    field_names: set[str] = set()
+
+    for field in dataclasses.fields(target_type):
+        field_names.add(field.name)
+        if field.name not in value:
+            continue
+
+        field_value = value[field.name]
+        if field.name in field_types:
+            field_value = coerce_value(field_value, field_types[field.name])
+
+        if field.init:
+            init_values[field.name] = field_value
+        else:
+            deferred_values[field.name] = field_value
+
+    extra_fields = set(value) - field_names
+    if extra_fields:
+        extras = ", ".join(sorted(extra_fields))
+        raise TypeError(f"{target_type.__qualname__} got unexpected field(s): {extras}")
+
+    instance = target_type(**init_values)
+    # object.__setattr__ also works for frozen dataclasses
+    for field_name, field_value in deferred_values.items():
+        object.__setattr__(instance, field_name, field_value)
+    return instance
diff --git a/python/src/waymark/workflow_runtime.py b/python/src/waymark/workflow_runtime.py
index 92741196..777d0086 100644
--- a/python/src/waymark/workflow_runtime.py
+++ b/python/src/waymark/workflow_runtime.py
@@ -5,14 +5,8 @@
 """
 
 import asyncio
-import dataclasses
-from base64 import b64decode
 from dataclasses import dataclass
-from datetime import date, datetime, time, timedelta
-from decimal import Decimal
-from pathlib import Path, PurePath
-from typing import Any, Dict, get_args, get_origin, get_type_hints
-from uuid import UUID
+from typing import Any, Dict, get_type_hints
 
 from pydantic import BaseModel
 
@@ -21,6 +15,7 @@
 from .dependencies import provide_dependencies
 from .registry import registry
 from .serialization import arguments_to_kwargs
+from .type_coercion import coerce_value as _coerce_value
 
 
 class WorkflowNodeResult(BaseModel):
@@ -37,185 +32,6 @@ class ActionExecutionResult:
     exception: BaseException | None = None
 
 
-def _is_pydantic_model(cls: type) -> bool:
-    """Check if a class is a Pydantic BaseModel subclass."""
-    try:
-        return isinstance(cls, type) and issubclass(cls, BaseModel)
-    except TypeError:
-        return False
-
-
-def _is_dataclass_type(cls: type) -> bool:
-    """Check if a class is a dataclass."""
-    return dataclasses.is_dataclass(cls) and isinstance(cls, type)
-
-
-def _coerce_primitive(value: Any, target_type: type) -> Any:
-    """Coerce a value to a primitive type based on target_type.
-
-    Handles conversion of serialized values (strings, floats) back to their
-    native Python types (UUID, datetime, etc.).
-    """
-    # Handle None
-    if value is None:
-        return None
-
-    # UUID from string
-    if target_type is UUID:
-        if isinstance(value, UUID):
-            return value
-        if isinstance(value, str):
-            return UUID(value)
-        return value
-
-    # datetime from ISO string
-    if target_type is datetime:
-        if isinstance(value, datetime):
-            return value
-        if isinstance(value, str):
-            return datetime.fromisoformat(value)
-        return value
-
-    # date from ISO string
-    if target_type is date:
-        if isinstance(value, date):
-            return value
-        if isinstance(value, str):
-            return date.fromisoformat(value)
-        return value
-
-    # time from ISO string
-    if target_type is time:
-        if isinstance(value, time):
-            return value
-        if isinstance(value, str):
-            return time.fromisoformat(value)
-        return value
-
-    # timedelta from total seconds
-    if target_type is timedelta:
-        if isinstance(value, timedelta):
-            return value
-        if isinstance(value, (int, float)):
-            return timedelta(seconds=value)
-        return value
-
-    # Decimal from string
-    if target_type is Decimal:
-        if isinstance(value, Decimal):
-            return value
-        if isinstance(value, (str, int, float)):
-            return Decimal(str(value))
-        return value
-
-    # bytes from base64 string
-    if target_type is bytes:
-        if isinstance(value, bytes):
-            return value
-        if isinstance(value, str):
-            return b64decode(value)
-        return value
-
-    # Path from string
-    if target_type is Path or target_type is PurePath:
-        if isinstance(value, PurePath):
-            return value
-        if isinstance(value, str):
-            return Path(value)
-        return value
-
-    return value
-
-
-# Types that can be coerced from serialized form
-COERCIBLE_TYPES = (UUID, datetime, date, time, timedelta, Decimal, bytes, Path, PurePath)
-
-
-def _coerce_dict_to_model(value: Any, target_type: type) -> Any:
-    """Convert a dict to a Pydantic model or dataclass if needed.
-
-    If value is a dict and target_type is a Pydantic model or dataclass,
-    instantiate the model with the dict values. Otherwise, return value unchanged.
-    """
-    if not isinstance(value, dict):
-        return value
-
-    if _is_pydantic_model(target_type):
-        # Use model_validate for Pydantic v2, fall back to direct instantiation
-        model_validate = getattr(target_type, "model_validate", None)
-        if model_validate is not None:
-            return model_validate(value)
-        return target_type(**value)
-
-    if _is_dataclass_type(target_type):
-        return target_type(**value)
-
-    return value
-
-
-def _coerce_value(value: Any, target_type: type) -> Any:
-    """Coerce a value to the target type.
-
-    Handles:
-    - Primitive types (UUID, datetime, etc.)
-    - Pydantic models and dataclasses (from dicts)
-    - Generic collections like list[UUID], set[datetime]
-    """
-    # Handle None
-    if value is None:
-        return None
-
-    # Check for coercible primitive types
-    if isinstance(target_type, type) and issubclass(target_type, COERCIBLE_TYPES):
-        return _coerce_primitive(value, target_type)
-
-    # Check for Pydantic models or dataclasses
-    if isinstance(value, dict):
-        coerced = _coerce_dict_to_model(value, target_type)
-        if coerced is not value:
-            return coerced
-
-    # Handle generic types like list[UUID], set[datetime]
-    origin = get_origin(target_type)
-    if origin is not None:
-        args = get_args(target_type)
-
-        # Handle list[T]
-        if origin is list and isinstance(value, list) and args:
-            item_type = args[0]
-            return [_coerce_value(item, item_type) for item in value]
-
-        # Handle set[T] (serialized as list)
-        if origin is set and isinstance(value, list) and args:
-            item_type = args[0]
-            return {_coerce_value(item, item_type) for item in value}
-
-        # Handle frozenset[T] (serialized as list)
-        if origin is frozenset and isinstance(value, list) and args:
-            item_type = args[0]
-            return frozenset(_coerce_value(item, item_type) for item in value)
-
-        # Handle tuple[T, ...] (serialized as list)
-        if origin is tuple and isinstance(value, (list, tuple)) and args:
-            # Variable length tuple like tuple[int, ...]
-            if len(args) == 2 and args[1] is ...:
-                item_type = args[0]
-                return tuple(_coerce_value(item, item_type) for item in value)
-            # Fixed length tuple like tuple[int, str, UUID]
-            return tuple(
-                _coerce_value(item, item_type) for item, item_type in zip(value, args, strict=False)
-            )
-
-        # Handle dict[K, V]
-        if origin is dict and isinstance(value, dict) and len(args) == 2:
-            key_type, val_type = args
-            return {
-                _coerce_value(k, key_type): _coerce_value(v, val_type) for k, v in value.items()
-            }
-
-    return value
-
-
 def _coerce_kwargs_to_type_hints(handler: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
     """Coerce kwargs to expected types based on handler's type hints.
 
diff --git a/python/tests/test_serialization.py b/python/tests/test_serialization.py
index 9f166de4..9b0b7750 100644
--- a/python/tests/test_serialization.py
+++ b/python/tests/test_serialization.py
@@ -24,6 +24,18 @@ class SampleDataclass:
     count: int
 
 
+@dataclass
+class NestedSampleDataclass:
+    identifier: UUID
+    created_at: datetime
+
+
+@dataclass
+class SampleDataclassEnvelope:
+    item: NestedSampleDataclass
+    related_ids: list[UUID]
+
+
 def test_result_round_trip_with_basemodel() -> None:
     payload = serialize_result_payload(SampleModel(payload="hello"))
     decoded = deserialize_result_payload(payload)
@@ -94,6 +106,27 @@ def test_result_round_trip_with_dataclass() -> None:
     assert decoded.result.count == 42
 
 
+def test_result_round_trip_with_nested_typed_dataclass() -> None:
+    created_at = datetime(2024, 1, 15, 10, 30, 45, tzinfo=timezone.utc)
+    identifier = UUID("12345678-1234-5678-1234-567812345678")
+    related_id = UUID("87654321-4321-8765-4321-876543218765")
+
+    payload = serialize_result_payload(
+        SampleDataclassEnvelope(
+            item=NestedSampleDataclass(identifier=identifier, created_at=created_at),
+            related_ids=[identifier, related_id],
+        )
+    )
+    decoded = deserialize_result_payload(payload)
+
+    assert decoded.error is None
+    assert isinstance(decoded.result, SampleDataclassEnvelope)
+    assert isinstance(decoded.result.item, NestedSampleDataclass)
+    assert decoded.result.item.identifier == identifier
+    assert decoded.result.item.created_at == created_at
+    assert decoded.result.related_ids == [identifier, related_id]
+
+
 class ModelWithUUID(BaseModel):
     id: UUID
     name: str
diff --git a/python/tests/test_type_coercion.py b/python/tests/test_type_coercion.py
new file mode 100644
index 00000000..d28b092d
--- /dev/null
+++ b/python/tests/test_type_coercion.py
@@ -0,0 +1,341 @@
+from base64 import b64encode
+from dataclasses import dataclass, field
+from datetime import date, datetime, time, timedelta, timezone
+from decimal import Decimal
+from pathlib import Path
+from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple, Union
+from uuid import UUID
+
+import pytest
+from pydantic import BaseModel
+
+from waymark.type_coercion import _coerce_dict_to_dataclass, coerce_value
+
+UUID_STR = "12345678-1234-5678-1234-567812345678"
+UUID_OBJ = UUID(UUID_STR)
+SECOND_UUID_STR = "87654321-4321-8765-4321-876543218765"
+SECOND_UUID_OBJ = UUID(SECOND_UUID_STR)
+RECORDED_AT = datetime(2024, 1, 15, 10, 30, 45, tzinfo=timezone.utc)
+NEXT_RECORDED_AT = datetime(2024, 1, 16, 11, 15, 0, tzinfo=timezone.utc)
+RECORDED_DATE = date(2024, 1, 15)
+RECORDED_TIME = time(10, 30, 45)
+DURATION = timedelta(hours=2, minutes=30)
+DECIMAL_VALUE = Decimal("123.456789012345678901234567890")
+BINARY_PAYLOAD = b"hello world"
+BINARY_PAYLOAD_B64 = b64encode(BINARY_PAYLOAD).decode("ascii")
+BIN_PATH = Path("/usr/local/bin")
+
+
+@dataclass
+class DataclassWithDefaults:
+    name: str
+    enabled: bool = True
+    tags: list[str] = field(default_factory=list)
+
+
+@dataclass
+class StrictDataclass:
+    name: str
+    retries: int
+    active: bool
+
+
+@dataclass
+class TypedDataclass:
+    reading_id: UUID
+
+
+class TypedModel(BaseModel):
+    name: str
+
+
+def _assert_coerced_value(annotation: Any, payload: Any, expected: Any) -> None:
+    result = coerce_value(payload, annotation)
+
+    if expected is None:
+        assert result is None
+        return
+
+    assert result == expected
+    assert isinstance(result, type(expected))
+
+
+def _pep604_optional(inner_type: Any) -> Any:
+    return inner_type | None
+
+
+def _typing_optional(inner_type: Any) -> Any:
+    return Optional[inner_type]
+
+
+def _typing_union_optional(inner_type: Any) -> Any:
+    return Union[inner_type, None]
+
+
+def _pep604_union(primary_type: Any, secondary_type: Any) -> Any:
+    return primary_type | secondary_type
+
+
+def _typing_union(primary_type: Any, secondary_type: Any) -> Any:
+    return Union[primary_type, secondary_type]
+
+
+@pytest.mark.parametrize(
+    ("payload", "expected"),
+    [
+        ({"name": "alpha"}, DataclassWithDefaults(name="alpha")),
+        (
+            {"name": "beta", "enabled": False},
+            DataclassWithDefaults(name="beta", enabled=False),
+        ),
+        (
+            {"name": "gamma", "tags": ["ops"]},
+            DataclassWithDefaults(name="gamma", tags=["ops"]),
+        ),
+    ],
+)
+def test_coerce_dict_to_dataclass_uses_defaults_for_missing_fields(
+    payload: dict[str, object],
+    expected: DataclassWithDefaults,
+) -> None:
+    result = _coerce_dict_to_dataclass(payload, DataclassWithDefaults)
+
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    ("payload", "expected"),
+    [
+        (
+            {"name": "alpha", "retries": 1, "active": True},
+            StrictDataclass(name="alpha", retries=1, active=True),
+        ),
+        (
+            {"name": "beta", "retries": 3, "active": False},
+            StrictDataclass(name="beta", retries=3, active=False),
+        ),
+    ],
+)
+def test_coerce_dict_to_dataclass_accepts_exact_payload(
+    payload: dict[str, object],
+    expected: StrictDataclass,
+) -> None:
+    result = _coerce_dict_to_dataclass(payload, StrictDataclass)
+
+    assert result == expected
+
+
+@pytest.mark.parametrize(
+    ("payload", "message"),
+    [
+        (
+            {"name": "alpha", "retries": 1, "active": True, "extra": "value"},
+            "StrictDataclass got unexpected field(s): extra",
+        ),
+        (
+            {
+                "name": "beta",
+                "retries": 2,
+                "active": False,
+                "extra_one": "value",
+                "extra_two": "value",
+            },
+            "StrictDataclass got unexpected field(s): extra_one, extra_two",
+        ),
+    ],
+)
+def test_coerce_dict_to_dataclass_rejects_extra_fields(
+    payload: dict[str, object],
+    message: str,
+) -> None:
+    with pytest.raises(TypeError) as exc_info:
+        _coerce_dict_to_dataclass(payload, StrictDataclass)
+
+    assert str(exc_info.value) == message
+
+
+@pytest.mark.parametrize(
+    ("annotation", "payload", "expected"),
+    [
+        pytest.param(UUID, UUID_STR, UUID_OBJ, id="uuid"),
+        pytest.param(datetime, RECORDED_AT.isoformat(), RECORDED_AT, id="datetime"),
+        pytest.param(date, RECORDED_DATE.isoformat(), RECORDED_DATE, id="date"),
+        pytest.param(time, RECORDED_TIME.isoformat(), RECORDED_TIME, id="time"),
+        pytest.param(timedelta, DURATION.total_seconds(), DURATION, id="timedelta"),
+        pytest.param(Decimal, str(DECIMAL_VALUE), DECIMAL_VALUE, id="decimal"),
+        pytest.param(bytes, BINARY_PAYLOAD_B64, BINARY_PAYLOAD, id="bytes"),
+        pytest.param(Path, str(BIN_PATH), BIN_PATH, id="path"),
+    ],
+)
+def test_coerce_value_primitives(annotation: Any, payload: Any, expected: Any) -> None:
+    _assert_coerced_value(annotation, payload, expected)
+
+
+@pytest.mark.parametrize(
+    ("annotation", "payload", "expected"),
+    [
+        pytest.param(
+            list[UUID],
+            [UUID_STR, SECOND_UUID_STR],
+            [UUID_OBJ, SECOND_UUID_OBJ],
+            id="builtins-list",
+        ),
+        pytest.param(
+            List[UUID],
+            [UUID_STR, SECOND_UUID_STR],
+            [UUID_OBJ, SECOND_UUID_OBJ],
+            id="typing-list",
+        ),
+        pytest.param(
+            set[datetime],
+            [RECORDED_AT.isoformat(), NEXT_RECORDED_AT.isoformat()],
+            {RECORDED_AT, NEXT_RECORDED_AT},
+            id="builtins-set",
+        ),
+        pytest.param(
+            Set[datetime],
+            [RECORDED_AT.isoformat(), NEXT_RECORDED_AT.isoformat()],
+            {RECORDED_AT, NEXT_RECORDED_AT},
+            id="typing-set",
+        ),
+        pytest.param(
+            frozenset[UUID],
+            [UUID_STR, SECOND_UUID_STR],
+            frozenset({UUID_OBJ, SECOND_UUID_OBJ}),
+            id="builtins-frozenset",
+        ),
+        pytest.param(
+            FrozenSet[UUID],
+            [UUID_STR, SECOND_UUID_STR],
+            frozenset({UUID_OBJ, SECOND_UUID_OBJ}),
+            id="typing-frozenset",
+        ),
+        pytest.param(
+            tuple[UUID, datetime],
+            [UUID_STR, RECORDED_AT.isoformat()],
+            (UUID_OBJ, RECORDED_AT),
+            id="builtins-tuple-fixed",
+        ),
+        pytest.param(
+            Tuple[UUID, datetime],
+            [UUID_STR, RECORDED_AT.isoformat()],
+            (UUID_OBJ, RECORDED_AT),
+            id="typing-tuple-fixed",
+        ),
+        pytest.param(
+            tuple[UUID, ...],
+            [UUID_STR, SECOND_UUID_STR],
+            (UUID_OBJ, SECOND_UUID_OBJ),
+            id="builtins-tuple-variadic",
+        ),
+        pytest.param(
+            Tuple[UUID, ...],
+            [UUID_STR, SECOND_UUID_STR],
+            (UUID_OBJ, SECOND_UUID_OBJ),
+            id="typing-tuple-variadic",
+        ),
+        pytest.param(
+            dict[str, UUID],
+            {"user_id": UUID_STR},
+            {"user_id": UUID_OBJ},
+            id="builtins-dict",
+        ),
+        pytest.param(
+            Dict[str, UUID],
+            {"user_id": UUID_STR},
+            {"user_id": UUID_OBJ},
+            id="typing-dict",
+        ),
+    ],
+)
+def test_coerce_value_container_annotations(annotation: Any, payload: Any, expected: Any) -> None:
+    _assert_coerced_value(annotation, payload, expected)
+
+
+@pytest.mark.parametrize(
+    "annotation_factory",
+    [
+        pytest.param(_pep604_optional, id="pep604-optional"),
+        pytest.param(_typing_optional, id="typing-optional"),
+        pytest.param(_typing_union_optional, id="typing-union-optional"),
+    ],
+)
+@pytest.mark.parametrize(
+    ("inner_type", "payload", "expected"),
+    [
+        pytest.param(UUID, UUID_STR, UUID_OBJ, id="uuid"),
+        pytest.param(
+            TypedDataclass,
+            {"reading_id": UUID_STR},
+            TypedDataclass(reading_id=UUID_OBJ),
+            id="dataclass",
+        ),
+        pytest.param(TypedModel, {"name": "alpha"}, TypedModel(name="alpha"), id="pydantic"),
+        pytest.param(UUID, None, None, id="none"),
+    ],
+)
+def test_coerce_value_optional_annotation_variants(
+    annotation_factory: Any,
+    inner_type: Any,
+    payload: Any,
+    expected: Any,
+) -> None:
+    annotation = annotation_factory(inner_type)
+
+    _assert_coerced_value(annotation, payload, expected)
+
+
+@pytest.mark.parametrize(
+    "annotation_factory",
+    [
+        pytest.param(_pep604_union, id="pep604-union"),
+        pytest.param(_typing_union, id="typing-union"),
+    ],
+)
+@pytest.mark.parametrize(
+    ("primary_type", "secondary_type", "payload", "expected"),
+    [
+        pytest.param(UUID, int, UUID_STR, UUID_OBJ, id="uuid-int"),
+        pytest.param(
+            TypedDataclass,
+            str,
+            {"reading_id": UUID_STR},
+            TypedDataclass(reading_id=UUID_OBJ),
+            id="dataclass-str",
+        ),
+        pytest.param(
+            TypedModel,
+            str,
+            {"name": "beta"},
+            TypedModel(name="beta"),
+            id="pydantic-str",
+        ),
+    ],
+)
+def test_coerce_value_union_annotation_variants(
+    annotation_factory: Any,
+    primary_type: Any,
+    secondary_type: Any,
+    payload: Any,
+    expected: Any,
+) -> None:
+    annotation = annotation_factory(primary_type, secondary_type)
+
+    _assert_coerced_value(annotation, payload, expected)
+
+
+def test_coerce_value_fixed_tuple_keeps_unannotated_items() -> None:
+    result = coerce_value([UUID_STR, RECORDED_AT.isoformat(), "extra"], tuple[UUID, datetime])
+
+    assert result == (UUID_OBJ, RECORDED_AT, "extra")
+
+
+@pytest.mark.parametrize(
+    "annotation",
+    [
+        pytest.param(timedelta, id="timedelta"),
+        pytest.param(Decimal, id="decimal"),
+    ],
+)
+def test_coerce_value_leaves_bools_unchanged(annotation: Any) -> None:
+    assert coerce_value(True, annotation) is True
diff --git a/python/tests/test_workflow.py b/python/tests/test_workflow.py
index cd6a3613..c9c99288 100644
--- a/python/tests/test_workflow.py
+++ b/python/tests/test_workflow.py
@@ -313,6 +313,53 @@ async def fake_execute_workflow(_payload: bytes) -> bytes:
     assert result.count == 3
 
 
+def test_workflow_result_coerces_nested_typed_dataclass(
+    monkeypatch: pytest.MonkeyPatch,
+) -> None:
+    @dataclass
+    class ResultMetadata:
+        created_at: datetime
+        sample_ids: list[UUID]
+
+    @dataclass
+    class ResultData:
+        result_id: UUID
+        metadata: ResultMetadata
+
+    @action
+    async def build_result() -> ResultData:
+        raise NotImplementedError
+
+    @workflow_decorator
+    class DataWorkflow(Workflow):
+        async def run(self) -> ResultData:
+            return await build_result()
+
+    result_id = uuid4()
+    created_at = datetime(2024, 1, 2, 3, 4, 5)
+    sample_ids = [uuid4(), uuid4()]
+
+    async def fake_execute_workflow(_payload: bytes) -> bytes:
+        response = {
+            "result_id": str(result_id),
+            "metadata": {
+                "created_at": created_at.isoformat(),
+                "sample_ids": [str(sample_id) for sample_id in sample_ids],
+            },
+        }
+        payload = serialize_result_payload(response)
+        return payload.SerializeToString()
+
+    monkeypatch.setattr(bridge, "execute_workflow", fake_execute_workflow)
+
+    result = asyncio.run(DataWorkflow().run())
+    assert isinstance(result, ResultData)
+    assert result.result_id == result_id
+    assert isinstance(result.metadata, ResultMetadata)
+    assert result.metadata.created_at == created_at
+    assert result.metadata.sample_ids == sample_ids
+
+
 def test_workflow_result_optional_returns_none(monkeypatch: pytest.MonkeyPatch) -> None:
     @dataclass
     class OptionalData:
diff --git a/python/tests/test_workflow_runtime.py b/python/tests/test_workflow_runtime.py
index 792e7930..941d8037 100644
--- a/python/tests/test_workflow_runtime.py
+++ b/python/tests/test_workflow_runtime.py
@@ -2,9 +2,7 @@
 
 import asyncio
 from dataclasses import dataclass as python_dataclass
-from datetime import date, datetime, time, timedelta, timezone
-from decimal import Decimal
-from pathlib import Path
+from datetime import datetime, timezone
 from typing import Annotated
 from uuid import UUID
 
@@ -14,7 +12,7 @@
 from waymark.actions import action
 from waymark.dependencies import Depend
 from waymark.proto import messages_pb2 as pb2
-from waymark.workflow_runtime import ActionExecutionResult, _coerce_value, execute_action
+from waymark.workflow_runtime import ActionExecutionResult, execute_action
 
 
 @action
@@ -151,6 +149,18 @@ class PointData:
     y: int
 
 
+@python_dataclass
+class ReadingMetadata:
+    recorded_at: datetime
+    sample_ids: list[UUID]
+
+
+@python_dataclass
+class ReadingRequest:
+    reading_id: UUID
+    metadata: ReadingMetadata
+
+
 @action
 async def greet_person(person: PersonModel) -> str:
     """Action that expects a Pydantic model argument."""
@@ -163,6 +173,20 @@ async def compute_distance(point: PointData) -> int:
     return point.x + point.y
 
 
+@action
+async def summarize_reading(reading: ReadingRequest) -> str:
+    """Action that validates nested dataclass coercion."""
+    if not isinstance(reading.reading_id, UUID):
+        raise TypeError("reading_id was not coerced to UUID")
+    if not isinstance(reading.metadata, ReadingMetadata):
+        raise TypeError("metadata was not coerced to ReadingMetadata")
+    if not isinstance(reading.metadata.recorded_at, datetime):
+        raise TypeError("recorded_at was not coerced to datetime")
+    if not all(isinstance(sample_id, UUID) for sample_id in reading.metadata.sample_ids):
+        raise TypeError("sample_ids were not coerced to UUID")
+    return f"{reading.metadata.recorded_at.year}:{len(reading.metadata.sample_ids)}"
+
+
 def _build_action_dispatch_with_dict(
     action_name: str,
     module_name: str,
@@ -190,6 +214,11 @@ def add_value_to_proto(proto_value: pb2.WorkflowArgumentValue, value: object) ->
             proto_value.primitive.double_value = value
         elif isinstance(value, bool):
             proto_value.primitive.bool_value = value
+        elif isinstance(value, list):
+            proto_value.list_value.SetInParent()
+            for item in value:
+                item_value = proto_value.list_value.items.add()
+                add_value_to_proto(item_value, item)
         elif isinstance(value, dict):
             proto_value.dict_value.SetInParent()
             for k, v in value.items():
@@ -244,114 +273,34 @@ def test_execute_action_coerces_dict_to_dataclass() -> None:
     assert result.result == 7  # 3 + 4
 
 
-# ---- Tests for primitive type coercion ----
-
-
-def test_coerce_uuid_from_string() -> None:
-    """Test that UUID strings are coerced to UUID objects."""
-    uuid_str = "12345678-1234-5678-1234-567812345678"
-    result = _coerce_value(uuid_str, UUID)
-    assert isinstance(result, UUID)
-    assert str(result) == uuid_str
-
-
-def test_coerce_datetime_from_string() -> None:
-    """Test that ISO datetime strings are coerced to datetime objects."""
-    dt = datetime(2024, 1, 15, 10, 30, 45, tzinfo=timezone.utc)
-    result = _coerce_value(dt.isoformat(), datetime)
-    assert isinstance(result, datetime)
-    assert result == dt
-
-
-def test_coerce_date_from_string() -> None:
-    """Test that ISO date strings are coerced to date objects."""
-    d = date(2024, 1, 15)
-    result = _coerce_value(d.isoformat(), date)
-    assert isinstance(result, date)
-    assert result == d
-
-
-def test_coerce_time_from_string() -> None:
-    """Test that ISO time strings are coerced to time objects."""
-    t = time(10, 30, 45)
-    result = _coerce_value(t.isoformat(), time)
-    assert isinstance(result, time)
-    assert result == t
-
-
-def test_coerce_timedelta_from_seconds() -> None:
-    """Test that numeric values are coerced to timedelta objects."""
-    td = timedelta(hours=2, minutes=30)
-    result = _coerce_value(td.total_seconds(), timedelta)
-    assert isinstance(result, timedelta)
-    assert result == td
-
-
-def test_coerce_decimal_from_string() -> None:
-    """Test that string values are coerced to Decimal objects."""
-    d = Decimal("123.456789012345678901234567890")
-    result = _coerce_value(str(d), Decimal)
-    assert isinstance(result, Decimal)
-    assert result == d
-
-
-def test_coerce_bytes_from_base64() -> None:
-    """Test that base64 strings are coerced to bytes."""
-    from base64 import b64encode
-
-    data = b"hello world"
-    result = _coerce_value(b64encode(data).decode("ascii"), bytes)
-    assert isinstance(result, bytes)
-    assert result == data
+def test_execute_action_coerces_nested_typed_dataclass() -> None:
+    """Test that nested dataclass fields use their type hints during coercion."""
+    if action_registry.get(__name__, "summarize_reading") is None:
+        action_registry.register(__name__, "summarize_reading", summarize_reading)
 
+    reading_id = UUID("12345678-1234-5678-1234-567812345678")
+    recorded_at = datetime(2024, 1, 15, 10, 30, 45, tzinfo=timezone.utc)
+    sample_ids = [
+        UUID("87654321-4321-8765-4321-876543218765"),
+        UUID("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"),
+    ]
 
-def test_coerce_path_from_string() -> None:
-    """Test that strings are coerced to Path objects."""
-    p = Path("/usr/local/bin")
-    result = _coerce_value(str(p), Path)
-    assert isinstance(result, Path)
-    assert result == p
+    dispatch = _build_action_dispatch_with_dict(
+        action_name="summarize_reading",
+        module_name=__name__,
+        kwargs={
+            "reading": {
+                "reading_id": str(reading_id),
+                "metadata": {
+                    "recorded_at": recorded_at.isoformat(),
+                    "sample_ids": [str(sample_id) for sample_id in sample_ids],
+                },
+            }
+        },
+    )
 
+    result = asyncio.run(execute_action(dispatch))
 
-def test_coerce_list_of_uuids() -> None:
-    """Test that list[UUID] coerces string items to UUIDs."""
-    uuid_strs = [
-        "12345678-1234-5678-1234-567812345678",
-        "87654321-4321-8765-4321-876543218765",
-    ]
-    result = _coerce_value(uuid_strs, list[UUID])
-    assert isinstance(result, list)
-    assert all(isinstance(u, UUID) for u in result)
-    assert [str(u) for u in result] == uuid_strs
-
-
-def test_coerce_set_of_datetimes() -> None:
-    """Test that set[datetime] coerces list of ISO strings to set of datetimes."""
-    dt1 = datetime(2024, 1, 15, 10, 0, 0, tzinfo=timezone.utc)
-    dt2 = datetime(2024, 1, 16, 11, 0, 0, tzinfo=timezone.utc)
-    result = _coerce_value([dt1.isoformat(), dt2.isoformat()], set[datetime])
-    assert isinstance(result, set)
-    assert all(isinstance(d, datetime) for d in result)
-    assert result == {dt1, dt2}
-
-
-def test_coerce_dict_with_uuid_values() -> None:
-    """Test that dict[str, UUID] coerces string values to UUIDs."""
-    uuid_str = "12345678-1234-5678-1234-567812345678"
-    result = _coerce_value({"user_id": uuid_str}, dict[str, UUID])
-    assert isinstance(result, dict)
-    assert isinstance(result["user_id"], UUID)
-    assert str(result["user_id"]) == uuid_str
-
-
-def test_coerce_preserves_already_correct_type() -> None:
-    """Test that values already of the correct type are preserved."""
-    uuid_obj = UUID("12345678-1234-5678-1234-567812345678")
-    result = _coerce_value(uuid_obj, UUID)
-    assert result is uuid_obj
-
-
-def test_coerce_none_returns_none() -> None:
-    """Test that None values are preserved."""
-    result = _coerce_value(None, UUID)
-    assert result is None
+    assert isinstance(result, ActionExecutionResult)
+    assert result.exception is None, f"Unexpected exception: {result.exception}"
+    assert result.result == "2024:2"