From 472d0ea529cc4e94f9e61b9dffb123bcc6efc5b4 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 31 Mar 2026 22:50:12 +0800 Subject: [PATCH] feat(functools): add `Placeholder` support and tests for `functools.partial` Re-export `functools.Placeholder` (Python 3.14+) from `optree.functools` and add comprehensive tests verifying that `optree.functools.partial` works consistently with stdlib's `functools.partial` when Placeholders are used in positional arguments. --- CHANGELOG.md | 1 + optree/functools.py | 18 ++++ tests/test_functools.py | 230 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 249 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8dae7113..c893c94a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Add `attrs` integration module `optree.integrations.attrs` with `field`, `define`, `frozen`, `mutable`, `make_class`, `register_node`, and `AttrsEntry` by [@XuehaiPan](https://github.com/XuehaiPan) in [#273](https://github.com/metaopt/optree/pull/273). - Add `optree.dataclasses.register_node` to register existing dataclasses as pytree nodes by [@XuehaiPan](https://github.com/XuehaiPan) in [#273](https://github.com/metaopt/optree/pull/273). - Extend `GetAttrEntry` to support dotted attribute paths for traversing nested attributes (e.g., `a.b.c`) by [@XuehaiPan](https://github.com/XuehaiPan). +- Add `functools.Placeholder` support and re-export for `optree.functools.partial` (Python 3.14+) by [@XuehaiPan](https://github.com/XuehaiPan) in [#276](https://github.com/metaopt/optree/pull/276). ### Changed diff --git a/optree/functools.py b/optree/functools.py index da04b85b..21563d28 100644 --- a/optree/functools.py +++ b/optree/functools.py @@ -16,6 +16,7 @@ from __future__ import annotations +import contextlib import functools from typing import TYPE_CHECKING, Any, Callable, ClassVar from typing_extensions import Self # Python 3.11+ @@ -36,6 +37,13 @@ ] +with contextlib.suppress(ImportError): # pragma: >=3.14 cover + # pylint: disable-next=no-name-in-module,unused-import + from functools import Placeholder # type: ignore[attr-defined] + + __all__ += ['Placeholder'] + + class _HashablePartialShim: """A shim object that delegates :meth:`__call__`, :meth:`__eq__`, and :meth:`__hash__` to a :func:`functools.partial` object.""" # pylint: disable=line-too-long @@ -111,6 +119,16 @@ class partial( # noqa: N801 # pylint: disable=invalid-name,too-few-public-metho Had we passed :func:`operator.add` to ``call_func_on_cuda`` directly, it would have resulted in a :class:`TypeError` or :class:`AttributeError`. + + On Python 3.14+, :data:`functools.Placeholder` can be used to reserve positional argument slots: + + >>> from functools import Placeholder # doctest: +SKIP + >>> square = partial(pow, Placeholder, 2) # doctest: +SKIP + >>> square(5) + 25 + + :data:`~functools.Placeholder` objects are treated as leaves in the pytree and their identity is + preserved through flatten/unflatten round-trips. """ __slots__: ClassVar[tuple[()]] = () diff --git a/tests/test_functools.py b/tests/test_functools.py index c7f92e6d..f740112d 100644 --- a/tests/test_functools.py +++ b/tests/test_functools.py @@ -17,10 +17,19 @@ import functools +import pytest + import optree from helpers import GLOBAL_NAMESPACE, parametrize +HAS_PLACEHOLDER = hasattr(functools, 'Placeholder') +needs_placeholder = pytest.mark.skipif( + not HAS_PLACEHOLDER, + reason='functools.Placeholder requires Python 3.14+', +) + + def dummy_func(*args, **kwargs): # pylint: disable=unused-argument return @@ -80,3 +89,224 @@ def test_partial_func_attribute_has_stable_hash(): assert fn == p1.func # pylint: disable=comparison-with-callable assert p1.func == p2.func assert hash(p1.func) == hash(p2.func) + + +@needs_placeholder +def test_partial_placeholder_roundtrip(): + ph = functools.Placeholder + + def f(*args, **kwargs): + return args, kwargs + + p1 = optree.functools.partial(f, ph, 42) + leaves, treespec = optree.tree_flatten(p1) + p2 = optree.tree_unflatten(treespec, leaves) + assert p2.func == p1.func + assert p2.args == p1.args + assert p2.args[0] is ph + assert p2.keywords == p1.keywords + assert p2('x') == f('x', 42) + + +@needs_placeholder +def test_partial_placeholder_call_after_roundtrip(): + ph = functools.Placeholder + + def f(*args, **kwargs): + return args, kwargs + + p1 = optree.functools.partial(f, ph, 42) + leaves, treespec = optree.tree_flatten(p1) + p2 = optree.tree_unflatten(treespec, leaves) + + # Fill placeholder + assert p2('x') == (('x', 42), {}) + + # Extra args beyond placeholder + assert p2('x', 'y') == (('x', 42, 'y'), {}) + + # Missing placeholder arg + with pytest.raises(TypeError, match='missing positional arguments'): + p2() + + +@needs_placeholder +def test_partial_multiple_placeholders_roundtrip(): + ph = functools.Placeholder + + def f(*args, **kwargs): + return args, kwargs + + p1 = optree.functools.partial(f, ph, 42, ph, 99) + leaves, treespec = optree.tree_flatten(p1) + p2 = optree.tree_unflatten(treespec, leaves) + assert p2.args == (ph, 42, ph, 99) + assert p2.args[0] is ph + assert p2.args[2] is ph + assert p2('a', 'b') == (('a', 42, 'b', 99), {}) + + +@needs_placeholder +def test_partial_placeholder_with_keywords(): + ph = functools.Placeholder + + def f(*args, **kwargs): + return args, kwargs + + p1 = optree.functools.partial(f, ph, 42, key='value') + leaves, treespec = optree.tree_flatten(p1) + p2 = optree.tree_unflatten(treespec, leaves) + assert p2.args == (ph, 42) + assert p2.keywords == {'key': 'value'} + assert p2('x') == (('x', 42), {'key': 'value'}) + + +@needs_placeholder +def test_partial_placeholder_is_leaf(): + ph = functools.Placeholder + + def f(*args, **kwargs): + return args, kwargs + + p = optree.functools.partial(f, ph, 42) + leaves = optree.tree_leaves(p) + assert ph in leaves + assert 42 in leaves + + +@needs_placeholder +def test_partial_placeholder_tree_map(): + ph = functools.Placeholder + + def f(*args, **kwargs): + return args, kwargs + + p1 = optree.functools.partial(f, ph, 42) + + # Identity tree_map preserves Placeholder + p2 = optree.tree_map(lambda x: x, p1) + assert p2.args[0] is ph + assert p2.args[1] == 42 + assert p2('test') == (('test', 42), {}) + + +@needs_placeholder +def test_partial_placeholder_in_larger_tree(): + ph = functools.Placeholder + + def f(*args, **kwargs): + return args, kwargs + + p = optree.functools.partial(f, ph, 42) + tree = {'fn': p, 'data': [1, 2, 3]} + leaves, treespec = optree.tree_flatten(tree) + tree2 = optree.tree_unflatten(treespec, leaves) + assert tree2['fn'].args[0] is ph + assert tree2['fn']('test') == (('test', 42), {}) + assert tree2['data'] == [1, 2, 3] + + +@needs_placeholder +def test_partial_wrapping_stdlib_partial_with_placeholder(): + ph = functools.Placeholder + + def f(*args, **kwargs): + return args, kwargs + + stdlib_p = functools.partial(f, ph, 42) + op1 = optree.functools.partial(stdlib_p, 'extra') + + # Anti-merge: outer args are separate + assert op1.args == ('extra',) + assert op1() == (('extra', 42), {}) + + # Roundtrip + leaves, treespec = optree.tree_flatten(op1) + op2 = optree.tree_unflatten(treespec, leaves) + assert op2.args == ('extra',) + assert op2() == (('extra', 42), {}) + + +@needs_placeholder +def test_partial_wrapping_stdlib_partial_with_placeholder_no_extra_args(): + ph = functools.Placeholder + + def f(*args, **kwargs): + return args, kwargs + + stdlib_p = functools.partial(f, ph, 42) + op1 = optree.functools.partial(stdlib_p) + assert op1.args == () + assert op1('hello') == (('hello', 42), {}) + + # Roundtrip + leaves, treespec = optree.tree_flatten(op1) + op2 = optree.tree_unflatten(treespec, leaves) + assert op2('hello') == (('hello', 42), {}) + + +@needs_placeholder +def test_partial_nested_optree_partial_with_placeholder(): + ph = functools.Placeholder + + def f(*args, **kwargs): + return args, kwargs + + inner = optree.functools.partial(f, ph, 42) + outer = optree.functools.partial(inner, 'extra') + + # Anti-merge behavior + assert outer.args == ('extra',) + assert outer() == (('extra', 42), {}) + + # Roundtrip of outer + leaves, treespec = optree.tree_flatten(outer) + outer2 = optree.tree_unflatten(treespec, leaves) + assert outer2() == (('extra', 42), {}) + + +@needs_placeholder +def test_partial_trailing_placeholder_rejection(): + ph = functools.Placeholder + + def f(*args, **kwargs): + return args, kwargs + + with pytest.raises(TypeError, match='trailing Placeholders are not allowed'): + optree.functools.partial(f, 42, ph) + + with pytest.raises(TypeError, match='trailing Placeholders are not allowed'): + optree.functools.partial(f, ph) + + with pytest.raises(TypeError, match='trailing Placeholders are not allowed'): + optree.functools.partial(f, ph, 1, ph) + + +@needs_placeholder +def test_partial_keyword_placeholder_rejection(): + ph = functools.Placeholder + + def f(*args, **kwargs): + return args, kwargs + + with pytest.raises(TypeError, match='Placeholder'): + optree.functools.partial(f, kw=ph) + + +@needs_placeholder +def test_partial_repr_with_placeholder(): + ph = functools.Placeholder + + def f(*args, **kwargs): + return args, kwargs + + p = optree.functools.partial(f, ph, 42) + r = repr(p) + assert 'Placeholder' in r + assert '42' in r + + +@needs_placeholder +def test_partial_placeholder_reexport(): + assert hasattr(optree.functools, 'Placeholder') + assert optree.functools.Placeholder is functools.Placeholder