Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add `attrs` integration module `optree.integrations.attrs` with `field`, `define`, `frozen`, `mutable`, `make_class`, `register_node`, and `AttrsEntry` by [@XuehaiPan](https://github.com/XuehaiPan) in [#273](https://github.com/metaopt/optree/pull/273).
- Add `optree.dataclasses.register_node` to register existing dataclasses as pytree nodes by [@XuehaiPan](https://github.com/XuehaiPan) in [#273](https://github.com/metaopt/optree/pull/273).
- Extend `GetAttrEntry` to support dotted attribute paths for traversing nested attributes (e.g., `a.b.c`) by [@XuehaiPan](https://github.com/XuehaiPan).
- Add `functools.Placeholder` support and re-export for `optree.functools.partial` (Python 3.14+) by [@XuehaiPan](https://github.com/XuehaiPan) in [#276](https://github.com/metaopt/optree/pull/276).

### Changed

Expand Down
18 changes: 18 additions & 0 deletions optree/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import contextlib
import functools
from typing import TYPE_CHECKING, Any, Callable, ClassVar
from typing_extensions import Self # Python 3.11+
Expand All @@ -36,6 +37,13 @@
]


with contextlib.suppress(ImportError): # pragma: >=3.14 cover
# pylint: disable-next=no-name-in-module,unused-import
from functools import Placeholder # type: ignore[attr-defined]

__all__ += ['Placeholder']


class _HashablePartialShim:
"""A shim object that delegates :meth:`__call__`, :meth:`__eq__`, and :meth:`__hash__` to a :func:`functools.partial` object.""" # pylint: disable=line-too-long

Expand Down Expand Up @@ -111,6 +119,16 @@ class partial( # noqa: N801 # pylint: disable=invalid-name,too-few-public-metho

Had we passed :func:`operator.add` to ``call_func_on_cuda`` directly, it would have resulted in
a :class:`TypeError` or :class:`AttributeError`.

On Python 3.14+, :data:`functools.Placeholder` can be used to reserve positional argument slots:

>>> from functools import Placeholder # doctest: +SKIP
>>> square = partial(pow, Placeholder, 2) # doctest: +SKIP
>>> square(5)
25

:data:`~functools.Placeholder` objects are treated as leaves in the pytree and their identity is
preserved through flatten/unflatten round-trips.
"""

__slots__: ClassVar[tuple[()]] = ()
Expand Down
230 changes: 230 additions & 0 deletions tests/test_functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,19 @@

import functools

import pytest

import optree
from helpers import GLOBAL_NAMESPACE, parametrize


HAS_PLACEHOLDER = hasattr(functools, 'Placeholder')
needs_placeholder = pytest.mark.skipif(
not HAS_PLACEHOLDER,
reason='functools.Placeholder requires Python 3.14+',
)


def dummy_func(*args, **kwargs): # pylint: disable=unused-argument
return

Expand Down Expand Up @@ -80,3 +89,224 @@ def test_partial_func_attribute_has_stable_hash():
assert fn == p1.func # pylint: disable=comparison-with-callable
assert p1.func == p2.func
assert hash(p1.func) == hash(p2.func)


@needs_placeholder
def test_partial_placeholder_roundtrip():
ph = functools.Placeholder

def f(*args, **kwargs):
return args, kwargs

p1 = optree.functools.partial(f, ph, 42)
leaves, treespec = optree.tree_flatten(p1)
p2 = optree.tree_unflatten(treespec, leaves)
assert p2.func == p1.func
assert p2.args == p1.args
assert p2.args[0] is ph
assert p2.keywords == p1.keywords
assert p2('x') == f('x', 42)


@needs_placeholder
def test_partial_placeholder_call_after_roundtrip():
ph = functools.Placeholder

def f(*args, **kwargs):
return args, kwargs

p1 = optree.functools.partial(f, ph, 42)
leaves, treespec = optree.tree_flatten(p1)
p2 = optree.tree_unflatten(treespec, leaves)

# Fill placeholder
assert p2('x') == (('x', 42), {})

# Extra args beyond placeholder
assert p2('x', 'y') == (('x', 42, 'y'), {})

# Missing placeholder arg
with pytest.raises(TypeError, match='missing positional arguments'):
p2()


@needs_placeholder
def test_partial_multiple_placeholders_roundtrip():
ph = functools.Placeholder

def f(*args, **kwargs):
return args, kwargs

p1 = optree.functools.partial(f, ph, 42, ph, 99)
leaves, treespec = optree.tree_flatten(p1)
p2 = optree.tree_unflatten(treespec, leaves)
assert p2.args == (ph, 42, ph, 99)
assert p2.args[0] is ph
assert p2.args[2] is ph
assert p2('a', 'b') == (('a', 42, 'b', 99), {})


@needs_placeholder
def test_partial_placeholder_with_keywords():
ph = functools.Placeholder

def f(*args, **kwargs):
return args, kwargs

p1 = optree.functools.partial(f, ph, 42, key='value')
leaves, treespec = optree.tree_flatten(p1)
p2 = optree.tree_unflatten(treespec, leaves)
assert p2.args == (ph, 42)
assert p2.keywords == {'key': 'value'}
assert p2('x') == (('x', 42), {'key': 'value'})


@needs_placeholder
def test_partial_placeholder_is_leaf():
ph = functools.Placeholder

def f(*args, **kwargs):
return args, kwargs

p = optree.functools.partial(f, ph, 42)
leaves = optree.tree_leaves(p)
assert ph in leaves
assert 42 in leaves


@needs_placeholder
def test_partial_placeholder_tree_map():
ph = functools.Placeholder

def f(*args, **kwargs):
return args, kwargs

p1 = optree.functools.partial(f, ph, 42)

# Identity tree_map preserves Placeholder
p2 = optree.tree_map(lambda x: x, p1)
assert p2.args[0] is ph
assert p2.args[1] == 42
assert p2('test') == (('test', 42), {})


@needs_placeholder
def test_partial_placeholder_in_larger_tree():
ph = functools.Placeholder

def f(*args, **kwargs):
return args, kwargs

p = optree.functools.partial(f, ph, 42)
tree = {'fn': p, 'data': [1, 2, 3]}
leaves, treespec = optree.tree_flatten(tree)
tree2 = optree.tree_unflatten(treespec, leaves)
assert tree2['fn'].args[0] is ph
assert tree2['fn']('test') == (('test', 42), {})
assert tree2['data'] == [1, 2, 3]


@needs_placeholder
def test_partial_wrapping_stdlib_partial_with_placeholder():
ph = functools.Placeholder

def f(*args, **kwargs):
return args, kwargs

stdlib_p = functools.partial(f, ph, 42)
op1 = optree.functools.partial(stdlib_p, 'extra')

# Anti-merge: outer args are separate
assert op1.args == ('extra',)
assert op1() == (('extra', 42), {})

# Roundtrip
leaves, treespec = optree.tree_flatten(op1)
op2 = optree.tree_unflatten(treespec, leaves)
assert op2.args == ('extra',)
assert op2() == (('extra', 42), {})


@needs_placeholder
def test_partial_wrapping_stdlib_partial_with_placeholder_no_extra_args():
ph = functools.Placeholder

def f(*args, **kwargs):
return args, kwargs

stdlib_p = functools.partial(f, ph, 42)
op1 = optree.functools.partial(stdlib_p)
assert op1.args == ()
assert op1('hello') == (('hello', 42), {})

# Roundtrip
leaves, treespec = optree.tree_flatten(op1)
op2 = optree.tree_unflatten(treespec, leaves)
assert op2('hello') == (('hello', 42), {})


@needs_placeholder
def test_partial_nested_optree_partial_with_placeholder():
ph = functools.Placeholder

def f(*args, **kwargs):
return args, kwargs

inner = optree.functools.partial(f, ph, 42)
outer = optree.functools.partial(inner, 'extra')

# Anti-merge behavior
assert outer.args == ('extra',)
assert outer() == (('extra', 42), {})

# Roundtrip of outer
leaves, treespec = optree.tree_flatten(outer)
outer2 = optree.tree_unflatten(treespec, leaves)
assert outer2() == (('extra', 42), {})


@needs_placeholder
def test_partial_trailing_placeholder_rejection():
ph = functools.Placeholder

def f(*args, **kwargs):
return args, kwargs

with pytest.raises(TypeError, match='trailing Placeholders are not allowed'):
optree.functools.partial(f, 42, ph)

with pytest.raises(TypeError, match='trailing Placeholders are not allowed'):
optree.functools.partial(f, ph)

with pytest.raises(TypeError, match='trailing Placeholders are not allowed'):
optree.functools.partial(f, ph, 1, ph)


@needs_placeholder
def test_partial_keyword_placeholder_rejection():
ph = functools.Placeholder

def f(*args, **kwargs):
return args, kwargs

with pytest.raises(TypeError, match='Placeholder'):
optree.functools.partial(f, kw=ph)


@needs_placeholder
def test_partial_repr_with_placeholder():
ph = functools.Placeholder

def f(*args, **kwargs):
return args, kwargs

p = optree.functools.partial(f, ph, 42)
r = repr(p)
assert 'Placeholder' in r
assert '42' in r


@needs_placeholder
def test_partial_placeholder_reexport():
assert hasattr(optree.functools, 'Placeholder')
assert optree.functools.Placeholder is functools.Placeholder