diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 37516d6f..2f698252 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -77,7 +77,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v6 with: - python-version: "3.9 - 3.14" + python-version: "3.9 - 3.15" update-environment: true - name: Upgrade pip @@ -154,6 +154,8 @@ jobs: - "3.13t" - "3.14" - "3.14t" + # - "3.15" + # - "3.15t" - "pypy-3.11" exclude: # Exclude unsupported Python versions @@ -183,6 +185,11 @@ jobs: runner: macos-latest platform: ios archs: "arm64_iphoneos" + # - python-version: "3.15" + # target: + # runner: macos-latest + # platform: ios + # archs: "arm64_iphoneos" # iOS Simulator - python-version: "3.13" target: @@ -194,6 +201,11 @@ jobs: runner: macos-latest platform: ios archs: "arm64_iphonesimulator" + # - python-version: "3.15" + # target: + # runner: macos-latest + # platform: ios + # archs: "arm64_iphonesimulator" # Android - python-version: "3.13" target: @@ -205,6 +217,11 @@ jobs: runner: ubuntu-latest platform: android archs: "arm64_v8a" + # - python-version: "3.15" + # target: + # runner: ubuntu-latest + # platform: android + # archs: "arm64_v8a" # Pyodide - python-version: "3.12" target: @@ -387,7 +404,7 @@ jobs: if: startsWith(github.ref, 'refs/tags/') uses: actions/setup-python@v6 with: - python-version: "3.9 - 3.14" + python-version: "3.9 - 3.15" update-environment: true - name: Upgrade pip diff --git a/.github/workflows/tests-with-pydebug.yml b/.github/workflows/tests-with-pydebug.yml index bc5f075f..d901926d 100644 --- a/.github/workflows/tests-with-pydebug.yml +++ b/.github/workflows/tests-with-pydebug.yml @@ -73,6 +73,7 @@ jobs: - "3.12" - "3.13" - "3.14" + - "3.15" python-abiflags: ["d", "td"] exclude: - python-version: "3.9" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 61803d37..64365150 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -74,6 +74,8 @@ jobs: - "3.13t" - "3.14" - "3.14t" + - "3.15" + - "3.15t" - "pypy-3.11" fail-fast: false timeout-minutes: 120 diff --git a/README.md b/README.md index 78b88602..ee62ec42 100644 --- a/README.md +++ b/README.md @@ -179,6 +179,7 @@ OpTree out-of-the-box supports the following Python container types in the globa - [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict) - [`collections.deque`](https://docs.python.org/3/library/collections.html#collections.deque) - [`PyStructSequence`](https://docs.python.org/3/c-api/tuple.html#struct-sequence-objects) types created by C API [`PyStructSequence_NewType`](https://docs.python.org/3/c-api/tuple.html#c.PyStructSequence_NewType) +- [`frozendict`](https://docs.python.org/3/library/stdtypes.html#frozendict) (Python 3.15+) These types are considered non-leaf nodes in the tree. Python objects whose type is not registered are treated as leaf nodes. @@ -356,7 +357,7 @@ There are several key attributes of the pytree type registry: > [!WARNING] > Any `PyTreeSpec` objects created before the unregistration still hold a reference to the old registration. Unflattening such a `PyTreeSpec` will use the **old** `unflatten_func`, not the newly registered one. -3. **Built-in types cannot be re-registered.** The behavior of the types listed in [Built-in PyTree Node Types](#built-in-pytree-node-types) (e.g., key-sorted traversal for `dict` and `collections.defaultdict`) is fixed. +3. **Built-in types cannot be re-registered.** The behavior of the types listed in [Built-in PyTree Node Types](#built-in-pytree-node-types) (e.g., key-sorted traversal for `dict`, `collections.defaultdict`, and `frozendict`) is fixed. 4. **Inherited subclasses are not implicitly registered.** The registry lookup uses `type(obj) is registered_type` rather than `isinstance(obj, registered_type)`. Users need to register the subclasses explicitly. To register all subclasses, it is easy to implement with [`metaclass`](https://docs.python.org/3/reference/datamodel.html#metaclasses) or [`__init_subclass__`](https://docs.python.org/3/reference/datamodel.html#customizing-class-creation), for example: @@ -496,7 +497,7 @@ OrderedDict({ The built-in Python dictionary ([`builtins.dict`](https://docs.python.org/3/library/stdtypes.html#dict)) is a mapping whose leaves are its values. Since [Python 3.7](https://docs.python.org/3/whatsnew/3.7.html), `dict` is guaranteed to be insertion ordered, but the equality operator (`==`) ignores key order. To ensure [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) — "equal `dict`" implies "equal ordering of leaves" — the leaves (values) are returned in key-sorted order. -The same applies to [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict). +The same applies to [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict) and [`frozendict`](https://docs.python.org/3/library/stdtypes.html#frozendict) (Python 3.15+). ```python >>> optree.tree_flatten({'a': [1, 2], 'b': [3]}) @@ -561,7 +562,7 @@ False ([3, 1, 2], PyTreeSpec(OrderedDict({'b': [*], 'a': [*, *]}))) ``` -To flatten [`builtins.dict`](https://docs.python.org/3/library/stdtypes.html#dict) and [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict) objects with the insertion order preserved, use the `dict_insertion_ordered` context manager: +To flatten [`builtins.dict`](https://docs.python.org/3/library/stdtypes.html#dict), [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict), and [`frozendict`](https://docs.python.org/3/library/stdtypes.html#frozendict) (Python 3.15+) objects with the insertion order preserved, use the `dict_insertion_ordered` context manager: ```python >>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5} diff --git a/docs/source/ops.rst b/docs/source/ops.rst index 56ff04e5..6ccb681a 100644 --- a/docs/source/ops.rst +++ b/docs/source/ops.rst @@ -175,4 +175,5 @@ Tree Reduce Functions .. autofunction:: treespec_defaultdict .. autofunction:: treespec_deque .. autofunction:: treespec_structseq +.. TODO(frozendict): Add ``.. autofunction:: treespec_frozendict`` when building with Python 3.15+. .. autofunction:: treespec_from_collection diff --git a/docs/source/spelling_wordlist.txt b/docs/source/spelling_wordlist.txt index f2a47c4e..57df2622 100644 --- a/docs/source/spelling_wordlist.txt +++ b/docs/source/spelling_wordlist.txt @@ -34,6 +34,7 @@ eq fillvalue fmt forwardref +frozendict frozenset func functools diff --git a/docs/source/treespec.rst b/docs/source/treespec.rst index 7be69e2b..dacf6f7b 100644 --- a/docs/source/treespec.rst +++ b/docs/source/treespec.rst @@ -23,3 +23,5 @@ Check section :ref:`PyTreeSpec Functions` for more detailed documentation. deque structseq from_collection + +.. TODO(frozendict): Add ``frozendict`` to the autosummary when building with Python 3.15+. diff --git a/include/optree/pymacros.h b/include/optree/pymacros.h index e0469047..8e214767 100644 --- a/include/optree/pymacros.h +++ b/include/optree/pymacros.h @@ -43,6 +43,12 @@ limitations under the License. # undef OPTREE_HAS_SUBINTERPRETER_SUPPORT #endif +#if PY_VERSION_HEX >= 0x030F00A7 // Python 3.15.0a7+ +# define OPTREE_HAS_FROZENDICT 1 +#else +# undef OPTREE_HAS_FROZENDICT +#endif + namespace py = pybind11; #if !defined(Py_ALWAYS_INLINE) diff --git a/include/optree/pytypes.h b/include/optree/pytypes.h index c0469a36..a1600790 100644 --- a/include/optree/pytypes.h +++ b/include/optree/pytypes.h @@ -66,6 +66,10 @@ constexpr py::ssize_t MAX_TYPE_CACHE_SIZE = 4096; #define PyOrderedDict_Type (reinterpret_cast(PyOrderedDictTypeObject.ptr())) #define PyDefaultDict_Type (reinterpret_cast(PyDefaultDictTypeObject.ptr())) #define PyDeque_Type (reinterpret_cast(PyDequeTypeObject.ptr())) +#if defined(OPTREE_HAS_FROZENDICT) +# define PyFrozenDictTypeObject \ + (py::reinterpret_borrow(reinterpret_cast(&PyFrozenDict_Type))) +#endif inline const py::object &ImportOrderedDict() { PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; @@ -181,6 +185,14 @@ inline Py_ALWAYS_INLINE void AssertExactDict(const py::handle &object) { } } +#if defined(OPTREE_HAS_FROZENDICT) +inline Py_ALWAYS_INLINE void AssertExactFrozenDict(const py::handle &object) { + if (!PyFrozenDict_CheckExact(object.ptr())) [[unlikely]] { + throw py::value_error("Expected an instance of frozendict, got " + PyRepr(object) + "."); + } +} +#endif + inline Py_ALWAYS_INLINE void AssertExactOrderedDict(const py::handle &object) { if (!py::type::handle_of(object).is(PyOrderedDictTypeObject)) [[unlikely]] { throw py::value_error("Expected an instance of collections.OrderedDict, got " + @@ -198,10 +210,17 @@ inline Py_ALWAYS_INLINE void AssertExactDefaultDict(const py::handle &object) { inline Py_ALWAYS_INLINE void AssertExactStandardDict(const py::handle &object) { if (!(PyDict_CheckExact(object.ptr()) || py::type::handle_of(object).is(PyOrderedDictTypeObject) || - py::type::handle_of(object).is(PyDefaultDictTypeObject))) [[unlikely]] { + py::type::handle_of(object).is(PyDefaultDictTypeObject) +#if defined(OPTREE_HAS_FROZENDICT) + || PyFrozenDict_CheckExact(object.ptr()) +#endif + )) [[unlikely]] { throw py::value_error( - "Expected an instance of dict, collections.OrderedDict, or collections.defaultdict, " - "got " + + "Expected an instance of dict, " +#if defined(OPTREE_HAS_FROZENDICT) + "frozendict, " +#endif + "collections.OrderedDict, or collections.defaultdict, got " + PyRepr(object) + "."); } } diff --git a/include/optree/registry.h b/include/optree/registry.h index 5d2be589..108a564c 100644 --- a/include/optree/registry.h +++ b/include/optree/registry.h @@ -53,6 +53,7 @@ enum class PyTreeKind : std::uint8_t { DefaultDict, // A collections.defaultdict Deque, // A collections.deque StructSequence, // A PyStructSequence + FrozenDict, // A frozendict (Python 3.15+) NumKinds, // Number of kinds (placed at the end) }; @@ -67,6 +68,7 @@ constexpr PyTreeKind kOrderedDict = PyTreeKind::OrderedDict; constexpr PyTreeKind kDefaultDict = PyTreeKind::DefaultDict; constexpr PyTreeKind kDeque = PyTreeKind::Deque; constexpr PyTreeKind kStructSequence = PyTreeKind::StructSequence; +constexpr PyTreeKind kFrozenDict = PyTreeKind::FrozenDict; constexpr PyTreeKind kNumPyTreeKinds = PyTreeKind::NumKinds; // Registry of custom node types. diff --git a/include/optree/treespec.h b/include/optree/treespec.h index 2e46a2a6..c33f955b 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -272,7 +272,7 @@ class PyTreeSpec { // Kind-specific metadata. // For a NamedTuple/PyStructSequence, contains the tuple type object. - // For a Dict, contains a sorted list of keys. + // For a Dict or FrozenDict, contains a sorted list of keys. // For a OrderedDict, contains a list of keys. // For a DefaultDict, contains a tuple of (default_factory, sorted list of keys). // For a Deque, contains the `maxlen` attribute. @@ -293,7 +293,7 @@ class PyTreeSpec { // Number of leaf and interior nodes in the subtree rooted at this node. ssize_t num_nodes = 0; - // For a Dict or DefaultDict, contains the keys in insertion order. + // For a Dict, DefaultDict, or FrozenDict, contains the keys in insertion order. py::object original_keys{}; }; diff --git a/optree/_C.pyi b/optree/_C.pyi index cc44efa4..caa8efaf 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -57,6 +57,7 @@ PYBIND11_HAS_SUBINTERPRETER_SUPPORT: Final[bool] GLIBCXX_USE_CXX11_ABI: Final[bool] OPTREE_HAS_SUBINTERPRETER_SUPPORT: Final[bool] OPTREE_HAS_READ_WRITE_LOCK: Final[bool] +OPTREE_HAS_FROZENDICT: Final[bool] @final class InternalError(SystemError): ... @@ -74,6 +75,7 @@ class PyTreeKind(enum.IntEnum): DEFAULTDICT = enum.auto() # a collections.defaultdict DEQUE = enum.auto() # a collections.deque STRUCTSEQUENCE = enum.auto() # a PyStructSequence + FROZENDICT = enum.auto() # a frozendict (Python 3.15+) NUM_KINDS: ClassVar[int] diff --git a/optree/__init__.py b/optree/__init__.py index fcd7cced..43e85072 100644 --- a/optree/__init__.py +++ b/optree/__init__.py @@ -14,6 +14,9 @@ # ============================================================================== """OpTree: Optimized PyTree Utilities.""" +import sys + +import optree._C as _C from optree import accessors, dataclasses, functools, integrations, pytree, treespec, typing from optree.accessors import ( AutoEntry, @@ -225,6 +228,13 @@ 'structseq_fields', ] + +if sys.version_info >= (3, 15) and _C.OPTREE_HAS_FROZENDICT: # pragma: >=3.15 cover + from optree.ops import treespec_frozendict + + __all__.insert(__all__.index('treespec_from_collection'), 'treespec_frozendict') + + MAX_RECURSION_DEPTH: int = MAX_RECURSION_DEPTH """Maximum recursion depth for pytree traversal. @@ -235,3 +245,6 @@ """Literal constant that treats :data:`None` as a pytree non-leaf node.""" NONE_IS_LEAF: bool = NONE_IS_LEAF # literal constant """Literal constant that treats :data:`None` as a pytree leaf node.""" + + +del sys diff --git a/optree/ops.py b/optree/ops.py index d311bfcb..ea8b78f7 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -21,6 +21,7 @@ import difflib import functools import itertools +import sys import textwrap from collections import OrderedDict, defaultdict, deque from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, overload @@ -115,6 +116,13 @@ 'prefix_errors', ] + +if sys.version_info >= (3, 15) and _C.OPTREE_HAS_FROZENDICT: # pragma: >=3.15 cover + from builtins import frozendict # type: ignore[import] # pylint: disable=no-name-in-module + + __all__.insert(__all__.index('treespec_from_collection'), 'treespec_frozendict') + + MAX_RECURSION_DEPTH: int = _C.MAX_RECURSION_DEPTH """Maximum recursion depth for pytree traversal. @@ -160,9 +168,10 @@ def tree_flatten( >>> tree_flatten(None, none_is_leaf=True) ([None], PyTreeSpec(*, NoneIsLeaf)) - For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is - dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict` - if you want to keep the keys in the insertion order. + For unordered dictionaries, :class:`dict`, :class:`collections.defaultdict`, and + :class:`frozendict` (Python 3.15+), the order is dependent on the **sorted** keys in the + dictionary. Please use :class:`collections.OrderedDict` if you want to keep the keys in the + insertion order. >>> from collections import OrderedDict >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)]) @@ -232,9 +241,10 @@ def tree_flatten_with_path( >>> tree_flatten_with_path(None, none_is_leaf=True) ([()], [None], PyTreeSpec(*, NoneIsLeaf)) - For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is - dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict` - if you want to keep the keys in the insertion order. + For unordered dictionaries, :class:`dict`, :class:`collections.defaultdict`, and + :class:`frozendict` (Python 3.15+), the order is dependent on the **sorted** keys in the + dictionary. Please use :class:`collections.OrderedDict` if you want to keep the keys in the + insertion order. >>> from collections import OrderedDict >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)]) @@ -320,9 +330,10 @@ def tree_flatten_with_accessor( >>> tree_flatten_with_accessor(None, none_is_leaf=True) ([PyTreeAccessor(*, ())], [None], PyTreeSpec(*, NoneIsLeaf)) - For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is - dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict` - if you want to keep the keys in the insertion order. + For unordered dictionaries, :class:`dict`, :class:`collections.defaultdict`, and + :class:`frozendict` (Python 3.15+), the order is dependent on the **sorted** keys in the + dictionary. Please use :class:`collections.OrderedDict` if you want to keep the keys in the + insertion order. >>> from collections import OrderedDict >>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)]) @@ -3408,7 +3419,7 @@ def treespec_deque( >>> treespec_deque([treespec_leaf(), treespec_leaf(), treespec_none()], maxlen=5) PyTreeSpec(deque([*, *, None], maxlen=5)) >>> treespec_deque() - PyTreeSpec(deque([])) + PyTreeSpec(deque()) >>> treespec_deque([treespec_leaf(), treespec_tuple([treespec_leaf(), treespec_leaf()])]) PyTreeSpec(deque([*, (*, *)])) >>> treespec_deque([treespec_leaf(), tree_structure({'a': 1, 'b': 2})], maxlen=5) @@ -3473,6 +3484,46 @@ def treespec_structseq( ) +if sys.version_info >= (3, 15) and _C.OPTREE_HAS_FROZENDICT: # pragma: >=3.15 cover + + def treespec_frozendict( + mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (), + /, + *, + none_is_leaf: bool = False, + namespace: str = '', + **kwargs: PyTreeSpec, + ) -> PyTreeSpec: + """Make a frozendict treespec from a frozendict of child treespecs. + + See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`. + + >>> treespec_frozendict({'a': treespec_leaf(), 'b': treespec_leaf()}) # doctest: +SKIP + PyTreeSpec(frozendict({'a': *, 'b': *})) + >>> treespec_frozendict() # doctest: +SKIP + PyTreeSpec(frozendict()) + + Args: + mapping (mapping of PyTreeSpec, optional): A mapping of child treespecs. They must have + the same ``none_is_leaf`` and ``namespace`` values. + none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`, + :data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the + treespec rather than in the leaves list and :data:`None` will remain in the result + pytree. (default: :data:`False`) + namespace (str, optional): The registry namespace used for custom pytree node types. + (default: :const:`''`, i.e., the global namespace) + **kwargs (PyTreeSpec, optional): Additional child treespecs to add to the mapping. + + Returns: + A treespec representing a frozendict node with the given children. + """ + return _C.make_from_collection( + frozendict(mapping, **kwargs), + none_is_leaf, + namespace, + ) + + def treespec_from_collection( collection: Collection[PyTreeSpec], /, @@ -3523,6 +3574,8 @@ def treespec_from_collection( STANDARD_DICT_TYPES: frozenset[type] = frozenset({dict, OrderedDict, defaultdict}) +if sys.version_info >= (3, 15) and _C.OPTREE_HAS_FROZENDICT: # pragma: >=3.15 cover + STANDARD_DICT_TYPES |= frozenset({frozendict}) def prefix_errors( # noqa: C901 diff --git a/optree/registry.py b/optree/registry.py index 96ea5707..4b1ac808 100644 --- a/optree/registry.py +++ b/optree/registry.py @@ -18,6 +18,7 @@ from __future__ import annotations +import builtins import contextlib import dataclasses import functools @@ -50,7 +51,6 @@ if TYPE_CHECKING: - import builtins from collections.abc import Collection, Generator, Iterable from optree.typing import KT, VT, CustomTreeNode, FlattenFunc, UnflattenFunc @@ -68,6 +68,8 @@ SLOTS = {'slots': True} if sys.version_info >= (3, 10) else {} # Python 3.10+ +if sys.version_info >= (3, 15) and _C.OPTREE_HAS_FROZENDICT: # pragma: >=3.15 cover + from builtins import frozendict # type: ignore[import] # pylint: disable=no-name-in-module @dataclasses.dataclass(init=True, repr=True, eq=True, frozen=True, **SLOTS) @@ -236,6 +238,8 @@ def pytree_node_registry_get( # noqa: C901 if _C.is_dict_insertion_ordered(namespace): registry[dict] = _DICT_INSERTION_ORDERED_REGISTRY_ENTRY registry[defaultdict] = _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY + if sys.version_info >= (3, 15) and _C.OPTREE_HAS_FROZENDICT: # pragma: >=3.15 cover + registry[frozendict] = _FROZENDICT_INSERTION_ORDERED_REGISTRY_ENTRY return registry if namespace != '': @@ -248,6 +252,9 @@ def pytree_node_registry_get( # noqa: C901 return _DICT_INSERTION_ORDERED_REGISTRY_ENTRY if cls is defaultdict: return _DEFAULTDICT_INSERTION_ORDERED_REGISTRY_ENTRY + if sys.version_info >= (3, 15) and _C.OPTREE_HAS_FROZENDICT: # pragma: >=3.15 cover + if cls is builtins.frozendict: # pylint: disable=no-member + return _FROZENDICT_INSERTION_ORDERED_REGISTRY_ENTRY handler = _NODETYPE_REGISTRY.get(cls) if handler is not None: @@ -936,3 +943,48 @@ def _structseq_unflatten( path_entry_type=MappingEntry, kind=PyTreeKind.DEFAULTDICT, ) + +if sys.version_info >= (3, 15) and _C.OPTREE_HAS_FROZENDICT: # pragma: >=3.15 cover + + def _frozendict_flatten( + dct: frozendict[KT, VT], # type: ignore[type-arg] + /, + ) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]: + keys, values = unzip2(_sorted_items(dct.items())) + return values, list(keys), keys + + def _frozendict_unflatten( + keys: list[KT], + values: Iterable[VT], + /, + ) -> frozendict[KT, VT]: # type: ignore[type-arg] + return frozendict(safe_zip(keys, values)) + + def _frozendict_insertion_ordered_flatten( + dct: frozendict[KT, VT], # type: ignore[type-arg] + /, + ) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]: + keys, values = unzip2(dct.items()) + return values, list(keys), keys + + def _frozendict_insertion_ordered_unflatten( + keys: list[KT], + values: Iterable[VT], + /, + ) -> frozendict[KT, VT]: # type: ignore[type-arg] + return frozendict(safe_zip(keys, values)) + + _NODETYPE_REGISTRY[frozendict] = PyTreeNodeRegistryEntry( + frozendict, + _frozendict_flatten, + _frozendict_unflatten, + path_entry_type=MappingEntry, + kind=PyTreeKind.FROZENDICT, + ) + _FROZENDICT_INSERTION_ORDERED_REGISTRY_ENTRY = PyTreeNodeRegistryEntry( + frozendict, + _frozendict_insertion_ordered_flatten, + _frozendict_insertion_ordered_unflatten, + path_entry_type=MappingEntry, + kind=PyTreeKind.FROZENDICT, + ) diff --git a/optree/treespec.py b/optree/treespec.py index 16bbceb5..d55b5c2a 100644 --- a/optree/treespec.py +++ b/optree/treespec.py @@ -27,6 +27,9 @@ from __future__ import annotations +import sys + +import optree._C as _C from optree.ops import treespec_defaultdict as defaultdict from optree.ops import treespec_deque as deque from optree.ops import treespec_dict as dict # pylint: disable=redefined-builtin @@ -53,3 +56,13 @@ 'structseq', 'from_collection', ] + + +if sys.version_info >= (3, 15) and _C.OPTREE_HAS_FROZENDICT: # pragma: >=3.15 cover + # pylint: disable-next=unused-import,redefined-builtin + from optree.ops import treespec_frozendict as frozendict # noqa: F401 + + __all__.insert(__all__.index('from_collection'), 'frozendict') + + +del sys diff --git a/optree/typing.py b/optree/typing.py index bbf3b98e..725e0d87 100644 --- a/optree/typing.py +++ b/optree/typing.py @@ -130,6 +130,16 @@ ] +if sys.version_info >= (3, 15) and _C.OPTREE_HAS_FROZENDICT: # pragma: >=3.15 cover + # pylint: disable-next=no-name-in-module,ungrouped-imports + from builtins import frozendict # type: ignore[import] + + # pylint: disable-next=no-name-in-module,unused-import + from builtins import frozendict as FrozenDict # type: ignore[import] # noqa: F401,N812 + + __all__.insert(__all__.index('Dict') + 1, 'FrozenDict') + + PyTreeDef: TypeAlias = PyTreeSpec # alias T = TypeVar('T') @@ -261,14 +271,21 @@ def __class_getitem__( # noqa: C901 # pylint: disable=too-many-branches else: recurse_ref = ForwardRef(f'{cls.__name__}[{param!r}]') - pytree_alias = Union[ - param, # type: ignore[valid-type] + pytree_types = [ + param, Tuple[recurse_ref, ...], # type: ignore[valid-type] # Tuple, NamedTuple, PyStructSequence List[recurse_ref], # type: ignore[valid-type] Dict[Any, recurse_ref], # type: ignore[valid-type] # Dict, OrderedDict, DefaultDict - Deque[recurse_ref], # type: ignore[valid-type] - CustomTreeNode[recurse_ref], # type: ignore[valid-type] ] + if sys.version_info >= (3, 15) and _C.OPTREE_HAS_FROZENDICT: # pragma: >=3.15 cover + pytree_types.append(frozendict[Any, recurse_ref]) # type: ignore[valid-type] + pytree_types.extend( + [ + Deque[recurse_ref], # type: ignore[list-item,valid-type] + CustomTreeNode[recurse_ref], # type: ignore[list-item,valid-type] + ], + ) + pytree_alias = Union[tuple(pytree_types)] # type: ignore[valid-type] with cls.__instance_lock__: cls.__instances__[pytree_alias] = (param, name) # type: ignore[index] @@ -310,7 +327,7 @@ def count(self, key: Any, /) -> int: """Emulate sequence-like behavior.""" raise NotImplementedError - def get(self, key: Any, /, default: S | None = None) -> PyTree[T] | T | S | None: + def get(self, key: Any, default: S | None = None, /) -> PyTree[T] | T | S | None: """Emulate mapping-like behavior.""" raise NotImplementedError diff --git a/pyproject.toml b/pyproject.toml index bd918447..c3c8dd95 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ classifiers = [ dependencies = [ "typing-extensions >= 4.6.0", "typing-extensions >= 4.12.0; python_version >= '3.13'", + "typing-extensions >= 4.14.0; python_version >= '3.15'", ] dynamic = ["version"] @@ -77,9 +78,12 @@ test = [ "typing-extensions == 4.6.0; python_version < '3.13' and platform_system == 'Linux'", "typing-extensions == 4.6.0; python_version < '3.13' and platform_system == 'Darwin'", "typing-extensions == 4.6.0; python_version < '3.13' and platform_system == 'Windows'", - "typing-extensions == 4.12.0; python_version >= '3.13' and platform_system == 'Linux'", - "typing-extensions == 4.12.0; python_version >= '3.13' and platform_system == 'Darwin'", - "typing-extensions == 4.12.0; python_version >= '3.13' and platform_system == 'Windows'", + "typing-extensions == 4.12.0; python_version >= '3.13' and python_version < '3.15' and platform_system == 'Linux'", + "typing-extensions == 4.12.0; python_version >= '3.13' and python_version < '3.15' and platform_system == 'Darwin'", + "typing-extensions == 4.12.0; python_version >= '3.13' and python_version < '3.15' and platform_system == 'Windows'", + "typing-extensions == 4.14.0; python_version >= '3.15' and platform_system == 'Linux'", + "typing-extensions == 4.14.0; python_version >= '3.15' and platform_system == 'Darwin'", + "typing-extensions == 4.14.0; python_version >= '3.15' and platform_system == 'Windows'", ] docs = [ "sphinx ~= 8.0", @@ -267,6 +271,9 @@ ignore = [ # TRY003: avoid specifying long messages outside the exception class # long messages are necessary for clarity "TRY003", + # FURB152: literals that are similar to constants in `math` module. + # change code semantics + "FURB152", # RUF022: `__all__` is not ordered according to an "isort-style" sort # `__all__` contains comments to group names "RUF022", @@ -299,6 +306,7 @@ typing-modules = ["optree.typing"] [tool.ruff.lint.isort] known-first-party = ["optree"] extra-standard-library = ["typing_extensions"] +known-local-folder = ["tests"] lines-after-imports = 2 [tool.ruff.lint.pydocstyle] @@ -317,4 +325,5 @@ ban-relative-imports = "all" [tool.pytest.ini_options] verbosity_assertions = 3 +testpaths = ["tests"] filterwarnings = ["always", "error"] diff --git a/requirements.txt b/requirements.txt index bb5ece41..9269cc68 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ typing-extensions >= 4.6.0 typing-extensions >= 4.12.0; python_version >= '3.13' +typing-extensions >= 4.14.0; python_version >= '3.15' diff --git a/src/optree.cpp b/src/optree.cpp index cd8753f8..8990249c 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -107,6 +107,11 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] #else BUILDTIME_METADATA["OPTREE_HAS_READ_WRITE_LOCK"] = py::bool_(false); #endif +#if defined(OPTREE_HAS_FROZENDICT) + BUILDTIME_METADATA["OPTREE_HAS_FROZENDICT"] = py::bool_(true); +#else + BUILDTIME_METADATA["OPTREE_HAS_FROZENDICT"] = py::bool_(false); +#endif mod.attr("BUILDTIME_METADATA") = std::move(BUILDTIME_METADATA); py::exec( @@ -285,6 +290,7 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] .value("DEFAULTDICT", PyTreeKind::DefaultDict, "A collections.defaultdict.") .value("DEQUE", PyTreeKind::Deque, "A collections.deque.") .value("STRUCTSEQUENCE", PyTreeKind::StructSequence, "A PyStructSequence.") + .value("FROZENDICT", PyTreeKind::FrozenDict, "A frozendict.") .finalize(); auto PyTreeKindTypeObject = py::getattr(mod, "PyTreeKind"); #else @@ -300,7 +306,8 @@ void BuildModule(py::module_ &mod) { // NOLINT[runtime/references] .value("ORDEREDDICT", PyTreeKind::OrderedDict, "A collections.OrderedDict.") .value("DEFAULTDICT", PyTreeKind::DefaultDict, "A collections.defaultdict.") .value("DEQUE", PyTreeKind::Deque, "A collections.deque.") - .value("STRUCTSEQUENCE", PyTreeKind::StructSequence, "A PyStructSequence."); + .value("STRUCTSEQUENCE", PyTreeKind::StructSequence, "A PyStructSequence.") + .value("FROZENDICT", PyTreeKind::FrozenDict, "A frozendict."); #endif auto * const PyTreeKind_Type = reinterpret_cast(PyTreeKindTypeObject.ptr()); PyTreeKind_Type->tp_name = "optree.PyTreeKind"; diff --git a/src/registry.cpp b/src/registry.cpp index ef9c182a..501fa555 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -61,6 +61,9 @@ template add_builtin_type(PyOrderedDictTypeObject, PyTreeKind::OrderedDict); add_builtin_type(PyDefaultDictTypeObject, PyTreeKind::DefaultDict); add_builtin_type(PyDequeTypeObject, PyTreeKind::Deque); +#if defined(OPTREE_HAS_FROZENDICT) + add_builtin_type(PyFrozenDictTypeObject, PyTreeKind::FrozenDict); +#endif return registry; }) .get_stored(); diff --git a/src/treespec/constructors.cpp b/src/treespec/constructors.cpp index 9c95ae14..84c0a2a5 100644 --- a/src/treespec/constructors.cpp +++ b/src/treespec/constructors.cpp @@ -161,7 +161,8 @@ template case PyTreeKind::Dict: case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DefaultDict: + case PyTreeKind::FrozenDict: { py::list keys; { const scoped_critical_section cs{handle}; diff --git a/src/treespec/flatten.cpp b/src/treespec/flatten.cpp index e51a4557..af6e8585 100644 --- a/src/treespec/flatten.cpp +++ b/src/treespec/flatten.cpp @@ -98,7 +98,8 @@ bool PyTreeSpec::FlattenIntoImpl(const py::handle &handle, case PyTreeKind::Dict: case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DefaultDict: + case PyTreeKind::FrozenDict: { py::list keys; { const scoped_critical_section cs{handle}; @@ -365,7 +366,8 @@ bool PyTreeSpec::FlattenIntoWithPathImpl(const py::handle &handle, case PyTreeKind::Dict: case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DefaultDict: + case PyTreeKind::FrozenDict: { const scoped_critical_section cs{handle}; const auto dict = py::reinterpret_borrow(handle); node.arity = DictGetSize(dict); @@ -638,7 +640,8 @@ py::list PyTreeSpec::FlattenUpTo(const py::object &tree) const { case PyTreeKind::Dict: case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DefaultDict: + case PyTreeKind::FrozenDict: { AssertExactStandardDict(object); const scoped_critical_section2 cs{object, node.node_data}; const auto dict = py::reinterpret_borrow(object); @@ -663,8 +666,10 @@ py::list PyTreeSpec::FlattenUpTo(const py::object &tree) const { oss << "dict"; } else if (node.kind == PyTreeKind::OrderedDict) [[likely]] { oss << "OrderedDict"; - } else [[unlikely]] { + } else if (node.kind == PyTreeKind::DefaultDict) [[likely]] { oss << "defaultdict"; + } else [[unlikely]] { + oss << "frozendict"; } oss << ": " << PyRepr(object) << "."; throw py::value_error(oss.str()); diff --git a/src/treespec/hashing.cpp b/src/treespec/hashing.cpp index 81acc013..cd2f492e 100644 --- a/src/treespec/hashing.cpp +++ b/src/treespec/hashing.cpp @@ -61,7 +61,8 @@ ssize_t PyTreeSpec::HashValueImpl() const { case PyTreeKind::Dict: case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DefaultDict: + case PyTreeKind::FrozenDict: { const scoped_critical_section cs{node.node_data}; if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { EXPECT_EQ(TupleGetSize(node.node_data), 2, "Number of metadata mismatch."); diff --git a/src/treespec/richcomparison.cpp b/src/treespec/richcomparison.cpp index 96b61798..6773625d 100644 --- a/src/treespec/richcomparison.cpp +++ b/src/treespec/richcomparison.cpp @@ -73,9 +73,11 @@ bool PyTreeSpec::IsPrefix(const PyTreeSpec &other, const bool &strict) const { case PyTreeKind::Dict: case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DefaultDict: + case PyTreeKind::FrozenDict: { if (b->kind != PyTreeKind::Dict && b->kind != PyTreeKind::OrderedDict && - b->kind != PyTreeKind::DefaultDict) [[likely]] { + b->kind != PyTreeKind::DefaultDict && b->kind != PyTreeKind::FrozenDict) + [[likely]] { return false; } const scoped_critical_section2 cs(a->node_data, b->node_data); diff --git a/src/treespec/serialization.cpp b/src/treespec/serialization.cpp index 8c397549..6e2ab7f7 100644 --- a/src/treespec/serialization.cpp +++ b/src/treespec/serialization.cpp @@ -48,6 +48,8 @@ namespace optree { return PyRepr(node.node_data); case PyTreeKind::Deque: return "deque"; + case PyTreeKind::FrozenDict: + return "frozendict"; case PyTreeKind::Custom: EXPECT_NE(node.custom, nullptr, "The custom registration is null."); return PyRepr(node.custom->type); @@ -104,13 +106,16 @@ std::string PyTreeSpec::ToStringImpl() const { } case PyTreeKind::Dict: - case PyTreeKind::OrderedDict: { + case PyTreeKind::OrderedDict: + case PyTreeKind::FrozenDict: { const scoped_critical_section cs{node.node_data}; EXPECT_EQ(ListGetSize(node.node_data), node.arity, "Number of keys and entries does not match."); if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] { sstream << "OrderedDict("; + } else if (node.kind == PyTreeKind::FrozenDict) [[unlikely]] { + sstream << "frozendict("; } if (node.kind == PyTreeKind::Dict || node.arity > 0) [[likely]] { sstream << "{"; @@ -128,7 +133,8 @@ std::string PyTreeSpec::ToStringImpl() const { if (node.kind == PyTreeKind::Dict || node.arity > 0) [[likely]] { sstream << "}"; } - if (node.kind == PyTreeKind::OrderedDict) [[unlikely]] { + if (node.kind == PyTreeKind::OrderedDict || node.kind == PyTreeKind::FrozenDict) + [[unlikely]] { sstream << ")"; } break; @@ -181,9 +187,15 @@ std::string PyTreeSpec::ToStringImpl() const { } case PyTreeKind::Deque: { - sstream << "deque([" << children << "]"; + sstream << "deque("; + if (node.arity > 0) [[likely]] { + sstream << "[" << children << "]"; + } if (!node.node_data.is_none()) [[unlikely]] { - sstream << ", maxlen=" << PyRepr(node.node_data); + if (node.arity > 0) [[likely]] { + sstream << ", "; + } + sstream << "maxlen=" << PyRepr(node.node_data); } sstream << ")"; break; @@ -332,13 +344,13 @@ py::object PyTreeSpec::ToPickleable() const { if (t.size() != 7) [[unlikely]] { if (t.size() == 8) [[likely]] { if (t[7].is_none()) [[likely]] { - if (node.kind == PyTreeKind::Dict || node.kind == PyTreeKind::DefaultDict) - [[unlikely]] { + if (node.kind == PyTreeKind::Dict || node.kind == PyTreeKind::DefaultDict || + node.kind == PyTreeKind::FrozenDict) [[unlikely]] { throw std::runtime_error("Malformed pickled PyTreeSpec."); } } else [[unlikely]] { - if (node.kind == PyTreeKind::Dict || node.kind == PyTreeKind::DefaultDict) - [[likely]] { + if (node.kind == PyTreeKind::Dict || node.kind == PyTreeKind::DefaultDict || + node.kind == PyTreeKind::FrozenDict) [[likely]] { node.original_keys = thread_safe_cast(t[7]); } else [[unlikely]] { throw std::runtime_error("Malformed pickled PyTreeSpec."); @@ -361,7 +373,8 @@ py::object PyTreeSpec::ToPickleable() const { } case PyTreeKind::Dict: - case PyTreeKind::OrderedDict: { + case PyTreeKind::OrderedDict: + case PyTreeKind::FrozenDict: { node.node_data = thread_safe_cast(t[2]); break; } diff --git a/src/treespec/traversal.cpp b/src/treespec/traversal.cpp index 5aab3862..6ecb1e1b 100644 --- a/src/treespec/traversal.cpp +++ b/src/treespec/traversal.cpp @@ -82,7 +82,8 @@ py::object PyTreeIter::NextImpl() { case PyTreeKind::Dict: case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DefaultDict: + case PyTreeKind::FrozenDict: { const scoped_critical_section cs{object}; const auto dict = py::reinterpret_borrow(object); py::list keys = DictKeys(dict); @@ -211,6 +212,7 @@ py::object PyTreeSpec::WalkImpl(const py::iterable &leaves, case PyTreeKind::DefaultDict: case PyTreeKind::Deque: case PyTreeKind::StructSequence: + case PyTreeKind::FrozenDict: case PyTreeKind::Custom: { const ssize_t size = py::ssize_t_cast(agenda.size()); EXPECT_GE(size, node.arity, "Too few elements for custom type."); diff --git a/src/treespec/treespec.cpp b/src/treespec/treespec.cpp index 0ae01cab..c98f1947 100644 --- a/src/treespec/treespec.cpp +++ b/src/treespec/treespec.cpp @@ -79,7 +79,8 @@ namespace optree { case PyTreeKind::Dict: case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DefaultDict: + case PyTreeKind::FrozenDict: { py::dict dict{}; const scoped_critical_section2 cs{node.node_data, node.original_keys}; if (node.kind == PyTreeKind::DefaultDict) [[unlikely]] { @@ -106,6 +107,11 @@ namespace optree { PyDefaultDictTypeObject(default_factory, std::move(dict)), default_factory); } +#if defined(OPTREE_HAS_FROZENDICT) + if (node.kind == PyTreeKind::FrozenDict) [[unlikely]] { + return PyFrozenDictTypeObject(std::move(dict)); + } +#endif return dict; } @@ -145,7 +151,8 @@ namespace optree { case PyTreeKind::Dict: case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DefaultDict: + case PyTreeKind::FrozenDict: { PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store storage; return storage .call_once_and_store_result( @@ -259,9 +266,11 @@ namespace optree { case PyTreeKind::Dict: case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DefaultDict: + case PyTreeKind::FrozenDict: { if (other_root.kind != PyTreeKind::Dict && other_root.kind != PyTreeKind::OrderedDict && - other_root.kind != PyTreeKind::DefaultDict) [[unlikely]] { + other_root.kind != PyTreeKind::DefaultDict && + other_root.kind != PyTreeKind::FrozenDict) [[unlikely]] { std::ostringstream oss{}; oss << "PyTreeSpecs have incompatible node types; expected type: " << NodeKindToString(root) << ", got: " << NodeKindToString(other_root) << "."; @@ -677,7 +686,8 @@ ssize_t PyTreeSpec::PathsImpl(PathVector &paths, // NOLINT[misc-no-recursion] case PyTreeKind::Dict: case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DefaultDict: + case PyTreeKind::FrozenDict: { const scoped_critical_section cs{root.node_data}; const auto keys = (root.kind != PyTreeKind::DefaultDict ? py::reinterpret_borrow(root.node_data) @@ -787,7 +797,8 @@ ssize_t PyTreeSpec::AccessorsImpl(Span &accessors, // NOLINT[misc-no-recursion] case PyTreeKind::Dict: case PyTreeKind::OrderedDict: - case PyTreeKind::DefaultDict: { + case PyTreeKind::DefaultDict: + case PyTreeKind::FrozenDict: { const scoped_critical_section cs{root.node_data}; const auto keys = (root.kind != PyTreeKind::DefaultDict ? py::reinterpret_borrow(root.node_data) @@ -854,7 +865,8 @@ py::list PyTreeSpec::Entries() const { } case PyTreeKind::Dict: - case PyTreeKind::OrderedDict: { + case PyTreeKind::OrderedDict: + case PyTreeKind::FrozenDict: { const scoped_critical_section cs{root.node_data}; return py::getattr(root.node_data, "copy")(); } @@ -894,7 +906,8 @@ py::object PyTreeSpec::Entry(ssize_t index) const { } case PyTreeKind::Dict: - case PyTreeKind::OrderedDict: { + case PyTreeKind::OrderedDict: + case PyTreeKind::FrozenDict: { const scoped_critical_section cs{root.node_data}; return ListGetItem(root.node_data, index); } @@ -995,6 +1008,10 @@ py::object PyTreeSpec::GetType(const std::optional &node) const { return PyDefaultDictTypeObject; case PyTreeKind::Deque: return PyDequeTypeObject; + case PyTreeKind::FrozenDict: +#if defined(OPTREE_HAS_FROZENDICT) + return PyFrozenDictTypeObject; +#endif case PyTreeKind::NumKinds: default: INTERNAL_ERROR(); diff --git a/src/treespec/unflatten.cpp b/src/treespec/unflatten.cpp index 137d1b14..a8b80f21 100644 --- a/src/treespec/unflatten.cpp +++ b/src/treespec/unflatten.cpp @@ -55,6 +55,7 @@ py::object PyTreeSpec::UnflattenImpl(const Span &leaves) const { case PyTreeKind::DefaultDict: case PyTreeKind::Deque: case PyTreeKind::StructSequence: + case PyTreeKind::FrozenDict: case PyTreeKind::Custom: { const ssize_t size = py::ssize_t_cast(agenda.size()); py::object out = MakeNode(node, diff --git a/tests/concurrent/test_subinterpreters.py b/tests/concurrent/test_subinterpreters.py index aec86e9a..5bcf1960 100644 --- a/tests/concurrent/test_subinterpreters.py +++ b/tests/concurrent/test_subinterpreters.py @@ -81,6 +81,7 @@ def concurrent_run(func, /, *args, **kwargs): def check_module_importable(): import collections + import sys import time import optree @@ -93,7 +94,9 @@ def check_module_importable(): if is_current_interpreter_main != (main_interpreter_id == current_interpreter_id): raise RuntimeError('interpreter identity mismatch') - if not is_current_interpreter_main and optree._C.get_registry_size() != 8: + if not is_current_interpreter_main and optree._C.get_registry_size() != ( + 9 if sys.version_info >= (3, 15) and optree._C.OPTREE_HAS_FROZENDICT else 8 + ): raise RuntimeError('registry size mismatch') tree = { @@ -106,6 +109,26 @@ def check_module_importable(): ), 'g': collections.defaultdict(list, h=collections.deque([7, 8, 9], maxlen=10)), } + expected_leaves1 = [1, 2, 3, 4, 5, 6, 7, 8, 9] + expected_leaves2 = [ + 1, + 2, + 3, + 4, + None, + 5, + 6, + *([None] * (time.struct_time.n_sequence_fields - 1)), + 7, + 8, + 9, + ] + if sys.version_info >= (3, 15) and optree._C.OPTREE_HAS_FROZENDICT: + from builtins import frozendict # pylint: disable=no-name-in-module + + tree['i'] = frozendict({'k': 11, 'j': 10}) + expected_leaves1.extend([10, 11]) + expected_leaves2.extend([10, 11]) leaves1, treespec1 = optree.tree_flatten(tree, none_is_leaf=False) reconstructed1 = optree.tree_unflatten(treespec1, leaves1) @@ -113,7 +136,7 @@ def check_module_importable(): raise RuntimeError('unflatten/flatten mismatch') if treespec1.num_leaves != len(leaves1): raise RuntimeError(f'num_leaves mismatch: ({leaves1}, {treespec1})') - if leaves1 != [1, 2, 3, 4, 5, 6, 7, 8, 9]: + if leaves1 != expected_leaves1: raise RuntimeError(f'flattened leaves mismatch: ({leaves1}, {treespec1})') leaves2, treespec2 = optree.tree_flatten(tree, none_is_leaf=True) @@ -122,19 +145,7 @@ def check_module_importable(): raise RuntimeError('unflatten/flatten mismatch') if treespec2.num_leaves != len(leaves2): raise RuntimeError(f'num_leaves mismatch: ({leaves2}, {treespec2})') - if leaves2 != [ - 1, - 2, - 3, - 4, - None, - 5, - 6, - *([None] * (time.struct_time.n_sequence_fields - 1)), - 7, - 8, - 9, - ]: + if leaves2 != expected_leaves2: raise RuntimeError(f'flattened leaves mismatch: ({leaves2}, {treespec2})') _ = optree.tree_flatten_with_path(tree, none_is_leaf=False) @@ -328,9 +339,13 @@ def test_import_in_subinterpreters_concurrently(): from concurrent.futures import InterpreterPoolExecutor, as_completed def check_import(): + import sys import optree + import optree._C - if optree._C.get_registry_size() != 8: + if optree._C.get_registry_size() != ( + 9 if sys.version_info >= (3, 15) and optree._C.OPTREE_HAS_FROZENDICT else 8 + ): raise RuntimeError('registry size mismatch') if optree._C.is_current_interpreter_main(): raise RuntimeError('expected subinterpreter') diff --git a/tests/concurrent/test_threading.py b/tests/concurrent/test_threading.py index 965cc96e..00a96061 100644 --- a/tests/concurrent/test_threading.py +++ b/tests/concurrent/test_threading.py @@ -19,7 +19,6 @@ import itertools import pickle import weakref -from collections import OrderedDict, defaultdict from concurrent.futures import ThreadPoolExecutor, as_completed import pytest @@ -28,6 +27,7 @@ from helpers import ( GLOBAL_NAMESPACE, PYPY, + STANDARD_DICT_TYPES, TREES, WASM, Py_DEBUG, @@ -324,7 +324,7 @@ def check3(): actual = pickle.loads(expected_serialized) concurrent_run(check1) concurrent_run(check2) - if expected.type in {dict, OrderedDict, defaultdict}: + if expected.type in STANDARD_DICT_TYPES: concurrent_run(check3) diff --git a/tests/helpers.py b/tests/helpers.py index 7f426b11..2b1d2548 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -36,6 +36,7 @@ import optree from optree._C import ( + OPTREE_HAS_FROZENDICT, OPTREE_HAS_SUBINTERPRETER_SUPPORT, PYBIND11_HAS_NATIVE_ENUM, PYBIND11_HAS_SUBINTERPRETER_SUPPORT, @@ -43,6 +44,7 @@ Py_GIL_DISABLED, get_registry_size, ) +from optree.ops import STANDARD_DICT_TYPES as STANDARD_DICT_TYPES from optree.registry import __GLOBAL_NAMESPACE as GLOBAL_NAMESPACE from optree.registry import _NODETYPE_REGISTRY as NODETYPE_REGISTRY @@ -51,7 +53,7 @@ INITIAL_REGISTRY_SIZE = get_registry_size() -assert INITIAL_REGISTRY_SIZE == 8 +assert INITIAL_REGISTRY_SIZE == (9 if sys.version_info >= (3, 15) and OPTREE_HAS_FROZENDICT else 8) assert INITIAL_REGISTRY_SIZE + 2 == len(NODETYPE_REGISTRY) _ = PYBIND11_HAS_NATIVE_ENUM @@ -241,7 +243,7 @@ def is_dict(dct): def is_primitive_collection(obj): if type(obj) in {tuple, list, deque}: return all(isinstance(item, (int, float, str, bool, type(None))) for item in obj) - if type(obj) in {dict, OrderedDict, defaultdict}: + if type(obj) in STANDARD_DICT_TYPES: return all(isinstance(value, (int, float, str, bool, type(None))) for value in obj.values()) return False @@ -1552,8 +1554,8 @@ def __next__(self): 'PyTreeSpec(defaultdict(None, {}))', "PyTreeSpec(defaultdict(, {}))", "PyTreeSpec(defaultdict(, {'baz': *, 'foo': *, 'something': *}))", - 'PyTreeSpec(deque([]))', - 'PyTreeSpec(deque([], maxlen=0))', + 'PyTreeSpec(deque())', + 'PyTreeSpec(deque(maxlen=0))', 'PyTreeSpec(deque([None, *, *]))', 'PyTreeSpec(deque([None, *], maxlen=2))', "PyTreeSpec(CustomTreeNode(MyDict[['foo', 'baz']], [CustomTreeNode(MyDict[['c', 'b', 'a']], [None, *, *]), *]))", @@ -1595,8 +1597,8 @@ def __next__(self): 'PyTreeSpec(defaultdict(None, {}), NoneIsLeaf)', "PyTreeSpec(defaultdict(, {}), NoneIsLeaf)", "PyTreeSpec(defaultdict(, {'baz': *, 'foo': *, 'something': *}), NoneIsLeaf)", - 'PyTreeSpec(deque([]), NoneIsLeaf)', - 'PyTreeSpec(deque([], maxlen=0), NoneIsLeaf)', + 'PyTreeSpec(deque(), NoneIsLeaf)', + 'PyTreeSpec(deque(maxlen=0), NoneIsLeaf)', 'PyTreeSpec(deque([*, *, *]), NoneIsLeaf)', 'PyTreeSpec(deque([*, *], maxlen=2), NoneIsLeaf)', "PyTreeSpec(CustomTreeNode(MyDict[['foo', 'baz']], [CustomTreeNode(MyDict[['c', 'b', 'a']], [*, *, *]), *]), NoneIsLeaf)", @@ -1610,6 +1612,87 @@ def __next__(self): "PyTreeSpec(CustomTreeNode(FlatCache[PyTreeSpec({'a': [*, *]})], [*, *]), NoneIsLeaf)", ) +if sys.version_info >= (3, 15) and OPTREE_HAS_FROZENDICT: + from builtins import frozendict # type: ignore[import] + + TREES = ( # type: ignore[no-redef] + *TREES, + frozendict(), + frozendict({'a': 1, 'b': 2}), + frozendict({'baz': 101, 'foo': -42, 'something': 7}), + ) + TREE_PATHS_NONE_IS_NODE = [ + *TREE_PATHS_NONE_IS_NODE, + [], + [('a',), ('b',)], + [('baz',), ('foo',), ('something',)], + ] + TREE_PATHS_NONE_IS_LEAF = [ + *TREE_PATHS_NONE_IS_LEAF, + [], + [('a',), ('b',)], + [('baz',), ('foo',), ('something',)], + ] + TREE_ACCESSORS_NONE_IS_NODE = [ + *TREE_ACCESSORS_NONE_IS_NODE, + [], + [ + optree.PyTreeAccessor( + (optree.MappingEntry('a', frozendict, optree.PyTreeKind.FROZENDICT),), + ), + optree.PyTreeAccessor( + (optree.MappingEntry('b', frozendict, optree.PyTreeKind.FROZENDICT),), + ), + ], + [ + optree.PyTreeAccessor( + (optree.MappingEntry('baz', frozendict, optree.PyTreeKind.FROZENDICT),), + ), + optree.PyTreeAccessor( + (optree.MappingEntry('foo', frozendict, optree.PyTreeKind.FROZENDICT),), + ), + optree.PyTreeAccessor( + (optree.MappingEntry('something', frozendict, optree.PyTreeKind.FROZENDICT),), + ), + ], + ] + TREE_ACCESSORS_NONE_IS_LEAF = [ + *TREE_ACCESSORS_NONE_IS_LEAF, + [], + [ + optree.PyTreeAccessor( + (optree.MappingEntry('a', frozendict, optree.PyTreeKind.FROZENDICT),), + ), + optree.PyTreeAccessor( + (optree.MappingEntry('b', frozendict, optree.PyTreeKind.FROZENDICT),), + ), + ], + [ + optree.PyTreeAccessor( + (optree.MappingEntry('baz', frozendict, optree.PyTreeKind.FROZENDICT),), + ), + optree.PyTreeAccessor( + (optree.MappingEntry('foo', frozendict, optree.PyTreeKind.FROZENDICT),), + ), + optree.PyTreeAccessor( + (optree.MappingEntry('something', frozendict, optree.PyTreeKind.FROZENDICT),), + ), + ], + ] + TREE_STRINGS_NONE_IS_NODE = ( + *TREE_STRINGS_NONE_IS_NODE, + 'PyTreeSpec(frozendict())', + "PyTreeSpec(frozendict({'a': *, 'b': *}))", + "PyTreeSpec(frozendict({'baz': *, 'foo': *, 'something': *}))", + ) + TREE_STRINGS_NONE_IS_LEAF = ( + *TREE_STRINGS_NONE_IS_LEAF, + 'PyTreeSpec(frozendict(), NoneIsLeaf)', + "PyTreeSpec(frozendict({'a': *, 'b': *}), NoneIsLeaf)", + "PyTreeSpec(frozendict({'baz': *, 'foo': *, 'something': *}), NoneIsLeaf)", + ) + + TREE_STRINGS = { optree.NONE_IS_NODE: TREE_STRINGS_NONE_IS_NODE, optree.NONE_IS_LEAF: TREE_STRINGS_NONE_IS_LEAF, diff --git a/tests/test_ops.py b/tests/test_ops.py index fc70fd8f..e7eadf5f 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -15,6 +15,7 @@ # pylint: disable=missing-function-docstring,invalid-name +import builtins import copy import functools import itertools @@ -32,6 +33,7 @@ GLOBAL_NAMESPACE, IS_LEAF_FUNCTIONS, LEAVES, + OPTREE_HAS_FROZENDICT, TREE_ACCESSORS, TREE_PATHS, TREES, @@ -3381,6 +3383,16 @@ def flatten(node): # noqa: C901 else: assert metadata == (node.default_factory, list(node.keys())) assert node_kind == optree.PyTreeKind.DEFAULTDICT + elif ( + sys.version_info >= (3, 15) + and OPTREE_HAS_FROZENDICT + and node_type is builtins.frozendict # type: ignore[attr-defined] + ): + if dict_should_be_sorted or dict_session_namespace not in {'', namespace}: + assert metadata == sorted(node.keys()) + else: + assert metadata == list(node.keys()) + assert node_kind == optree.PyTreeKind.FROZENDICT elif node_type is deque: assert metadata == node.maxlen assert node_kind == optree.PyTreeKind.DEQUE diff --git a/tests/test_prefix_errors.py b/tests/test_prefix_errors.py index 73712ce6..6dc6cf0e 100644 --- a/tests/test_prefix_errors.py +++ b/tests/test_prefix_errors.py @@ -25,6 +25,7 @@ import optree from helpers import ( GLOBAL_NAMESPACE, + STANDARD_DICT_TYPES, TREES, CustomTuple, FlatCache, @@ -58,7 +59,7 @@ def test_different_types(): with pytest.raises( ValueError, match=( - r'Expected an instance of dict, collections.OrderedDict, or collections.defaultdict, ' + r'Expected an instance of dict, (frozendict, )?collections.OrderedDict, or collections.defaultdict, ' r'got .*\.' ), ): @@ -484,10 +485,10 @@ def build_subtree(x): return def shuffle_dictionary(x): - if type(x) in {dict, OrderedDict, defaultdict}: + if type(x) in STANDARD_DICT_TYPES: items = list(x.items()) random.shuffle(items) - dict_type = random.choice([dict, OrderedDict, defaultdict]) + dict_type = random.choice(list(STANDARD_DICT_TYPES)) if dict_type is defaultdict: return defaultdict(getattr(x, 'default_factory', int), items) return dict_type(items) @@ -496,7 +497,7 @@ def shuffle_dictionary(x): shuffled_tree = optree.tree_map( shuffle_dictionary, tree, - is_leaf=lambda x: type(x) in {dict, OrderedDict, defaultdict}, + is_leaf=lambda x: type(x) in STANDARD_DICT_TYPES, none_is_leaf=none_is_leaf, namespace=namespace, ) @@ -508,7 +509,7 @@ def shuffle_dictionary(x): shuffled_suffix_tree = optree.tree_map( shuffle_dictionary, suffix_tree, - is_leaf=lambda x: type(x) in {dict, OrderedDict, defaultdict}, + is_leaf=lambda x: type(x) in STANDARD_DICT_TYPES, none_is_leaf=none_is_leaf, namespace=namespace, ) diff --git a/tests/test_treespec.py b/tests/test_treespec.py index 63fa2796..8340942a 100644 --- a/tests/test_treespec.py +++ b/tests/test_treespec.py @@ -15,6 +15,7 @@ # pylint: disable=missing-function-docstring,invalid-name +import builtins import contextlib import itertools import os @@ -36,7 +37,9 @@ from helpers import ( GLOBAL_NAMESPACE, NAMESPACED_TREE, + OPTREE_HAS_FROZENDICT, PYPY, + STANDARD_DICT_TYPES, TEST_ROOT, TREE_STRINGS, TREES, @@ -511,7 +514,7 @@ def test_treespec_pickle_roundtrip( else: actual = pickle.loads(pickle.dumps(expected)) assert actual == expected - if expected.type in {dict, OrderedDict, defaultdict}: + if expected.type in STANDARD_DICT_TYPES: assert list(optree.tree_unflatten(actual, range(len(actual)))) == list( optree.tree_unflatten(expected, range(len(expected))), ) @@ -1665,6 +1668,58 @@ def test_treespec_constructor( # noqa: C901 ) == expected_treespec ) + elif ( + sys.version_info >= (3, 15) + and OPTREE_HAS_FROZENDICT + and node_type is builtins.frozendict # type: ignore[attr-defined] + ): + if dict_should_be_sorted or dict_session_namespace not in {'', namespace}: + assert ( + optree.treespec_frozendict( + zip(sorted(node), children_treespecs), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + assert ( + optree.treespec_from_collection( + builtins.frozendict( # type: ignore[attr-defined] + zip(sorted(node), children_treespecs), + ), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + else: + context = ( + optree.dict_insertion_ordered( + True, + namespace=passed_namespace or GLOBAL_NAMESPACE, + ) + if dict_session_namespace != passed_namespace + else contextlib.nullcontext() + ) + with context: + assert ( + optree.treespec_frozendict( + zip(node, children_treespecs), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) + assert ( + optree.treespec_from_collection( + builtins.frozendict( # type: ignore[attr-defined] + zip(node, children_treespecs), + ), + none_is_leaf=none_is_leaf, + namespace=passed_namespace, + ) + == expected_treespec + ) elif node_type is deque: assert ( optree.treespec_deque(