Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ lines-after-imports = 2
"tests/link/numba/**/test_*.py" = ["E402"]
"tests/link/pytorch/**/test_*.py" = ["E402"]
"tests/link/mlx/**/test_*.py" = ["E402"]
"tests/xtensor/**/test_*.py" = ["E402"]
"tests/xtensor/**/*.py" = ["E402"]



Expand Down
35 changes: 25 additions & 10 deletions pytensor/graph/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,19 +208,13 @@ def toposort_key(


@singledispatch
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply:
def _vectorize_node(op: Op, node: Apply, *batched_inputs) -> Apply | Sequence[Variable]:
# Default implementation is provided in pytensor.tensor.blockwise
raise NotImplementedError


def vectorize_node(node: Apply, *batched_inputs) -> Apply:
"""Returns vectorized version of node with new batched inputs."""
op = node.op
return _vectorize_node(op, node, *batched_inputs)


def _vectorize_not_needed(op, node, *batched_inputs):
return op.make_node(*batched_inputs)
return op.make_node(*batched_inputs).outputs


@overload
Expand Down Expand Up @@ -289,19 +283,40 @@ def vectorize_graph(
# [array([-10., -11.]), array([10., 11.])]

"""
# TODO: Move this to tensor.vectorize, and make this helper type agnostic.
#
# This helper may dispatch to tensor.vectorize_graph or xtensor.vectorize_graph depending on the replacement types
# The behavior is distinct, because tensor vectorization depends on axis-position while xtensor depends on dimension labels
#
# xtensor.vectorize_graph will be able to handle batched inner tensor operations, while tensor.vectorize_graph won't,
# as it is by design unaware of xtensors and their semantics.
if isinstance(outputs, Sequence):
seq_outputs = outputs
else:
seq_outputs = [outputs]

if not all(
isinstance(key, Variable) and isinstance(value, Variable)
for key, value in replace.items()
):
raise ValueError(f"Some of the replaced items are not Variables: {replace}")

inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
new_inputs = [replace.get(inp, inp) for inp in inputs]

vect_vars = dict(zip(inputs, new_inputs, strict=True))
for node in toposort(seq_outputs, blockers=inputs):
vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
vect_node = vectorize_node(node, *vect_inputs)
for output, vect_output in zip(node.outputs, vect_node.outputs, strict=True):

vect_node_or_outputs = _vectorize_node(node.op, node, *vect_inputs)
# Compatibility with the old API
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs a bit more documentation, I wouldn't have understood the Apply change without our call yesterday

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's pretty internal though, the type signature shows _vectorize_node returns Sequence[Variable] | Apply and I was hoping vect_node_or_outputs -> vect_outputs further drives the point home. I want to deprecate the Apply eventually as well, but don't want to do it in this PR

vect_outputs = (
vect_node_or_outputs.outputs
if isinstance(vect_node_or_outputs, Apply)
else vect_node_or_outputs
)

for output, vect_output in zip(node.outputs, vect_outputs, strict=True):
if output in vect_vars:
# This can happen when some outputs of a multi-output node are given a replacement,
# while some of the remaining outputs are still needed in the graph.
Expand Down
4 changes: 2 additions & 2 deletions pytensor/tensor/extra_ops.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from collections.abc import Collection, Iterable
from collections.abc import Collection, Iterable, Sequence
from textwrap import dedent

import numpy as np
Expand Down Expand Up @@ -1926,7 +1926,7 @@ def logspace(


def broadcast_to(
x: TensorVariable, shape: TensorVariable | tuple[Variable, ...]
x: TensorLike, shape: TensorLike | Sequence[TensorLike]
) -> TensorVariable:
"""Broadcast an array to a new shape.

Expand Down
16 changes: 11 additions & 5 deletions pytensor/tensor/rewriting/blockwise.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from pytensor.compile.mode import optdb
from pytensor.graph import Constant, Op, node_rewriter
from pytensor.graph.destroyhandler import inplace_candidates
from pytensor.graph.replace import vectorize_node
from pytensor.graph.replace import vectorize_graph
from pytensor.graph.rewriting.basic import copy_stack_trace, dfs_rewriter
from pytensor.graph.rewriting.unify import OpPattern, OpPatternOpTypeType
from pytensor.graph.traversal import apply_ancestors
from pytensor.tensor.basic import Alloc, ARange, alloc, shape_padleft
from pytensor.tensor.blockwise import Blockwise, _squeeze_left
from pytensor.tensor.math import Dot
Expand Down Expand Up @@ -37,10 +38,15 @@ def local_useless_blockwise(fgraph, node):
"""
op = node.op
inputs = node.inputs
dummy_core_node = op._create_dummy_core_node(node.inputs)
vect_node = vectorize_node(dummy_core_node, *inputs)
if not isinstance(vect_node.op, Blockwise):
return copy_stack_trace(node.outputs, vect_node.outputs)
dummy_core_node, dummy_inputs = op._create_dummy_core_node(
inputs, return_dummy_inputs=True
)
outputs = vectorize_graph(dummy_core_node.outputs, dict(zip(dummy_inputs, inputs)))
if not any(
isinstance(vect_node.op, Blockwise)
for vect_node in apply_ancestors(outputs, blockers=inputs)
):
return copy_stack_trace(node.outputs, outputs)


@node_rewriter([Blockwise])
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/rewriting/subtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1602,7 +1602,7 @@ def local_blockwise_of_subtensor(fgraph, node):
def local_blockwise_inc_subtensor(fgraph, node):
"""Rewrite blockwised inc_subtensors.

Note: The reason we don't apply this rewrite eagerly in the `vectorize_node` dispatch
Note: The reason we don't apply this rewrite eagerly in the `_vectorize_node` dispatch
Is that we often have batch dimensions from alloc of shapes/reshape that can be removed by rewrites

such as x[:vectorized(w.shape[0])].set(y), that will later be rewritten as x[:w.shape[1]].set(y),
Expand Down
2 changes: 1 addition & 1 deletion pytensor/tensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -856,7 +856,7 @@ def _vectorize_reshape(op, node, x, shape):
else:
raise ValueError("Invalid shape length passed into vectorize node of Reshape")

return reshape(x, new_shape, ndim=len(new_shape)).owner
return reshape(x, new_shape, ndim=len(tuple(new_shape))).owner


def reshape(
Expand Down
44 changes: 44 additions & 0 deletions pytensor/xtensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from pytensor.compile.ops import TypeCastingOp
from pytensor.graph import Apply, Op
from pytensor.graph.basic import Variable
from pytensor.tensor.type import TensorType
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor

Expand All @@ -17,6 +18,11 @@ def perform(self, node, inputs, outputs):
def do_constant_folding(self, fgraph, node):
return False

def vectorize_node(
    self, node, *new_inputs, new_dim: str | None
) -> Sequence[Variable]:
    """Build batched outputs of this Op given already-batched inputs.

    Concrete XOp subclasses that support vectorization must override this;
    the base implementation only signals that no rule exists.
    """
    msg = f"Vectorized node not implemented for {self}"
    raise NotImplementedError(msg)


class XTypeCastOp(TypeCastingOp):
"""Base class for Ops that type cast between TensorType and XTensorType.
Expand All @@ -27,6 +33,11 @@ class XTypeCastOp(TypeCastingOp):
def infer_shape(self, fgraph, node, input_shapes):
return input_shapes

def vectorize_node(
    self, node, *new_inputs, new_dim: str | None
) -> Sequence[Variable]:
    """Build batched outputs of this type-casting Op for batched inputs.

    Subclasses that support vectorization override this; by default no
    vectorization rule is available.
    """
    msg = f"Vectorized node not implemented for {self}"
    raise NotImplementedError(msg)


class TensorFromXTensor(XTypeCastOp):
__props__ = ()
Expand All @@ -42,6 +53,17 @@ def L_op(self, inputs, outs, g_outs):
[g_out] = g_outs
return [xtensor_from_tensor(g_out, dims=x.type.dims)]

def vectorize_node(self, node, new_x, new_dim):
    """Vectorize TensorFromXTensor over a batched xtensor input.

    The core dims are moved to the back (in their original order) so the
    resulting tensor keeps the core axis layout, with the single batch
    dim in front.
    """
    [core_x] = node.inputs
    n_batch_dims = new_x.ndim - core_x.ndim
    if n_batch_dims > 1:
        raise NotImplementedError(
            f"Vectorization of {self} cannot guarantee correct placement of multiple batch dimensions. "
            "You can call vectorize_graph one batch dimension at a time, "
            "or pytensor.xtensor.vectorization.vectorize_graph instead."
        )
    # Batch dims (matched by `...`) stay in front; core dims go last.
    reordered_x = new_x.transpose(..., *core_x.dims)
    return [self(reordered_x)]


tensor_from_xtensor = TensorFromXTensor()

Expand All @@ -63,6 +85,18 @@ def L_op(self, inputs, outs, g_outs):
[g_out] = g_outs
return [tensor_from_xtensor(g_out)]

def vectorize_node(self, node, new_x, new_dim):
    """Vectorize XTensorFromTensor over a batched tensor input.

    When the input gained a batch axis, a label for it must be supplied
    via ``new_dim``; it is prepended to the Op's core dims.
    """
    [core_x] = node.inputs
    if new_x.ndim == core_x.ndim:
        # No batch axis was added; rebuild the node unchanged.
        return [self(new_x)]
    if new_dim is None:
        raise NotImplementedError(
            f"Vectorization of {self} cannot infer the new dimension labels. "
            "Use pytensor.xtensor.vectorization.vectorize_graph instead."
        )
    batched_op = type(self)(dims=(new_dim, *self.dims))
    return [batched_op(new_x)]


def xtensor_from_tensor(x, dims, name=None):
return XTensorFromTensor(dims=dims)(x, name=name)
Expand All @@ -85,6 +119,16 @@ def L_op(self, inputs, outs, g_outs):
[g_out] = g_outs
return [rename(g_out, dims=x.type.dims)]

def vectorize_node(self, node, new_x, new_dim):
    """Vectorize the rename Op over a batched input.

    Pre-existing dims keep their original renaming (even if reordered on
    the batched input); batch dims pass through unrenamed.
    """
    [core_x] = node.inputs
    rename_map = dict(zip(core_x.dims, self.new_dims, strict=True))
    batched_new_dims = tuple(rename_map.get(d, d) for d in new_x.dims)
    return [type(self)(batched_new_dims)(new_x)]


def rename(x, name_dict: dict[str, str] | None = None, **names: str):
if name_dict is not None:
Expand Down
34 changes: 34 additions & 0 deletions pytensor/xtensor/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
# https://numpy.org/neps/nep-0021-advanced-indexing.html
# https://docs.xarray.dev/en/latest/user-guide/indexing.html
# https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
from itertools import chain
from typing import Literal

from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.scalar.basic import discrete_dtypes
from pytensor.tensor.basic import as_tensor
from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
from pytensor.xtensor.basic import XOp, xtensor_from_tensor
from pytensor.xtensor.shape import broadcast
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor


Expand Down Expand Up @@ -195,6 +197,15 @@ def combine_dim_info(idx_dim, idx_dim_shape):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x, *idxs], [output])

def vectorize_node(self, node, new_x, *new_idxs, new_dim):
    """Vectorize Index over a batched input.

    Indices map positionally onto the core input's dims, so each
    pre-existing dim is paired with its index; dims present only on the
    batched input (possibly reordered) are indexed with a full slice.
    """
    core_x = node.inputs[0]
    idx_by_dim = dict(zip(core_x.dims, new_idxs, strict=False))
    full_slice = slice(None)
    reordered_idxs = tuple(idx_by_dim.get(dim, full_slice) for dim in new_x.dims)
    return [self(new_x, *reordered_idxs)]


index = Index()

Expand Down Expand Up @@ -226,6 +237,29 @@ def make_node(self, x, y, *idxs):
out = x.type()
return Apply(self, [x, y, *idxs], [out])

def vectorize_node(self, node, *new_inputs, new_dim):
    """Vectorize IndexUpdate over batched inputs.

    If ``y`` or the indices carry new (batch) dimensions, ``x`` must be
    broadcast to include them; dims already present on the core inputs
    are excluded from broadcasting.
    """
    # Collect every dim that already existed on a core xtensor input.
    core_dims: set[str] = set()
    for core_inp in node.inputs:
        if isinstance(core_inp.type, XTensorType):
            core_dims.update(core_inp.dims)

    xtensor_inputs = [
        new_inp for new_inp in new_inputs if isinstance(new_inp.type, XTensorType)
    ]
    new_x, *_ = broadcast(*xtensor_inputs, exclude=tuple(core_dims))

    # Batch dims must go on the right: indices map positionally to the
    # indexed (leftmost) dimensions in this Op.
    core_x = node.inputs[0]
    new_x = new_x.transpose(*core_x.dims, ...)
    _, new_y, *new_idxs = new_inputs
    return [self(new_x, new_y, *new_idxs)]


index_assignment = IndexUpdate("set")
index_increment = IndexUpdate("inc")
6 changes: 6 additions & 0 deletions pytensor/xtensor/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ def make_node(self, x):
output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
return Apply(self, [x], [output])

def vectorize_node(self, node, new_x, new_dim):
    # Reduction over named dims is unaffected by extra batch dims.
    return self(new_x, return_list=True)


def _process_user_dims(x: XTensorVariable, dim: REDUCE_DIM) -> Sequence[str]:
if isinstance(dim, str):
Expand Down Expand Up @@ -117,6 +120,9 @@ def make_node(self, x):
out = x.type()
return Apply(self, [x], [out])

def vectorize_node(self, node, new_x, new_dim):
    # Cumulative reduction works along a named dim, so batch dims pass through.
    return self(new_x, return_list=True)


def cumreduce(x, dim: REDUCE_DIM, *, binary_op):
x = as_xtensor(x)
Expand Down
43 changes: 43 additions & 0 deletions pytensor/xtensor/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def make_node(self, x):
)
return Apply(self, [x], [output])

def vectorize_node(self, node, new_x, new_dim):
    # Stacking is defined by named dims, so batch dims pass through unchanged.
    return self(new_x, return_list=True)


def stack(x, dim: dict[str, Sequence[str]] | None = None, **dims: Sequence[str]):
if dim is not None:
Expand Down Expand Up @@ -146,6 +149,14 @@ def make_node(self, x, *unstacked_length):
)
return Apply(self, [x, *unstacked_lengths], [output])

def vectorize_node(self, node, new_x, *new_unstacked_length, new_dim):
    """Vectorize UnStack over a batched input.

    Raises
    ------
    NotImplementedError
        If any unstacked length is itself batched (non-scalar after
        squeezing broadcastable dims).
    """
    # Vectorization may have added broadcastable dims to the (scalar)
    # lengths; squeeze them away before checking.
    new_unstacked_length = [ul.squeeze() for ul in new_unstacked_length]
    if not all(ul.type.ndim == 0 for ul in new_unstacked_length):
        # Fixed: message previously ended with a dangling ", " (truncated sentence).
        raise NotImplementedError(
            f"Vectorization of {self} with batched unstacked_length not implemented."
        )
    return [self(new_x, *new_unstacked_length)]


def unstack(x, dim: dict[str, dict[str, int]] | None = None, **dims: dict[str, int]):
if dim is not None:
Expand Down Expand Up @@ -189,6 +200,11 @@ def make_node(self, x):
)
return Apply(self, [x], [output])

def vectorize_node(self, node, new_x, new_dim):
    """Vectorize Transpose: batch dims lead, then the original core order."""
    core_dims = self.dims
    batch_dims = tuple(d for d in new_x.dims if d not in core_dims)
    batched_op = type(self)(dims=(*batch_dims, *core_dims))
    return [batched_op(new_x)]


def transpose(
x,
Expand Down Expand Up @@ -302,6 +318,9 @@ def make_node(self, *inputs):
output = xtensor(dtype=dtype, dims=dims, shape=shape)
return Apply(self, inputs, [output])

def vectorize_node(self, node, *new_inputs, new_dim):
    # Concatenation along a named dim is unchanged by batch dims.
    return self(*new_inputs, return_list=True)


def concat(xtensors, dim: str):
"""Concatenate a sequence of XTensorVariables along a specified dimension.
Expand Down Expand Up @@ -383,6 +402,9 @@ def make_node(self, x):
)
return Apply(self, [x], [out])

def vectorize_node(self, node, new_x, new_dim):
    # Squeezing named size-1 dims is unaffected by extra batch dims.
    return self(new_x, return_list=True)


def squeeze(x, dim: str | Sequence[str] | None = None):
"""Remove dimensions of size 1 from an XTensorVariable."""
Expand Down Expand Up @@ -442,6 +464,14 @@ def make_node(self, x, size):
)
return Apply(self, [x, size], [out])

def vectorize_node(self, node, new_x, new_size, new_dim):
    """Vectorize ExpandDims over a batched input.

    Raises
    ------
    NotImplementedError
        If the new-dim size is itself batched (non-scalar after squeezing
        broadcastable dims).
    """
    # Vectorization may have added broadcastable dims to the (scalar) size.
    new_size = new_size.squeeze()
    if new_size.type.ndim != 0:
        # Fixed: message previously ended with a dangling ", " (truncated sentence).
        raise NotImplementedError(
            f"Vectorization of {self} with batched new_size not implemented."
        )
    return [self(new_x, new_size)]


def expand_dims(x, dim=None, axis=None, **dim_kwargs):
"""Add one or more new dimensions to an XTensorVariable."""
Expand Down Expand Up @@ -537,6 +567,19 @@ def make_node(self, *inputs):

return Apply(self, inputs, outputs)

def vectorize_node(self, node, *new_inputs, new_dim):
    """Vectorize Broadcast over batched inputs.

    Raises
    ------
    NotImplementedError
        If a batched input gained a dimension that is in this Op's
        ``exclude`` list (excluded dims must not be broadcast).
    """
    if exclude_set := set(self.exclude):
        # Pair each core (old) input with its batched replacement.
        # Fixed: the zip previously bound ``new_x`` to node.inputs and
        # ``old_x`` to new_inputs, inverting the dim-difference check and
        # reporting the wrong variable in the error message.
        for old_x, new_x in zip(node.inputs, new_inputs, strict=True):
            # Dims the batched input *gained* must not be excluded dims.
            if invalid_excluded := (
                (set(new_x.dims) - set(old_x.dims)) & exclude_set
            ):
                raise NotImplementedError(
                    f"Vectorize of {self} is undefined because one of the inputs {new_x} "
                    f"has an excluded dimension {sorted(invalid_excluded)} that it did not have before."
                )

    return self(*new_inputs, return_list=True)


def broadcast(
*args, exclude: str | Sequence[str] | None = None
Expand Down
Loading
Loading