Skip to content

Commit 3e538e8

Browse files
committed
feat(functools): add Placeholder support and tests for functools.partial
Re-export `functools.Placeholder` (Python 3.14+) from `optree.functools` and add comprehensive tests verifying that `optree.functools.partial` works consistently with stdlib's `functools.partial` when Placeholders are used in positional arguments.
1 parent bd32248 commit 3e538e8

3 files changed

Lines changed: 250 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1616
- Add `attrs` integration module `optree.integrations.attrs` with `field`, `define`, `frozen`, `mutable`, `make_class`, `register_node`, and `AttrsEntry` by [@XuehaiPan](https://github.com/XuehaiPan) in [#273](https://github.com/metaopt/optree/pull/273).
1717
- Add `optree.dataclasses.register_node` to register existing dataclasses as pytree nodes by [@XuehaiPan](https://github.com/XuehaiPan) in [#273](https://github.com/metaopt/optree/pull/273).
1818
- Extend `GetAttrEntry` to support dotted attribute paths for traversing nested attributes (e.g., `a.b.c`) by [@XuehaiPan](https://github.com/XuehaiPan).
19+
- Add `functools.Placeholder` support and re-export for `optree.functools.partial` (Python 3.14+) by [@XuehaiPan](https://github.com/XuehaiPan) in [#276](https://github.com/metaopt/optree/pull/276).
1920

2021
### Changed
2122

optree/functools.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from __future__ import annotations
1818

19+
import contextlib
1920
import functools
2021
from typing import TYPE_CHECKING, Any, Callable, ClassVar
2122
from typing_extensions import Self # Python 3.11+
@@ -29,13 +30,19 @@
2930
if TYPE_CHECKING:
3031
from optree.accessors import PyTreeEntry
3132

32-
3333
__all__ = [
3434
'partial',
3535
'reduce',
3636
]
3737

3838

39+
with contextlib.suppress(ImportError): # pragma: >=3.14 cover
40+
# pylint: disable-next=no-name-in-module,unused-import
41+
from functools import Placeholder # type: ignore[attr-defined]
42+
43+
__all__ += ['Placeholder']
44+
45+
3946
class _HashablePartialShim:
4047
"""A shim object that delegates :meth:`__call__`, :meth:`__eq__`, and :meth:`__hash__` to a :func:`functools.partial` object.""" # pylint: disable=line-too-long
4148

@@ -111,6 +118,17 @@ class partial( # noqa: N801 # pylint: disable=invalid-name,too-few-public-metho
111118
112119
Had we passed :func:`operator.add` to ``call_func_on_cuda`` directly, it would have resulted in
113120
a :class:`TypeError` or :class:`AttributeError`.
121+
122+
On Python 3.14+, :data:`functools.Placeholder` can be used to reserve positional argument slots:
123+
124+
>>> from optree.functools import partial, Placeholder # doctest: +SKIP
125+
>>> import operator
126+
>>> sub_from = partial(operator.sub, Placeholder, 3) # doctest: +SKIP
127+
>>> sub_from(10) # doctest: +SKIP
128+
7
129+
130+
:data:`~functools.Placeholder` objects are treated as leaves in the pytree and their identity is
131+
preserved through flatten/unflatten round-trips.
114132
"""
115133

116134
__slots__: ClassVar[tuple[()]] = ()

tests/test_functools.py

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,19 @@
1717

1818
import functools
1919

20+
import pytest
21+
2022
import optree
2123
from helpers import GLOBAL_NAMESPACE, parametrize
2224

2325

26+
HAS_PLACEHOLDER = hasattr(functools, 'Placeholder')
27+
needs_placeholder = pytest.mark.skipif(
28+
not HAS_PLACEHOLDER,
29+
reason='functools.Placeholder requires Python 3.14+',
30+
)
31+
32+
2433
def dummy_func(*args, **kwargs): # pylint: disable=unused-argument
2534
return
2635

@@ -80,3 +89,224 @@ def test_partial_func_attribute_has_stable_hash():
8089
assert fn == p1.func # pylint: disable=comparison-with-callable
8190
assert p1.func == p2.func
8291
assert hash(p1.func) == hash(p2.func)
92+
93+
94+
@needs_placeholder
95+
def test_partial_placeholder_roundtrip():
96+
ph = functools.Placeholder
97+
98+
def f(*args, **kwargs):
99+
return args, kwargs
100+
101+
p1 = optree.functools.partial(f, ph, 42)
102+
leaves, treespec = optree.tree_flatten(p1)
103+
p2 = optree.tree_unflatten(treespec, leaves)
104+
assert p2.func == p1.func
105+
assert p2.args == p1.args
106+
assert p2.args[0] is ph
107+
assert p2.keywords == p1.keywords
108+
assert p2('x') == f('x', 42)
109+
110+
111+
@needs_placeholder
112+
def test_partial_placeholder_call_after_roundtrip():
113+
ph = functools.Placeholder
114+
115+
def f(*args, **kwargs):
116+
return args, kwargs
117+
118+
p1 = optree.functools.partial(f, ph, 42)
119+
leaves, treespec = optree.tree_flatten(p1)
120+
p2 = optree.tree_unflatten(treespec, leaves)
121+
122+
# Fill placeholder
123+
assert p2('x') == (('x', 42), {})
124+
125+
# Extra args beyond placeholder
126+
assert p2('x', 'y') == (('x', 42, 'y'), {})
127+
128+
# Missing placeholder arg
129+
with pytest.raises(TypeError, match='missing positional arguments'):
130+
p2()
131+
132+
133+
@needs_placeholder
134+
def test_partial_multiple_placeholders_roundtrip():
135+
ph = functools.Placeholder
136+
137+
def f(*args, **kwargs):
138+
return args, kwargs
139+
140+
p1 = optree.functools.partial(f, ph, 42, ph, 99)
141+
leaves, treespec = optree.tree_flatten(p1)
142+
p2 = optree.tree_unflatten(treespec, leaves)
143+
assert p2.args == (ph, 42, ph, 99)
144+
assert p2.args[0] is ph
145+
assert p2.args[2] is ph
146+
assert p2('a', 'b') == (('a', 42, 'b', 99), {})
147+
148+
149+
@needs_placeholder
150+
def test_partial_placeholder_with_keywords():
151+
ph = functools.Placeholder
152+
153+
def f(*args, **kwargs):
154+
return args, kwargs
155+
156+
p1 = optree.functools.partial(f, ph, 42, key='value')
157+
leaves, treespec = optree.tree_flatten(p1)
158+
p2 = optree.tree_unflatten(treespec, leaves)
159+
assert p2.args == (ph, 42)
160+
assert p2.keywords == {'key': 'value'}
161+
assert p2('x') == (('x', 42), {'key': 'value'})
162+
163+
164+
@needs_placeholder
165+
def test_partial_placeholder_is_leaf():
166+
ph = functools.Placeholder
167+
168+
def f(*args, **kwargs):
169+
return args, kwargs
170+
171+
p = optree.functools.partial(f, ph, 42)
172+
leaves = optree.tree_leaves(p)
173+
assert ph in leaves
174+
assert 42 in leaves
175+
176+
177+
@needs_placeholder
178+
def test_partial_placeholder_tree_map():
179+
ph = functools.Placeholder
180+
181+
def f(*args, **kwargs):
182+
return args, kwargs
183+
184+
p1 = optree.functools.partial(f, ph, 42)
185+
186+
# Identity tree_map preserves Placeholder
187+
p2 = optree.tree_map(lambda x: x, p1)
188+
assert p2.args[0] is ph
189+
assert p2.args[1] == 42
190+
assert p2('test') == (('test', 42), {})
191+
192+
193+
@needs_placeholder
194+
def test_partial_placeholder_in_larger_tree():
195+
ph = functools.Placeholder
196+
197+
def f(*args, **kwargs):
198+
return args, kwargs
199+
200+
p = optree.functools.partial(f, ph, 42)
201+
tree = {'fn': p, 'data': [1, 2, 3]}
202+
leaves, treespec = optree.tree_flatten(tree)
203+
tree2 = optree.tree_unflatten(treespec, leaves)
204+
assert tree2['fn'].args[0] is ph
205+
assert tree2['fn']('test') == (('test', 42), {})
206+
assert tree2['data'] == [1, 2, 3]
207+
208+
209+
@needs_placeholder
210+
def test_partial_wrapping_stdlib_partial_with_placeholder():
211+
ph = functools.Placeholder
212+
213+
def f(*args, **kwargs):
214+
return args, kwargs
215+
216+
stdlib_p = functools.partial(f, ph, 42)
217+
op1 = optree.functools.partial(stdlib_p, 'extra')
218+
219+
# Anti-merge: outer args are separate
220+
assert op1.args == ('extra',)
221+
assert op1() == (('extra', 42), {})
222+
223+
# Roundtrip
224+
leaves, treespec = optree.tree_flatten(op1)
225+
op2 = optree.tree_unflatten(treespec, leaves)
226+
assert op2.args == ('extra',)
227+
assert op2() == (('extra', 42), {})
228+
229+
230+
@needs_placeholder
231+
def test_partial_wrapping_stdlib_partial_with_placeholder_no_extra_args():
232+
ph = functools.Placeholder
233+
234+
def f(*args, **kwargs):
235+
return args, kwargs
236+
237+
stdlib_p = functools.partial(f, ph, 42)
238+
op1 = optree.functools.partial(stdlib_p)
239+
assert op1.args == ()
240+
assert op1('hello') == (('hello', 42), {})
241+
242+
# Roundtrip
243+
leaves, treespec = optree.tree_flatten(op1)
244+
op2 = optree.tree_unflatten(treespec, leaves)
245+
assert op2('hello') == (('hello', 42), {})
246+
247+
248+
@needs_placeholder
249+
def test_partial_nested_optree_partial_with_placeholder():
250+
ph = functools.Placeholder
251+
252+
def f(*args, **kwargs):
253+
return args, kwargs
254+
255+
inner = optree.functools.partial(f, ph, 42)
256+
outer = optree.functools.partial(inner, 'extra')
257+
258+
# Anti-merge behavior
259+
assert outer.args == ('extra',)
260+
assert outer() == (('extra', 42), {})
261+
262+
# Roundtrip of outer
263+
leaves, treespec = optree.tree_flatten(outer)
264+
outer2 = optree.tree_unflatten(treespec, leaves)
265+
assert outer2() == (('extra', 42), {})
266+
267+
268+
@needs_placeholder
269+
def test_partial_trailing_placeholder_rejection():
270+
ph = functools.Placeholder
271+
272+
def f(*args, **kwargs):
273+
return args, kwargs
274+
275+
with pytest.raises(TypeError, match='trailing Placeholders are not allowed'):
276+
optree.functools.partial(f, 42, ph)
277+
278+
with pytest.raises(TypeError, match='trailing Placeholders are not allowed'):
279+
optree.functools.partial(f, ph)
280+
281+
with pytest.raises(TypeError, match='trailing Placeholders are not allowed'):
282+
optree.functools.partial(f, ph, 1, ph)
283+
284+
285+
@needs_placeholder
286+
def test_partial_keyword_placeholder_rejection():
287+
ph = functools.Placeholder
288+
289+
def f(*args, **kwargs):
290+
return args, kwargs
291+
292+
with pytest.raises(TypeError, match='Placeholder'):
293+
optree.functools.partial(f, kw=ph)
294+
295+
296+
@needs_placeholder
297+
def test_partial_repr_with_placeholder():
298+
ph = functools.Placeholder
299+
300+
def f(*args, **kwargs):
301+
return args, kwargs
302+
303+
p = optree.functools.partial(f, ph, 42)
304+
r = repr(p)
305+
assert 'Placeholder' in r
306+
assert '42' in r
307+
308+
309+
@needs_placeholder
310+
def test_partial_placeholder_reexport():
311+
assert hasattr(optree.functools, 'Placeholder')
312+
assert optree.functools.Placeholder is functools.Placeholder

0 commit comments

Comments
 (0)