diff --git a/docs/tutorial/field-types.md b/docs/tutorial/field-types.md index aeea2d9..b0d6e20 100644 --- a/docs/tutorial/field-types.md +++ b/docs/tutorial/field-types.md @@ -30,6 +30,46 @@ `DbmModel` (Embed Model) > Field value can be an instance of another DbmModel subclass. See [Embed Models](embed-models.md) for details. +## Optional Types + +Any supported type can be made optional using `Optional[T]` or `T | None` (Python 3.10+). +Optional fields default to `None` if no value is provided and accept both the inner type value and `None`. + +```python +import typing + +from pydbm import DbmModel + + +class User(DbmModel): + name: str + nickname: typing.Optional[str] # defaults to None, accepts str or None + age: typing.Optional[int] # defaults to None, accepts int or None +``` + +```python +user = User(name="hakan") +assert user.nickname is None +assert user.age is None + +user = User(name="hakan", nickname="hako", age=30) +assert user.nickname == "hako" +assert user.age == 30 +``` + +You can also provide a custom default value for optional fields: + +```python +class User(DbmModel): + name: str + nickname: typing.Optional[str] = "anonymous" + +user = User(name="hakan") +assert user.nickname == "anonymous" +``` + +## Example + ```python import datetime diff --git a/src/pydbm/database/manager.py b/src/pydbm/database/manager.py index 8c34105..d4ebf8a 100644 --- a/src/pydbm/database/manager.py +++ b/src/pydbm/database/manager.py @@ -8,7 +8,7 @@ from pydbm import contstant as C from pydbm.database.data_types import BaseDataType -from pydbm.inspect_extra import get_obj_annotations +from pydbm.inspect_extra import get_obj_annotations, is_optional_type, unwrap_optional from pydbm.models.fields import AutoField if typing.TYPE_CHECKING: @@ -52,6 +52,7 @@ class DatabaseManager: "db_path", "db", DATABASE_HEADER_NAME, + "__optional_fields__", "_keys", "__is_db_open", ) @@ -113,8 +114,18 @@ def set_database_header(self): ann = get_obj_annotations(obj=self.model) resolved_ann = {} + optional_fields: set[str] = set() for key, value in ann.items(): - if value in DATABASE_HEADER_MAPPING: + if is_optional_type(value): + inner_type = unwrap_optional(value) + optional_fields.add(key) + if inner_type in DATABASE_HEADER_MAPPING: + resolved_ann[key] = inner_type + elif isinstance(inner_type, type) and hasattr(inner_type, "objects"): + resolved_ann[key] = dict + else: + resolved_ann[key] = inner_type + elif value in DATABASE_HEADER_MAPPING: resolved_ann[key] = value elif isinstance(value, type) and hasattr(value, "objects"): resolved_ann[key] = dict @@ -133,6 +144,7 @@ def set_database_header(self): assert database_header == db_headers, f"Database headers are not equal: '{database_header}' != '{db_headers}'" # type: ignore[str-bytes-safe] # noqa: E501 setattr(self, DATABASE_HEADER_NAME, resolved_ann) + self.__optional_fields__ = optional_fields def open(self): if not self.__is_db_open: @@ -148,6 +160,9 @@ def close(self) -> None: def save(self, *, id: str, fields: dict[str, typing.Any]) -> None: data: dict[str, typing.Any] = {} for key, value in fields.items(): + if value is None and key in self.__optional_fields__: + data[key] = None + continue header_type = self.__database_headers__[key] if header_type is dict and hasattr(value, "as_dict"): embed_headers = value.objects.__database_headers__ @@ -194,7 +209,10 @@ def get(self, *, id: str | None = None, **unique_together) -> DbmModel: to_python = ast.literal_eval(data_from_dbm.decode("utf-8")) # TODO: implement own parser fields: dict[str, typing.Any] = {} for key, value in to_python.items(): - fields[key] = BaseDataType.get_data_type(self.__database_headers__[key]).get(value) + if value is None and key in self.__optional_fields__: + fields[key] = None + else: + fields[key] = BaseDataType.get_data_type(self.__database_headers__[key]).get(value) return self.model(**fields) @@ -212,6 +230,9 @@ def update(self, *, id: str, **updated_fields) -> None: data = ast.literal_eval(data_from_dbm.decode("utf-8")) for key, value in updated_fields.items(): + if value is None and key in self.__optional_fields__: + data[key] = None + continue header_type = self.__database_headers__[key] if header_type is dict and hasattr(value, "as_dict"): embed_headers = value.objects.__database_headers__ @@ -241,7 +262,10 @@ def all(self) -> typing.Iterable[DbmModel]: to_python = ast.literal_eval(data_from_dbm.decode("utf-8")) fields: dict[str, typing.Any] = {} for key, value in to_python.items(): - fields[key] = BaseDataType.get_data_type(self.__database_headers__[key]).get(value) + if value is None and key in self.__optional_fields__: + fields[key] = None + else: + fields[key] = BaseDataType.get_data_type(self.__database_headers__[key]).get(value) yield self.model(**fields) def filter(self, **kwargs) -> typing.Iterator[DbmModel]: diff --git a/src/pydbm/inspect_extra.py b/src/pydbm/inspect_extra.py index 78166e9..b257d64 100644 --- a/src/pydbm/inspect_extra.py +++ b/src/pydbm/inspect_extra.py @@ -6,9 +6,35 @@ __all__ = ( "get_obj_annotations", + "is_optional_type", + "unwrap_optional", ) +def is_optional_type(tp: typing.Any) -> bool: + """Check if a type is Optional (Union with NoneType), e.g. Optional[str] or str | None.""" + origin = typing.get_origin(tp) + if origin is typing.Union: + args = typing.get_args(tp) + return type(None) in args and len(args) == 2 + if sys.version_info >= (3, 10): + import types as _types + + if isinstance(tp, _types.UnionType): + args = typing.get_args(tp) + return type(None) in args and len(args) == 2 + return False + + +def unwrap_optional(tp: typing.Any) -> typing.Any: + """Extract the inner type from Optional[X] / X | None.""" + args = typing.get_args(tp) + for arg in args: + if arg is not type(None): + return arg + return type(None) + + def get_obj_annotations(*, obj: typing.Type[typing.Any]) -> dict[str, typing.Any]: assert inspect.isclass(obj), f"{obj!r} must be a class" diff --git a/src/pydbm/models/fields/base.py b/src/pydbm/models/fields/base.py index 31790c0..293ac75 100644 --- a/src/pydbm/models/fields/base.py +++ b/src/pydbm/models/fields/base.py @@ -41,6 +41,7 @@ class BaseField: "kwargs", "_is_call_run", "_is_embed_model", + "_is_optional", ) def __init__( @@ -65,6 +66,7 @@ def __init__( self._is_call_run = False self._is_embed_model = False + self._is_optional = False def __set_name__(self, instance: Meta, name: str) -> None: self.public_name = name @@ -78,6 +80,12 @@ def __get__(self, instance: Meta, owner: DbmModel) -> typing.Any: def __set__(self, instance: DbmModel, value: typing.Any) -> None: if self._is_embed_model: + if value is None and self._is_optional: + setattr(instance, self.private_name, None) + if self.field_name != C.PRIMARY_KEY: + instance.fields[self.field_name] = None + return + from pydbm.database.data_types import BaseDataType if isinstance(value, dict): @@ -102,8 +110,9 @@ def __set__(self, instance: DbmModel, value: typing.Any) -> None: if self.field_name != C.PRIMARY_KEY: instance.fields[self.field_name] = eligible_value - def __call__(self: Self, field_name: str, field_type: SupportedClassT, *args, **kwargs) -> Self: # type: ignore[valid-type] # noqa: E501 + def __call__(self: Self, field_name: str, field_type: SupportedClassT, *args, is_optional: bool = False, **kwargs) -> Self: # type: ignore[valid-type] # noqa: E501 self._is_call_run = True + self._is_optional = is_optional self.field_name = field_name self.field_type = field_type @@ -122,7 +131,15 @@ def __call__(self: Self, field_name: str, field_type: SupportedClassT, *args, ** ) self.validators.append(validator_mapping[field_type]) else: - self.validators.append(validator_mapping[field_type]) + inner_validator = validator_mapping[field_type] + if is_optional: + def optional_validator(value: typing.Any, v: ValidatorT = inner_validator) -> None: + if value is not None: + v(value) + + self.validators.append(optional_validator) + else: + self.validators.append(inner_validator) if field_type is int: if self.min_value: diff --git a/src/pydbm/models/meta.py b/src/pydbm/models/meta.py index ca8d31a..0b9a277 100644 --- a/src/pydbm/models/meta.py +++ b/src/pydbm/models/meta.py @@ -6,7 +6,7 @@ from pydbm import typing_extra from pydbm.database import DatabaseManager from pydbm.exceptions import EmptyModelError, PydbmBaseException, ReadOnlyFieldError, UnnecessaryParamsError -from pydbm.inspect_extra import get_obj_annotations +from pydbm.inspect_extra import get_obj_annotations, is_optional_type, unwrap_optional from pydbm.models.fields import AutoField, Field, Undefined __all__ = ( @@ -130,9 +130,17 @@ def generate_fields(mcs, cls, cls_name: str, namespace: dict[str, typing.Any]) - if field_name == C.PRIMARY_KEY: continue + optional = is_optional_type(field_type) + actual_type = unwrap_optional(field_type) if optional else field_type + default_value: Field | typing.Any = namespace.get(field_name, Undefined) - field = default_value if isinstance(default_value, Field) else Field(default=default_value) - fields.update({field_name: field(field_name, field_type)}) + if isinstance(default_value, Field): + field = default_value + elif optional and default_value is Undefined: + field = Field(default=None) + else: + field = Field(default=default_value) + fields.update({field_name: field(field_name, actual_type, is_optional=optional)}) return fields @staticmethod diff --git a/tests/models/test_optional.py b/tests/models/test_optional.py new file mode 100644 index 0000000..3e285ec --- /dev/null +++ b/tests/models/test_optional.py @@ -0,0 +1,156 @@ +import typing + +import pytest + +from pydbm import DbmModel, Field, ValidationError + + +class UserModel(DbmModel): + name: str + nickname: typing.Optional[str] + + +def test_optional_field_default_none(): + user = UserModel(name="hakan") + assert user.nickname is None + + +def test_optional_field_with_value(): + user = UserModel(name="hakan", nickname="hako") + assert user.nickname == "hako" + + +def test_optional_field_explicit_none(): + user = UserModel(name="hakan", nickname=None) + assert user.nickname is None + + +def test_optional_field_save_and_load_none(teardown_db): + user = UserModel(name="hakan") + user.save() + + loaded = UserModel.objects.get(id=user.id) + assert loaded.name == "hakan" + assert loaded.nickname is None + + +def test_optional_field_save_and_load_value(teardown_db): + user = UserModel(name="hakan", nickname="hako") + user.save() + + loaded = UserModel.objects.get(id=user.id) + assert loaded.name == "hakan" + assert loaded.nickname == "hako" + + +def test_optional_field_update_to_none(teardown_db): + user = UserModel.objects.create(name="hakan", nickname="hako") + assert user.nickname == "hako" + + user.update(nickname=None) + assert user.nickname is None + + loaded = UserModel.objects.get(id=user.id) + assert loaded.nickname is None + + +def test_optional_field_update_from_none(teardown_db): + user = UserModel.objects.create(name="hakan") + assert user.nickname is None + + user.update(nickname="hako") + assert user.nickname == "hako" + + loaded = UserModel.objects.get(id=user.id) + assert loaded.nickname == "hako" + + +def test_optional_field_validation_wrong_type(): + with pytest.raises(ValidationError) as cm: + UserModel(name="hakan", nickname=123) + + assert cm.value.field_name == "nickname" + assert cm.value.field_value == 123 + + +def test_optional_field_all(teardown_db): + UserModel.objects.create(name="hakan", nickname="hako") + UserModel.objects.create(name="celik") + + users = sorted(list(UserModel.objects.all()), key=lambda x: x.name) + assert users[0].name == "celik" + assert users[0].nickname is None + assert users[1].name == "hakan" + assert users[1].nickname == "hako" + + +def test_optional_field_filter(teardown_db): + UserModel.objects.create(name="hakan", nickname="hako") + UserModel.objects.create(name="celik") + + result = list(UserModel.objects.filter(nickname=None)) + assert len(result) == 1 + assert result[0].name == "celik" + + +def test_optional_int_field(teardown_db): + class Model(DbmModel): + name: str + age: typing.Optional[int] + + model = Model(name="hakan") + assert model.age is None + + model = Model(name="hakan", age=30) + assert model.age == 30 + + model.save() + loaded = Model.objects.get(id=model.id) + assert loaded.age == 30 + + +def test_optional_int_field_save_none(teardown_db): + class Model(DbmModel): + name: str + age: typing.Optional[int] + + model = Model(name="hakan") + model.save() + + loaded = Model.objects.get(id=model.id) + assert loaded.age is None + + +def test_optional_field_with_custom_default(): + class Model(DbmModel): + name: str + nickname: typing.Optional[str] = "default_nick" + + model = Model(name="hakan") + assert model.nickname == "default_nick" + + +def test_optional_field_with_field_descriptor(): + class Model(DbmModel): + name: str + nickname: typing.Optional[str] = Field(default="custom") + + model = Model(name="hakan") + assert model.nickname == "custom" + + +def test_optional_field_repr(): + user = UserModel(name="hakan", nickname="hako") + assert "nickname='hako'" in repr(user) + + user2 = UserModel(name="hakan") + assert "nickname=None" in repr(user2) + + +def test_optional_field_eq(): + user1 = UserModel(name="hakan") + user2 = UserModel(name="hakan") + assert user1 == user2 + + user3 = UserModel(name="hakan", nickname="hako") + assert user1 != user3