diff --git a/iceaxe/__tests__/schemas/test_db_memory_serializer.py b/iceaxe/__tests__/schemas/test_db_memory_serializer.py index 6c4bb80..5094a37 100644 --- a/iceaxe/__tests__/schemas/test_db_memory_serializer.py +++ b/iceaxe/__tests__/schemas/test_db_memory_serializer.py @@ -736,6 +736,74 @@ class ExampleModel2(TableBase): ] +def test_simple_uuid_subclass_column_assignment(clear_all_database_objects): + class CustomUUID(UUID): + pass + + class ExampleModel(TableBase): + id: CustomUUID = Field(primary_key=True) + values: list[CustomUUID] + + migrator = DatabaseMemorySerializer() + db_objects = list(migrator.delegate([ExampleModel])) + assert db_objects == [ + ( + DBTable(table_name="examplemodel"), + [], + ), + ( + DBColumn( + table_name="examplemodel", + column_name="id", + column_type=ColumnType.UUID, + column_is_list=False, + nullable=False, + ), + [ + DBTable(table_name="examplemodel"), + ], + ), + ( + DBColumn( + table_name="examplemodel", + column_name="values", + column_type=ColumnType.UUID, + column_is_list=True, + nullable=False, + ), + [ + DBTable(table_name="examplemodel"), + ], + ), + ( + DBConstraint( + table_name="examplemodel", + constraint_name="examplemodel_pkey", + columns=frozenset({"id"}), + constraint_type=ConstraintType.PRIMARY_KEY, + foreign_key_constraint=None, + ), + [ + DBTable(table_name="examplemodel"), + DBColumn( + table_name="examplemodel", + column_name="id", + column_type=ColumnType.UUID, + column_is_list=False, + nullable=False, + ), + DBColumn( + table_name="examplemodel", + column_name="values", + column_type=ColumnType.UUID, + column_is_list=True, + nullable=False, + ), + ], + ), + ] + + @pytest.mark.asyncio @pytest.mark.parametrize( "field_name, annotation, field_info, expected_db_objects", diff --git a/iceaxe/__tests__/test_base.py b/iceaxe/__tests__/test_base.py index 72bc5ec..663dc9d 100644 --- a/iceaxe/__tests__/test_base.py +++ b/iceaxe/__tests__/test_base.py @@ -1,4 +1,5 @@ from typing import Annotated, Any, Generic, 
TypeVar, cast +from uuid import UUID from iceaxe.base import ( DBModelMetaclass, @@ -92,3 +93,28 @@ class Event(TableBase, autodetect=False): assert field.annotation == dict[str, Any] | None assert field.default is None assert field.is_json is True + + +def test_model_fields_with_simple_uuid_subclass(): + class CustomUUID(UUID): + pass + + class Event(TableBase, autodetect=False): + id: CustomUUID + maybe_id: CustomUUID | None = None + ids: list[CustomUUID] + + raw_uuid = UUID("12345678-1234-5678-1234-567812345678") + event = cast( + Any, + Event, + )( + id=raw_uuid, + maybe_id=str(raw_uuid), + ids=[raw_uuid, str(raw_uuid)], + ) + + assert event.model_fields["id"].annotation == CustomUUID + assert isinstance(event.id, CustomUUID) + assert isinstance(event.maybe_id, CustomUUID) + assert all(isinstance(value, CustomUUID) for value in event.ids) diff --git a/iceaxe/__tests__/test_session.py b/iceaxe/__tests__/test_session.py index 6bda421..aff1b02 100644 --- a/iceaxe/__tests__/test_session.py +++ b/iceaxe/__tests__/test_session.py @@ -3,6 +3,7 @@ from json import dumps as json_dumps, loads as json_loads from typing import Any, Type from unittest.mock import AsyncMock, patch +from uuid import UUID import asyncpg import pytest @@ -66,6 +67,44 @@ async def test_db_connection_update(db_connection: DBConnection): assert user.get_modified_attributes() == {} +@pytest.mark.asyncio +async def test_db_connection_uuid_subclass_round_trip( + db_connection: DBConnection, + clear_all_database_objects, +): + class CustomUUID(UUID): + pass + + class UUIDSubclassDemo(TableBase): + id: CustomUUID = Field(primary_key=True) + + await db_connection.conn.execute("DROP TABLE IF EXISTS uuidsubclassdemo") + await create_all(db_connection, [UUIDSubclassDemo]) + + row_id = CustomUUID("12345678-1234-5678-1234-567812345678") + demo = UUIDSubclassDemo(id=row_id) + await db_connection.insert([demo]) + + raw_row = await db_connection.conn.fetchrow( + "SELECT id FROM uuidsubclassdemo WHERE id = $1", 
+ UUID(str(row_id)), + ) + assert raw_row is not None + assert isinstance(raw_row["id"], UUID) + + result = await db_connection.exec( + QueryBuilder().select(UUIDSubclassDemo).where(UUIDSubclassDemo.id == row_id) + ) + assert len(result) == 1 + assert isinstance(result[0].id, CustomUUID) + + selected_ids = await db_connection.exec( + QueryBuilder().select(UUIDSubclassDemo.id).where(UUIDSubclassDemo.id == row_id) + ) + assert selected_ids == [row_id] + assert isinstance(selected_ids[0], CustomUUID) + + @pytest.mark.asyncio async def test_db_obj_mixin_track_modifications(): user = UserDemo(name="John Doe", email="john@example.com") diff --git a/iceaxe/base.py b/iceaxe/base.py index b6c088c..db3743c 100644 --- a/iceaxe/base.py +++ b/iceaxe/base.py @@ -12,7 +12,9 @@ from pydantic.main import _model_construction from pydantic_core import PydanticUndefined +from iceaxe.custom_typehints import wrap_simple_subclass_annotation from iceaxe.field import DBFieldClassDefinition, DBFieldInfo, Field +from iceaxe.typing import transform_typehint @dataclass_transform(kw_only_default=True, field_specifiers=(PydanticField,)) @@ -51,6 +53,20 @@ def __new__(mcs, name, bases, namespace, **kwargs): Create a new database model class with proper field tracking. Handles registration of the model and processes any table-specific arguments. """ + # Pydantic must see these normalized annotations up front; otherwise simple + # subclasses like `CustomUUID(UUID)` are treated as unknown types. 
+ namespace = dict(namespace) + raw_annotations = namespace.get("__annotations__", {}) + if raw_annotations: + namespace["__annotations__"] = { + key: ( + transform_typehint(annotation, wrap_simple_subclass_annotation) + if not isinstance(annotation, str) + else annotation + ) + for key, annotation in raw_annotations.items() + } + raw_kwargs = {**kwargs} mcs.is_constructing = True diff --git a/iceaxe/custom_typehints.py b/iceaxe/custom_typehints.py new file mode 100644 index 0000000..2e9aadc --- /dev/null +++ b/iceaxe/custom_typehints.py @@ -0,0 +1,250 @@ +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from enum import Enum +from inspect import isclass +from typing import Annotated, Any, Literal, assert_never, get_args, get_origin +from uuid import UUID + +from pydantic import GetCoreSchemaHandler +from pydantic_core import core_schema + +from iceaxe.typing import resolve_typehint, unwrap_annotated + +# This literal definition seems overly verbose for what we're doing (ie. defining the types +# of simple subclasses for which we support type coercion). But it's required so we can properly +# throw a type error during static analysis if we don't properly support a handler for one +# of these supported types. +SimpleSubclassKind = Literal[ + "datetime", + "date", + "time", + "timedelta", + "uuid", + "bytes", + "str", + "int", + "float", + "bool", +] + +SIMPLE_SUBCLASS_BASE_TYPES_BY_KIND: dict[SimpleSubclassKind, type[Any]] = { + "datetime": datetime, + "date": date, + "time": time, + "timedelta": timedelta, + "uuid": UUID, + "bytes": bytes, + "str": str, + "int": int, + "float": float, + "bool": bool, +} +SIMPLE_SUBCLASS_BASE_TYPES = tuple(SIMPLE_SUBCLASS_BASE_TYPES_BY_KIND.values()) + + +class SimpleSubclassAnnotation: + """ + Pydantic metadata wrapper for "validate as base type, return subclass". 
+ + Simple subclasses such as `CustomUUID(UUID)` are structurally compatible + with their parent type for database storage, but Pydantic does not know how + to build a schema for those subclasses by default. If Iceaxe passes the raw + subclass annotation through unchanged, model construction fails because + Pydantic treats it as an unknown arbitrary type. + + This metadata object is attached via `Annotated[...]` during + `transform_typehint(..., wrap_simple_subclass_annotation)`. When Pydantic + sees that annotation, + it calls `__get_pydantic_core_schema__`, which lets us: + - reuse the existing schema for the storage/base type (`UUID`, `date`, etc.) + - keep all of Pydantic's normal parsing behavior for that base type + - run one final post-validation cast that reconstructs the requested subclass + + The important constraint is that this only works for "simple" subclasses + whose runtime value can be losslessly rebuilt from the validated base value. + We are not defining a brand new schema here; we are explicitly piggybacking + on the parent type's schema and restoring the subclass identity afterward. + + """ + + def __init__(self, subtype: type[Any], base_type: type[Any]): + self.subtype = subtype + self.base_type = base_type + + def __get_pydantic_core_schema__( + self, + source_type: Any, + handler: GetCoreSchemaHandler, + ) -> core_schema.CoreSchema: + schema = handler.generate_schema(self.base_type) + return core_schema.no_info_after_validator_function( + self._cast_value, + schema, + ) + + def _cast_value(self, value: Any): + return coerce_single_subclass_value(value, self.subtype) + + +def wrap_simple_subclass_annotation(annotation: Any) -> Any: + """ + Wrap one annotation node with `SimpleSubclassAnnotation` when it represents + a supported simple subclass. + + This function is intentionally small in scope: it does not walk nested type + structures itself. 
Instead it is designed to be passed into + `transform_typehint`, which recursively traverses unions, `Annotated`, and + other generic wrappers and invokes this function on each node. At each node + we decide whether the current type needs Pydantic metadata so it can be + validated as its base storage type and then reconstructed as the subclass. + + If the node is already wrapped, or if it is not one of the supported + subclass shapes, the annotation is returned unchanged. + + """ + if get_origin(annotation) is Annotated: + inner, *metadata = get_args(annotation) + if any(isinstance(item, SimpleSubclassAnnotation) for item in metadata): + return annotation + + base_type = get_simple_subclass_base_type(inner) + if base_type is None: + return annotation + + return Annotated[ + inner, + *metadata, + SimpleSubclassAnnotation(inner, base_type), + ] + + base_type = get_simple_subclass_base_type(annotation) + if base_type is None: + return annotation + + return Annotated[annotation, SimpleSubclassAnnotation(annotation, base_type)] + + +def get_simple_subclass_base_type(annotation: Any) -> type[Any] | None: + """ + Resolve the storage/base type for a supported simple subclass annotation. + + For a value type like `CustomUUID(UUID)`, the runtime annotation is the + subclass but the storage behavior should follow `UUID`. This helper maps the + subclass back to that base type so callers can reason about database storage + and coercion without losing track of the original runtime type. + + The return value is intentionally `None` for non-subclasses and for base + types themselves. That lets callers distinguish "this annotation should be + treated specially" from "this is already a normal built-in/base type". 
+ + """ + annotation = unwrap_annotated(annotation) + kind = get_simple_subclass_kind(annotation) + if kind is None: + return None + + base_type = SIMPLE_SUBCLASS_BASE_TYPES_BY_KIND[kind] + if annotation is base_type: + return None + + return base_type + + +def get_simple_subclass_kind(annotation: Any) -> SimpleSubclassKind | None: + """ + Classify an annotation into one of the supported simple-subclass families. + + The subclass feature only supports a bounded set of base runtime/storage + types, captured in `SIMPLE_SUBCLASS_BASE_TYPES_BY_KIND`. Rather than + repeatedly branching on `issubclass(..., UUID)` / `issubclass(..., date)` in + multiple places, we first collapse a candidate type into one stable literal + kind. Downstream code can then switch on that kind and get both clearer + control flow and exhaustiveness checking from static analysis. + + The function returns `None` for values that are not classes, enums, or do + not inherit from one of the supported base types. + + """ + annotation = unwrap_annotated(annotation) + if not isclass(annotation): + return None + if issubclass(annotation, Enum): + return None + + mro = annotation.mro() + matches: list[tuple[int, SimpleSubclassKind]] = [ + (mro.index(base_type), kind) + for kind, base_type in SIMPLE_SUBCLASS_BASE_TYPES_BY_KIND.items() + if base_type in mro + ] + if not matches: + return None + + return min(matches, key=lambda match: match[0])[1] + + +def convert_simple_subclass_value(value: Any, annotation: Any, *, to_db: bool) -> Any: + if value is None: + return None + + resolved = resolve_typehint(annotation) + storage_type = get_simple_subclass_base_type(resolved.runtime_type) + if storage_type is None: + return value + + target_type = storage_type if to_db else resolved.runtime_type + if resolved.is_list: + return [coerce_single_subclass_value(item, target_type) for item in value] + + return coerce_single_subclass_value(value, target_type) + + +def coerce_single_subclass_value(value: Any, target_type: 
type[Any]) -> Any: + if type(value) is target_type: + return value + + kind = get_simple_subclass_kind(target_type) + if kind is None: + raise TypeError(f"Unsupported simple subclass target type: {target_type}") + + match kind: + case "uuid": + return target_type(str(value)) + case "datetime": + return target_type( + value.year, + value.month, + value.day, + value.hour, + value.minute, + value.second, + value.microsecond, + tzinfo=value.tzinfo, + fold=value.fold, + ) + case "date": + return target_type( + value.year, + value.month, + value.day, + ) + case "time": + return target_type( + value.hour, + value.minute, + value.second, + value.microsecond, + tzinfo=value.tzinfo, + fold=value.fold, + ) + case "timedelta": + return target_type( + days=value.days, + seconds=value.seconds, + microseconds=value.microseconds, + ) + case "bytes" | "str" | "int" | "float" | "bool": + return target_type(value) + case unexpected: + assert_never(unexpected) diff --git a/iceaxe/field.py b/iceaxe/field.py index 31e90bb..74269db 100644 --- a/iceaxe/field.py +++ b/iceaxe/field.py @@ -17,6 +17,7 @@ from pydantic_core import PydanticUndefined from iceaxe.comparison import ComparisonBase +from iceaxe.custom_typehints import convert_simple_subclass_value from iceaxe.postgres import PostgresFieldBase from iceaxe.queries_str import QueryIdentifier, QueryLiteral from iceaxe.sql_types import ColumnType @@ -171,10 +172,16 @@ def extend_field( def to_db_value(self, value: Any): if self.is_json: return json_dumps(value) + if self.annotation is not None: + return convert_simple_subclass_value(value, self.annotation, to_db=True) return value def from_db_value(self, value: Any): - if not self.is_json or value is None: + if not self.is_json: + if self.annotation is None: + return value + return convert_simple_subclass_value(value, self.annotation, to_db=False) + if value is None: return value parsed_value = json_loads(value) if isinstance(value, str) else value diff --git 
a/iceaxe/schemas/db_memory_serializer.py b/iceaxe/schemas/db_memory_serializer.py index fbc9dec..8802a5a 100644 --- a/iceaxe/schemas/db_memory_serializer.py +++ b/iceaxe/schemas/db_memory_serializer.py @@ -14,6 +14,7 @@ TableBase, UniqueConstraint, ) +from iceaxe.custom_typehints import get_simple_subclass_base_type from iceaxe.generics import ( get_typevar_mapping, has_null_type, @@ -51,6 +52,7 @@ DATE_TYPES, JSON_WRAPPER_FALLBACK, PRIMITIVE_WRAPPER_TYPES, + resolve_typehint, ) NodeYieldType = Union[DBObject, DBObjectPointer, "NodeDefinition"] @@ -423,39 +425,51 @@ def handle_column_type(self, key: str, info: DBFieldInfo, table: Type[TableBase] ) annotation = remove_null_type(info.annotation) + resolved_annotation = resolve_typehint(annotation) + storage_annotation = ( + get_simple_subclass_base_type(resolved_annotation.runtime_type) + or resolved_annotation.runtime_type + ) + is_list = resolved_annotation.is_list # Resolve the type of the column, if generic - if isinstance(annotation, TypeVar): + if isinstance(storage_annotation, TypeVar): typevar_map = get_typevar_mapping(table) - annotation = typevar_map[annotation] + storage_annotation = typevar_map[storage_annotation] + resolved_annotation = resolve_typehint(storage_annotation) + storage_annotation = ( + get_simple_subclass_base_type(resolved_annotation.runtime_type) + or resolved_annotation.runtime_type + ) + is_list = resolved_annotation.is_list # Should be prioritized in terms of MRO; StrEnums should be processed # before the str types - if is_type_compatible(annotation, ALL_ENUM_TYPES): + if is_type_compatible(storage_annotation, ALL_ENUM_TYPES): # We only support string values for enums because postgres enums are defined # as name-based types - for value in annotation: # type: ignore + for value in storage_annotation: # type: ignore if not isinstance(value.value, str): raise ValueError( - f"Only string values are supported for enums, received: {value.value} (enum: {annotation})" + f"Only string values 
are supported for enums, received: {value.value} (enum: {storage_annotation})" ) return TypeDeclarationResponse( custom_type=DBType( - name=enum_to_name(annotation), # type: ignore - values=frozenset([value.value for value in annotation]), # type: ignore + name=enum_to_name(storage_annotation), # type: ignore + values=frozenset([value.value for value in storage_annotation]), # type: ignore reference_columns=frozenset({(table.get_table_name(), key)}), ), ) - elif is_type_compatible(annotation, PRIMITIVE_WRAPPER_TYPES): + elif is_type_compatible(storage_annotation, PRIMITIVE_WRAPPER_TYPES): for primitive, json_type in self.python_to_sql.items(): - if annotation == primitive or annotation == list[primitive]: # type: ignore + if storage_annotation == primitive: # type: ignore return TypeDeclarationResponse( primitive_type=json_type, - is_list=(annotation == list[primitive]), # type: ignore + is_list=is_list, ) - elif is_type_compatible(annotation, DATE_TYPES): - if is_type_compatible(annotation, datetime): # type: ignore + elif is_type_compatible(storage_annotation, DATE_TYPES): + if is_type_compatible(storage_annotation, datetime): # type: ignore if isinstance(info.postgres_config, PostgresDateTime): return TypeDeclarationResponse( primitive_type=( @@ -468,11 +482,11 @@ def handle_column_type(self, key: str, info: DBFieldInfo, table: Type[TableBase] return TypeDeclarationResponse( primitive_type=ColumnType.TIMESTAMP_WITHOUT_TIME_ZONE, ) - elif is_type_compatible(annotation, date): # type: ignore + elif is_type_compatible(storage_annotation, date): # type: ignore return TypeDeclarationResponse( primitive_type=ColumnType.DATE, ) - elif is_type_compatible(annotation, time): # type: ignore + elif is_type_compatible(storage_annotation, time): # type: ignore if isinstance(info.postgres_config, PostgresTime): return TypeDeclarationResponse( primitive_type=( @@ -484,34 +498,34 @@ def handle_column_type(self, key: str, info: DBFieldInfo, table: Type[TableBase] return 
TypeDeclarationResponse( primitive_type=ColumnType.TIME_WITHOUT_TIME_ZONE, ) - elif is_type_compatible(annotation, timedelta): # type: ignore + elif is_type_compatible(storage_annotation, timedelta): # type: ignore return TypeDeclarationResponse( primitive_type=ColumnType.INTERVAL, ) else: - raise ValueError(f"Unsupported date type: {annotation}") - elif is_type_compatible(annotation, BaseModel): + raise ValueError(f"Unsupported date type: {storage_annotation}") + elif is_type_compatible(storage_annotation, BaseModel): if info.is_json: return TypeDeclarationResponse( primitive_type=ColumnType.JSON, ) else: raise ValueError( - f"Pydantic model fields must have Field(is_json=True) specified: {annotation}\n" + f"Pydantic model fields must have Field(is_json=True) specified: {storage_annotation}\n" f"Column: {table.__name__}.{key}" ) - elif is_type_compatible(annotation, JSON_WRAPPER_FALLBACK): + elif is_type_compatible(storage_annotation, JSON_WRAPPER_FALLBACK): if info.is_json: return TypeDeclarationResponse( primitive_type=ColumnType.JSON, ) else: raise ValueError( - f"JSON fields must have Field(is_json=True) specified: {annotation}\n" + f"JSON fields must have Field(is_json=True) specified: {storage_annotation}\n" f"Column: {table.__name__}.{key}" ) - raise ValueError(f"Unsupported column type: {annotation}") + raise ValueError(f"Unsupported column type: {storage_annotation}") def handle_single_constraints( self, key: str, info: DBFieldInfo, table: Type[TableBase] diff --git a/iceaxe/session.py b/iceaxe/session.py index b3f5d6f..25544a1 100644 --- a/iceaxe/session.py +++ b/iceaxe/session.py @@ -219,6 +219,50 @@ def get_dsn(self) -> str: return "".join(dsn_parts) + def _cast_column_select_results( + self, + results: list[Any], + select_raw: Sequence[Any], + ) -> list[Any]: + """ + Apply field-level `from_db_value` casting to direct column selections. 
+ + The optimized select path already constructs full table objects with each + field deserialized through the model metadata. Direct column selections are + intentionally left as raw driver values for speed, which means subclass-backed + fields like `CustomUUID(UUID)` come back as the base database type instead of + the annotated runtime type. This helper restores that field-level coercion for + non-JSON column selects while leaving table, alias, and function selections on + the existing fast path. + + """ + column_cast_indices: list[int] = [] + for index, raw_value in enumerate(select_raw): + if is_column(raw_value) and not raw_value.field_definition.is_json: + column_cast_indices.append(index) + + if not column_cast_indices: + return results + + for result_index, row in enumerate(results): + if len(select_raw) == 1: + raw_column = cast(DBFieldClassDefinition[Any], select_raw[0]) + results[result_index] = raw_column.field_definition.from_db_value(row) + continue + + row_values = list(cast(tuple[Any, ...], row)) + for column_index in column_cast_indices: + raw_column = cast( + DBFieldClassDefinition[Any], + select_raw[column_index], + ) + row_values[column_index] = raw_column.field_definition.from_db_value( + row_values[column_index] + ) + results[result_index] = tuple(row_values) + + return results + @asynccontextmanager async def transaction(self, *, ensure: bool = False): """ @@ -332,6 +376,10 @@ async def exec( ] result_all = optimize_exec_casting(values, query._select_raw, select_types) + result_all = self._cast_column_select_results( + result_all, + query._select_raw, + ) # Only loop through results if we have verbosity enabled, since this logic otherwise # is wasted if no content will eventually be logged diff --git a/iceaxe/typing.py b/iceaxe/typing.py index f4b94e5..422cc71 100644 --- a/iceaxe/typing.py +++ b/iceaxe/typing.py @@ -1,14 +1,21 @@ from __future__ import annotations +import types +from dataclasses import dataclass from datetime import date, 
datetime, time, timedelta from enum import Enum, IntEnum, StrEnum from inspect import isclass from typing import ( TYPE_CHECKING, + Annotated, Any, + Callable, Type, TypeGuard, TypeVar, + Union, + get_args, + get_origin, ) from uuid import UUID @@ -31,6 +38,157 @@ T = TypeVar("T") +# +# Simple type utility function +# + + +def is_union_type(annotation: Any) -> bool: + origin = get_origin(annotation) + return origin is Union or isinstance(annotation, types.UnionType) + + +def rebuild_typehint(annotation: Any, args: tuple[Any, ...]): + if is_union_type(annotation): + return Union[args] # type: ignore + + origin = get_origin(annotation) + if origin is None: + return annotation + + item = args[0] if len(args) == 1 else args + return origin[item] + + +def unwrap_annotated(annotation: Any) -> Any: + while get_origin(annotation) is Annotated: + annotation = get_args(annotation)[0] + return annotation + + +def get_optional_inner(annotation: Any) -> Any | None: + if not is_union_type(annotation): + return None + + non_null_args = tuple( + arg for arg in get_args(annotation) if unwrap_annotated(arg) is not type(None) + ) + if len(non_null_args) != 1: + return None + + return non_null_args[0] + + +# +# Type introspection +# + + +@dataclass(frozen=True) +class ResolvedTypehint: + runtime_type: Any + is_list: bool + + +def resolve_typehint(annotation: Any) -> ResolvedTypehint: + """ + Normalize a field annotation into the subset of typing structure Iceaxe + needs for runtime coercion and schema inference. + + Python annotations can describe the same logical field in several wrapped + forms: + - `Annotated[T, ...]` carries metadata but doesn't change the core type. + - `T | None` / `Optional[T]` is represented as a union and needs to be + unwrapped to reach the concrete value type. + - `list[T]` means the ORM should treat the column as an array while still + reasoning about the element type `T`. 
+ + Callers that need to infer database/storage behavior should not each have to + reimplement `get_origin()` / `get_args()` handling or care about the exact + runtime shape Python uses for unions and annotated metadata. This helper + resolves those wrappers into a canonical form: + - `runtime_type`: the innermost non-`Annotated`, non-nullable element type + - `is_list`: whether the annotation represents a top-level `list[...]` + + The resolver is intentionally narrow. It understands the container/wrapper + shapes Iceaxe needs structurally, but it does not try to semantically reduce + arbitrary generic types. For example, nested generics are preserved inside + `runtime_type` once the top-level list/optional wrappers have been handled. + + """ + current = annotation + is_list = False + + while True: + current = unwrap_annotated(current) + + optional_inner = get_optional_inner(current) + if optional_inner is not None: + current = optional_inner + continue + + if not is_list and get_origin(current) is list: + (current,) = get_args(current) + is_list = True + continue + + break + + return ResolvedTypehint( + runtime_type=unwrap_annotated(current), + is_list=is_list, + ) + + +def transform_typehint( + annotation: Any, + transform: Callable[[Any], Any], +) -> Any: + """ + Recursively rebuild an annotation tree while applying a callback to each node. + + Python type hints are often nested combinations of wrappers such as + `Annotated[...]`, unions, and container generics. Callers sometimes need to + inject or rewrite metadata inside that structure without losing the overall + typing shape. This helper performs that traversal once and hands each rebuilt + node to `transform`, allowing feature-specific code to focus on "what should + this node become?" rather than repeatedly reimplementing `get_origin()` / + `get_args()` recursion. 
+ + Some examples of the supported traversal behavior: + - `CustomUUID | None` visits `CustomUUID`, applies the transform there, and + then rebuilds the nullable union around the transformed result. + - `list[CustomUUID]` visits `CustomUUID`, applies the transform there, and + then rebuilds the outer list as `list[...]` around the transformed element type. + - `dict[str, CustomUUID]` preserves the `dict[str, ...]` shape while still + transforming the nested value type. + - `Annotated[list[CustomUUID], Meta()]` first transforms the inner + `list[CustomUUID]`, then rebuilds the `Annotated[...]` wrapper with the + original metadata still attached. + + In all of these cases, child nodes are transformed before their parent + wrapper is rebuilt. That lets callers inspect the already-normalized inner + annotation when they receive a parent node such as `Annotated[...]`. + + """ + origin = get_origin(annotation) + + if origin is Annotated: + inner, *metadata = get_args(annotation) + return transform(Annotated[transform_typehint(inner, transform), *metadata]) + + if origin is not None: + args = tuple(transform_typehint(arg, transform) for arg in get_args(annotation)) + return transform(rebuild_typehint(annotation, args)) + + return transform(annotation) + + +# +# Typeguards +# + + def is_base_table(obj: Any) -> TypeGuard[type[TableBase]]: from iceaxe.base import TableBase