Optimize JoinDims and SplitDims by canonicalizing to simpler operations (Partial fixes #1843) #1847
mengxingbw wants to merge 9 commits into pymc-devs:main
Conversation
Optimize JoinDims and SplitDims by canonicalizing to simpler operations (identity, expand_dims, squeeze). Partial fixes pymc-devs#1843
My guess is that ricardo meant reshape, not literally specify_shape (which, you're right, just adds metadata but doesn't do any computation).
I meant split dims, when the shape argument has just one entry. That's what that syntax refers to.
pytensor/tensor/rewriting/reshape.py (Outdated)

```python
x, shape = node.inputs
axis = node.op.axis

if isinstance(shape, Constant) and shape.data.size == 0:
```
Doesn't need to be constant, just a static shape of zero: `shape.type.shape == (0,)`
Also I would merge this with the split-to-reshape rewrite so we don't accidentally run that before this
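Roughly something like this, just as a sketch of the merged rewrite (the import paths, the SplitDims attributes, and the helper names are assumptions on my part, not the actual PR code):

```python
import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.rewriting.basic import register_canonicalize

# assumed import location for the Op introduced in pymc-devs#1843
from pytensor.tensor.shape import SplitDims


@register_canonicalize
@node_rewriter([SplitDims])
def local_lower_split_dims(fgraph, node):
    x, shape = node.inputs
    axis = node.op.axis

    # shape=() means the axis is removed entirely: use squeeze, not reshape
    if shape.type.shape == (0,):
        return [pt.squeeze(x, axis=axis)]

    # general case: lower to a single reshape
    new_shape = [*x.shape[:axis], *shape, *x.shape[axis + 1 :]]
    return [x.reshape(new_shape)]
```

Handling the special case inside the same rewrite guarantees the reshape fallback can never fire first.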
Thank you @jessegrabowski and @ricardoV94 for clarifying. So it sounds like we don't need the split_dims(x, axis=axis, shape=(dim,)) → specify_shape(...) rewrite, since it will fall into reshape anyway? I have made the changes according to the comment above.
reshape should be our last resort; everything we can avoid expressing as reshape, we should.
To clarify, none of the changes in this PR were strictly needed; they are an improvement over simple reshape.
Understood. Is there anything else to do with the last function:
ricardoV94 left a comment:
Making progress, needs a few more tweaks
pytensor/tensor/rewriting/reshape.py (Outdated)

```python
@register_canonicalize
@node_rewriter([JoinDims])
def local_join_dims_noop(fgraph, node):
```
Merge these join_dims rewrites into a single one, like we did with SplitDims.
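For example, roughly (a sketch only; the op attributes `axis`/`n_axes`, the import locations, and the joined-axes semantics are assumptions, not the PR's code):

```python
import pytensor.tensor as pt
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.rewriting.basic import register_canonicalize

# assumed import location for the Op introduced in pymc-devs#1843
from pytensor.tensor.shape import JoinDims


@register_canonicalize
@node_rewriter([JoinDims])
def local_lower_join_dims(fgraph, node):
    [x] = node.inputs
    axis, n_axes = node.op.axis, node.op.n_axes

    if n_axes == 1:
        # joining a single axis is a no-op
        return [x]
    if n_axes == 0:
        # joining zero axes just inserts a new length-1 axis
        return [pt.expand_dims(x, axis)]

    # general case: collapse the n_axes consecutive axes with one reshape
    new_shape = [*x.shape[:axis], -1, *x.shape[axis + n_axes :]]
    return [x.reshape(new_shape)]
```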
pytensor/tensor/rewriting/reshape.py (Outdated)

```
@@ -9,11 +11,24 @@
def local_split_dims_to_reshape(fgraph, node):
```
Now that we don't do only reshape, we should have a more generic name. Same for the join_dims when we merge the special cases
Suggested change:

```diff
-def local_split_dims_to_reshape(fgraph, node):
+def local_lower_split_dims(fgraph, node):
```
```python
# After rewrite: should have 0 JoinDims nodes
assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0
# Output should be equivalent to input (identity rewrite)
# The rewrite returns the input variable, so output should match input shape/type
assert fg.outputs[0].type.shape == x.type.shape
assert fg.outputs[0].type.dtype == x.type.dtype
assert fg.outputs[0].type.ndim == x.type.ndim
```
Use utt.assert_equal_computations to check we have the specific graph that we expect, not just anything without JoinDims
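For the identity case, something along these lines (a sketch; it uses the public equal_computations helper rather than the test-suite wrapper mentioned above, and the join_dims import path is an assumption):

```python
import pytensor.tensor as pt
from pytensor.graph.basic import equal_computations
from pytensor.graph.rewriting.utils import rewrite_graph

# assumed import path for the helper added in pymc-devs#1843
from pytensor.tensor.shape import join_dims


def test_local_join_dims_single_axis_is_identity():
    x = pt.tensor("x", shape=(2, 10, 3))
    out = join_dims(x, axis=1, n_axes=1)
    rewritten = rewrite_graph(out, include=("canonicalize",))
    # the canonical graph should be exactly the input variable, nothing else
    assert equal_computations([rewritten], [x])
```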
This recommendation applies to all new tests
I can't seem to get it to pass for the first two tests. When I looked it up, I got: "assert_equal_computations is better suited for cases where the canonical form is a specific operation (like expand_dims, squeeze, or identity) and the graph structures match. For basic reshape cases, the rewrite produces a different but equivalent graph structure, so structural checks are sufficient."
Please let me know how to proceed!
Suggest you look at the generated graph, the utility prints it when the assert fails. It shouldn't have anything too strange in it.
After excluding:
```
--------------------------------------- Captured stdout call ----------------------------------------
rewriting: rewrite local_split_dims replaces SplitDims{axis=1}.0 of SplitDims{axis=1}(x, [2 5 1]) with Reshape{5}.0 of Reshape{5}(x, MakeVector{dtype='int64'}.0)
rewriting: rewrite MergeOptimizer replaces 2 of None with 2 of None
rewriting: rewrite MergeOptimizer replaces 0 of None with 0 of None
rewriting: rewrite MergeOptimizer replaces 1 of None with 1 of None
rewriting: rewrite MergeOptimizer replaces Shape.0 of Shape(x) with Shape.0 of Shape(x)
rewriting: rewrite MergeOptimizer replaces 0 of None with 0 of None
rewriting: rewrite MergeOptimizer replaces 2 of None with 2 of None
rewriting: rewrite local_subtensor_remove_broadcastable_index replaces Subtensor{i}.0 of Subtensor{i}(Subtensor{:stop}.0, 0) with Squeeze{axis=0}.0 of Squeeze{axis=0}(Subtensor{:stop}.0)
rewriting: rewrite local_subtensor_remove_broadcastable_index replaces Subtensor{i}.0 of Subtensor{i}(Subtensor{start:}.0, 0) with Squeeze{axis=0}.0 of Squeeze{axis=0}(Subtensor{start:}.0)
rewriting: rewrite constant_folding replaces Subtensor{i}.0 of Subtensor{i}([2 5 1], 2) with 1 of None
rewriting: rewrite constant_folding replaces Subtensor{i}.0 of Subtensor{i}([2 5 1], 1) with 5 of None
rewriting: rewrite constant_folding replaces Subtensor{i}.0 of Subtensor{i}([2 5 1], 0) with 2 of None
rewriting: rewrite local_reshape_to_dimshuffle replaces Reshape{5}.0 of Reshape{5}(x, MakeVector{dtype='int64'}.0) with ExpandDims{axis=3}.0 of ExpandDims{axis=3}(Reshape{4}.0)
```
And how does the rewritten graph look now (vs the expected)?
```
E
E Rewritten:
E SpecifyShape [id A] <Tensor5(float64, shape=(2, 2, 5, 1, 3))>
E ├─ ExpandDims{axis=3} [id B] <Tensor5(float64, shape=(?, 2, 5, 1, ?))>
E │ └─ Reshape{4} [id C] <Tensor4(float64, shape=(?, 2, 5, ?))>
E │ ├─ x [id D] <Tensor3(float64, shape=(2, 10, 3))>
E │ └─ MakeVector{dtype='int64'} [id E] <Vector(int64, shape=(4,))>
E │ ├─ Squeeze{axis=0} [id F] <Scalar(int64, shape=())>
E │ │ └─ Subtensor{:stop} [id G] <Vector(int64, shape=(1,))>
E │ │ ├─ Shape [id H] <Vector(int64, shape=(3,))>
E │ │ │ └─ x [id D] <Tensor3(float64, shape=(2, 10, 3))>
E │ │ └─ 1 [id I] <int64>
E │ ├─ 2 [id J] <Scalar(int64, shape=())>
E │ ├─ 5 [id K] <Scalar(int64, shape=())>
E │ └─ Squeeze{axis=0} [id L] <Scalar(int64, shape=())>
E │ └─ Subtensor{start:} [id M] <Vector(int64, shape=(1,))>
E │ ├─ Shape [id H] <Vector(int64, shape=(3,))>
E │ │ └─ ···
E │ └─ 2 [id N] <int64>
E ├─ 2 [id O] <Scalar(int8, shape=())>
E ├─ 2 [id O] <Scalar(int8, shape=())>
E ├─ 5 [id P] <Scalar(int8, shape=())>
E ├─ 1 [id Q] <Scalar(int8, shape=())>
E └─ 3 [id R] <Scalar(int8, shape=())>
E
E Expected:
E ExpandDims{axis=3} [id A] <Tensor5(float64, shape=(2, 2, 5, 1, 3))>
E └─ Reshape{4} [id B] <Tensor4(float64, shape=(2, 2, 5, 3))>
E ├─ x [id C] <Tensor3(float64, shape=(2, 10, 3))>
E └─ MakeVector{dtype='int64'} [id D] <Vector(int64, shape=(4,))>
E ├─ Subtensor{i} [id E] <Scalar(int64, shape=())>
E │ ├─ Shape [id F] <Vector(int64, shape=(3,))>
E │ │ └─ x [id C] <Tensor3(float64, shape=(2, 10, 3))>
E │ └─ 0 [id G] <int64>
E ├─ Cast{int64} [id H] <Scalar(int64, shape=())>
E │ └─ 2 [id I] <Scalar(int8, shape=())>
E ├─ Cast{int64} [id J] <Scalar(int64, shape=())>
E │ └─ 5 [id K] <Scalar(int8, shape=())>
E └─ Subtensor{i} [id L] <Scalar(int64, shape=())>
E ├─ Shape [id F] <Vector(int64, shape=(3,))>
E │ └─ ···
E └─ 2 [id M] <int64>
```
Removing local_subtensor_remove_broadcastable_index should bring you closer, and using np.int64(2) / np.int64(5) for the expected shape. That will get rid of the Cast thing, which comes from #1073.
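For example (a sketch with hypothetical test variables, matching the shapes in the trace above):

```python
import numpy as np
import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 10, 3))
# int64 constants go straight into the MakeVector, so no Cast node is introduced
expected = pt.expand_dims(
    x.reshape([x.shape[0], np.int64(2), np.int64(5), x.shape[2]]),
    3,
)
```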
```
E Rewritten:
E SpecifyShape [id A] <Tensor5(float64, shape=(2, 2, 5, 1, 3))>
E ├─ ExpandDims{axis=3} [id B] <Tensor5(float64, shape=(?, 2, 5, 1, ?))>
E │ └─ Reshape{4} [id C] <Tensor4(float64, shape=(?, 2, 5, ?))>
E │ ├─ x [id D] <Tensor3(float64, shape=(2, 10, 3))>
E │ └─ MakeVector{dtype='int64'} [id E] <Vector(int64, shape=(4,))>
E │ ├─ Subtensor{i} [id F] <Scalar(int64, shape=())>
E │ │ ├─ Subtensor{:stop} [id G] <Vector(int64, shape=(1,))>
E │ │ │ ├─ Shape [id H] <Vector(int64, shape=(3,))>
E │ │ │ │ └─ x [id D] <Tensor3(float64, shape=(2, 10, 3))>
E │ │ │ └─ 1 [id I] <int64>
E │ │ └─ 0 [id J] <int64>
E │ ├─ 2 [id K] <Scalar(int64, shape=())>
E │ ├─ 5 [id L] <Scalar(int64, shape=())>
E │ └─ Subtensor{i} [id M] <Scalar(int64, shape=())>
E │ ├─ Subtensor{start:} [id N] <Vector(int64, shape=(1,))>
E │ │ ├─ Shape [id H] <Vector(int64, shape=(3,))>
E │ │ │ └─ ···
E │ │ └─ 2 [id O] <int64>
E │ └─ 0 [id J] <int64>
E ├─ 2 [id P] <Scalar(int8, shape=())>
E ├─ 2 [id P] <Scalar(int8, shape=())>
E ├─ 5 [id Q] <Scalar(int8, shape=())>
E ├─ 1 [id R] <Scalar(int8, shape=())>
E └─ 3 [id S] <Scalar(int8, shape=())>
E
E Expected:
E ExpandDims{axis=3} [id A] <Tensor5(float64, shape=(2, 2, 5, 1, 3))>
E └─ Reshape{4} [id B] <Tensor4(float64, shape=(2, 2, 5, 3))>
E ├─ x [id C] <Tensor3(float64, shape=(2, 10, 3))>
E └─ MakeVector{dtype='int64'} [id D] <Vector(int64, shape=(4,))>
E ├─ Subtensor{i} [id E] <Scalar(int64, shape=())>
E │ ├─ Shape [id F] <Vector(int64, shape=(3,))>
E │ │ └─ x [id C] <Tensor3(float64, shape=(2, 10, 3))>
E │ └─ 0 [id G] <int64>
E ├─ 2 [id H] <Scalar(int64, shape=())>
E ├─ 5 [id I] <Scalar(int64, shape=())>
E └─ Subtensor{i} [id J] <Scalar(int64, shape=())>
E ├─ Shape [id F] <Vector(int64, shape=(3,))>
E │ └─ ···
E └─ 2 [id K] <int64>
```
ricardoV94 left a comment:
Looks good.
However, you are basically doing the same checks twice in the tests; assert_equal_computations already checks that the specific Ops are in the graph (and that no unexpected ones are in it), so the redundant checks can be removed.
```python
# After rewrite: should have 0 JoinDims nodes
assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0
# Should have 1 DimShuffle node with is_expand_dims=True
expand_nodes = [
    node
    for node in fg.toposort()
    if isinstance(node.op, DimShuffle) and node.op.is_expand_dims
]
assert len(expand_nodes) == 1
```
This is all already covered by the assert_equal_computations
Suggested change: remove these lines.
```python
# After rewrite: should have 0 JoinDims nodes
assert sum([1 for node in fg.toposort() if isinstance(node.op, JoinDims)]) == 0
```
Already covered by assert_equal_computations
Suggested change: remove these lines.
```python
assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1
assert fg.outputs[0].type.shape == (2, 10, 3)
```
Already covered by assert_equal_computations
Suggested change: remove these lines.
```python
assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 0
assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 1
```
Already covered by assert_equal_computations
Suggested change: remove these lines.
```python
# After rewrite: should have 0 SplitDims nodes
assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 0
# Should have 1 DimShuffle node with is_squeeze=True (not Reshape)
squeeze_nodes = [
    node
    for node in fg.toposort()
    if isinstance(node.op, DimShuffle) and node.op.is_squeeze
]
assert len(squeeze_nodes) == 1
# Should NOT have a Reshape node
assert sum([1 for node in fg.toposort() if isinstance(node.op, Reshape)]) == 0
```
Already covered by assert_equal_computations
Suggested change: remove these lines.
```python
assert sum([1 for node in fg.toposort() if isinstance(node.op, SplitDims)]) == 1

rewrite_graph(fg, include=("canonicalize",))
with pt.config.change_flags(optimizer_verbose=True):
```
Just for debug, should be removed. You'll have to dedent the inner code
Suggested change: remove the `with pt.config.change_flags(optimizer_verbose=True):` line.
pytensor/tensor/rewriting/reshape.py (Outdated)

```python
*[x.shape[i] for i in range(axis)],
*shape,
*x.shape[axis + 1 :],
*[x.shape[i] for i in range(axis + 1, x.type.ndim)],
```
The previous code is cleaner; this creates repeated Shape Ops in the list comprehension that will have to be removed.
That was to fix a test failure. For example, when I switched to the previous code, this happened:
```
E AssertionError: equal_computations failed
E
E Rewritten:
E SpecifyShape [id A] <Tensor5(float64, shape=(2, 2, 5, 1, 3))>
E ├─ ExpandDims{axis=3} [id B] <Tensor5(float64, shape=(?, 2, 5, 1, ?))>
E │ └─ Reshape{4} [id C] <Tensor4(float64, shape=(?, 2, 5, ?))>
E │ ├─ x [id D] <Tensor3(float64, shape=(2, 10, 3))>
E │ └─ MakeVector{dtype='int64'} [id E] <Vector(int64, shape=(4,))>
E │ ├─ Subtensor{i} [id F] <Scalar(int64, shape=())>
E │ │ ├─ Subtensor{:stop} [id G] <Vector(int64, shape=(1,))>
E │ │ │ ├─ Shape [id H] <Vector(int64, shape=(3,))>
E │ │ │ │ └─ x [id D] <Tensor3(float64, shape=(2, 10, 3))>
E │ │ │ └─ 1 [id I] <int64>
E │ │ └─ 0 [id J] <int64>
E │ ├─ 2 [id K] <Scalar(int64, shape=())>
E │ ├─ 5 [id L] <Scalar(int64, shape=())>
E │ └─ Subtensor{i} [id M] <Scalar(int64, shape=())>
E │ ├─ Subtensor{start:} [id N] <Vector(int64, shape=(1,))>
E │ │ ├─ Shape [id H] <Vector(int64, shape=(3,))>
E │ │ │ └─ ···
E │ │ └─ 2 [id O] <int64>
E │ └─ 0 [id J] <int64>
E ├─ 2 [id P] <Scalar(int8, shape=())>
E ├─ 2 [id P] <Scalar(int8, shape=())>
E ├─ 5 [id Q] <Scalar(int8, shape=())>
E ├─ 1 [id R] <Scalar(int8, shape=())>
E └─ 3 [id S] <Scalar(int8, shape=())>
E
E Expected:
E ExpandDims{axis=3} [id A] <Tensor5(float64, shape=(2, 2, 5, 1, 3))>
E └─ Reshape{4} [id B] <Tensor4(float64, shape=(2, 2, 5, 3))>
E ├─ x [id C] <Tensor3(float64, shape=(2, 10, 3))>
E └─ MakeVector{dtype='int64'} [id D] <Vector(int64, shape=(4,))>
E ├─ Subtensor{i} [id E] <Scalar(int64, shape=())>
E │ ├─ Shape [id F] <Vector(int64, shape=(3,))>
E │ │ └─ x [id C] <Tensor3(float64, shape=(2, 10, 3))>
E │ └─ 0 [id G] <int64>
E ├─ 2 [id H] <Scalar(int64, shape=())>
E ├─ 5 [id I] <Scalar(int64, shape=())>
E └─ Subtensor{i} [id J] <Scalar(int64, shape=())>
E ├─ Shape [id F] <Vector(int64, shape=(3,))>
E │ └─ ···
E └─ 2 [id K] <int64>
```
We shouldn't change the actual code for the sake of the tests (but the other way around).
That makes sense! Just updated the test.
```python
)


def test_local_join_dims():
```
Now that the dust has settled, we can revert this test to how it looked before (including the name). The ones with Reshape are a bit too messy for the assert_equal_computations approach. This isn't affected by your PR anyway
```diff
-def test_local_split_dims_to_reshape():
+def test_local_split_dims():
```
same here, revert test changes and name. Keep only the new tests you added as they are
Description
This PR implements 3 out of the 4 canonicalization rewrites suggested in #1843:
- join_dims(x, axis=axis, n_axes=1) → identity (no-op)
- join_dims(x, axis=axis, n_axes=0) → expand_dims(x, axis)
- split_dims(x, axis=axis, shape=()) → squeeze(x, axis)
- split_dims(x, axis=axis, shape=(dim,)) → specify_shape(...) (see Block section)
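A small sketch of what these canonical forms mean in code (join_dims/split_dims are the helpers from #1843; the import path and the example shapes below are only for illustration):

```python
import pytensor.tensor as pt

# assumed import path for the helpers from pymc-devs#1843
from pytensor.tensor.shape import join_dims, split_dims

x = pt.tensor("x", shape=(2, 10, 3))
y = pt.tensor("y", shape=(2, 1, 3))

out1 = join_dims(x, axis=1, n_axes=1)       # canonicalizes to x itself (no-op)
out2 = join_dims(x, axis=1, n_axes=0)       # canonicalizes to expand_dims(x, 1)
out3 = split_dims(y, axis=1, shape=())      # canonicalizes to squeeze(y, axis=1)
out4 = split_dims(x, axis=1, shape=(2, 5))  # general case, stays a reshape
```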
Questions

I tried to work on the last requested change:
The issue: specify_shape preserves the input's known shape when it's already concrete, so it doesn't match SplitDims's output type. If the input already has a known shape at a dimension, it uses that shape; it only uses the specified shape when the input shape is None. This caused the rewrite to fail.
For this rewrite to work even when the input shape is known, I'd need to use reshape instead of specify_shape, but that defeats the purpose of using specify_shape for shape assertion.
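To illustrate the behaviour described above (a minimal sketch with made-up shapes, using only pt.specify_shape):

```python
import pytensor.tensor as pt

x_unknown = pt.tensor("x", shape=(None, None))
print(pt.specify_shape(x_unknown, (2, 5)).type.shape)   # (2, 5): unknown dims are filled in

x_known = pt.tensor("x", shape=(2, 5))
print(pt.specify_shape(x_known, (2, None)).type.shape)  # (2, 5): already-known dims are kept
```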
Related Issue
Checklist
Type of change