diff --git a/src/dataclass_binder/_impl.py b/src/dataclass_binder/_impl.py index 60ede50..87dc1ae 100644 --- a/src/dataclass_binder/_impl.py +++ b/src/dataclass_binder/_impl.py @@ -24,7 +24,7 @@ from inspect import cleandoc, get_annotations, getmodule, getsource, isabstract from pathlib import Path from textwrap import dedent -from types import MappingProxyType, ModuleType, NoneType, UnionType +from types import GenericAlias, MappingProxyType, ModuleType, NoneType, UnionType from typing import TYPE_CHECKING, Any, BinaryIO, ClassVar, Generic, TypeVar, Union, cast, get_args, get_origin, overload from weakref import WeakKeyDictionary @@ -37,7 +37,7 @@ # Note: Actually 'field_type' can either be a type of a typing special form, # but there is no way yet to annotate typing special forms. # This is the source of a lot of the casts and suppressions in this function. -def _collect_type(field_type: type, context: str) -> type | Binder[Any]: +def _collect_type(field_type: type, context: str) -> type | GenericAlias | Binder[Any]: """ Verify and streamline a type annotation. @@ -48,7 +48,7 @@ def _collect_type(field_type: type, context: str) -> type | Binder[Any]: """ origin = get_origin(field_type) if origin is None: - if field_type is Any: # type: ignore[comparison-overlap] + if field_type is Any: return object elif not isinstance(field_type, type): raise TypeError(f"Annotation for field '{context}' is not a type") @@ -94,7 +94,7 @@ def _collect_type(field_type: type, context: str) -> type | Binder[Any]: else: return origin[tuple(_collect_type(arg, arg_context) for arg in args)] # type: ignore[no-any-return] # Use the convention that the first argument is the element type. - return origin[_collect_type(args[0], arg_context)] # type: ignore[no-any-return] + return origin[_collect_type(args[0], arg_context)] elif origin is type: try: (arg,) = get_args(field_type) @@ -215,7 +215,7 @@ class _ClassInfo(Generic[T]): _cache: ClassVar[MutableMapping[type[Any], _ClassInfo[Any]]] = WeakKeyDictionary() dataclass: type[T] - field_types: Mapping[str, type | Binder[Any]] + field_types: Mapping[str, type | GenericAlias | Binder[Any]] _field_docstrings: Mapping[str, str] | None = None @classmethod @@ -225,7 +225,7 @@ def get(cls, dataclass: type[T]) -> _ClassInfo[T]: except KeyError: # Populate field_types *after* adding new instance to the cache to make sure # _collect_type() will find the given dataclass if it's accessed recursively. - field_types: dict[str, type | Binder[Any]] = {} + field_types: dict[str, type | GenericAlias | Binder[Any]] = {} info = cls(dataclass, field_types) cls._cache[dataclass] = info for field, field_type in _get_fields(dataclass): @@ -371,7 +371,9 @@ def _bind_to_single_type(self, value: object, field_type: type, context: str) -> raise TypeError(f"Value for '{context}' has type '{type(value).__name__}', expected '{field_type.__name__}'") - def _bind_to_field(self, value: object, field_type: type | Binder[Any], instance: T | None, context: str) -> object: + def _bind_to_field( + self, value: object, field_type: type | GenericAlias | Binder[Any], instance: T | None, context: str + ) -> object: """ Convert a TOML value to a field type which is possibly a union type. @@ -870,7 +872,7 @@ def get_field_docstrings(dataclass: type[Any]) -> Mapping[str, str]: match node: case ast.AnnAssign(target=ast.Name(id=name)): scope = name - case ast.Expr(value=ast.Constant(value=docstring)): + case ast.Expr(value=ast.Constant(value=str(docstring))): if scope is None: # When using 'scope is not None', Coverage 6.4.4 will consider the 'is None' branch uncovered. pass @@ -884,7 +886,7 @@ def format_template(class_or_instance: Any) -> Iterator[str]: yield from Binder(class_or_instance).format_toml_template() -def _format_value_for_type(field_type: type[Any]) -> str: +def _format_value_for_type(field_type: GenericAlias | type[Any]) -> str: origin = get_origin(field_type) if origin is None: if field_type is str: