Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions src/dataclass_binder/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from inspect import cleandoc, get_annotations, getmodule, getsource, isabstract
from pathlib import Path
from textwrap import dedent
from types import MappingProxyType, ModuleType, NoneType, UnionType
from types import GenericAlias, MappingProxyType, ModuleType, NoneType, UnionType
from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar, Generic, TypeVar, Union, cast, get_args, get_origin, overload
from weakref import WeakKeyDictionary

Expand All @@ -37,7 +37,7 @@
# Note: Actually 'field_type' can either be a type of a typing special form,
# but there is no way yet to annotate typing special forms.
# This is the source of a lot of the casts and suppressions in this function.
def _collect_type(field_type: type, context: str) -> type | Binder[Any]:
def _collect_type(field_type: type, context: str) -> type | GenericAlias | Binder[Any]:
"""
Verify and streamline a type annotation.

Expand All @@ -48,7 +48,7 @@ def _collect_type(field_type: type, context: str) -> type | Binder[Any]:
"""
origin = get_origin(field_type)
if origin is None:
if field_type is Any: # type: ignore[comparison-overlap]
if field_type is Any:
return object
elif not isinstance(field_type, type):
raise TypeError(f"Annotation for field '{context}' is not a type")
Expand Down Expand Up @@ -94,7 +94,7 @@ def _collect_type(field_type: type, context: str) -> type | Binder[Any]:
else:
return origin[tuple(_collect_type(arg, arg_context) for arg in args)] # type: ignore[no-any-return]
# Use the convention that the first argument is the element type.
return origin[_collect_type(args[0], arg_context)] # type: ignore[no-any-return]
return origin[_collect_type(args[0], arg_context)]
elif origin is type:
try:
(arg,) = get_args(field_type)
Expand Down Expand Up @@ -215,7 +215,7 @@ class _ClassInfo(Generic[T]):
_cache: ClassVar[MutableMapping[type[Any], _ClassInfo[Any]]] = WeakKeyDictionary()

dataclass: type[T]
field_types: Mapping[str, type | Binder[Any]]
field_types: Mapping[str, type | GenericAlias | Binder[Any]]
_field_docstrings: Mapping[str, str] | None = None

@classmethod
Expand All @@ -225,7 +225,7 @@ def get(cls, dataclass: type[T]) -> _ClassInfo[T]:
except KeyError:
# Populate field_types *after* adding new instance to the cache to make sure
# _collect_type() will find the given dataclass if it's accessed recursively.
field_types: dict[str, type | Binder[Any]] = {}
field_types: dict[str, type | GenericAlias | Binder[Any]] = {}
info = cls(dataclass, field_types)
cls._cache[dataclass] = info
for field, field_type in _get_fields(dataclass):
Expand Down Expand Up @@ -371,7 +371,9 @@ def _bind_to_single_type(self, value: object, field_type: type, context: str) ->

raise TypeError(f"Value for '{context}' has type '{type(value).__name__}', expected '{field_type.__name__}'")

def _bind_to_field(self, value: object, field_type: type | Binder[Any], instance: T | None, context: str) -> object:
def _bind_to_field(
self, value: object, field_type: type | GenericAlias | Binder[Any], instance: T | None, context: str
) -> object:
"""
Convert a TOML value to a field type which is possibly a union type.

Expand Down Expand Up @@ -870,7 +872,7 @@ def get_field_docstrings(dataclass: type[Any]) -> Mapping[str, str]:
match node:
case ast.AnnAssign(target=ast.Name(id=name)):
scope = name
case ast.Expr(value=ast.Constant(value=docstring)):
case ast.Expr(value=ast.Constant(value=str(docstring))):
if scope is None:
# When using 'scope is not None', Coverage 6.4.4 will consider the 'is None' branch uncovered.
pass
Expand All @@ -884,7 +886,7 @@ def format_template(class_or_instance: Any) -> Iterator[str]:
yield from Binder(class_or_instance).format_toml_template()


def _format_value_for_type(field_type: type[Any]) -> str:
def _format_value_for_type(field_type: GenericAlias | type[Any]) -> str:
origin = get_origin(field_type)
if origin is None:
if field_type is str:
Expand Down