Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.9 - 3.14"
python-version: "3.9 - 3.15"
update-environment: true

- name: Upgrade pip
Expand Down Expand Up @@ -154,6 +154,8 @@ jobs:
- "3.13t"
- "3.14"
- "3.14t"
# - "3.15"
# - "3.15t"
- "pypy-3.11"
exclude:
# Exclude unsupported Python versions
Expand Down Expand Up @@ -183,6 +185,11 @@ jobs:
runner: macos-latest
platform: ios
archs: "arm64_iphoneos"
# - python-version: "3.15"
# target:
# runner: macos-latest
# platform: ios
# archs: "arm64_iphoneos"
# iOS Simulator
- python-version: "3.13"
target:
Expand All @@ -194,6 +201,11 @@ jobs:
runner: macos-latest
platform: ios
archs: "arm64_iphonesimulator"
# - python-version: "3.15"
# target:
# runner: macos-latest
# platform: ios
# archs: "arm64_iphonesimulator"
# Android
- python-version: "3.13"
target:
Expand All @@ -205,6 +217,11 @@ jobs:
runner: ubuntu-latest
platform: android
archs: "arm64_v8a"
# - python-version: "3.15"
# target:
# runner: ubuntu-latest
# platform: android
# archs: "arm64_v8a"
# Pyodide
- python-version: "3.12"
target:
Expand Down Expand Up @@ -387,7 +404,7 @@ jobs:
if: startsWith(github.ref, 'refs/tags/')
uses: actions/setup-python@v6
with:
python-version: "3.9 - 3.14"
python-version: "3.9 - 3.15"
update-environment: true

- name: Upgrade pip
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/tests-with-pydebug.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ jobs:
- "3.12"
- "3.13"
- "3.14"
- "3.15"
python-abiflags: ["d", "td"]
exclude:
- python-version: "3.9"
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ jobs:
- "3.13t"
- "3.14"
- "3.14t"
- "3.15"
- "3.15t"
- "pypy-3.11"
fail-fast: false
timeout-minutes: 120
Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ OpTree out-of-the-box supports the following Python container types in the globa
- [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict)
- [`collections.deque`](https://docs.python.org/3/library/collections.html#collections.deque)
- [`PyStructSequence`](https://docs.python.org/3/c-api/tuple.html#struct-sequence-objects) types created by C API [`PyStructSequence_NewType`](https://docs.python.org/3/c-api/tuple.html#c.PyStructSequence_NewType)
- [`frozendict`](https://docs.python.org/3/library/stdtypes.html#frozendict) (Python 3.15+)

These types are considered non-leaf nodes in the tree.
Python objects whose type is not registered are treated as leaf nodes.
Expand Down Expand Up @@ -356,7 +357,7 @@ There are several key attributes of the pytree type registry:
> [!WARNING]
> Any `PyTreeSpec` objects created before the unregistration still hold a reference to the old registration. Unflattening such a `PyTreeSpec` will use the **old** `unflatten_func`, not the newly registered one.

3. **Built-in types cannot be re-registered.** The behavior of the types listed in [Built-in PyTree Node Types](#built-in-pytree-node-types) (e.g., key-sorted traversal for `dict` and `collections.defaultdict`) is fixed.
3. **Built-in types cannot be re-registered.** The behavior of the types listed in [Built-in PyTree Node Types](#built-in-pytree-node-types) (e.g., key-sorted traversal for `dict`, `collections.defaultdict`, and `frozendict`) is fixed.

4. **Inherited subclasses are not implicitly registered.** The registry lookup uses `type(obj) is registered_type` rather than `isinstance(obj, registered_type)`. Users need to register the subclasses explicitly. To register all subclasses, it is easy to implement with [`metaclass`](https://docs.python.org/3/reference/datamodel.html#metaclasses) or [`__init_subclass__`](https://docs.python.org/3/reference/datamodel.html#customizing-class-creation), for example:

Expand Down Expand Up @@ -496,7 +497,7 @@ OrderedDict({
The built-in Python dictionary ([`builtins.dict`](https://docs.python.org/3/library/stdtypes.html#dict)) is a mapping whose leaves are its values.
Since [Python 3.7](https://docs.python.org/3/whatsnew/3.7.html), `dict` is guaranteed to be insertion ordered, but the equality operator (`==`) ignores key order.
To ensure [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) — "equal `dict`" implies "equal ordering of leaves" — the leaves (values) are returned in key-sorted order.
The same applies to [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict).
The same applies to [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict) and [`frozendict`](https://docs.python.org/3/library/stdtypes.html#frozendict) (Python 3.15+).

```python
>>> optree.tree_flatten({'a': [1, 2], 'b': [3]})
Expand Down Expand Up @@ -561,7 +562,7 @@ False
([3, 1, 2], PyTreeSpec(OrderedDict({'b': [*], 'a': [*, *]})))
```

To flatten [`builtins.dict`](https://docs.python.org/3/library/stdtypes.html#dict) and [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict) objects with the insertion order preserved, use the `dict_insertion_ordered` context manager:
To flatten [`builtins.dict`](https://docs.python.org/3/library/stdtypes.html#dict), [`collections.defaultdict`](https://docs.python.org/3/library/collections.html#collections.defaultdict), and [`frozendict`](https://docs.python.org/3/library/stdtypes.html#frozendict) (Python 3.15+) objects with the insertion order preserved, use the `dict_insertion_ordered` context manager:

```python
>>> tree = {'b': (2, [3, 4]), 'a': 1, 'c': None, 'd': 5}
Expand Down
1 change: 1 addition & 0 deletions docs/source/ops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -175,4 +175,5 @@ Tree Reduce Functions
.. autofunction:: treespec_defaultdict
.. autofunction:: treespec_deque
.. autofunction:: treespec_structseq
.. TODO(frozendict): Add ``.. autofunction:: treespec_frozendict`` when building with Python 3.15+.
.. autofunction:: treespec_from_collection
1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ eq
fillvalue
fmt
forwardref
frozendict
frozenset
func
functools
Expand Down
2 changes: 2 additions & 0 deletions docs/source/treespec.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@ Check section :ref:`PyTreeSpec Functions` for more detailed documentation.
deque
structseq
from_collection

.. TODO(frozendict): Add ``frozendict`` to the autosummary when building with Python 3.15+.
25 changes: 22 additions & 3 deletions include/optree/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ constexpr py::ssize_t MAX_TYPE_CACHE_SIZE = 4096;
#define PyOrderedDict_Type (reinterpret_cast<PyTypeObject *>(PyOrderedDictTypeObject.ptr()))
#define PyDefaultDict_Type (reinterpret_cast<PyTypeObject *>(PyDefaultDictTypeObject.ptr()))
#define PyDeque_Type (reinterpret_cast<PyTypeObject *>(PyDequeTypeObject.ptr()))
#if PY_VERSION_HEX >= 0x030F00A7 // Python 3.15.0a7+
# define PyFrozenDictTypeObject \
(py::reinterpret_borrow<py::object>(reinterpret_cast<PyObject *>(&PyFrozenDict_Type)))
#endif

inline const py::object &ImportOrderedDict() {
PYBIND11_CONSTINIT static py::gil_safe_call_once_and_store<py::object> storage;
Expand Down Expand Up @@ -181,6 +185,14 @@ inline Py_ALWAYS_INLINE void AssertExactDict(const py::handle &object) {
}
}

#if PY_VERSION_HEX >= 0x030F00A7 // Python 3.15.0a7+
inline Py_ALWAYS_INLINE void AssertExactFrozenDict(const py::handle &object) {
if (!PyFrozenDict_CheckExact(object.ptr())) [[unlikely]] {
throw py::value_error("Expected an instance of frozendict, got " + PyRepr(object) + ".");
}
}
#endif

inline Py_ALWAYS_INLINE void AssertExactOrderedDict(const py::handle &object) {
if (!py::type::handle_of(object).is(PyOrderedDictTypeObject)) [[unlikely]] {
throw py::value_error("Expected an instance of collections.OrderedDict, got " +
Expand All @@ -198,10 +210,17 @@ inline Py_ALWAYS_INLINE void AssertExactDefaultDict(const py::handle &object) {
inline Py_ALWAYS_INLINE void AssertExactStandardDict(const py::handle &object) {
if (!(PyDict_CheckExact(object.ptr()) ||
py::type::handle_of(object).is(PyOrderedDictTypeObject) ||
py::type::handle_of(object).is(PyDefaultDictTypeObject))) [[unlikely]] {
py::type::handle_of(object).is(PyDefaultDictTypeObject)
#if PY_VERSION_HEX >= 0x030F00A7 // Python 3.15.0a7+
|| PyFrozenDict_CheckExact(object.ptr())
#endif
)) [[unlikely]] {
throw py::value_error(
"Expected an instance of dict, collections.OrderedDict, or collections.defaultdict, "
"got " +
"Expected an instance of dict, "
#if PY_VERSION_HEX >= 0x030F00A7 // Python 3.15.0a7+
"frozendict, "
#endif
"collections.OrderedDict, or collections.defaultdict, got " +
PyRepr(object) + ".");
}
}
Expand Down
2 changes: 2 additions & 0 deletions include/optree/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ enum class PyTreeKind : std::uint8_t {
DefaultDict, // A collections.defaultdict
Deque, // A collections.deque
StructSequence, // A PyStructSequence
FrozenDict, // A frozendict (Python 3.15+)
NumKinds, // Number of kinds (placed at the end)
};

Expand All @@ -67,6 +68,7 @@ constexpr PyTreeKind kOrderedDict = PyTreeKind::OrderedDict;
constexpr PyTreeKind kDefaultDict = PyTreeKind::DefaultDict;
constexpr PyTreeKind kDeque = PyTreeKind::Deque;
constexpr PyTreeKind kStructSequence = PyTreeKind::StructSequence;
constexpr PyTreeKind kFrozenDict = PyTreeKind::FrozenDict;
constexpr PyTreeKind kNumPyTreeKinds = PyTreeKind::NumKinds;

// Registry of custom node types.
Expand Down
4 changes: 2 additions & 2 deletions include/optree/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ class PyTreeSpec {

// Kind-specific metadata.
// For a NamedTuple/PyStructSequence, contains the tuple type object.
// For a Dict, contains a sorted list of keys.
// For a Dict or FrozenDict, contains a sorted list of keys.
// For a OrderedDict, contains a list of keys.
// For a DefaultDict, contains a tuple of (default_factory, sorted list of keys).
// For a Deque, contains the `maxlen` attribute.
Expand All @@ -293,7 +293,7 @@ class PyTreeSpec {
// Number of leaf and interior nodes in the subtree rooted at this node.
ssize_t num_nodes = 0;

// For a Dict or DefaultDict, contains the keys in insertion order.
// For a Dict, DefaultDict, or FrozenDict, contains the keys in insertion order.
py::object original_keys{};
};

Expand Down
1 change: 1 addition & 0 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class PyTreeKind(enum.IntEnum):
DEFAULTDICT = enum.auto() # a collections.defaultdict
DEQUE = enum.auto() # a collections.deque
STRUCTSEQUENCE = enum.auto() # a PyStructSequence
FROZENDICT = enum.auto() # a frozendict (Python 3.15+)

NUM_KINDS: ClassVar[int]

Expand Down
12 changes: 12 additions & 0 deletions optree/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# ==============================================================================
"""OpTree: Optimized PyTree Utilities."""

import sys

from optree import accessors, dataclasses, functools, integrations, pytree, treespec, typing
from optree.accessors import (
AutoEntry,
Expand Down Expand Up @@ -225,6 +227,13 @@
'structseq_fields',
]


if sys.version_info >= (3, 15): # pragma: >=3.15 cover
from optree.ops import treespec_frozendict

__all__.insert(__all__.index('treespec_from_collection'), 'treespec_frozendict')


MAX_RECURSION_DEPTH: int = MAX_RECURSION_DEPTH
"""Maximum recursion depth for pytree traversal.

Expand All @@ -235,3 +244,6 @@
"""Literal constant that treats :data:`None` as a pytree non-leaf node."""
NONE_IS_LEAF: bool = NONE_IS_LEAF # literal constant
"""Literal constant that treats :data:`None` as a pytree leaf node."""


del sys
73 changes: 63 additions & 10 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import difflib
import functools
import itertools
import sys
import textwrap
from collections import OrderedDict, defaultdict, deque
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Generic, overload
Expand Down Expand Up @@ -115,6 +116,13 @@
'prefix_errors',
]


if sys.version_info >= (3, 15): # pragma: >=3.15 cover
from builtins import frozendict # type: ignore[import] # pylint: disable=no-name-in-module

__all__.insert(__all__.index('treespec_from_collection'), 'treespec_frozendict')


MAX_RECURSION_DEPTH: int = _C.MAX_RECURSION_DEPTH
"""Maximum recursion depth for pytree traversal.

Expand Down Expand Up @@ -160,9 +168,10 @@ def tree_flatten(
>>> tree_flatten(None, none_is_leaf=True)
([None], PyTreeSpec(*, NoneIsLeaf))

For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
if you want to keep the keys in the insertion order.
For unordered dictionaries, :class:`dict`, :class:`collections.defaultdict`, and
:class:`frozendict` (Python 3.15+), the order is dependent on the **sorted** keys in the
dictionary. Please use :class:`collections.OrderedDict` if you want to keep the keys in the
insertion order.

>>> from collections import OrderedDict
>>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
Expand Down Expand Up @@ -232,9 +241,10 @@ def tree_flatten_with_path(
>>> tree_flatten_with_path(None, none_is_leaf=True)
([()], [None], PyTreeSpec(*, NoneIsLeaf))

For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
if you want to keep the keys in the insertion order.
For unordered dictionaries, :class:`dict`, :class:`collections.defaultdict`, and
:class:`frozendict` (Python 3.15+), the order is dependent on the **sorted** keys in the
dictionary. Please use :class:`collections.OrderedDict` if you want to keep the keys in the
insertion order.

>>> from collections import OrderedDict
>>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
Expand Down Expand Up @@ -320,9 +330,10 @@ def tree_flatten_with_accessor(
>>> tree_flatten_with_accessor(None, none_is_leaf=True)
([PyTreeAccessor(*, ())], [None], PyTreeSpec(*, NoneIsLeaf))

For unordered dictionaries, :class:`dict` and :class:`collections.defaultdict`, the order is
dependent on the **sorted** keys in the dictionary. Please use :class:`collections.OrderedDict`
if you want to keep the keys in the insertion order.
For unordered dictionaries, :class:`dict`, :class:`collections.defaultdict`, and
:class:`frozendict` (Python 3.15+), the order is dependent on the **sorted** keys in the
dictionary. Please use :class:`collections.OrderedDict` if you want to keep the keys in the
insertion order.

>>> from collections import OrderedDict
>>> tree = OrderedDict([('b', (2, [3, 4])), ('a', 1), ('c', None), ('d', 5)])
Expand Down Expand Up @@ -3408,7 +3419,7 @@ def treespec_deque(
>>> treespec_deque([treespec_leaf(), treespec_leaf(), treespec_none()], maxlen=5)
PyTreeSpec(deque([*, *, None], maxlen=5))
>>> treespec_deque()
PyTreeSpec(deque([]))
PyTreeSpec(deque())
>>> treespec_deque([treespec_leaf(), treespec_tuple([treespec_leaf(), treespec_leaf()])])
PyTreeSpec(deque([*, (*, *)]))
>>> treespec_deque([treespec_leaf(), tree_structure({'a': 1, 'b': 2})], maxlen=5)
Expand Down Expand Up @@ -3473,6 +3484,46 @@ def treespec_structseq(
)


if sys.version_info >= (3, 15): # pragma: >=3.15 cover

def treespec_frozendict(
mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (),
/,
*,
none_is_leaf: bool = False,
namespace: str = '',
**kwargs: PyTreeSpec,
) -> PyTreeSpec:
"""Make a frozendict treespec from a frozendict of child treespecs.

See also :func:`tree_structure`, :func:`treespec_leaf`, and :func:`treespec_none`.

>>> treespec_frozendict({'a': treespec_leaf(), 'b': treespec_leaf()}) # doctest: +SKIP
PyTreeSpec(frozendict({'a': *, 'b': *}))
>>> treespec_frozendict() # doctest: +SKIP
PyTreeSpec(frozendict())

Args:
mapping (mapping of PyTreeSpec, optional): A mapping of child treespecs. They must have
the same ``none_is_leaf`` and ``namespace`` values.
none_is_leaf (bool, optional): Whether to treat :data:`None` as a leaf. If :data:`False`,
:data:`None` is a non-leaf node with arity 0. Thus :data:`None` is contained in the
treespec rather than in the leaves list and :data:`None` will remain in the result
pytree. (default: :data:`False`)
namespace (str, optional): The registry namespace used for custom pytree node types.
(default: :const:`''`, i.e., the global namespace)
**kwargs (PyTreeSpec, optional): Additional child treespecs to add to the mapping.

Returns:
A treespec representing a frozendict node with the given children.
"""
return _C.make_from_collection(
frozendict(mapping, **kwargs),
none_is_leaf,
namespace,
)


def treespec_from_collection(
collection: Collection[PyTreeSpec],
/,
Expand Down Expand Up @@ -3523,6 +3574,8 @@ def treespec_from_collection(


STANDARD_DICT_TYPES: frozenset[type] = frozenset({dict, OrderedDict, defaultdict})
if sys.version_info >= (3, 15): # pragma: >=3.15 cover
STANDARD_DICT_TYPES |= frozenset({frozendict})


def prefix_errors( # noqa: C901
Expand Down
Loading
Loading