From 7050bbb87940fd68506ae996275a88c8bbc02f18 Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Mon, 6 Apr 2026 11:01:03 -0700 Subject: [PATCH 1/6] Allow type class subclasses --- .../schemas/test_db_memory_serializer.py | 68 +++++ iceaxe/__tests__/test_base.py | 26 ++ iceaxe/__tests__/test_session.py | 39 +++ iceaxe/base.py | 13 + iceaxe/field.py | 9 +- iceaxe/schemas/db_memory_serializer.py | 45 ++-- iceaxe/session.py | 28 ++ iceaxe/typing.py | 245 ++++++++++++++++++ 8 files changed, 451 insertions(+), 22 deletions(-) diff --git a/iceaxe/__tests__/schemas/test_db_memory_serializer.py b/iceaxe/__tests__/schemas/test_db_memory_serializer.py index 6c4bb80..5094a37 100644 --- a/iceaxe/__tests__/schemas/test_db_memory_serializer.py +++ b/iceaxe/__tests__/schemas/test_db_memory_serializer.py @@ -736,6 +736,74 @@ class ExampleModel2(TableBase): ] +def test_simple_uuid_subclass_column_assignment(clear_all_database_objects): + class CustomUUID(UUID): + pass + + class ExampleModel(TableBase): + id: CustomUUID = Field(primary_key=True) + values: list[CustomUUID] + + migrator = DatabaseMemorySerializer() + db_objects = list(migrator.delegate([ExampleModel])) + assert db_objects == [ + ( + DBTable(table_name="examplemodel"), + [], + ), + ( + DBColumn( + table_name="examplemodel", + column_name="id", + column_type=ColumnType.UUID, + column_is_list=False, + nullable=False, + ), + [ + DBTable(table_name="examplemodel"), + ], + ), + ( + DBColumn( + table_name="examplemodel", + column_name="values", + column_type=ColumnType.UUID, + column_is_list=True, + nullable=False, + ), + [ + DBTable(table_name="examplemodel"), + ], + ), + ( + DBConstraint( + table_name="examplemodel", + constraint_name="examplemodel_pkey", + columns=frozenset({"id"}), + constraint_type=ConstraintType.PRIMARY_KEY, + foreign_key_constraint=None, + ), + [ + DBTable(table_name="examplemodel"), + DBColumn( + table_name="examplemodel", + column_name="id", + column_type=ColumnType.UUID, + column_is_list=False, + 
nullable=False, + ), + DBColumn( + table_name="examplemodel", + column_name="values", + column_type=ColumnType.UUID, + column_is_list=True, + nullable=False, + ), + ], + ), + ] + + @pytest.mark.asyncio @pytest.mark.parametrize( "field_name, annotation, field_info, expected_db_objects", diff --git a/iceaxe/__tests__/test_base.py b/iceaxe/__tests__/test_base.py index 72bc5ec..663dc9d 100644 --- a/iceaxe/__tests__/test_base.py +++ b/iceaxe/__tests__/test_base.py @@ -1,4 +1,5 @@ from typing import Annotated, Any, Generic, TypeVar, cast +from uuid import UUID from iceaxe.base import ( DBModelMetaclass, @@ -92,3 +93,28 @@ class Event(TableBase, autodetect=False): assert field.annotation == dict[str, Any] | None assert field.default is None assert field.is_json is True + + +def test_model_fields_with_simple_uuid_subclass(): + class CustomUUID(UUID): + pass + + class Event(TableBase, autodetect=False): + id: CustomUUID + maybe_id: CustomUUID | None = None + ids: list[CustomUUID] + + raw_uuid = UUID("12345678-1234-5678-1234-567812345678") + event = cast( + Any, + Event, + )( + id=raw_uuid, + maybe_id=str(raw_uuid), + ids=[raw_uuid, str(raw_uuid)], + ) + + assert event.model_fields["id"].annotation == CustomUUID + assert isinstance(event.id, CustomUUID) + assert isinstance(event.maybe_id, CustomUUID) + assert all(isinstance(value, CustomUUID) for value in event.ids) diff --git a/iceaxe/__tests__/test_session.py b/iceaxe/__tests__/test_session.py index 6bda421..aff1b02 100644 --- a/iceaxe/__tests__/test_session.py +++ b/iceaxe/__tests__/test_session.py @@ -3,6 +3,7 @@ from json import dumps as json_dumps, loads as json_loads from typing import Any, Type from unittest.mock import AsyncMock, patch +from uuid import UUID import asyncpg import pytest @@ -66,6 +67,44 @@ async def test_db_connection_update(db_connection: DBConnection): assert user.get_modified_attributes() == {} +@pytest.mark.asyncio +async def test_db_connection_uuid_subclass_round_trip( + db_connection: 
DBConnection, + clear_all_database_objects, +): + class CustomUUID(UUID): + pass + + class UUIDSubclassDemo(TableBase): + id: CustomUUID = Field(primary_key=True) + + await db_connection.conn.execute("DROP TABLE IF EXISTS uuidsubclassdemo") + await create_all(db_connection, [UUIDSubclassDemo]) + + row_id = CustomUUID("12345678-1234-5678-1234-567812345678") + demo = UUIDSubclassDemo(id=row_id) + await db_connection.insert([demo]) + + raw_row = await db_connection.conn.fetchrow( + "SELECT id FROM uuidsubclassdemo WHERE id = $1", + UUID(str(row_id)), + ) + assert raw_row is not None + assert isinstance(raw_row["id"], UUID) + + result = await db_connection.exec( + QueryBuilder().select(UUIDSubclassDemo).where(UUIDSubclassDemo.id == row_id) + ) + assert len(result) == 1 + assert isinstance(result[0].id, CustomUUID) + + selected_ids = await db_connection.exec( + QueryBuilder().select(UUIDSubclassDemo.id).where(UUIDSubclassDemo.id == row_id) + ) + assert selected_ids == [row_id] + assert isinstance(selected_ids[0], CustomUUID) + + @pytest.mark.asyncio async def test_db_obj_mixin_track_modifications(): user = UserDemo(name="John Doe", email="john@example.com") diff --git a/iceaxe/base.py b/iceaxe/base.py index b6c088c..70c4b58 100644 --- a/iceaxe/base.py +++ b/iceaxe/base.py @@ -13,6 +13,7 @@ from pydantic_core import PydanticUndefined from iceaxe.field import DBFieldClassDefinition, DBFieldInfo, Field +from iceaxe.typing import normalize_simple_subclass_annotation @dataclass_transform(kw_only_default=True, field_specifiers=(PydanticField,)) @@ -51,6 +52,18 @@ def __new__(mcs, name, bases, namespace, **kwargs): Create a new database model class with proper field tracking. Handles registration of the model and processes any table-specific arguments. 
""" + namespace = dict(namespace) + raw_annotations = namespace.get("__annotations__", {}) + if raw_annotations: + namespace["__annotations__"] = { + key: ( + normalize_simple_subclass_annotation(annotation) + if not isinstance(annotation, str) + else annotation + ) + for key, annotation in raw_annotations.items() + } + raw_kwargs = {**kwargs} mcs.is_constructing = True diff --git a/iceaxe/field.py b/iceaxe/field.py index 31e90bb..2f5b943 100644 --- a/iceaxe/field.py +++ b/iceaxe/field.py @@ -20,6 +20,7 @@ from iceaxe.postgres import PostgresFieldBase from iceaxe.queries_str import QueryIdentifier, QueryLiteral from iceaxe.sql_types import ColumnType +from iceaxe.typing import convert_value_from_db_storage, convert_value_to_db_storage if TYPE_CHECKING: from iceaxe.base import TableBase @@ -171,10 +172,16 @@ def extend_field( def to_db_value(self, value: Any): if self.is_json: return json_dumps(value) + if self.annotation is not None: + return convert_value_to_db_storage(value, self.annotation) return value def from_db_value(self, value: Any): - if not self.is_json or value is None: + if not self.is_json: + if self.annotation is None: + return value + return convert_value_from_db_storage(value, self.annotation) + if value is None: return value parsed_value = json_loads(value) if isinstance(value, str) else value diff --git a/iceaxe/schemas/db_memory_serializer.py b/iceaxe/schemas/db_memory_serializer.py index fbc9dec..54efe3e 100644 --- a/iceaxe/schemas/db_memory_serializer.py +++ b/iceaxe/schemas/db_memory_serializer.py @@ -51,6 +51,7 @@ DATE_TYPES, JSON_WRAPPER_FALLBACK, PRIMITIVE_WRAPPER_TYPES, + get_db_storage_annotation, ) NodeYieldType = Union[DBObject, DBObjectPointer, "NodeDefinition"] @@ -423,39 +424,41 @@ def handle_column_type(self, key: str, info: DBFieldInfo, table: Type[TableBase] ) annotation = remove_null_type(info.annotation) + storage_annotation, is_list = get_db_storage_annotation(annotation) # Resolve the type of the column, if generic - if 
isinstance(annotation, TypeVar): + if isinstance(storage_annotation, TypeVar): typevar_map = get_typevar_mapping(table) - annotation = typevar_map[annotation] + storage_annotation = typevar_map[storage_annotation] + storage_annotation, is_list = get_db_storage_annotation(storage_annotation) # Should be prioritized in terms of MRO; StrEnums should be processed # before the str types - if is_type_compatible(annotation, ALL_ENUM_TYPES): + if is_type_compatible(storage_annotation, ALL_ENUM_TYPES): # We only support string values for enums because postgres enums are defined # as name-based types - for value in annotation: # type: ignore + for value in storage_annotation: # type: ignore if not isinstance(value.value, str): raise ValueError( - f"Only string values are supported for enums, received: {value.value} (enum: {annotation})" + f"Only string values are supported for enums, received: {value.value} (enum: {storage_annotation})" ) return TypeDeclarationResponse( custom_type=DBType( - name=enum_to_name(annotation), # type: ignore - values=frozenset([value.value for value in annotation]), # type: ignore + name=enum_to_name(storage_annotation), # type: ignore + values=frozenset([value.value for value in storage_annotation]), # type: ignore reference_columns=frozenset({(table.get_table_name(), key)}), ), ) - elif is_type_compatible(annotation, PRIMITIVE_WRAPPER_TYPES): + elif is_type_compatible(storage_annotation, PRIMITIVE_WRAPPER_TYPES): for primitive, json_type in self.python_to_sql.items(): - if annotation == primitive or annotation == list[primitive]: # type: ignore + if storage_annotation == primitive: # type: ignore return TypeDeclarationResponse( primitive_type=json_type, - is_list=(annotation == list[primitive]), # type: ignore + is_list=is_list, ) - elif is_type_compatible(annotation, DATE_TYPES): - if is_type_compatible(annotation, datetime): # type: ignore + elif is_type_compatible(storage_annotation, DATE_TYPES): + if is_type_compatible(storage_annotation, 
datetime): # type: ignore if isinstance(info.postgres_config, PostgresDateTime): return TypeDeclarationResponse( primitive_type=( @@ -468,11 +471,11 @@ def handle_column_type(self, key: str, info: DBFieldInfo, table: Type[TableBase] return TypeDeclarationResponse( primitive_type=ColumnType.TIMESTAMP_WITHOUT_TIME_ZONE, ) - elif is_type_compatible(annotation, date): # type: ignore + elif is_type_compatible(storage_annotation, date): # type: ignore return TypeDeclarationResponse( primitive_type=ColumnType.DATE, ) - elif is_type_compatible(annotation, time): # type: ignore + elif is_type_compatible(storage_annotation, time): # type: ignore if isinstance(info.postgres_config, PostgresTime): return TypeDeclarationResponse( primitive_type=( @@ -484,34 +487,34 @@ def handle_column_type(self, key: str, info: DBFieldInfo, table: Type[TableBase] return TypeDeclarationResponse( primitive_type=ColumnType.TIME_WITHOUT_TIME_ZONE, ) - elif is_type_compatible(annotation, timedelta): # type: ignore + elif is_type_compatible(storage_annotation, timedelta): # type: ignore return TypeDeclarationResponse( primitive_type=ColumnType.INTERVAL, ) else: - raise ValueError(f"Unsupported date type: {annotation}") - elif is_type_compatible(annotation, BaseModel): + raise ValueError(f"Unsupported date type: {storage_annotation}") + elif is_type_compatible(storage_annotation, BaseModel): if info.is_json: return TypeDeclarationResponse( primitive_type=ColumnType.JSON, ) else: raise ValueError( - f"Pydantic model fields must have Field(is_json=True) specified: {annotation}\n" + f"Pydantic model fields must have Field(is_json=True) specified: {storage_annotation}\n" f"Column: {table.__name__}.{key}" ) - elif is_type_compatible(annotation, JSON_WRAPPER_FALLBACK): + elif is_type_compatible(storage_annotation, JSON_WRAPPER_FALLBACK): if info.is_json: return TypeDeclarationResponse( primitive_type=ColumnType.JSON, ) else: raise ValueError( - f"JSON fields must have Field(is_json=True) specified: 
{annotation}\n" + f"JSON fields must have Field(is_json=True) specified: {storage_annotation}\n" f"Column: {table.__name__}.{key}" ) - raise ValueError(f"Unsupported column type: {annotation}") + raise ValueError(f"Unsupported column type: {storage_annotation}") def handle_single_constraints( self, key: str, info: DBFieldInfo, table: Type[TableBase] diff --git a/iceaxe/session.py b/iceaxe/session.py index b3f5d6f..bbca08d 100644 --- a/iceaxe/session.py +++ b/iceaxe/session.py @@ -332,6 +332,34 @@ async def exec( ] result_all = optimize_exec_casting(values, query._select_raw, select_types) + column_cast_indices: list[int] = [] + for index, select_raw in enumerate(query._select_raw): + if is_column(select_raw) and not select_raw.field_definition.is_json: + column_cast_indices.append(index) + if column_cast_indices: + for result_index, row in enumerate(result_all): + if len(query._select_raw) == 1: + select_raw = cast( + DBFieldClassDefinition[Any], + query._select_raw[0], + ) + result_all[result_index] = ( + select_raw.field_definition.from_db_value(row) + ) + continue + + row_values = list(cast(tuple[Any, ...], row)) + for column_index in column_cast_indices: + select_raw = cast( + DBFieldClassDefinition[Any], + query._select_raw[column_index], + ) + row_values[column_index] = ( + select_raw.field_definition.from_db_value( + row_values[column_index] + ) + ) + result_all[result_index] = tuple(row_values) # Only loop through results if we have verbosity enabled, since this logic otherwise # is wasted if no content will eventually be logged diff --git a/iceaxe/typing.py b/iceaxe/typing.py index f4b94e5..9b71613 100644 --- a/iceaxe/typing.py +++ b/iceaxe/typing.py @@ -1,17 +1,25 @@ from __future__ import annotations +import types from datetime import date, datetime, time, timedelta from enum import Enum, IntEnum, StrEnum from inspect import isclass from typing import ( TYPE_CHECKING, + Annotated, Any, Type, TypeGuard, TypeVar, + Union, + get_args, + get_origin, ) from 
uuid import UUID +from pydantic import GetCoreSchemaHandler +from pydantic_core import core_schema + if TYPE_CHECKING: from iceaxe.alias_values import Alias from iceaxe.base import ( @@ -27,10 +35,247 @@ PRIMITIVE_WRAPPER_TYPES = list[PRIMITIVE_TYPES] | PRIMITIVE_TYPES DATE_TYPES = datetime | date | time | timedelta JSON_WRAPPER_FALLBACK = list[Any] | dict[Any, Any] +SIMPLE_SUBCLASS_BASE_TYPES = ( + datetime, + date, + time, + timedelta, + UUID, + bytes, + str, + int, + float, + bool, +) T = TypeVar("T") +class SimpleSubclassAnnotation: + def __init__(self, subtype: type[Any], base_type: type[Any]): + self.subtype = subtype + self.base_type = base_type + + def __get_pydantic_core_schema__( + self, + source_type: Any, + handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + schema = handler.generate_schema(self.base_type) + return core_schema.no_info_after_validator_function( + self._cast_value, + schema, + ) + + def _cast_value(self, value: Any): + return _coerce_simple_subclass_value(value, self.subtype) + + +def _is_union_type(annotation: Any) -> bool: + origin = get_origin(annotation) + return origin is Union or isinstance(annotation, types.UnionType) + + +def _rebuild_annotation(origin: Any, args: tuple[Any, ...]): + if not args: + return origin + if len(args) == 1: + item = args[0] + if hasattr(origin, "__class_getitem__"): + return origin.__class_getitem__(item) + return origin[item] + if hasattr(origin, "__class_getitem__"): + return origin.__class_getitem__(args) + return origin[args] + + +def unwrap_annotated(annotation: Any) -> Any: + while get_origin(annotation) is Annotated: + annotation = get_args(annotation)[0] + return annotation + + +def get_simple_subclass_base_type(annotation: Any) -> type[Any] | None: + annotation = unwrap_annotated(annotation) + if not isclass(annotation) or annotation in SIMPLE_SUBCLASS_BASE_TYPES: + return None + if issubclass(annotation, Enum): + return None + + mro = annotation.mro() + matches = [ + 
(mro.index(base_type), base_type) + for base_type in SIMPLE_SUBCLASS_BASE_TYPES + if base_type in mro + ] + if not matches: + return None + + return min(matches, key=lambda match: match[0])[1] + + +def normalize_simple_subclass_annotation(annotation: Any) -> Any: + origin = get_origin(annotation) + + if origin is Annotated: + inner, *metadata = get_args(annotation) + if any(isinstance(item, SimpleSubclassAnnotation) for item in metadata): + normalized_inner = normalize_simple_subclass_annotation(inner) + if normalized_inner == inner: + return annotation + return Annotated[normalized_inner, *metadata] + + base_type = get_simple_subclass_base_type(inner) + if base_type is not None: + return Annotated[ + inner, + *metadata, + SimpleSubclassAnnotation(inner, base_type), + ] + + normalized_inner = normalize_simple_subclass_annotation(inner) + if normalized_inner == inner: + return annotation + return Annotated[normalized_inner, *metadata] + + if _is_union_type(annotation): + normalized_args = tuple( + normalize_simple_subclass_annotation(arg) for arg in get_args(annotation) + ) + if normalized_args == get_args(annotation): + return annotation + return Union[normalized_args] # type: ignore + + if origin is not None: + normalized_args = tuple( + normalize_simple_subclass_annotation(arg) for arg in get_args(annotation) + ) + if normalized_args == get_args(annotation): + return annotation + return _rebuild_annotation(origin, normalized_args) + + base_type = get_simple_subclass_base_type(annotation) + if base_type is None: + return annotation + + return Annotated[annotation, SimpleSubclassAnnotation(annotation, base_type)] + + +def get_db_storage_annotation(annotation: Any) -> tuple[Any, bool]: + annotation = unwrap_annotated(annotation) + + if _is_union_type(annotation): + non_null_args = tuple( + arg + for arg in get_args(annotation) + if unwrap_annotated(arg) is not type(None) + ) + if len(non_null_args) == 1: + return get_db_storage_annotation(non_null_args[0]) + return 
annotation, False + + origin = get_origin(annotation) + if origin is list: + (value_type,) = get_args(annotation) + resolved_type, _ = get_db_storage_annotation(value_type) + return resolved_type, True + + base_type = get_simple_subclass_base_type(annotation) + if base_type is not None: + return base_type, False + return annotation, False + + +def convert_value_to_db_storage(value: Any, annotation: Any) -> Any: + return _convert_simple_subclass_value(value, annotation, to_db=True) + + +def convert_value_from_db_storage(value: Any, annotation: Any) -> Any: + return _convert_simple_subclass_value(value, annotation, to_db=False) + + +def _convert_simple_subclass_value(value: Any, annotation: Any, *, to_db: bool) -> Any: + if value is None: + return None + + annotation = unwrap_annotated(annotation) + if _is_union_type(annotation): + non_null_args = tuple( + arg + for arg in get_args(annotation) + if unwrap_annotated(arg) is not type(None) + ) + if len(non_null_args) == 1: + return _convert_simple_subclass_value( + value, + non_null_args[0], + to_db=to_db, + ) + return value + + origin = get_origin(annotation) + if origin is list: + (value_type,) = get_args(annotation) + return [ + _convert_simple_subclass_value(item, value_type, to_db=to_db) + for item in value + ] + + base_type = get_simple_subclass_base_type(annotation) + if base_type is None: + return value + + target_type = base_type if to_db else annotation + return _coerce_simple_subclass_value(value, target_type) + + +def _coerce_simple_subclass_value(value: Any, target_type: type[Any]) -> Any: + if type(value) is target_type: + return value + + if issubclass(target_type, UUID): + return target_type(str(value)) + + if issubclass(target_type, datetime): + return target_type( + value.year, + value.month, + value.day, + value.hour, + value.minute, + value.second, + value.microsecond, + tzinfo=value.tzinfo, + fold=value.fold, + ) + + if issubclass(target_type, date): + return target_type( + value.year, + 
value.month, + value.day, + ) + + if issubclass(target_type, time): + return target_type( + value.hour, + value.minute, + value.second, + value.microsecond, + tzinfo=value.tzinfo, + fold=value.fold, + ) + + if issubclass(target_type, timedelta): + return target_type( + days=value.days, + seconds=value.seconds, + microseconds=value.microseconds, + ) + + return target_type(value) + + def is_base_table(obj: Any) -> TypeGuard[type[TableBase]]: from iceaxe.base import TableBase From d979eeac20d5481c91d968eb5591eae0281ab51b Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Mon, 6 Apr 2026 11:22:51 -0700 Subject: [PATCH 2/6] Add docstrings --- iceaxe/base.py | 2 ++ iceaxe/session.py | 76 ++++++++++++++++++++++++++++++----------------- 2 files changed, 50 insertions(+), 28 deletions(-) diff --git a/iceaxe/base.py b/iceaxe/base.py index 70c4b58..b088fe6 100644 --- a/iceaxe/base.py +++ b/iceaxe/base.py @@ -52,6 +52,8 @@ def __new__(mcs, name, bases, namespace, **kwargs): Create a new database model class with proper field tracking. Handles registration of the model and processes any table-specific arguments. """ + # Pydantic must see these normalized annotations up front; otherwise simple + # subclasses like `CustomUUID(UUID)` are treated as unknown types. namespace = dict(namespace) raw_annotations = namespace.get("__annotations__", {}) if raw_annotations: diff --git a/iceaxe/session.py b/iceaxe/session.py index bbca08d..25544a1 100644 --- a/iceaxe/session.py +++ b/iceaxe/session.py @@ -219,6 +219,50 @@ def get_dsn(self) -> str: return "".join(dsn_parts) + def _cast_column_select_results( + self, + results: list[Any], + select_raw: Sequence[Any], + ) -> list[Any]: + """ + Apply field-level `from_db_value` casting to direct column selections. + + The optimized select path already constructs full table objects with each + field deserialized through the model metadata. 
Direct column selections are + intentionally left as raw driver values for speed, which means subclass-backed + fields like `CustomUUID(UUID)` come back as the base database type instead of + the annotated runtime type. This helper restores that field-level coercion for + non-JSON column selects while leaving table, alias, and function selections on + the existing fast path. + + """ + column_cast_indices: list[int] = [] + for index, raw_value in enumerate(select_raw): + if is_column(raw_value) and not raw_value.field_definition.is_json: + column_cast_indices.append(index) + + if not column_cast_indices: + return results + + for result_index, row in enumerate(results): + if len(select_raw) == 1: + raw_column = cast(DBFieldClassDefinition[Any], select_raw[0]) + results[result_index] = raw_column.field_definition.from_db_value(row) + continue + + row_values = list(cast(tuple[Any, ...], row)) + for column_index in column_cast_indices: + raw_column = cast( + DBFieldClassDefinition[Any], + select_raw[column_index], + ) + row_values[column_index] = raw_column.field_definition.from_db_value( + row_values[column_index] + ) + results[result_index] = tuple(row_values) + + return results + @asynccontextmanager async def transaction(self, *, ensure: bool = False): """ @@ -332,34 +376,10 @@ async def exec( ] result_all = optimize_exec_casting(values, query._select_raw, select_types) - column_cast_indices: list[int] = [] - for index, select_raw in enumerate(query._select_raw): - if is_column(select_raw) and not select_raw.field_definition.is_json: - column_cast_indices.append(index) - if column_cast_indices: - for result_index, row in enumerate(result_all): - if len(query._select_raw) == 1: - select_raw = cast( - DBFieldClassDefinition[Any], - query._select_raw[0], - ) - result_all[result_index] = ( - select_raw.field_definition.from_db_value(row) - ) - continue - - row_values = list(cast(tuple[Any, ...], row)) - for column_index in column_cast_indices: - select_raw = cast( - 
DBFieldClassDefinition[Any], - query._select_raw[column_index], - ) - row_values[column_index] = ( - select_raw.field_definition.from_db_value( - row_values[column_index] - ) - ) - result_all[result_index] = tuple(row_values) + result_all = self._cast_column_select_results( + result_all, + query._select_raw, + ) # Only loop through results if we have verbosity enabled, since this logic otherwise # is wasted if no content will eventually be logged From 52b2b65cd30c7a1c2b3689fa03ebc74fdb8d217f Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Mon, 6 Apr 2026 11:28:44 -0700 Subject: [PATCH 3/6] Refactor type resolver --- iceaxe/typing.py | 153 +++++++++++++++++++++++++---------------------- 1 file changed, 83 insertions(+), 70 deletions(-) diff --git a/iceaxe/typing.py b/iceaxe/typing.py index 9b71613..51d3a6a 100644 --- a/iceaxe/typing.py +++ b/iceaxe/typing.py @@ -1,6 +1,7 @@ from __future__ import annotations import types +from dataclasses import dataclass from datetime import date, datetime, time, timedelta from enum import Enum, IntEnum, StrEnum from inspect import isclass @@ -51,6 +52,15 @@ T = TypeVar("T") +@dataclass(frozen=True) +class ResolvedFieldAnnotation: + runtime_annotation: Any + storage_annotation: Any + is_list: bool + is_nullable: bool + is_simple_subclass: bool + + class SimpleSubclassAnnotation: def __init__(self, subtype: type[Any], base_type: type[Any]): self.subtype = subtype @@ -76,17 +86,16 @@ def _is_union_type(annotation: Any) -> bool: return origin is Union or isinstance(annotation, types.UnionType) -def _rebuild_annotation(origin: Any, args: tuple[Any, ...]): - if not args: - return origin - if len(args) == 1: - item = args[0] - if hasattr(origin, "__class_getitem__"): - return origin.__class_getitem__(item) - return origin[item] - if hasattr(origin, "__class_getitem__"): - return origin.__class_getitem__(args) - return origin[args] +def _rebuild_annotation(annotation: Any, args: tuple[Any, ...]): + if _is_union_type(annotation): + 
return Union[args] # type: ignore + + origin = get_origin(annotation) + if origin is None: + return annotation + + item = args[0] if len(args) == 1 else args + return origin[item] def unwrap_annotated(annotation: Any) -> Any: @@ -95,6 +104,19 @@ def unwrap_annotated(annotation: Any) -> Any: return annotation +def _get_optional_annotation_inner(annotation: Any) -> Any | None: + if not _is_union_type(annotation): + return None + + non_null_args = tuple( + arg for arg in get_args(annotation) if unwrap_annotated(arg) is not type(None) + ) + if len(non_null_args) != 1: + return None + + return non_null_args[0] + + def get_simple_subclass_base_type(annotation: Any) -> type[Any] | None: annotation = unwrap_annotated(annotation) if not isclass(annotation) or annotation in SIMPLE_SUBCLASS_BASE_TYPES: @@ -114,6 +136,39 @@ def get_simple_subclass_base_type(annotation: Any) -> type[Any] | None: return min(matches, key=lambda match: match[0])[1] +def resolve_field_annotation(annotation: Any) -> ResolvedFieldAnnotation: + is_list = False + is_nullable = False + current = annotation + + while True: + current = unwrap_annotated(current) + + optional_inner = _get_optional_annotation_inner(current) + if optional_inner is not None: + is_nullable = True + current = optional_inner + continue + + if not is_list and get_origin(current) is list: + (current,) = get_args(current) + is_list = True + continue + + break + + current = unwrap_annotated(current) + base_type = get_simple_subclass_base_type(current) + + return ResolvedFieldAnnotation( + runtime_annotation=current, + storage_annotation=base_type or current, + is_list=is_list, + is_nullable=is_nullable, + is_simple_subclass=base_type is not None, + ) + + def normalize_simple_subclass_annotation(annotation: Any) -> Any: origin = get_origin(annotation) @@ -138,52 +193,30 @@ def normalize_simple_subclass_annotation(annotation: Any) -> Any: return annotation return Annotated[normalized_inner, *metadata] - if _is_union_type(annotation): - 
normalized_args = tuple( - normalize_simple_subclass_annotation(arg) for arg in get_args(annotation) - ) - if normalized_args == get_args(annotation): - return annotation - return Union[normalized_args] # type: ignore - if origin is not None: normalized_args = tuple( normalize_simple_subclass_annotation(arg) for arg in get_args(annotation) ) if normalized_args == get_args(annotation): return annotation - return _rebuild_annotation(origin, normalized_args) + return _rebuild_annotation(annotation, normalized_args) - base_type = get_simple_subclass_base_type(annotation) - if base_type is None: + resolved = resolve_field_annotation(annotation) + if not resolved.is_simple_subclass: return annotation - return Annotated[annotation, SimpleSubclassAnnotation(annotation, base_type)] + return Annotated[ + annotation, + SimpleSubclassAnnotation( + resolved.runtime_annotation, + resolved.storage_annotation, + ), + ] def get_db_storage_annotation(annotation: Any) -> tuple[Any, bool]: - annotation = unwrap_annotated(annotation) - - if _is_union_type(annotation): - non_null_args = tuple( - arg - for arg in get_args(annotation) - if unwrap_annotated(arg) is not type(None) - ) - if len(non_null_args) == 1: - return get_db_storage_annotation(non_null_args[0]) - return annotation, False - - origin = get_origin(annotation) - if origin is list: - (value_type,) = get_args(annotation) - resolved_type, _ = get_db_storage_annotation(value_type) - return resolved_type, True - - base_type = get_simple_subclass_base_type(annotation) - if base_type is not None: - return base_type, False - return annotation, False + resolved = resolve_field_annotation(annotation) + return resolved.storage_annotation, resolved.is_list def convert_value_to_db_storage(value: Any, annotation: Any) -> Any: @@ -198,34 +231,14 @@ def _convert_simple_subclass_value(value: Any, annotation: Any, *, to_db: bool) if value is None: return None - annotation = unwrap_annotated(annotation) - if _is_union_type(annotation): - 
non_null_args = tuple( - arg - for arg in get_args(annotation) - if unwrap_annotated(arg) is not type(None) - ) - if len(non_null_args) == 1: - return _convert_simple_subclass_value( - value, - non_null_args[0], - to_db=to_db, - ) + resolved = resolve_field_annotation(annotation) + if not resolved.is_simple_subclass: return value - origin = get_origin(annotation) - if origin is list: - (value_type,) = get_args(annotation) - return [ - _convert_simple_subclass_value(item, value_type, to_db=to_db) - for item in value - ] - - base_type = get_simple_subclass_base_type(annotation) - if base_type is None: - return value + target_type = resolved.storage_annotation if to_db else resolved.runtime_annotation + if resolved.is_list: + return [_coerce_simple_subclass_value(item, target_type) for item in value] - target_type = base_type if to_db else annotation return _coerce_simple_subclass_value(value, target_type) From 85494403495a792bf74e989d5ab5ff4fff46ac74 Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Mon, 6 Apr 2026 11:57:41 -0700 Subject: [PATCH 4/6] Refactor custom typehints --- iceaxe/base.py | 5 +- iceaxe/custom_typehints.py | 208 ++++++++++++++++++++++ iceaxe/field.py | 6 +- iceaxe/schemas/db_memory_serializer.py | 17 +- iceaxe/typing.py | 234 +++++-------------------- 5 files changed, 274 insertions(+), 196 deletions(-) create mode 100644 iceaxe/custom_typehints.py diff --git a/iceaxe/base.py b/iceaxe/base.py index b088fe6..db3743c 100644 --- a/iceaxe/base.py +++ b/iceaxe/base.py @@ -12,8 +12,9 @@ from pydantic.main import _model_construction from pydantic_core import PydanticUndefined +from iceaxe.custom_typehints import wrap_simple_subclass_annotation from iceaxe.field import DBFieldClassDefinition, DBFieldInfo, Field -from iceaxe.typing import normalize_simple_subclass_annotation +from iceaxe.typing import transform_typehint @dataclass_transform(kw_only_default=True, field_specifiers=(PydanticField,)) @@ -59,7 +60,7 @@ def __new__(mcs, name, bases, 
namespace, **kwargs): if raw_annotations: namespace["__annotations__"] = { key: ( - normalize_simple_subclass_annotation(annotation) + transform_typehint(annotation, wrap_simple_subclass_annotation) if not isinstance(annotation, str) else annotation ) diff --git a/iceaxe/custom_typehints.py b/iceaxe/custom_typehints.py new file mode 100644 index 0000000..c7c7937 --- /dev/null +++ b/iceaxe/custom_typehints.py @@ -0,0 +1,208 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from enum import Enum +from inspect import isclass +from typing import Annotated, Any, Literal, assert_never, get_args, get_origin +from uuid import UUID + +from pydantic import GetCoreSchemaHandler +from pydantic_core import core_schema + +from iceaxe.typing import resolve_typehint, unwrap_annotated + +# This literal definition seems overly verbose for what we're doing (ie. defining the types +# of simple subclasses for which we support type coercion). But it's required so we can properly +# throw a type error during static analysis if we don't properly support a handler for one +# of these supported types. +SimpleSubclassKind = Literal[ + "datetime", + "date", + "time", + "timedelta", + "uuid", + "bytes", + "str", + "int", + "float", + "bool", +] + +SIMPLE_SUBCLASS_BASE_TYPES_BY_KIND: dict[SimpleSubclassKind, type[Any]] = { + "datetime": datetime, + "date": date, + "time": time, + "timedelta": timedelta, + "uuid": UUID, + "bytes": bytes, + "str": str, + "int": int, + "float": float, + "bool": bool, +} +SIMPLE_SUBCLASS_BASE_TYPES = tuple(SIMPLE_SUBCLASS_BASE_TYPES_BY_KIND.values()) + + +class SimpleSubclassAnnotation: + """ + Pydantic metadata wrapper for "validate as base type, return subclass". + + Simple subclasses such as `CustomUUID(UUID)` are structurally compatible + with their parent type for database storage, but Pydantic does not know how + to build a schema for those subclasses by default. 
If Iceaxe passes the raw + subclass annotation through unchanged, model construction fails because + Pydantic treats it as an unknown arbitrary type. + + This metadata object is attached via `Annotated[...]` during + `transform_typehint(..., wrap_simple_subclass_annotation)`. When Pydantic + sees that annotation, + it calls `__get_pydantic_core_schema__`, which lets us: + - reuse the existing schema for the storage/base type (`UUID`, `date`, etc.) + - keep all of Pydantic's normal parsing behavior for that base type + - run one final post-validation cast that reconstructs the requested subclass + + The important constraint is that this only works for "simple" subclasses + whose runtime value can be losslessly rebuilt from the validated base value. + We are not defining a brand new schema here; we are explicitly piggybacking + on the parent type's schema and restoring the subclass identity afterward. + + """ + + def __init__(self, subtype: type[Any], base_type: type[Any]): + self.subtype = subtype + self.base_type = base_type + + def __get_pydantic_core_schema__( + self, + source_type: Any, + handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + schema = handler.generate_schema(self.base_type) + return core_schema.no_info_after_validator_function( + self._cast_value, + schema, + ) + + def _cast_value(self, value: Any): + return coerce_simple_subclass_value(value, self.subtype) + + +def wrap_simple_subclass_annotation(annotation: Any) -> Any: + if get_origin(annotation) is Annotated: + inner, *metadata = get_args(annotation) + if any(isinstance(item, SimpleSubclassAnnotation) for item in metadata): + return annotation + + base_type = get_simple_subclass_base_type(inner) + if base_type is None: + return annotation + + return Annotated[ + inner, + *metadata, + SimpleSubclassAnnotation(inner, base_type), + ] + + base_type = get_simple_subclass_base_type(annotation) + if base_type is None: + return annotation + + return Annotated[annotation, 
SimpleSubclassAnnotation(annotation, base_type)] + + +def get_simple_subclass_base_type(annotation: Any) -> type[Any] | None: + annotation = unwrap_annotated(annotation) + kind = get_simple_subclass_kind(annotation) + if kind is None: + return None + + base_type = SIMPLE_SUBCLASS_BASE_TYPES_BY_KIND[kind] + if annotation is base_type: + return None + + return base_type + + +def get_simple_subclass_kind(annotation: Any) -> SimpleSubclassKind | None: + annotation = unwrap_annotated(annotation) + if not isclass(annotation): + return None + if issubclass(annotation, Enum): + return None + + mro = annotation.mro() + matches: list[tuple[int, SimpleSubclassKind]] = [ + (mro.index(base_type), kind) + for kind, base_type in SIMPLE_SUBCLASS_BASE_TYPES_BY_KIND.items() + if base_type in mro + ] + if not matches: + return None + + return min(matches, key=lambda match: match[0])[1] + + +def convert_simple_subclass_value(value: Any, annotation: Any, *, to_db: bool) -> Any: + if value is None: + return None + + resolved = resolve_typehint(annotation) + storage_type = get_simple_subclass_base_type(resolved.runtime_type) + if storage_type is None: + return value + + target_type = storage_type if to_db else resolved.runtime_type + if resolved.is_list: + return [coerce_simple_subclass_value(item, target_type) for item in value] + + return coerce_simple_subclass_value(value, target_type) + + +def coerce_simple_subclass_value(value: Any, target_type: type[Any]) -> Any: + if type(value) is target_type: + return value + + kind = get_simple_subclass_kind(target_type) + if kind is None: + raise TypeError(f"Unsupported simple subclass target type: {target_type}") + + match kind: + case "uuid": + return target_type(str(value)) + case "datetime": + return target_type( + value.year, + value.month, + value.day, + value.hour, + value.minute, + value.second, + value.microsecond, + tzinfo=value.tzinfo, + fold=value.fold, + ) + case "date": + return target_type( + value.year, + value.month, + 
value.day, + ) + case "time": + return target_type( + value.hour, + value.minute, + value.second, + value.microsecond, + tzinfo=value.tzinfo, + fold=value.fold, + ) + case "timedelta": + return target_type( + days=value.days, + seconds=value.seconds, + microseconds=value.microseconds, + ) + case "bytes" | "str" | "int" | "float" | "bool": + return target_type(value) + case unexpected: + assert_never(unexpected) diff --git a/iceaxe/field.py b/iceaxe/field.py index 2f5b943..74269db 100644 --- a/iceaxe/field.py +++ b/iceaxe/field.py @@ -17,10 +17,10 @@ from pydantic_core import PydanticUndefined from iceaxe.comparison import ComparisonBase +from iceaxe.custom_typehints import convert_simple_subclass_value from iceaxe.postgres import PostgresFieldBase from iceaxe.queries_str import QueryIdentifier, QueryLiteral from iceaxe.sql_types import ColumnType -from iceaxe.typing import convert_value_from_db_storage, convert_value_to_db_storage if TYPE_CHECKING: from iceaxe.base import TableBase @@ -173,14 +173,14 @@ def to_db_value(self, value: Any): if self.is_json: return json_dumps(value) if self.annotation is not None: - return convert_value_to_db_storage(value, self.annotation) + return convert_simple_subclass_value(value, self.annotation, to_db=True) return value def from_db_value(self, value: Any): if not self.is_json: if self.annotation is None: return value - return convert_value_from_db_storage(value, self.annotation) + return convert_simple_subclass_value(value, self.annotation, to_db=False) if value is None: return value diff --git a/iceaxe/schemas/db_memory_serializer.py b/iceaxe/schemas/db_memory_serializer.py index 54efe3e..8802a5a 100644 --- a/iceaxe/schemas/db_memory_serializer.py +++ b/iceaxe/schemas/db_memory_serializer.py @@ -14,6 +14,7 @@ TableBase, UniqueConstraint, ) +from iceaxe.custom_typehints import get_simple_subclass_base_type from iceaxe.generics import ( get_typevar_mapping, has_null_type, @@ -51,7 +52,7 @@ DATE_TYPES, JSON_WRAPPER_FALLBACK, 
PRIMITIVE_WRAPPER_TYPES, - get_db_storage_annotation, + resolve_typehint, ) NodeYieldType = Union[DBObject, DBObjectPointer, "NodeDefinition"] @@ -424,13 +425,23 @@ def handle_column_type(self, key: str, info: DBFieldInfo, table: Type[TableBase] ) annotation = remove_null_type(info.annotation) - storage_annotation, is_list = get_db_storage_annotation(annotation) + resolved_annotation = resolve_typehint(annotation) + storage_annotation = ( + get_simple_subclass_base_type(resolved_annotation.runtime_type) + or resolved_annotation.runtime_type + ) + is_list = resolved_annotation.is_list # Resolve the type of the column, if generic if isinstance(storage_annotation, TypeVar): typevar_map = get_typevar_mapping(table) storage_annotation = typevar_map[storage_annotation] - storage_annotation, is_list = get_db_storage_annotation(storage_annotation) + resolved_annotation = resolve_typehint(storage_annotation) + storage_annotation = ( + get_simple_subclass_base_type(resolved_annotation.runtime_type) + or resolved_annotation.runtime_type + ) + is_list = resolved_annotation.is_list # Should be prioritized in terms of MRO; StrEnums should be processed # before the str types diff --git a/iceaxe/typing.py b/iceaxe/typing.py index 51d3a6a..ed4471e 100644 --- a/iceaxe/typing.py +++ b/iceaxe/typing.py @@ -9,6 +9,7 @@ TYPE_CHECKING, Annotated, Any, + Callable, Type, TypeGuard, TypeVar, @@ -18,9 +19,6 @@ ) from uuid import UUID -from pydantic import GetCoreSchemaHandler -from pydantic_core import core_schema - if TYPE_CHECKING: from iceaxe.alias_values import Alias from iceaxe.base import ( @@ -36,58 +34,23 @@ PRIMITIVE_WRAPPER_TYPES = list[PRIMITIVE_TYPES] | PRIMITIVE_TYPES DATE_TYPES = datetime | date | time | timedelta JSON_WRAPPER_FALLBACK = list[Any] | dict[Any, Any] -SIMPLE_SUBCLASS_BASE_TYPES = ( - datetime, - date, - time, - timedelta, - UUID, - bytes, - str, - int, - float, - bool, -) T = TypeVar("T") @dataclass(frozen=True) -class ResolvedFieldAnnotation: - 
runtime_annotation: Any - storage_annotation: Any +class ResolvedTypehint: + runtime_type: Any is_list: bool - is_nullable: bool - is_simple_subclass: bool - - -class SimpleSubclassAnnotation: - def __init__(self, subtype: type[Any], base_type: type[Any]): - self.subtype = subtype - self.base_type = base_type - def __get_pydantic_core_schema__( - self, - source_type: Any, - handler: GetCoreSchemaHandler, - ) -> core_schema.CoreSchema: - schema = handler.generate_schema(self.base_type) - return core_schema.no_info_after_validator_function( - self._cast_value, - schema, - ) - def _cast_value(self, value: Any): - return _coerce_simple_subclass_value(value, self.subtype) - - -def _is_union_type(annotation: Any) -> bool: +def is_union_type(annotation: Any) -> bool: origin = get_origin(annotation) return origin is Union or isinstance(annotation, types.UnionType) -def _rebuild_annotation(annotation: Any, args: tuple[Any, ...]): - if _is_union_type(annotation): +def rebuild_typehint(annotation: Any, args: tuple[Any, ...]): + if is_union_type(annotation): return Union[args] # type: ignore origin = get_origin(annotation) @@ -104,8 +67,8 @@ def unwrap_annotated(annotation: Any) -> Any: return annotation -def _get_optional_annotation_inner(annotation: Any) -> Any | None: - if not _is_union_type(annotation): +def get_optional_inner(annotation: Any) -> Any | None: + if not is_union_type(annotation): return None non_null_args = tuple( @@ -117,36 +80,40 @@ def _get_optional_annotation_inner(annotation: Any) -> Any | None: return non_null_args[0] -def get_simple_subclass_base_type(annotation: Any) -> type[Any] | None: - annotation = unwrap_annotated(annotation) - if not isclass(annotation) or annotation in SIMPLE_SUBCLASS_BASE_TYPES: - return None - if issubclass(annotation, Enum): - return None - - mro = annotation.mro() - matches = [ - (mro.index(base_type), base_type) - for base_type in SIMPLE_SUBCLASS_BASE_TYPES - if base_type in mro - ] - if not matches: - return None - - 
return min(matches, key=lambda match: match[0])[1] - - -def resolve_field_annotation(annotation: Any) -> ResolvedFieldAnnotation: - is_list = False - is_nullable = False +def resolve_typehint(annotation: Any) -> ResolvedTypehint: + """ + Normalize a field annotation into the subset of typing structure Iceaxe + needs for runtime coercion and schema inference. + + Python annotations can describe the same logical field in several wrapped + forms: + - `Annotated[T, ...]` carries metadata but doesn't change the core type. + - `T | None` / `Optional[T]` is represented as a union and needs to be + unwrapped to reach the concrete value type. + - `list[T]` means the ORM should treat the column as an array while still + reasoning about the element type `T`. + + Callers that need to infer database/storage behavior should not each have to + reimplement `get_origin()` / `get_args()` handling or care about the exact + runtime shape Python uses for unions and annotated metadata. This helper + resolves those wrappers into a canonical form: + - `runtime_type`: the innermost non-`Annotated`, non-nullable element type + - `is_list`: whether the annotation represents a top-level `list[...]` + + The resolver is intentionally narrow. It understands the container/wrapper + shapes Iceaxe needs structurally, but it does not try to semantically reduce + arbitrary generic types. For example, nested generics are preserved inside + `runtime_type` once the top-level list/optional wrappers have been handled. 
+ + """ current = annotation + is_list = False while True: current = unwrap_annotated(current) - optional_inner = _get_optional_annotation_inner(current) + optional_inner = get_optional_inner(current) if optional_inner is not None: - is_nullable = True current = optional_inner continue @@ -157,136 +124,27 @@ def resolve_field_annotation(annotation: Any) -> ResolvedFieldAnnotation: break - current = unwrap_annotated(current) - base_type = get_simple_subclass_base_type(current) - - return ResolvedFieldAnnotation( - runtime_annotation=current, - storage_annotation=base_type or current, + return ResolvedTypehint( + runtime_type=unwrap_annotated(current), is_list=is_list, - is_nullable=is_nullable, - is_simple_subclass=base_type is not None, ) -def normalize_simple_subclass_annotation(annotation: Any) -> Any: +def transform_typehint( + annotation: Any, + transform: Callable[[Any], Any], +) -> Any: origin = get_origin(annotation) if origin is Annotated: inner, *metadata = get_args(annotation) - if any(isinstance(item, SimpleSubclassAnnotation) for item in metadata): - normalized_inner = normalize_simple_subclass_annotation(inner) - if normalized_inner == inner: - return annotation - return Annotated[normalized_inner, *metadata] - - base_type = get_simple_subclass_base_type(inner) - if base_type is not None: - return Annotated[ - inner, - *metadata, - SimpleSubclassAnnotation(inner, base_type), - ] - - normalized_inner = normalize_simple_subclass_annotation(inner) - if normalized_inner == inner: - return annotation - return Annotated[normalized_inner, *metadata] + return transform(Annotated[transform_typehint(inner, transform), *metadata]) if origin is not None: - normalized_args = tuple( - normalize_simple_subclass_annotation(arg) for arg in get_args(annotation) - ) - if normalized_args == get_args(annotation): - return annotation - return _rebuild_annotation(annotation, normalized_args) - - resolved = resolve_field_annotation(annotation) - if not 
resolved.is_simple_subclass: - return annotation - - return Annotated[ - annotation, - SimpleSubclassAnnotation( - resolved.runtime_annotation, - resolved.storage_annotation, - ), - ] - - -def get_db_storage_annotation(annotation: Any) -> tuple[Any, bool]: - resolved = resolve_field_annotation(annotation) - return resolved.storage_annotation, resolved.is_list - - -def convert_value_to_db_storage(value: Any, annotation: Any) -> Any: - return _convert_simple_subclass_value(value, annotation, to_db=True) - - -def convert_value_from_db_storage(value: Any, annotation: Any) -> Any: - return _convert_simple_subclass_value(value, annotation, to_db=False) - - -def _convert_simple_subclass_value(value: Any, annotation: Any, *, to_db: bool) -> Any: - if value is None: - return None + args = tuple(transform_typehint(arg, transform) for arg in get_args(annotation)) + return transform(rebuild_typehint(annotation, args)) - resolved = resolve_field_annotation(annotation) - if not resolved.is_simple_subclass: - return value - - target_type = resolved.storage_annotation if to_db else resolved.runtime_annotation - if resolved.is_list: - return [_coerce_simple_subclass_value(item, target_type) for item in value] - - return _coerce_simple_subclass_value(value, target_type) - - -def _coerce_simple_subclass_value(value: Any, target_type: type[Any]) -> Any: - if type(value) is target_type: - return value - - if issubclass(target_type, UUID): - return target_type(str(value)) - - if issubclass(target_type, datetime): - return target_type( - value.year, - value.month, - value.day, - value.hour, - value.minute, - value.second, - value.microsecond, - tzinfo=value.tzinfo, - fold=value.fold, - ) - - if issubclass(target_type, date): - return target_type( - value.year, - value.month, - value.day, - ) - - if issubclass(target_type, time): - return target_type( - value.hour, - value.minute, - value.second, - value.microsecond, - tzinfo=value.tzinfo, - fold=value.fold, - ) - - if 
issubclass(target_type, timedelta): - return target_type( - days=value.days, - seconds=value.seconds, - microseconds=value.microseconds, - ) - - return target_type(value) + return transform(annotation) def is_base_table(obj: Any) -> TypeGuard[type[TableBase]]: From 9e18057f97c06949cd38f530c61641086ecde7f6 Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Mon, 6 Apr 2026 11:58:48 -0700 Subject: [PATCH 5/6] Add comments --- iceaxe/typing.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/iceaxe/typing.py b/iceaxe/typing.py index ed4471e..f3b037d 100644 --- a/iceaxe/typing.py +++ b/iceaxe/typing.py @@ -38,10 +38,9 @@ T = TypeVar("T") -@dataclass(frozen=True) -class ResolvedTypehint: - runtime_type: Any - is_list: bool +# +# Simple type utility function +# def is_union_type(annotation: Any) -> bool: @@ -80,6 +79,17 @@ def get_optional_inner(annotation: Any) -> Any | None: return non_null_args[0] +# +# Type introspection +# + + +@dataclass(frozen=True) +class ResolvedTypehint: + runtime_type: Any + is_list: bool + + def resolve_typehint(annotation: Any) -> ResolvedTypehint: """ Normalize a field annotation into the subset of typing structure Iceaxe @@ -147,6 +157,11 @@ def transform_typehint( return transform(annotation) +# +# Typeguards +# + + def is_base_table(obj: Any) -> TypeGuard[type[TableBase]]: from iceaxe.base import TableBase From ad192fd97005edd52479d13132468bad6ce6629f Mon Sep 17 00:00:00 2001 From: Pierce Freeman Date: Mon, 6 Apr 2026 12:06:02 -0700 Subject: [PATCH 6/6] Add comments --- iceaxe/custom_typehints.py | 50 +++++++++++++++++++++++++++++++++++--- iceaxe/typing.py | 27 ++++++++++++++++++++ 2 files changed, 73 insertions(+), 4 deletions(-) diff --git a/iceaxe/custom_typehints.py b/iceaxe/custom_typehints.py index c7c7937..2e9aadc 100644 --- a/iceaxe/custom_typehints.py +++ b/iceaxe/custom_typehints.py @@ -84,10 +84,25 @@ def __get_pydantic_core_schema__( ) def _cast_value(self, value: Any): - return 
coerce_simple_subclass_value(value, self.subtype) + return coerce_single_subclass_value(value, self.subtype) def wrap_simple_subclass_annotation(annotation: Any) -> Any: + """ + Wrap one annotation node with `SimpleSubclassAnnotation` when it represents + a supported simple subclass. + + This function is intentionally small in scope: it does not walk nested type + structures itself. Instead it is designed to be passed into + `transform_typehint`, which recursively traverses unions, `Annotated`, and + other generic wrappers and invokes this function on each node. At each node + we decide whether the current type needs Pydantic metadata so it can be + validated as its base storage type and then reconstructed as the subclass. + + If the node is already wrapped, or if it is not one of the supported + subclass shapes, the annotation is returned unchanged. + + """ if get_origin(annotation) is Annotated: inner, *metadata = get_args(annotation) if any(isinstance(item, SimpleSubclassAnnotation) for item in metadata): @@ -111,6 +126,19 @@ def wrap_simple_subclass_annotation(annotation: Any) -> Any: def get_simple_subclass_base_type(annotation: Any) -> type[Any] | None: + """ + Resolve the storage/base type for a supported simple subclass annotation. + + For a value type like `CustomUUID(UUID)`, the runtime annotation is the + subclass but the storage behavior should follow `UUID`. This helper maps the + subclass back to that base type so callers can reason about database storage + and coercion without losing track of the original runtime type. + + The return value is intentionally `None` for non-subclasses and for base + types themselves. That lets callers distinguish "this annotation should be + treated specially" from "this is already a normal built-in/base type". 
+ + """ annotation = unwrap_annotated(annotation) kind = get_simple_subclass_kind(annotation) if kind is None: @@ -124,6 +152,20 @@ def get_simple_subclass_base_type(annotation: Any) -> type[Any] | None: def get_simple_subclass_kind(annotation: Any) -> SimpleSubclassKind | None: + """ + Classify an annotation into one of the supported simple-subclass families. + + The subclass feature only supports a bounded set of base runtime/storage + types, captured in `SIMPLE_SUBCLASS_BASE_TYPES_BY_KIND`. Rather than + repeatedly branching on `issubclass(..., UUID)` / `issubclass(..., date)` in + multiple places, we first collapse a candidate type into one stable literal + kind. Downstream code can then switch on that kind and get both clearer + control flow and exhaustiveness checking from static analysis. + + The function returns `None` for non-class values, for `Enum` subclasses, and + for classes that do not inherit from one of the supported base types. + + """ annotation = unwrap_annotated(annotation) if not isclass(annotation): return None @@ -153,12 +195,12 @@ def convert_simple_subclass_value(value: Any, annotation: Any, *, to_db: bool) - target_type = storage_type if to_db else resolved.runtime_type if resolved.is_list: - return [coerce_simple_subclass_value(item, target_type) for item in value] + return [coerce_single_subclass_value(item, target_type) for item in value] - return coerce_simple_subclass_value(value, target_type) + return coerce_single_subclass_value(value, target_type) -def coerce_simple_subclass_value(value: Any, target_type: type[Any]) -> Any: +def coerce_single_subclass_value(value: Any, target_type: type[Any]) -> Any: if type(value) is target_type: return value diff --git a/iceaxe/typing.py b/iceaxe/typing.py index f3b037d..422cc71 100644 --- a/iceaxe/typing.py +++ b/iceaxe/typing.py @@ -144,6 +144,33 @@ def transform_typehint( annotation: Any, transform: Callable[[Any], Any], ) -> Any: + """ + Recursively rebuild an annotation tree while applying a callback to each
node. + + Python type hints are often nested combinations of wrappers such as + `Annotated[...]`, unions, and container generics. Callers sometimes need to + inject or rewrite metadata inside that structure without losing the overall + typing shape. This helper performs that traversal once and hands each rebuilt + node to `transform`, allowing feature-specific code to focus on "what should + this node become?" rather than repeatedly reimplementing `get_origin()` / + `get_args()` recursion. + + Some examples of the supported traversal behavior: + - `CustomUUID | None` visits `CustomUUID`, applies the transform there, and + then rebuilds the nullable union around the transformed result. + - `list[CustomUUID]` visits `CustomUUID`, applies the transform there, and + then rebuilds the outer `list[...]` around the transformed element type. + - `dict[str, CustomUUID]` preserves the `dict[str, ...]` shape while still + transforming the nested value type. + - `Annotated[list[CustomUUID], Meta()]` first transforms the inner + `list[CustomUUID]`, then rebuilds the `Annotated[...]` wrapper with the + original metadata still attached. + + In all of these cases, child nodes are transformed before their parent + wrapper is rebuilt. That lets callers inspect the already-normalized inner + annotation when they receive a parent node such as `Annotated[...]`. + + """ origin = get_origin(annotation) if origin is Annotated: