Skip to content

Commit b30bc1e

Browse files
feat(internal/types): support eagerly validating pydantic iterators
1 parent e3f0b8d commit b30bc1e

2 files changed

Lines changed: 137 additions & 3 deletions

File tree

src/kernel/_models.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
ClassVar,
2626
Protocol,
2727
Required,
28+
Annotated,
2829
ParamSpec,
30+
TypeAlias,
2931
TypedDict,
3032
TypeGuard,
3133
final,
@@ -79,7 +81,15 @@
7981
from ._constants import RAW_RESPONSE_HEADER
8082

8183
if TYPE_CHECKING:
84+
from pydantic import GetCoreSchemaHandler, ValidatorFunctionWrapHandler
85+
from pydantic_core import CoreSchema, core_schema
8286
from pydantic_core.core_schema import ModelField, ModelSchema, LiteralSchema, ModelFieldsSchema
87+
else:
88+
try:
89+
from pydantic_core import CoreSchema, core_schema
90+
except ImportError:
91+
CoreSchema = None
92+
core_schema = None
8393

8494
__all__ = ["BaseModel", "GenericModel"]
8595

@@ -396,6 +406,76 @@ def model_dump_json(
396406
)
397407

398408

409+
class _EagerIterable(list[_T], Generic[_T]):
    """
    Accepts any Iterable[T] input (including generators), consumes it
    eagerly, and validates all items upfront.

    Validation preserves the original container type where possible
    (e.g. a set[T] stays a set[T]). Inputs whose type cannot be rebuilt
    from a list of items (str, dict, iterators, generators, ...) come back
    as a plain list. Serialization (model_dump / JSON) always emits a
    list — round-tripping through model_dump() will not restore the
    original container type.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        source_type: Any,
        handler: GetCoreSchemaHandler,
    ) -> CoreSchema:
        """Build a wrap-validator schema around ``list[item_type]``.

        The item type is taken from the first type argument of
        ``source_type``; an unparameterized source falls back to ``Any``.
        """
        (item_type,) = get_args(source_type) or (Any,)
        item_schema: CoreSchema = handler.generate_schema(item_type)
        list_of_items_schema: CoreSchema = core_schema.list_schema(item_schema)

        return core_schema.no_info_wrap_validator_function(
            cls._validate,
            list_of_items_schema,
            serialization=core_schema.plain_serializer_function_ser_schema(
                cls._serialize,
                info_arg=False,
            ),
        )

    @staticmethod
    def _validate(v: Iterable[_T], handler: "ValidatorFunctionWrapHandler") -> Any:
        """Eagerly consume ``v``, validate every item, and restore the container.

        Raises:
            ValueError: if ``v`` is not iterable. A ``ValueError`` (rather
                than ``TypeError``) is raised on purpose: pydantic only turns
                ``ValueError`` / ``AssertionError`` raised inside a validator
                into a ``ValidationError``; a ``TypeError`` would escape
                ``model_validate`` as a raw exception and would also abort
                union validation instead of letting other members be tried.
        """
        original_type: type[Any] = type(v)

        # Normalize to list so list_schema can validate each item
        if isinstance(v, list):
            items: list[_T] = v
        else:
            try:
                items = list(v)
            except TypeError as e:
                raise ValueError("Value is not iterable") from e

        # Validate items against the inner schema
        validated: list[_T] = handler(items)

        # Reconstruct original container type
        if original_type is list:
            return validated
        # str(list) produces the list's repr, not a string built from items,
        # and dict(list_of_keys) either fails or, for keys that happen to be
        # 2-item sequences (e.g. "ab"), silently builds an unrelated mapping.
        # Skip reconstruction for both and their subclasses.
        if issubclass(original_type, (str, dict)):
            return validated
        try:
            return original_type(validated)
        except (TypeError, ValueError):
            # If the type cannot be reconstructed (iterators, generators,
            # map objects, ...), just return the validated list
            return validated

    @staticmethod
    def _serialize(v: Iterable[_T]) -> list[_T]:
        """Always serialize as a list so Pydantic's JSON encoder is happy."""
        if isinstance(v, list):
            return v
        return list(v)
475+
476+
# Annotation alias for fields that accept any iterable: `EagerIterable[T]`
# expands to `Annotated[Iterable[T], _EagerIterable]`, attaching
# `_EagerIterable` as metadata so that (presumably via pydantic's
# `__get_pydantic_core_schema__` hook on annotated metadata) the input is
# consumed and validated eagerly instead of lazily.
EagerIterable: TypeAlias = Annotated[Iterable[_T], _EagerIterable]
477+
478+
399479
def _construct_field(value: object, field: FieldInfo, key: str) -> object:
400480
if value is None:
401481
return field_get_default(field)

tests/test_models.py

Lines changed: 57 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import json
2-
from typing import TYPE_CHECKING, Any, Dict, List, Union, Optional, cast
2+
from typing import TYPE_CHECKING, Any, Dict, List, Union, Iterable, Optional, cast
33
from datetime import datetime, timezone
4-
from typing_extensions import Literal, Annotated, TypeAliasType
4+
from collections import deque
5+
from typing_extensions import Literal, Annotated, TypedDict, TypeAliasType
56

67
import pytest
78
import pydantic
89
from pydantic import Field
910

1011
from kernel._utils import PropertyInfo
1112
from kernel._compat import PYDANTIC_V1, parse_obj, model_dump, model_json
12-
from kernel._models import DISCRIMINATOR_CACHE, BaseModel, construct_type
13+
from kernel._models import DISCRIMINATOR_CACHE, BaseModel, EagerIterable, construct_type
1314

1415

1516
class BasicModel(BaseModel):
@@ -961,3 +962,56 @@ def __getattr__(self, attr: str) -> Item: ...
961962
assert model.a.prop == 1
962963
assert isinstance(model.a, Item)
963964
assert model.other == "foo"
965+
966+
967+
# NOTE: Workaround for Pydantic Iterable behavior.
968+
# Iterable fields are replaced with a ValidatorIterator and may be consumed
969+
# during serialization, which can cause subsequent dumps to return empty data.
970+
# See: https://github.com/pydantic/pydantic/issues/9541
971+
@pytest.mark.parametrize(
    "data, expected_validated",
    [
        ([1, 2, 3], [1, 2, 3]),
        ((1, 2, 3), (1, 2, 3)),
        ({1, 2, 3}, {1, 2, 3}),
        (iter([1, 2, 3]), [1, 2, 3]),
        ([], []),
        ((x for x in [1, 2, 3]), [1, 2, 3]),
        (map(lambda x: x, [1, 2, 3]), [1, 2, 3]),
        (frozenset({1, 2, 3}), frozenset({1, 2, 3})),
        (deque([1, 2, 3]), deque([1, 2, 3])),
    ],
    ids=["list", "tuple", "set", "iterator", "empty", "generator", "map", "frozenset", "deque"],
)
@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2")
def test_iterable_construction(data: Iterable[int], expected_validated: Iterable[int]) -> None:
    """EagerIterable fields keep their data across validation and repeated dumps."""

    class TypeWithIterable(TypedDict):
        items: EagerIterable[int]

    class Model(BaseModel):
        data: TypeWithIterable

    model = Model.model_validate({"data": {"items": data}})
    assert model.data["items"] == expected_validated

    # Dumping must be repeatable: the original bug was that the first dump
    # consumed the iterable and later dumps came back empty.
    expected_dump = list(expected_validated)
    first_dump = model.model_dump()
    assert first_dump["data"]["items"] == expected_dump
    second_dump = model.model_dump()
    assert second_dump["data"]["items"] == expected_dump
1000+
1001+
1002+
@pytest.mark.skipif(PYDANTIC_V1, reason="this is only supported in pydantic v2")
def test_iterable_construction_str_falls_back_to_list() -> None:
    """A str input is validated as its characters and returned as a list.

    str is iterable (over chars), but str(list_of_chars) yields the list's
    repr instead of rebuilding the string, so str is special-cased to fall
    back to a list rather than attempting reconstruction.
    """

    class TypeWithIterable(TypedDict):
        items: EagerIterable[str]

    class Model(BaseModel):
        data: TypeWithIterable

    expected_chars = ["h", "e", "l", "l", "o"]
    model = Model.model_validate({"data": {"items": "hello"}})

    # list of chars, not the result of str(["h", "e", "l", "l", "o"])
    assert model.data["items"] == expected_chars
    dumped = model.model_dump()
    assert dumped["data"]["items"] == expected_chars

0 commit comments

Comments
 (0)