diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 37516d6f..937d01a1 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -77,7 +77,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 with: - python-version: "3.9 - 3.14" + python-version: "3.10 - 3.14" update-environment: true - name: Upgrade pip @@ -146,7 +146,6 @@ jobs: - { runner: windows-latest, platform: windows, archs: "auto32" } - { runner: windows-11-arm, platform: windows, archs: "ARM64" } python-version: - - "3.9" - "3.10" - "3.11" - "3.12" @@ -157,8 +156,6 @@ jobs: - "pypy-3.11" exclude: # Exclude unsupported Python versions - - python-version: "3.9" - target: { archs: "ARM64" } - python-version: "3.10" target: { archs: "ARM64" } - python-version: "pypy-3.11" @@ -387,7 +384,7 @@ jobs: if: startsWith(github.ref, 'refs/tags/') uses: actions/setup-python@v6 with: - python-version: "3.9 - 3.14" + python-version: "3.10 - 3.14" update-environment: true - name: Upgrade pip diff --git a/.github/workflows/tests-with-pydebug.yml b/.github/workflows/tests-with-pydebug.yml index 71042462..d3a4a6bf 100644 --- a/.github/workflows/tests-with-pydebug.yml +++ b/.github/workflows/tests-with-pydebug.yml @@ -67,7 +67,6 @@ jobs: matrix: runner: [ubuntu-latest, macos-latest, windows-latest] python-version: - - "3.9" - "3.10" - "3.11" - "3.12" @@ -75,8 +74,6 @@ jobs: - "3.14" python-abiflags: ["d", "td"] exclude: - - python-version: "3.9" - python-abiflags: "td" - python-version: "3.10" python-abiflags: "td" - python-version: "3.11" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d43a86cc..fa2d1ac3 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -66,7 +66,6 @@ jobs: matrix: runner: [ubuntu-latest, macos-latest, windows-latest] python-version: - - "3.9" - "3.10" - "3.11" - "3.12" diff --git a/CHANGELOG.md b/CHANGELOG.md index 65807279..244ff643 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,7 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Removed -- +- Drop Python 3.9 support by [@XuehaiPan](https://github.com/XuehaiPan) in [#272](https://github.com/metaopt/optree/pull/272). ------ diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 439104cc..2f08a271 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -88,13 +88,13 @@ python3 -m cibuildwheel --platform=linux --output-dir=wheelhouse --config-file=p It will build wheel binaries for all supported CPython versions. The outputs will be placed in the `wheelhouse` directory. To build a wheel for a specific CPython version, you can use the [`CIBW_BUILD`](https://cibuildwheel.readthedocs.io/en/stable/options/#build-skip) environment variable. -For example, the following command will build a wheel for Python 3.9: +For example, the following command will build a wheel for Python 3.10: ```bash -CIBW_BUILD="cp39*manylinux*" python3 -m cibuildwheel --platform=linux --output-dir=wheelhouse --config-file=pyproject.toml +CIBW_BUILD="cp310*manylinux*" python3 -m cibuildwheel --platform=linux --output-dir=wheelhouse --config-file=pyproject.toml ``` -You can change `cp39*` to `cp312*` to build for Python 3.12. See for more options. +You can change `cp310*` to `cp312*` to build for Python 3.12. See for more options. ## Documentation diff --git a/README.md b/README.md index 7046b661..e62758fe 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ # OpTree -![Python 3.9+](https://img.shields.io/badge/Python-3.9%2B-brightgreen) +![Python 3.10+](https://img.shields.io/badge/Python-3.10%2B-brightgreen) [![PyPI](https://img.shields.io/pypi/v/optree?logo=pypi)](https://pypi.org/project/optree) ![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/metaopt/optree/build.yml?label=build&logo=github) ![GitHub Workflow Status](https://img.shields.io/github/actions/workflow/status/metaopt/optree/tests.yml?label=tests&logo=github) @@ -71,7 +71,7 @@ export pybind11_DIR="/path/to/custom/pybind11" pip3 install . ``` -Compiling from source requires Python 3.9+, a C++ compiler (`g++` / `clang++` / `icpx` / `cl.exe`) that supports C++20, and a `cmake` installation. +Compiling from source requires Python 3.10+, a C++ compiler (`g++` / `clang++` / `icpx` / `cl.exe`) that supports C++20, and a `cmake` installation. -------------------------------------------------------------------------------- diff --git a/docs/source/conf.py b/docs/source/conf.py index 7a5adfa0..c27a41f1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -281,7 +281,7 @@ def matches_pytree_typing_alias( return all( matches_pytree_typing_alias(arg, pat, recursive_pattern, recursive_ref_names) - for arg, pat in zip(annotation_args, pattern_args) + for arg, pat in zip(annotation_args, pattern_args, strict=True) ) for pytree_alias, pytree_instance in tuple(PyTree.__instances__.items()): diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index e0469047..b8bba882 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -23,8 +23,8 @@ limitations under the License. #include -#if !(defined(PY_VERSION_HEX) && PY_VERSION_HEX >= 0x03090000) // Python 3.9 -# error "Python 3.9 or newer is required." +#if !(defined(PY_VERSION_HEX) && PY_VERSION_HEX >= 0x030A0000) // Python 3.10 +# error "Python 3.10 or newer is required." #endif #if !(defined(PYBIND11_VERSION_HEX) && PYBIND11_VERSION_HEX >= 0x020C00F0) // pybind11 2.12.0 diff --git a/optree/accessors.py b/optree/accessors.py index b4f3de93..cc2bffe0 100644 --- a/optree/accessors.py +++ b/optree/accessors.py @@ -17,7 +17,6 @@ from __future__ import annotations import dataclasses -import sys from collections.abc import Iterable, Mapping, Sequence from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeVar, overload from typing_extensions import Self # Python 3.11+ @@ -47,10 +46,7 @@ ] -SLOTS = {'slots': True} if sys.version_info >= (3, 10) else {} # Python 3.10+ - - -@dataclasses.dataclass(init=True, repr=False, eq=False, frozen=True, **SLOTS) +@dataclasses.dataclass(init=True, repr=False, eq=False, frozen=True, slots=True) class PyTreeEntry: """Base class for path entries.""" @@ -122,9 +118,6 @@ def codify(self, /, node: str = '') -> str: return f'{node}[]' # should be overridden -del SLOTS - - _T = TypeVar('_T') _T_co = TypeVar('_T_co', covariant=True) _KT_co = TypeVar('_KT_co', covariant=True) @@ -134,7 +127,7 @@ def codify(self, /, node: str = '') -> str: class AutoEntry(PyTreeEntry): """A generic path entry class that determines the entry type on creation automatically.""" - __slots__: ClassVar[tuple[()]] = () + __slots__: ClassVar[tuple[()]] = () # type: ignore[misc] def __new__( # type: ignore[misc] cls, @@ -184,7 +177,7 @@ def __new__( # type: ignore[misc] class GetItemEntry(PyTreeEntry): """A generic path entry class for nodes that access their children by :meth:`__getitem__`.""" - __slots__: ClassVar[tuple[()]] = () + __slots__: ClassVar[tuple[()]] = () # type: ignore[misc] def __call__(self, obj: Any, /) -> Any: """Get the child object.""" @@ -198,7 +191,7 @@ def codify(self, /, node: str = '') -> str: class GetAttrEntry(PyTreeEntry): """A generic path entry class for nodes that access their children by :meth:`__getattr__`.""" - __slots__: ClassVar[tuple[()]] = () + __slots__: ClassVar[tuple[()]] = () # type: ignore[misc] entry: str @@ -219,13 +212,13 @@ def codify(self, /, node: str = '') -> str: class FlattenedEntry(PyTreeEntry): # pylint: disable=too-few-public-methods """A fallback path entry class for flattened objects.""" - __slots__: ClassVar[tuple[()]] = () + __slots__: ClassVar[tuple[()]] = () # type: ignore[misc] class SequenceEntry(GetItemEntry, Generic[_T_co]): """A path entry class for sequences.""" - __slots__: ClassVar[tuple[()]] = () + __slots__: ClassVar[tuple[()]] = () # type: ignore[misc] entry: int type: builtins.type[Sequence[_T_co]] @@ -247,7 +240,7 @@ def __repr__(self, /) -> str: class MappingEntry(GetItemEntry, Generic[_KT_co, _VT_co]): """A path entry class for mappings.""" - __slots__: ClassVar[tuple[()]] = () + __slots__: ClassVar[tuple[()]] = () # type: ignore[misc] entry: _KT_co type: builtins.type[Mapping[_KT_co, _VT_co]] @@ -269,7 +262,7 @@ def __repr__(self, /) -> str: class NamedTupleEntry(SequenceEntry[_T]): """A path entry class for namedtuple objects.""" - __slots__: ClassVar[tuple[()]] = () + __slots__: ClassVar[tuple[()]] = () # type: ignore[misc] entry: int type: builtins.type[NamedTuple[_T]] # type: ignore[type-arg] @@ -299,7 +292,7 @@ def codify(self, /, node: str = '') -> str: class StructSequenceEntry(SequenceEntry[_T]): """A path entry class for PyStructSequence objects.""" - __slots__: ClassVar[tuple[()]] = () + __slots__: ClassVar[tuple[()]] = () # type: ignore[misc] entry: int type: builtins.type[StructSequence[_T]] @@ -329,7 +322,7 @@ def codify(self, /, node: str = '') -> str: class DataclassEntry(GetAttrEntry): """A path entry class for dataclasses.""" - __slots__: ClassVar[tuple[()]] = () + __slots__: ClassVar[tuple[()]] = () # type: ignore[misc] entry: str | int # type: ignore[assignment] diff --git a/optree/dataclasses.py b/optree/dataclasses.py index dfc871f3..50fb9b7e 100644 --- a/optree/dataclasses.py +++ b/optree/dataclasses.py @@ -69,14 +69,14 @@ import warnings from dataclasses import * # noqa: F401,F403,RUF100 # pylint: disable=wildcard-import,unused-wildcard-import from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Callable, Literal, Protocol, TypeVar, overload +from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload from typing_extensions import dataclass_transform # Python 3.11+ from optree.accessors import DataclassEntry if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Callable, Iterable __all__ = [ @@ -106,7 +106,7 @@ def field( hash: bool | None = None, # pylint: disable=redefined-builtin compare: bool = True, metadata: dict[Any, Any] | None = None, - kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING, # type: ignore[valid-type] # Python 3.10+ + kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING, # type: ignore[valid-type] doc: str | None = None, # Python 3.14+ pytree_node: bool | None = None, ) -> _T: ... @@ -121,7 +121,7 @@ def field( hash: bool | None = None, # pylint: disable=redefined-builtin compare: bool = True, metadata: dict[Any, Any] | None = None, - kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING, # type: ignore[valid-type] # Python 3.10+ + kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING, # type: ignore[valid-type] doc: str | None = None, # Python 3.14+ pytree_node: bool | None = None, ) -> _T: ... @@ -135,7 +135,7 @@ def field( hash: bool | None = None, # pylint: disable=redefined-builtin compare: bool = True, metadata: dict[Any, Any] | None = None, - kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING, # type: ignore[valid-type] # Python 3.10+ + kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING, # type: ignore[valid-type] doc: str | None = None, # Python 3.14+ pytree_node: bool | None = None, ) -> Any: ... @@ -150,7 +150,7 @@ def field( # noqa: D417 # pylint: disable=function-redefined hash: bool | None = None, # pylint: disable=redefined-builtin compare: bool = True, metadata: dict[Any, Any] | None = None, - kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING, # type: ignore[valid-type] # Python 3.10+ + kw_only: bool | Literal[dataclasses.MISSING] = dataclasses.MISSING, # type: ignore[valid-type] doc: str | None = None, # Python 3.14+ pytree_node: bool | None = None, ) -> Any: @@ -192,13 +192,9 @@ def field( # noqa: D417 # pylint: disable=function-redefined 'hash': hash, 'compare': compare, 'metadata': metadata, + 'kw_only': kw_only, } - if sys.version_info >= (3, 10): # pragma: >=3.10 cover - kwargs['kw_only'] = kw_only - elif kw_only is not dataclasses.MISSING: # pragma: <3.10 cover - raise TypeError("field() got an unexpected keyword argument 'kw_only'") - if sys.version_info >= (3, 14): # pragma: >=3.14 cover kwargs['doc'] = doc elif doc is not None: # pragma: <3.14 cover @@ -222,9 +218,9 @@ def dataclass( order: bool = False, unsafe_hash: bool = False, frozen: bool = False, - match_args: bool = True, # Python 3.10+ - kw_only: bool = False, # Python 3.10+ - slots: bool = False, # Python 3.10+ + match_args: bool = True, + kw_only: bool = False, + slots: bool = False, weakref_slot: bool = False, # Python 3.11+ namespace: str, ) -> Callable[[_TypeT], _TypeT]: ... @@ -241,16 +237,16 @@ def dataclass( order: bool = False, unsafe_hash: bool = False, frozen: bool = False, - match_args: bool = True, # Python 3.10+ - kw_only: bool = False, # Python 3.10+ - slots: bool = False, # Python 3.10+ + match_args: bool = True, + kw_only: bool = False, + slots: bool = False, weakref_slot: bool = False, # Python 3.11+ namespace: str, ) -> _TypeT: ... @dataclass_transform(field_specifiers=(field,)) -def dataclass( # noqa: C901,D417 # pylint: disable=function-redefined +def dataclass( # noqa: D417 # pylint: disable=function-redefined cls: _TypeT | None = None, /, *, @@ -260,9 +256,9 @@ def dataclass( # noqa: C901,D417 # pylint: disable=function-redefined order: bool = False, unsafe_hash: bool = False, frozen: bool = False, - match_args: bool = True, # Python 3.10+ - kw_only: bool = False, # Python 3.10+ - slots: bool = False, # Python 3.10+ + match_args: bool = True, + kw_only: bool = False, + slots: bool = False, weakref_slot: bool = False, # Python 3.11+ namespace: str, ) -> _TypeT | Callable[[_TypeT], _TypeT]: @@ -286,19 +282,11 @@ def dataclass( # noqa: C901,D417 # pylint: disable=function-redefined 'order': order, 'unsafe_hash': unsafe_hash, 'frozen': frozen, + 'match_args': match_args, + 'kw_only': kw_only, + 'slots': slots, } - if sys.version_info >= (3, 10): # pragma: >=3.10 cover - kwargs['match_args'] = match_args - kwargs['kw_only'] = kw_only - kwargs['slots'] = slots - elif match_args is not True: # pragma: <3.10 cover - raise TypeError("dataclass() got an unexpected keyword argument 'match_args'") - elif kw_only is not False: # pragma: <3.10 cover - raise TypeError("dataclass() got an unexpected keyword argument 'kw_only'") - elif slots is not False: # pragma: <3.10 cover - raise TypeError("dataclass() got an unexpected keyword argument 'slots'") - if sys.version_info >= (3, 11): # pragma: >=3.11 cover kwargs['weakref_slot'] = weakref_slot elif weakref_slot is not False: # pragma: <3.11 cover @@ -360,9 +348,9 @@ def make_dataclass( # type: ignore[no-redef] # noqa: C901,D417 order: bool = False, unsafe_hash: bool = False, frozen: bool = False, - match_args: bool = True, # Python 3.10+ - kw_only: bool = False, # Python 3.10+ - slots: bool = False, # Python 3.10+ + match_args: bool = True, + kw_only: bool = False, + slots: bool = False, weakref_slot: bool = False, # Python 3.11+ module: str | None = None, # Python 3.12+ decorator: _DataclassDecorator[_TypeT] = dataclasses.dataclass, # type: ignore[assignment] # Python 3.14+ @@ -414,23 +402,15 @@ def make_dataclass( # type: ignore[no-redef] # noqa: C901,D417 'order': order, 'unsafe_hash': unsafe_hash, 'frozen': frozen, + 'match_args': match_args, + 'kw_only': kw_only, + 'slots': slots, } make_dataclass_kwargs = { 'bases': bases, 'namespace': ns, } - if sys.version_info >= (3, 10): # pragma: >=3.10 cover - dataclass_kwargs['match_args'] = match_args - dataclass_kwargs['kw_only'] = kw_only - dataclass_kwargs['slots'] = slots - elif match_args is not True: # pragma: <3.10 cover - raise TypeError("make_dataclass() got an unexpected keyword argument 'match_args'") - elif kw_only is not False: # pragma: <3.10 cover - raise TypeError("make_dataclass() got an unexpected keyword argument 'kw_only'") - elif slots is not False: # pragma: <3.10 cover - raise TypeError("make_dataclass() got an unexpected keyword argument 'slots'") - if sys.version_info >= (3, 11): # pragma: >=3.11 cover dataclass_kwargs['weakref_slot'] = weakref_slot elif weakref_slot is not False: # pragma: <3.11 cover @@ -596,7 +576,7 @@ def flatten_func( # pylint: disable-next=line-too-long def unflatten_func(metadata: tuple[tuple[str, Any], ...], children: tuple[_U, ...], /) -> _T: # type: ignore[type-var] - kwargs = dict(zip(children_field_names, children)) + kwargs = dict(zip(children_field_names, children, strict=True)) kwargs.update(metadata) return cls(**kwargs) # type: ignore[return-value] diff --git a/optree/functools.py b/optree/functools.py index da04b85b..046e8ae5 100644 --- a/optree/functools.py +++ b/optree/functools.py @@ -17,7 +17,7 @@ from __future__ import annotations import functools -from typing import TYPE_CHECKING, Any, Callable, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar from typing_extensions import Self # Python 3.11+ from optree import registry @@ -27,6 +27,8 @@ if TYPE_CHECKING: + from collections.abc import Callable + from optree.accessors import PyTreeEntry diff --git a/optree/integrations/attrs.py b/optree/integrations/attrs.py index c70a3130..2ee2e199 100644 --- a/optree/integrations/attrs.py +++ b/optree/integrations/attrs.py @@ -56,7 +56,7 @@ import inspect import warnings from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Callable, TypeVar, overload +from typing import TYPE_CHECKING, Any, TypeVar, overload import attrs from attrs import ( @@ -84,6 +84,7 @@ if TYPE_CHECKING: + from collections.abc import Callable from typing import ClassVar @@ -462,7 +463,7 @@ def flatten_func( # pylint: disable-next=line-too-long def unflatten_func(metadata: tuple[tuple[str, Any], ...], children: tuple[_U, ...], /) -> _T: # type: ignore[type-var] - kwargs = dict(zip(children_aliases, children)) + kwargs = dict(zip(children_aliases, children, strict=True)) kwargs.update(metadata) return cls(**kwargs) diff --git a/optree/integrations/jax.py b/optree/integrations/jax.py index 9b576149..84bf0a67 100644 --- a/optree/integrations/jax.py +++ b/optree/integrations/jax.py @@ -41,8 +41,7 @@ import warnings from operator import itemgetter from types import FunctionType -from typing import Any, Callable -from typing_extensions import TypeAlias # Python 3.10+ +from typing import TYPE_CHECKING, Any, TypeAlias import jax.numpy as jnp from jax import Array, lax @@ -54,6 +53,10 @@ from optree.utils import safe_zip, total_order_sorted +if TYPE_CHECKING: + from collections.abc import Callable + + __all__ = ['ArrayLikeTree', 'ArrayTree', 'tree_ravel'] diff --git a/optree/integrations/numpy.py b/optree/integrations/numpy.py index f2450c1d..cecbb88b 100644 --- a/optree/integrations/numpy.py +++ b/optree/integrations/numpy.py @@ -22,8 +22,7 @@ import functools import itertools import warnings -from typing import Any, Callable -from typing_extensions import TypeAlias # Python 3.10+ +from typing import TYPE_CHECKING, Any, TypeAlias import numpy as np from numpy.typing import ArrayLike @@ -33,6 +32,10 @@ from optree.utils import safe_zip +if TYPE_CHECKING: + from collections.abc import Callable + + __all__ = ['ArrayLikeTree', 'ArrayTree', 'tree_ravel'] diff --git a/optree/integrations/torch.py b/optree/integrations/torch.py index cf8a1a03..dccbec84 100644 --- a/optree/integrations/torch.py +++ b/optree/integrations/torch.py @@ -21,8 +21,7 @@ import functools import warnings -from typing import Any, Callable -from typing_extensions import TypeAlias # Python 3.10+ +from typing import TYPE_CHECKING, Any, TypeAlias import torch @@ -31,6 +30,10 @@ from optree.utils import safe_zip +if TYPE_CHECKING: + from collections.abc import Callable + + __all__ = ['TensorTree', 'tree_ravel'] diff --git a/optree/ops.py b/optree/ops.py index 1dcec0d6..f44289a2 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -23,7 +23,7 @@ import itertools import textwrap from collections import OrderedDict, defaultdict, deque -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, overload import optree._C as _C from optree.accessors import PyTreeAccessor @@ -32,7 +32,7 @@ if TYPE_CHECKING: import builtins - from collections.abc import Collection, Iterable, Mapping + from collections.abc import Callable, Collection, Iterable, Mapping from optree.accessors import PyTreeEntry from optree.typing import ( @@ -1233,7 +1233,7 @@ def tree_transpose( leaves[offset : offset + inner_size] for offset in range(0, outer_size * inner_size, inner_size) ] - transposed = zip(*grouped) + transposed = zip(*grouped, strict=True) subtrees = map(outer_treespec.unflatten, transposed) return inner_treespec.unflatten(subtrees) # type: ignore[arg-type] @@ -1335,7 +1335,7 @@ def tree_transpose_map( raise ValueError(f'The inner structure must have at least one leaf. Got: {inner_treespec}.') grouped = [inner_treespec.flatten_up_to(o) for o in outputs] - transposed = zip(*grouped) + transposed = zip(*grouped, strict=True) subtrees = map(outer_treespec.unflatten, transposed) return inner_treespec.unflatten(subtrees) # type: ignore[arg-type] @@ -1422,7 +1422,7 @@ def tree_transpose_map_with_path( raise ValueError(f'The inner structure must have at least one leaf. Got: {inner_treespec}.') grouped = [inner_treespec.flatten_up_to(o) for o in outputs] - transposed = zip(*grouped) + transposed = zip(*grouped, strict=True) subtrees = map(outer_treespec.unflatten, transposed) return inner_treespec.unflatten(subtrees) # type: ignore[arg-type] @@ -1536,7 +1536,7 @@ def tree_transpose_map_with_accessor( raise ValueError(f'The inner structure must have at least one leaf. Got: {inner_treespec}.') grouped = [inner_treespec.flatten_up_to(o) for o in outputs] - transposed = zip(*grouped) + transposed = zip(*grouped, strict=True) subtrees = map(outer_treespec.unflatten, transposed) return inner_treespec.unflatten(subtrees) # type: ignore[arg-type] @@ -3725,7 +3725,7 @@ def helper( # pylint: disable=too-many-locals or entries == entries_ ), f'equal pytree nodes gave different keys: {entries} and {entries_}' # pylint: disable-next=invalid-name - for e, t1, t2 in zip(entries, prefix_tree_children, full_tree_children): + for e, t1, t2 in zip(entries, prefix_tree_children, full_tree_children, strict=True): yield from helper(accessor + e, t1, t2) return list(helper(PyTreeAccessor(), prefix_tree, full_tree)) diff --git a/optree/pytree.py b/optree/pytree.py index 3e725652..4e711bdb 100644 --- a/optree/pytree.py +++ b/optree/pytree.py @@ -132,8 +132,7 @@ if _TYPE_CHECKING: from collections.abc import Callable, Iterable - from typing import Any, TypeVar # pylint: disable=ungrouped-imports - from typing_extensions import ParamSpec # Python 3.10+ + from typing import Any, ParamSpec, TypeVar # pylint: disable=ungrouped-imports _P = ParamSpec('_P') _T = TypeVar('_T') diff --git a/optree/registry.py b/optree/registry.py index 96ea5707..0d600f04 100644 --- a/optree/registry.py +++ b/optree/registry.py @@ -22,11 +22,10 @@ import dataclasses import functools import inspect -import sys from collections import OrderedDict, defaultdict, deque, namedtuple from operator import itemgetter, methodcaller from threading import Lock -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, NamedTuple, TypeVar, overload +from typing import TYPE_CHECKING, Any, ClassVar, Generic, NamedTuple, TypeVar, overload import optree._C as _C from optree.accessors import ( @@ -51,7 +50,7 @@ if TYPE_CHECKING: import builtins - from collections.abc import Collection, Generator, Iterable + from collections.abc import Callable, Collection, Generator, Iterable from optree.typing import KT, VT, CustomTreeNode, FlattenFunc, UnflattenFunc @@ -67,10 +66,7 @@ ] -SLOTS = {'slots': True} if sys.version_info >= (3, 10) else {} # Python 3.10+ - - -@dataclasses.dataclass(init=True, repr=True, eq=True, frozen=True, **SLOTS) +@dataclasses.dataclass(init=True, repr=True, eq=True, frozen=True, slots=True) class PyTreeNodeRegistryEntry(Generic[T]): """A dataclass that stores the information of a pytree node type.""" @@ -78,17 +74,13 @@ class PyTreeNodeRegistryEntry(Generic[T]): flatten_func: FlattenFunc[T] unflatten_func: UnflattenFunc[T] - if sys.version_info >= (3, 10): # pragma: >=3.10 cover - _: dataclasses.KW_ONLY # Python 3.10+ + _: dataclasses.KW_ONLY path_entry_type: builtins.type[PyTreeEntry] = AutoEntry kind: PyTreeKind = PyTreeKind.CUSTOM namespace: str = '' -del SLOTS - - # pylint: disable-next=missing-class-docstring,too-few-public-methods class GlobalNamespace: # pragma: no cover __slots__: ClassVar[tuple[()]] = () @@ -103,7 +95,7 @@ def __repr__(self, /) -> str: if TYPE_CHECKING: - from typing_extensions import ParamSpec # Python 3.10+ + from typing import ParamSpec _P = ParamSpec('_P') _T = TypeVar('_T') @@ -219,11 +211,9 @@ def pytree_node_registry_get( # noqa: C901 and cls is not namedtuple # noqa: PYI024 and not inspect.isclass(cls) ): - raise TypeError(f'Expected a class or None, got {cls!r}.') # pragma: !=3.9 cover + raise TypeError(f'Expected a class or None, got {cls!r}.') if not isinstance(namespace, str): - raise TypeError( # pragma: !=3.9 cover - f'The namespace must be a string, got {namespace!r}.', - ) + raise TypeError(f'The namespace must be a string, got {namespace!r}.') if cls is None: namespaces = frozenset({namespace, ''}) diff --git a/optree/typing.py b/optree/typing.py index bbf3b98e..10885ca9 100644 --- a/optree/typing.py +++ b/optree/typing.py @@ -29,6 +29,7 @@ from collections import defaultdict as DefaultDict # noqa: N812 from collections import deque as Deque # noqa: N812 from collections.abc import ( + Callable, Collection, Hashable, ItemsView, @@ -40,13 +41,13 @@ ) from typing import ( Any, - Callable, ClassVar, Final, ForwardRef, Generic, - Optional, + ParamSpec, Protocol, + TypeAlias, TypeVar, Union, final, @@ -56,9 +57,7 @@ from typing_extensions import ( NamedTuple, # Generic NamedTuple: Python 3.11+ Never, # Python 3.11+ - ParamSpec, # Python 3.10+ Self, # Python 3.11+ - TypeAlias, # Python 3.10+ TypeAliasType, # Python 3.12+ ) from weakref import WeakKeyDictionary @@ -142,7 +141,7 @@ Children: TypeAlias = Iterable[T] -MetaData: TypeAlias = Optional[Hashable] +MetaData: TypeAlias = Hashable | None @runtime_checkable @@ -166,7 +165,12 @@ def __tree_unflatten__(cls, metadata: MetaData, children: Children[T], /) -> Sel """Unflatten the children and metadata into the custom pytree node.""" -_UnionType = type(Union[int, str]) +# Before Python 3.14, `Union[int, str]` produces `typing._UnionGenericAlias` while `int | str` +# produces `types.UnionType` -- they are different types. On Python 3.14+, the two are unified and +# `Union[int, str]` also produces `types.UnionType`. Using `type(Union[int, str])` here ensures +# `_UnionType` automatically matches the pytree alias type on all supported Python versions. See +# the comment at `__class_getitem__` below for why the pytree aliases use `Union[...]`. +_UnionType = type(Union[int, str]) # noqa: UP007 try: # pragma: no cover @@ -261,7 +265,12 @@ def __class_getitem__( # noqa: C901 # pylint: disable=too-many-branches else: recurse_ref = ForwardRef(f'{cls.__name__}[{param!r}]') - pytree_alias = Union[ + # We use `Union[...]` explicitly rather than chained `|` for clarity. Before Python 3.14, + # chained `|` with `typing._GenericAlias` operands (e.g., `Tuple[x]`, `List[y]`) would still + # produce `typing._UnionGenericAlias` (not `types.UnionType`) via `__or__`/`__ror__`. + # On Python 3.14+, both `Union[...]` and `|` produce `types.UnionType`. + # TODO(PEP 604): migrate to `|` when minimum Python is raised to 3.14+. + pytree_alias = Union[ # noqa: UP007 param, # type: ignore[valid-type] Tuple[recurse_ref, ...], # type: ignore[valid-type] # Tuple, NamedTuple, PyStructSequence List[recurse_ref], # type: ignore[valid-type] diff --git a/optree/utils.py b/optree/utils.py index 502cf41a..f5a01d67 100644 --- a/optree/utils.py +++ b/optree/utils.py @@ -16,8 +16,8 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence -from typing import TYPE_CHECKING, Any, Callable, overload +from collections.abc import Callable, Iterable, Sequence +from typing import TYPE_CHECKING, Any, overload if TYPE_CHECKING: @@ -101,7 +101,7 @@ def safe_zip(*args: Iterable[Any]) -> zip[tuple[Any, ...]]: seqs = [arg if isinstance(arg, Sequence) else list(arg) for arg in args] if len(set(map(len, seqs))) > 1: raise ValueError(f'length mismatch: {list(map(len, seqs))}') - return zip(*seqs) + return zip(*seqs, strict=True) def unzip2(xys: Iterable[tuple[T, S]], /) -> tuple[tuple[T, ...], tuple[S, ...]]: diff --git a/pyproject.toml b/pyproject.toml index bd918447..4304a1c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "setuptools.build_meta" name = "optree" description = "Optimized PyTree Utilities." readme = "README.md" -requires-python = ">= 3.9" +requires-python = ">= 3.10" authors = [{ name = "OpTree Contributors" }] license = "Apache-2.0" keywords = [ @@ -21,7 +21,6 @@ classifiers = [ "Development Status :: 4 - Beta", "Programming Language :: C++", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", @@ -155,7 +154,7 @@ environment = { CMAKE_OSX_SYSROOT = "iphonesimulator" } # Linter tools ################################################################# [tool.mypy] -python_version = "3.9" +python_version = "3.10" exclude = ['^tests/.*\.py$', '^(third|3rd)[_\-]?party/.*$', '^\.?venv/.*$'] pretty = true show_column_numbers = true @@ -178,7 +177,7 @@ warn_unused_ignores = true no_site_packages = true [tool.pylint] -main.py-version = "3.9" +main.py-version = "3.10" main.extension-pkg-allow-list = ["optree._C"] main.ignore-paths = ['^tests/$', '^(third|3rd)[_\-]?party/$'] basic.good-names = [] @@ -204,7 +203,7 @@ builtin = "clear,rare,en-GB_to_en-US" ignore-words = "docs/source/spelling_wordlist.txt" [tool.ruff] -target-version = "py39" +target-version = "py310" line-length = 100 output-format = "full" src = ["optree", "tests"] diff --git a/tests/helpers.py b/tests/helpers.py index 7f426b11..28043dac 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -121,7 +121,7 @@ def represent(value): return repr(value) ids = tuple( - '-'.join(f'{arg}({represent(value)})' for arg, value in zip(arguments, values)) + '-'.join(f'{arg}({represent(value)})' for arg, value in zip(arguments, values, strict=True)) for values in argvalues ) @@ -222,7 +222,7 @@ def assert_equal_type_and_value(actual, expected=MISSING, *, expected_type=None) assert actual == expected if isinstance(expected, optree.PyTreeAccessor): assert hash(actual) == hash(expected) - for i, j in zip(actual, expected): + for i, j in zip(actual, expected, strict=True): assert_equal_type_and_value(i, j) @@ -459,7 +459,7 @@ def __tree_flatten__(self): @classmethod def __tree_unflatten__(cls, metadata, children): - return cls(zip(metadata, children)) + return cls(zip(metadata, children, strict=True)) def __repr__(self): return f'{self.__class__.__name__}({super().__repr__()})' diff --git a/tests/integrations/test_jax.py b/tests/integrations/test_jax.py index ddd70b00..4cc0ee18 100644 --- a/tests/integrations/test_jax.py +++ b/tests/integrations/test_jax.py @@ -72,7 +72,7 @@ def replace_leaf(_): reconstructed_leaves, reconstructed_treespec = optree.tree_flatten(reconstructed) assert reconstructed_treespec == treespec assert len(leaves) == len(reconstructed_leaves) - for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves): + for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves, strict=True): assert jnp.allclose(leaf, reconstructed_leaf) leaf = jnp.asarray(leaf) reconstructed_leaf = jnp.asarray(reconstructed_leaf) @@ -133,7 +133,7 @@ def replace_leaf(_): reconstructed_leaves, reconstructed_treespec = optree.tree_flatten(reconstructed) assert reconstructed_treespec == treespec assert len(leaves) == len(reconstructed_leaves) - for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves): + for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves, strict=True): assert jnp.allclose(leaf, reconstructed_leaf) leaf = jnp.asarray(leaf) reconstructed_leaf = jnp.asarray(reconstructed_leaf) diff --git a/tests/integrations/test_numpy.py b/tests/integrations/test_numpy.py index 6660940c..2e5d5bda 100644 --- a/tests/integrations/test_numpy.py +++ b/tests/integrations/test_numpy.py @@ -66,7 +66,7 @@ def replace_leaf(_): reconstructed_leaves, reconstructed_treespec = optree.tree_flatten(reconstructed) assert reconstructed_treespec == treespec assert len(leaves) == len(reconstructed_leaves) - for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves): + for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves, strict=True): assert np.allclose(leaf, reconstructed_leaf) leaf = np.asarray(leaf) reconstructed_leaf = np.asarray(reconstructed_leaf) @@ -126,7 +126,7 @@ def replace_leaf(_): reconstructed_leaves, reconstructed_treespec = optree.tree_flatten(reconstructed) assert reconstructed_treespec == treespec assert len(leaves) == len(reconstructed_leaves) - for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves): + for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves, strict=True): assert np.allclose(leaf, reconstructed_leaf) leaf = np.asarray(leaf) reconstructed_leaf = np.asarray(reconstructed_leaf) diff --git a/tests/integrations/test_torch.py b/tests/integrations/test_torch.py index c55cd26d..6cbf9c89 100644 --- a/tests/integrations/test_torch.py +++ b/tests/integrations/test_torch.py @@ -74,7 +74,7 @@ def replace_leaf(_): reconstructed_leaves, reconstructed_treespec = optree.tree_flatten(reconstructed) assert reconstructed_treespec == treespec assert len(leaves) == len(reconstructed_leaves) - for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves): + for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves, strict=True): assert torch.is_tensor(leaf) assert torch.is_tensor(reconstructed_leaf) assert torch.allclose(leaf, reconstructed_leaf) @@ -136,7 +136,7 @@ def replace_leaf(_): reconstructed_leaves, reconstructed_treespec = optree.tree_flatten(reconstructed) assert reconstructed_treespec == treespec assert len(leaves) == len(reconstructed_leaves) - for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves): + for leaf, reconstructed_leaf in zip(leaves, reconstructed_leaves, strict=True): assert torch.is_tensor(leaf) assert torch.is_tensor(reconstructed_leaf) assert torch.allclose(leaf, reconstructed_leaf) diff --git a/tests/test_accessors.py b/tests/test_accessors.py index bd50bc10..71fe18ad 100644 --- a/tests/test_accessors.py +++ b/tests/test_accessors.py @@ -456,7 +456,7 @@ def __tree_unflatten__(cls, metadata, children): accessors, leaves, _ = optree.tree_flatten_with_accessor(obj, namespace='namespace') assert leaves == [1, 2, 3] assert accessors == expected_accessors - for a, b in zip(accessors, expected_accessors): + for a, b in zip(accessors, expected_accessors, strict=True): assert_equal_type_and_value(a, b) for accessor in accessors: diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index 8a4bae98..4371d7e2 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -106,20 +106,10 @@ def test_field_future_parameters(): optree.dataclasses.field() dataclasses.field() - if sys.version_info >= (3, 10): - optree.dataclasses.field(kw_only=True) - dataclasses.field(kw_only=True) - optree.dataclasses.field(kw_only=False) - dataclasses.field(kw_only=False) - else: - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - optree.dataclasses.field(kw_only=True) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.field(kw_only=True) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - optree.dataclasses.field(kw_only=False) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.field(kw_only=False) + optree.dataclasses.field(kw_only=True) + dataclasses.field(kw_only=True) + optree.dataclasses.field(kw_only=False) + dataclasses.field(kw_only=False) if sys.version_info >= (3, 14): optree.dataclasses.field(doc='doc') @@ -179,41 +169,18 @@ def test_dataclass_future_parameters(): optree.dataclasses.dataclass(namespace='namespace') dataclasses.dataclass() - if sys.version_info >= (3, 10): - optree.dataclasses.dataclass(match_args=True, namespace='namespace') - dataclasses.dataclass(match_args=True) - optree.dataclasses.dataclass(match_args=False, namespace='namespace') - dataclasses.dataclass(match_args=False) - optree.dataclasses.dataclass(kw_only=True, namespace='namespace') - dataclasses.dataclass(kw_only=True) - optree.dataclasses.dataclass(kw_only=False, namespace='namespace') - dataclasses.dataclass(kw_only=False) - optree.dataclasses.dataclass(slots=True, namespace='namespace') - dataclasses.dataclass(slots=True) - optree.dataclasses.dataclass(slots=False, namespace='namespace') - dataclasses.dataclass(slots=False) - else: - optree.dataclasses.dataclass(match_args=True, namespace='namespace') - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.dataclass(match_args=True) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - optree.dataclasses.dataclass(match_args=False, namespace='error') - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.dataclass(match_args=False) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - optree.dataclasses.dataclass(kw_only=True, namespace='error') - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.dataclass(kw_only=True) - optree.dataclasses.dataclass(kw_only=False, namespace='namespace') - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.dataclass(kw_only=False) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - optree.dataclasses.dataclass(slots=True, namespace='error') - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.dataclass(slots=True) - optree.dataclasses.dataclass(slots=False, namespace='namespace') - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.dataclass(slots=False) + optree.dataclasses.dataclass(match_args=True, namespace='namespace') + dataclasses.dataclass(match_args=True) + optree.dataclasses.dataclass(match_args=False, namespace='namespace') + dataclasses.dataclass(match_args=False) + optree.dataclasses.dataclass(kw_only=True, namespace='namespace') + dataclasses.dataclass(kw_only=True) + optree.dataclasses.dataclass(kw_only=False, namespace='namespace') + dataclasses.dataclass(kw_only=False) + optree.dataclasses.dataclass(slots=True, namespace='namespace') + dataclasses.dataclass(slots=True) + optree.dataclasses.dataclass(slots=False, namespace='namespace') + dataclasses.dataclass(slots=False) if sys.version_info >= (3, 11): optree.dataclasses.dataclass(weakref_slot=True, namespace='namespace') @@ -442,149 +409,72 @@ def test_make_dataclass_future_parameters(): }, ) - if sys.version_info >= (3, 10): - optree.dataclasses.make_dataclass( - 'Foo2', - ['x', ('y', int), ('z', float, 0.0)], - match_args=True, - namespace='namespace', - ) - dataclasses.make_dataclass( - 'Foo3', - ['x', ('y', int), ('z', float, 0.0)], - match_args=True, - ) - optree.dataclasses.make_dataclass( - 'Foo4', - ['x', ('y', int), ('z', float, 0.0)], - match_args=False, - namespace='namespace', - ) - dataclasses.make_dataclass( - 'Foo5', - ['x', ('y', int), ('z', float, 0.0)], - match_args=False, - ) - optree.dataclasses.make_dataclass( - 'Foo6', - ['x', ('y', int), ('z', float, 0.0)], - kw_only=True, - namespace='namespace', - ) - dataclasses.make_dataclass( - 'Foo7', - ['x', ('y', int), ('z', float, 0.0)], - kw_only=True, - ) - optree.dataclasses.make_dataclass( - 'Foo8', - ['x', ('y', int), ('z', float, 0.0)], - kw_only=False, - namespace='namespace', - ) - dataclasses.make_dataclass( - 'Foo9', - ['x', ('y', int), ('z', float, 0.0)], - kw_only=False, - ) - optree.dataclasses.make_dataclass( - 'Foo10', - ['x', ('y', int), ('z', float, 0.0)], - slots=True, - namespace='namespace', - ) - dataclasses.make_dataclass( - 'Foo11', - ['x', ('y', int), ('z', float, 0.0)], - slots=True, - ) - optree.dataclasses.make_dataclass( - 'Foo12', - ['x', ('y', int), ('z', float, 0.0)], - slots=False, - namespace='namespace', - ) - dataclasses.make_dataclass( - 'Foo13', - ['x', ('y', int), ('z', float, 0.0)], - slots=False, - ) - else: - optree.dataclasses.make_dataclass( - 'Foo2', - ['x', ('y', int), ('z', float, 0.0)], - match_args=True, - namespace='namespace', - ) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.make_dataclass( - 'Foo3', - ['x', ('y', int), ('z', float, 0.0)], - match_args=True, - ) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - optree.dataclasses.make_dataclass( - 'Foo4', - ['x', ('y', int), ('z', float, 0.0)], - match_args=False, - namespace='error', - ) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.make_dataclass( - 'Foo5', - ['x', ('y', int), ('z', float, 0.0)], - match_args=False, - ) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - optree.dataclasses.make_dataclass( - 'Foo6', - ['x', ('y', int), ('z', float, 0.0)], - kw_only=True, - namespace='error', - ) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.make_dataclass( - 'Foo7', - ['x', ('y', int), ('z', float, 0.0)], - kw_only=True, - ) - optree.dataclasses.make_dataclass( - 'Foo8', - ['x', ('y', int), ('z', float, 0.0)], - kw_only=False, - namespace='namespace', - ) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.make_dataclass( - 'Foo9', - ['x', ('y', int), ('z', float, 0.0)], - kw_only=False, - ) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - optree.dataclasses.make_dataclass( - 'Foo10', - ['x', ('y', int), ('z', float, 0.0)], - slots=True, - namespace='error', - ) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.make_dataclass( - 'Foo11', - ['x', ('y', int), ('z', float, 0.0)], - slots=True, - ) - optree.dataclasses.make_dataclass( - 'Foo12', - ['x', ('y', int), ('z', float, 0.0)], - slots=False, - namespace='namespace', - ) - with pytest.raises(TypeError, match='got an unexpected keyword argument'): - dataclasses.make_dataclass( - 'Foo13', - ['x', ('y', int), ('z', float, 0.0)], - slots=False, - ) + optree.dataclasses.make_dataclass( + 'Foo2', + ['x', ('y', int), ('z', float, 0.0)], + match_args=True, + namespace='namespace', + ) + dataclasses.make_dataclass( + 'Foo3', + ['x', ('y', int), ('z', float, 0.0)], + match_args=True, + ) + optree.dataclasses.make_dataclass( + 'Foo4', + ['x', ('y', int), ('z', float, 0.0)], + match_args=False, + namespace='namespace', + ) + dataclasses.make_dataclass( + 'Foo5', + ['x', ('y', int), ('z', float, 0.0)], + match_args=False, + ) + optree.dataclasses.make_dataclass( + 'Foo6', + ['x', ('y', int), ('z', float, 0.0)], + kw_only=True, + namespace='namespace', + ) + dataclasses.make_dataclass( + 'Foo7', + ['x', ('y', int), ('z', float, 0.0)], + kw_only=True, + ) + optree.dataclasses.make_dataclass( + 'Foo8', + ['x', ('y', int), ('z', float, 0.0)], + kw_only=False, + namespace='namespace', + ) + dataclasses.make_dataclass( + 'Foo9', + ['x', ('y', int), ('z', float, 0.0)], + kw_only=False, + ) + optree.dataclasses.make_dataclass( + 'Foo10', + ['x', ('y', int), ('z', float, 0.0)], + slots=True, + namespace='namespace', + ) + dataclasses.make_dataclass( + 'Foo11', + ['x', ('y', int), ('z', float, 0.0)], + slots=True, + ) + optree.dataclasses.make_dataclass( + 'Foo12', + ['x', ('y', int), ('z', float, 0.0)], + slots=False, + namespace='namespace', + ) + dataclasses.make_dataclass( + 'Foo13', + ['x', ('y', int), ('z', float, 0.0)], + slots=False, + ) if sys.version_info >= (3, 11): optree.dataclasses.make_dataclass( diff --git a/tests/test_ops.py b/tests/test_ops.py index fc70fd8f..e97779a2 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -20,9 +20,7 @@ import itertools import operator import pickle -import platform import re -import sys from collections import OrderedDict, defaultdict, deque import pytest @@ -39,7 +37,6 @@ CustomTuple, FlatCache, MyAnotherDict, - Py_DEBUG, always, assert_equal_type_and_value, check_script_in_subprocess, @@ -62,9 +59,6 @@ def test_import_no_warnings(): def test_max_depth(): - if sys.version_info < (3, 10) and platform.system() == 'Windows' and Py_DEBUG: - pytest.skip('Flaky with Python 3.9 on Windows in debug mode.') - lst = [1] for _ in range(optree.MAX_RECURSION_DEPTH - 1): lst = [lst] @@ -557,7 +551,7 @@ def test_paths_and_accessors(data): assert other_treespec == expected_treespec assert paths == expected_paths assert accessors == expected_accessors - for leaf, accessor, path in zip(leaves, accessors, paths): + for leaf, accessor, path in zip(leaves, accessors, paths, strict=True): assert isinstance(accessor, optree.PyTreeAccessor) assert isinstance(path, tuple) assert len(accessor) == len(path) @@ -628,7 +622,7 @@ def test_paths_and_accessors_with_is_leaf( assert treespec == expected_treespec assert other_leaves == expected_leaves assert other_treespec == expected_treespec - for leaf, accessor, path in zip(leaves, accessors, paths): + for leaf, accessor, path in zip(leaves, accessors, paths, strict=True): assert isinstance(accessor, optree.PyTreeAccessor) assert isinstance(path, tuple) assert len(accessor) == len(path) @@ -3398,7 +3392,7 @@ def flatten(node): # noqa: C901 assert node_kind == optree.PyTreeKind.CUSTOM assert len(entries) == len(children) if hasattr(node, '__getitem__'): - for child, entry in zip(children, entries): + for child, entry in zip(children, entries, strict=True): assert node[entry] is child assert unflatten_func(metadata, children) == node @@ -3407,7 +3401,7 @@ def flatten(node): # noqa: C901 with pytest.raises(ValueError, match=re.escape('Expected no children.')): unflatten_func(metadata, range(1)) - for child, entry in zip(children, entries): + for child, entry in zip(children, entries, strict=True): path_stack.append(entry) accessor_stack.append(output.path_entry_type(entry, node_type, node_kind)) flatten(child) diff --git a/tests/test_registry.py b/tests/test_registry.py index 229dd50e..069e391d 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -16,7 +16,6 @@ # pylint: disable=missing-function-docstring,invalid-name import re -import sys import weakref from collections import UserDict, UserList, namedtuple from dataclasses import dataclass @@ -28,7 +27,6 @@ from helpers import ( GLOBAL_NAMESPACE, NODETYPE_REGISTRY, - PYPY, Py_GIL_DISABLED, disable_systrace, gc_collect, @@ -580,11 +578,10 @@ def test_pytree_node_registry_get_with_invalid_arguments(): assert optree.register_pytree_node.get(None) == registry assert optree.register_pytree_node.get(namespace=GLOBAL_NAMESPACE) == registry assert optree.register_pytree_node.get(namedtuple) is registry[namedtuple] # noqa: PYI024 - if sys.version_info[:2] != (3, 9) or PYPY: - with pytest.raises(TypeError, match='Expected a class or None'): - optree.register_pytree_node.get(dataclass) - with pytest.raises(TypeError, match='The namespace must be a string'): - optree.register_pytree_node.get(list, namespace=None) + with pytest.raises(TypeError, match='Expected a class or None'): + optree.register_pytree_node.get(dataclass) + with pytest.raises(TypeError, match='The namespace must be a string'): + optree.register_pytree_node.get(list, namespace=None) def test_pytree_node_registry_with_init_subclass(): @@ -600,7 +597,7 @@ def __tree_flatten__(self): @classmethod def __tree_unflatten__(cls, metadata, children): - return cls(zip(metadata, children)) + return cls(zip(metadata, children, strict=True)) class MyAnotherDict(MyDict): pass diff --git a/tests/test_treespec.py b/tests/test_treespec.py index 63fa2796..ec180fb3 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -704,7 +704,11 @@ def test_treespec_compose_children( stack = [(composed_treespec.children(), expected_treespec.children())] while stack: composed_children, expected_children = stack.pop() - for composed_child, expected_child in zip(composed_children, expected_children): + for composed_child, expected_child in zip( + composed_children, + expected_children, + strict=True, + ): assert composed_child == expected_child stack.append((composed_child.children(), expected_child.children())) @@ -786,7 +790,7 @@ def gen_path(spec): yield () return - for entry, child in zip(entries, children): + for entry, child in zip(entries, children, strict=True): for suffix in gen_path(child): yield (entry, *suffix) @@ -815,7 +819,7 @@ def gen_typed_path(spec): node_type = spec.type node_kind = spec.kind - for entry, child in zip(entries, children): + for entry, child in zip(entries, children, strict=True): for suffix in gen_typed_path(child): yield ((entry, node_type, node_kind), *suffix) @@ -1078,11 +1082,11 @@ def test_treespec_transform(): ) == optree.tree_structure([[1, 2, 3], [4]]) assert optree.treespec_transform( treespec, - lambda spec: optree.treespec_dict(zip('abcd', spec.children())), + lambda spec: optree.treespec_dict(zip('abcd', spec.children(), strict=False)), ) == optree.tree_structure({'a': {'a': 0, 'b': 1, 'c': 2}, 'b': {'a': 3}}) assert optree.treespec_transform( treespec, - lambda spec: optree.treespec_dict(zip('abcd', spec.children())), + lambda spec: optree.treespec_dict(zip('abcd', spec.children(), strict=False)), lambda spec: optree.tree_structure([0, None, 1]), ) == optree.tree_structure( {'a': {'a': [0, None, 1], 'b': [2, None, 3], 'c': [4, None, 5]}, 'b': {'a': [6, None, 7]}}, @@ -1095,7 +1099,7 @@ def test_treespec_transform(): optree.treespec_transform( treespec, lambda spec: optree.tree_structure( - MyAnotherDict(zip(spec.entries(), spec.children())), + MyAnotherDict(zip(spec.entries(), spec.children(), strict=True)), namespace='namespace', ), ) @@ -1135,7 +1139,10 @@ def test_treespec_transform(): def fn(spec): with optree.dict_insertion_ordered(True, namespace='undefined'): - return optree.treespec_dict(zip('abcd', spec.children()), namespace='undefined') + return optree.treespec_dict( + zip('abcd', spec.children(), strict=False), + namespace='undefined', + ) with pytest.raises(ValueError, match=r'Expected treespec\(s\) with namespace .*, got .*\.'): optree.treespec_transform(namespaced_treespec, fn) @@ -1556,7 +1563,7 @@ def test_treespec_constructor( # noqa: C901 if dict_should_be_sorted or dict_session_namespace not in {'', namespace}: assert ( optree.treespec_dict( - zip(sorted(node), children_treespecs), + zip(sorted(node), children_treespecs, strict=True), none_is_leaf=none_is_leaf, namespace=passed_namespace, ) @@ -1564,7 +1571,7 @@ def test_treespec_constructor( # noqa: C901 ) assert ( optree.treespec_from_collection( - dict(zip(sorted(node), children_treespecs)), + dict(zip(sorted(node), children_treespecs, strict=True)), none_is_leaf=none_is_leaf, namespace=passed_namespace, ) @@ -1582,7 +1589,7 @@ def test_treespec_constructor( # noqa: C901 with context: assert ( optree.treespec_dict( - zip(node, children_treespecs), + zip(node, children_treespecs, strict=True), none_is_leaf=none_is_leaf, namespace=passed_namespace, ) @@ -1590,7 +1597,7 @@ def test_treespec_constructor( # noqa: C901 ) assert ( optree.treespec_from_collection( - dict(zip(node, children_treespecs)), + dict(zip(node, children_treespecs, strict=True)), none_is_leaf=none_is_leaf, namespace=passed_namespace, ) @@ -1599,7 +1606,7 @@ def test_treespec_constructor( # noqa: C901 elif node_type is OrderedDict: assert ( optree.treespec_ordereddict( - zip(node, children_treespecs), + zip(node, children_treespecs, strict=True), none_is_leaf=none_is_leaf, namespace=passed_namespace, ) @@ -1607,7 +1614,7 @@ def test_treespec_constructor( # noqa: C901 ) assert ( optree.treespec_from_collection( - OrderedDict(zip(node, children_treespecs)), + OrderedDict(zip(node, children_treespecs, strict=True)), none_is_leaf=none_is_leaf, namespace=passed_namespace, ) @@ -1618,7 +1625,7 @@ def test_treespec_constructor( # noqa: C901 assert ( optree.treespec_defaultdict( node.default_factory, - zip(sorted(node), children_treespecs), + zip(sorted(node), children_treespecs, strict=True), none_is_leaf=none_is_leaf, namespace=passed_namespace, ) @@ -1628,7 +1635,7 @@ def test_treespec_constructor( # noqa: C901 optree.treespec_from_collection( defaultdict( node.default_factory, - zip(sorted(node), children_treespecs), + zip(sorted(node), children_treespecs, strict=True), ), none_is_leaf=none_is_leaf, namespace=passed_namespace, @@ -1648,7 +1655,7 @@ def test_treespec_constructor( # noqa: C901 assert ( optree.treespec_defaultdict( node.default_factory, - zip(node, children_treespecs), + zip(node, children_treespecs, strict=True), none_is_leaf=none_is_leaf, namespace=passed_namespace, ) @@ -1658,7 +1665,7 @@ def test_treespec_constructor( # noqa: C901 optree.treespec_from_collection( defaultdict( node.default_factory, - zip(node, children_treespecs), + zip(node, children_treespecs, strict=True), ), none_is_leaf=none_is_leaf, namespace=passed_namespace, diff --git a/tests/test_typing.py b/tests/test_typing.py index 51cb95cc..8e04c8c8 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -113,14 +113,13 @@ def test_pytree_typing(): T = TypeVar('T') optree.PyTree[int] - optree.PyTree[Union[int, str]] + optree.PyTree[Union[int, str]] # noqa: UP007 optree.PyTree[T] assert optree.PyTree[optree.PyTree[int]] == optree.PyTree[int] - assert optree.PyTree[optree.PyTree[Union[int, str]]] == optree.PyTree[Union[int, str]] + assert optree.PyTree[optree.PyTree[Union[int, str]]] == optree.PyTree[Union[int, str]] # noqa: UP007 assert optree.PyTree[optree.PyTree[T]] == optree.PyTree[T] - if sys.version_info >= (3, 10): - optree.PyTree[int | str] - assert optree.PyTree[optree.PyTree[int | str]] == optree.PyTree[int | str] + optree.PyTree[float | bytes] + assert optree.PyTree[optree.PyTree[float | bytes]] == optree.PyTree[float | bytes] IntTree = optree.PyTreeTypeVar('IntTree', int) # noqa: N806 assert IntTree == optree.PyTree[IntTree]