Implement vectorize_graph for XTensor Ops #1876
Changes from all commits: d4a6659, 6316869, d9bfd26, 4949bbf, e56c37f
```diff
@@ -4,13 +4,15 @@
 # https://numpy.org/neps/nep-0021-advanced-indexing.html
 # https://docs.xarray.dev/en/latest/user-guide/indexing.html
 # https://tutorial.xarray.dev/intermediate/indexing/advanced-indexing.html
+from itertools import chain
 from typing import Literal
 
 from pytensor.graph.basic import Apply, Constant, Variable
 from pytensor.scalar.basic import discrete_dtypes
 from pytensor.tensor.basic import as_tensor
 from pytensor.tensor.type_other import NoneTypeT, SliceType, make_slice
 from pytensor.xtensor.basic import XOp, xtensor_from_tensor
+from pytensor.xtensor.shape import broadcast
 from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
 
```
```diff
@@ -195,6 +197,15 @@ def combine_dim_info(idx_dim, idx_dim_shape):
         output = xtensor(dtype=x.type.dtype, shape=out_shape, dims=out_dims)
         return Apply(self, [x, *idxs], [output])
 
+    def vectorize_node(self, node, new_x, *new_idxs, new_dim):
+        # new_x may have dims in different order
+        # we pair each pre-existing dim to the respective index
+        # with new dims having simply a slice(None)
+        old_x, *_ = node.inputs
+        dims_to_idxs = dict(zip(old_x.dims, new_idxs, strict=False))
+        new_idxs = tuple(dims_to_idxs.get(dim, slice(None)) for dim in new_x.dims)
+        return [self(new_x, *new_idxs)]
+
 
 index = Index()
```
```diff
@@ -226,6 +237,29 @@ def make_node(self, x, y, *idxs):
         out = x.type()
         return Apply(self, [x, y, *idxs], [out])
 
+    def vectorize_node(self, node, *new_inputs, new_dim):
+        # If y or the indices have new dimensions we need to broadcast_x
+        exclude: set[str] = set(
+            chain.from_iterable(
+                old_inp.dims
+                for old_inp in node.inputs
+                if isinstance(old_inp.type, XTensorType)
+            )
+        )
+        old_x, *_ = node.inputs
+        new_x, *_ = broadcast(
```
Member:
Does the `exclude` here mean that you can't vectorize on a dimension that already exists in one of the inputs, or is that automatically handled somewhere else? I tried to write out a specific example a few times, but since it's an indexing Op I'm not clear enough on how broadcasting will work. If I have

Member (Author):
Your first question, yes you can index with

On vectorization: no, you can't do

Member (Author):
This is enforced in:

```python
# Or have new dimensions that were already in the graph
if new_core_dims := ((new_dims_set - old_dims_set) & all_old_dims_set):
    raise ValueError(
        f"Vectorized input {new_inp} has new dimensions that were present in the original graph: {new_core_dims}"
    )
```
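Pulled out of its context, the quoted guard can be run as a standalone sketch. The `check_new_dims` wrapper and its arguments are made up for illustration; only the walrus-operator check mirrors the real code.

```python
# Standalone sketch of the quoted guard (the wrapper function is hypothetical):
# batch dims added by vectorization must not collide with any dim that was
# already present elsewhere in the original graph.
def check_new_dims(new_dims, old_dims, all_old_dims):
    new_dims_set, old_dims_set = set(new_dims), set(old_dims)
    all_old_dims_set = set(all_old_dims)
    # Or have new dimensions that were already in the graph
    if new_core_dims := ((new_dims_set - old_dims_set) & all_old_dims_set):
        raise ValueError(
            f"Vectorized input has new dimensions that were present "
            f"in the original graph: {new_core_dims}"
        )

check_new_dims(("batch", "a"), ("a",), ("a", "b"))  # "batch" is new everywhere: ok
try:
    check_new_dims(("b", "a"), ("a",), ("a", "b"))  # "b" exists elsewhere: rejected
except ValueError as err:
    print(err)
```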
```diff
+            *[
+                new_inp
+                for new_inp in new_inputs
+                if isinstance(new_inp.type, XTensorType)
+            ],
+            exclude=tuple(exclude),
+        )
+        # New batch dimensions must go on the right since indices map to indexed dimensions positionally in the Op
+        new_x = new_x.transpose(*old_x.dims, ...)
+        _, new_y, *new_idxs = new_inputs
+        return [self(new_x, new_y, *new_idxs)]
+
 
 index_assignment = IndexUpdate("set")
 index_increment = IndexUpdate("inc")
```
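The two steps in `IndexUpdate.vectorize_node` — broadcasting only over dims that are genuinely new, then keeping the old dims leftmost so indices still line up positionally — can be mimicked on plain dim tuples. Both helpers below are illustrative stand-ins, not pytensor API (the real code calls `broadcast` and `transpose` on XTensor variables).

```python
# Illustrative helpers (hypothetical names) mirroring the dim bookkeeping of
# IndexUpdate.vectorize_node on plain tuples of dim names.

def excluded_dims(*old_inputs_dims):
    # Every dim of the original inputs is excluded from broadcasting,
    # so only dims introduced by vectorization get broadcast.
    return set().union(*old_inputs_dims)

def reorder_dims(old_x_dims, new_x_dims):
    # Old dims go on the left; new batch dims end up on the right,
    # because indices map to the indexed dimensions positionally in the Op.
    batch = [d for d in new_x_dims if d not in old_x_dims]
    return tuple(old_x_dims) + tuple(batch)

exclude = excluded_dims(("a", "b"), ("a",))           # dims of old x and old y
print(sorted(exclude))                                # ['a', 'b']
print(reorder_dims(("a", "b"), ("batch", "b", "a")))  # ('a', 'b', 'batch')
```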
This needs a bit more documentation, I wouldn't have understood the `Apply` change without our call yesterday.

It's pretty internal though: the type signature shows `_vectorize_node` returns `Sequence[Variable] | Apply`, and I was hoping `vect_node_or_outputs -> vect_outputs` further drives the point home. I want to deprecate the `Apply` eventually as well, but don't want to do it in this PR.
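The `Sequence[Variable] | Apply` duality discussed here can be sketched with stand-in classes. `ToyVariable`, `ToyApply`, and `as_vect_outputs` are made up for illustration; pytensor's real `Apply`, `Variable`, and dispatch logic differ.

```python
# Hypothetical stand-ins (not pytensor's classes) sketching the convention:
# a vectorize_node implementation may return either an Apply-like node or a
# sequence of output variables, and the caller normalizes both to a list of
# outputs -- hence the vect_node_or_outputs -> vect_outputs rename.

class ToyVariable:
    pass

class ToyApply:
    def __init__(self, outputs):
        self.outputs = outputs

def as_vect_outputs(vect_node_or_outputs):
    if isinstance(vect_node_or_outputs, ToyApply):
        # Legacy form, slated for deprecation: unwrap the node's outputs.
        return list(vect_node_or_outputs.outputs)
    # Preferred form: already a sequence of output variables.
    return list(vect_node_or_outputs)

v = ToyVariable()
assert as_vect_outputs(ToyApply([v])) == [v]
assert as_vect_outputs([v]) == [v]
```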