diff --git a/iceaxe/__tests__/conftest.py b/iceaxe/__tests__/conftest.py index 3869d84..0dd0076 100644 --- a/iceaxe/__tests__/conftest.py +++ b/iceaxe/__tests__/conftest.py @@ -59,6 +59,7 @@ async def db_connection(docker_postgres): "demomodela", "demomodelb", "jsondemo", + "pydanticjsondemo", "complextypedemo", ] known_types = ["statusenum", "employeestatus"] diff --git a/iceaxe/__tests__/schemas/test_db_memory_serializer.py b/iceaxe/__tests__/schemas/test_db_memory_serializer.py index 853451c..6c4bb80 100644 --- a/iceaxe/__tests__/schemas/test_db_memory_serializer.py +++ b/iceaxe/__tests__/schemas/test_db_memory_serializer.py @@ -6,7 +6,7 @@ from uuid import UUID import pytest -from pydantic import create_model +from pydantic import BaseModel, create_model from pydantic.fields import FieldInfo from iceaxe import Field, TableBase @@ -1569,3 +1569,22 @@ class TestModel(TableBase): id_column = next(c for c in columns if c.column_name == "id") assert id_column.column_type == ColumnType.INTEGER assert not id_column.nullable + + +def test_pydantic_model_json_field(clear_all_database_objects): + class SettingsModel(BaseModel): + theme: str + notifications: bool + + class TestModel(TableBase): + id: int = Field(primary_key=True) + settings: SettingsModel = Field(is_json=True) + + migrator = DatabaseMemorySerializer() + db_objects = list(migrator.delegate([TestModel])) + + columns = [obj for obj, _ in db_objects if isinstance(obj, DBColumn)] + settings_column = next(c for c in columns if c.column_name == "settings") + + assert settings_column.column_type == ColumnType.JSON + assert not settings_column.nullable diff --git a/iceaxe/__tests__/test_session.py b/iceaxe/__tests__/test_session.py index f38889b..9e639df 100644 --- a/iceaxe/__tests__/test_session.py +++ b/iceaxe/__tests__/test_session.py @@ -1,11 +1,13 @@ from contextlib import asynccontextmanager from enum import StrEnum +from json import dumps as json_dumps, loads as json_loads from typing import Any, Type from unittest.mock import AsyncMock, patch import asyncpg import pytest from asyncpg.connection import Connection +from pydantic import BaseModel from iceaxe.__tests__.conf_models import ( ArtifactDemo, @@ -544,6 +546,122 @@ async def test_refresh(db_connection: DBConnection): assert user.name == "Jane Doe" +@pytest.mark.asyncio +async def test_pydantic_json_round_trip(db_connection: DBConnection): + class Preferences(BaseModel): + theme: str + notifications: bool + + class PydanticJsonDemo(TableBase): + id: int | None = Field(primary_key=True, default=None) + payload: Preferences = Field(is_json=True) + + await db_connection.conn.execute("DROP TABLE IF EXISTS pydanticjsondemo") + await create_all(db_connection, [PydanticJsonDemo]) + + demo = PydanticJsonDemo( + payload=Preferences(theme="dark", notifications=True), + ) + await db_connection.insert([demo]) + + full_result = await db_connection.exec( + QueryBuilder().select(PydanticJsonDemo).where(PydanticJsonDemo.id == demo.id) + ) + assert len(full_result) == 1 + assert full_result[0].payload == Preferences(theme="dark", notifications=True) + + column_result = await db_connection.exec( + QueryBuilder() + .select(PydanticJsonDemo.payload) + .where(PydanticJsonDemo.id == demo.id) + ) + assert column_result == [Preferences(theme="dark", notifications=True)] + + demo.payload = Preferences(theme="light", notifications=False) + await db_connection.update([demo]) + + updated_result = await db_connection.exec( + QueryBuilder().select(PydanticJsonDemo).where(PydanticJsonDemo.id == demo.id) + ) + assert updated_result == [ + PydanticJsonDemo( + id=demo.id, + payload=Preferences(theme="light", notifications=False), + ) + ] + + await db_connection.conn.execute( + "UPDATE pydanticjsondemo SET payload = $1::json WHERE id = $2", + json_dumps({"theme": "system", "notifications": True}), + demo.id, + ) + await db_connection.refresh([demo]) + assert demo.payload == Preferences(theme="system", notifications=True) + + +@pytest.mark.asyncio +async def test_pydantic_json_serialization_to_database(db_connection: DBConnection): + class Preferences(BaseModel): + theme: str + notifications: bool + + class PydanticJsonDemo(TableBase): + id: int | None = Field(primary_key=True, default=None) + payload: Preferences = Field(is_json=True) + + await db_connection.conn.execute("DROP TABLE IF EXISTS pydanticjsondemo") + await create_all(db_connection, [PydanticJsonDemo]) + + demo = PydanticJsonDemo( + payload=Preferences(theme="dark", notifications=True), + ) + await db_connection.insert([demo]) + + raw_row = await db_connection.conn.fetchrow( + "SELECT payload::text AS payload FROM pydanticjsondemo WHERE id = $1", + demo.id, + ) + assert raw_row is not None + assert json_loads(raw_row["payload"]) == { + "theme": "dark", + "notifications": True, + } + + +@pytest.mark.asyncio +async def test_pydantic_json_deserialization_from_database( + db_connection: DBConnection, +): + class Preferences(BaseModel): + theme: str + notifications: bool + + class PydanticJsonDemo(TableBase): + id: int | None = Field(primary_key=True, default=None) + payload: Preferences = Field(is_json=True) + + await db_connection.conn.execute("DROP TABLE IF EXISTS pydanticjsondemo") + await create_all(db_connection, [PydanticJsonDemo]) + + inserted_row = await db_connection.conn.fetchrow( + "INSERT INTO pydanticjsondemo (payload) VALUES ($1::json) RETURNING id", + json_dumps({"theme": "system", "notifications": False}), + ) + assert inserted_row is not None + + result = await db_connection.exec( + QueryBuilder() + .select(PydanticJsonDemo) + .where(PydanticJsonDemo.id == inserted_row["id"]) + ) + assert result == [ + PydanticJsonDemo( + id=inserted_row["id"], + payload=Preferences(theme="system", notifications=False), + ) + ] + + @pytest.mark.asyncio async def test_get(db_connection: DBConnection): """ diff --git a/iceaxe/field.py b/iceaxe/field.py index 0b315b5..363a508 100644 --- a/iceaxe/field.py +++ b/iceaxe/field.py @@ -1,4 +1,4 @@ -from json import dumps as json_dumps +from json import dumps as json_dumps, loads as json_loads from typing import ( TYPE_CHECKING, Any, @@ -12,7 +12,7 @@ cast, ) -from pydantic import Field as PydanticField +from pydantic import Field as PydanticField, TypeAdapter from pydantic.fields import FieldInfo, _FieldInfoInputs from pydantic_core import PydanticUndefined @@ -105,6 +105,8 @@ class DBFieldInfo(FieldInfo): When set, this type takes precedence over automatic type inference. """ + _json_type_adapter: TypeAdapter[Any] | None = None + def __init__(self, **kwargs: Unpack[DBFieldInputs]): """ Initialize a new DBFieldInfo instance with the given field configuration. @@ -163,6 +165,19 @@ def to_db_value(self, value: Any): return json_dumps(value) return value + def from_db_value(self, value: Any): + if not self.is_json or value is None: + return value + + parsed_value = json_loads(value) if isinstance(value, str) else value + if self.annotation is None: + return parsed_value + + if self._json_type_adapter is None: + self._json_type_adapter = TypeAdapter(self.annotation) + + return self._json_type_adapter.validate_python(parsed_value) + def __get_db_field(_: Callable[Concatenate[Any, P], Any] = PydanticField): # type: ignore """ diff --git a/iceaxe/queries.py b/iceaxe/queries.py index ea719de..021aff2 100644 --- a/iceaxe/queries.py +++ b/iceaxe/queries.py @@ -5,6 +5,8 @@ from functools import wraps from typing import Any, Generic, Literal, Type, TypeVar, TypeVarTuple, cast, overload +from pydantic import BaseModel + from iceaxe.alias_values import Alias from iceaxe.base import ( DBFieldClassDefinition, @@ -46,6 +48,7 @@ | PRIMITIVE_WRAPPER_TYPES | DATE_TYPES | JSON_WRAPPER_FALLBACK + | BaseModel | None ) diff --git a/iceaxe/schemas/db_memory_serializer.py b/iceaxe/schemas/db_memory_serializer.py index 07d5554..fbc9dec 100644 --- a/iceaxe/schemas/db_memory_serializer.py +++ b/iceaxe/schemas/db_memory_serializer.py @@ -5,6 +5,7 @@ from typing import Any, Generator, Sequence, Type, TypeVar, Union from uuid import UUID +from pydantic import BaseModel from pydantic_core import PydanticUndefined from iceaxe.base import ( @@ -489,6 +490,16 @@ def handle_column_type(self, key: str, info: DBFieldInfo, table: Type[TableBase] ) else: raise ValueError(f"Unsupported date type: {annotation}") + elif is_type_compatible(annotation, BaseModel): + if info.is_json: + return TypeDeclarationResponse( + primitive_type=ColumnType.JSON, + ) + else: + raise ValueError( + f"Pydantic model fields must have Field(is_json=True) specified: {annotation}\n" + f"Column: {table.__name__}.{key}" + ) elif is_type_compatible(annotation, JSON_WRAPPER_FALLBACK): if info.is_json: return TypeDeclarationResponse( diff --git a/iceaxe/session.py b/iceaxe/session.py index 17c20d5..0184bc1 100644 --- a/iceaxe/session.py +++ b/iceaxe/session.py @@ -2,7 +2,6 @@ from collections import defaultdict from contextlib import asynccontextmanager from inspect import isclass -from json import loads as json_loads from math import ceil from typing import ( Any, @@ -558,13 +557,9 @@ async def upsert( processed_values = [] for field in returning_fields_cols: value = row[field.key] - if ( - value is not None - and field.root_model.model_fields[ - field.key - ].is_json - ): - value = json_loads(value) + value = field.root_model.model_fields[ + field.key + ].from_db_value(value) processed_values.append(value) results.append(tuple(processed_values)) else: @@ -747,7 +742,13 @@ async def refresh(self, objects: Sequence[TableBase]): if obj_id in results: # Update field-by-field for field in fields: - setattr(obj, field, results[obj_id][field]) + setattr( + obj, + field, + model.model_fields[field].from_db_value( + results[obj_id][field] + ), + ) else: LOGGER.error( f"Object {obj} with primary key {obj_id} not found in database" diff --git a/iceaxe/session_optimized.pyx b/iceaxe/session_optimized.pyx index 08349c2..f90f37f 100644 --- a/iceaxe/session_optimized.pyx +++ b/iceaxe/session_optimized.pyx @@ -2,7 +2,6 @@ from typing import Any, List, Tuple from iceaxe.base import TableBase from iceaxe.queries import FunctionMetadata from iceaxe.alias_values import Alias -from json import loads as json_loads from cpython.ref cimport PyObject from cpython.object cimport PyObject_GetItem from libc.stdlib cimport malloc, free @@ -132,7 +131,7 @@ cdef list process_values( if field_value is not None: all_none = False if fields[j][k].is_json: - field_value = json_loads(field_value) + field_value = select_raw.model_fields[field_name].from_db_value(field_value) obj_dict[field_name] = field_value @@ -153,6 +152,8 @@ cdef list process_values( item = value[f"{table_name}_{column_name}"] except KeyError: raise KeyError(f"Key '{table_name}_{column_name}' not found in value.") + if item is not None and select_raw.field_definition.is_json: + item = select_raw.field_definition.from_db_value(item) result_value[j] = item Py_INCREF(item)