relax scan broadcastability checks#1861
Conversation
|
related to pymc-devs/pymc#7892 and pymc-devs/pymc#8016 |
pytensor/scan/op.py
Outdated
There was a problem hiding this comment.
We should try to get rid of this function altogether, whether broadcastable matches or not is irrelevant: If the outer input is more precise, we can add a specify_shape to the inner output. Conversely, if the inner output is more precise we can add a specify_shape to the outer input.
Only case that is invalid is if we have static shape in both variables that is inconsistent
pytensor/scan/op.py
Outdated
| if isinstance(type_input, TensorType) and isinstance( | ||
| type_output, TensorType | ||
| ): | ||
| if not type_input.is_super(type_output): |
There was a problem hiding this comment.
I think we can just use something like type_input.filter_variable(type_output), which works regardless of the type (TensorVariable or otherwise). There's also a kwarg that allows it to automatically add a specify_shape or something.
This can be inside a try/except to still provide a specialized error message to the user
|
is this what you meant? |
| type_input.dtype != type_output.dtype | ||
| or type_input.broadcastable != type_output.broadcastable | ||
| ): | ||
| if type_input.dtype != type_output.dtype: |
There was a problem hiding this comment.
This case is covered by the new try/except below, so we don't need it separately
There was a problem hiding this comment.
wouldnt removing the dtype check change semantics? type api will be allowed to perform dtype conversions instead of raising
tests/scan/test_basic.py
Outdated
| ) | ||
| # Error, because the broadcast patterns are inconsistent. | ||
| # This should now work with relaxed broadcast checks | ||
| g = grad(y.sum(), x) |
There was a problem hiding this comment.
let's verify the gradient is correct with utt.verify_grad, and update the test name
tests/scan/test_basic.py
Outdated
| assert g is not None | ||
|
|
||
|
|
||
| def test_filter_variable_allows_inner_more_specific(): |
There was a problem hiding this comment.
This can't be right, can't convert a vector to a matrix. I don't think you want this test anyway
tests/scan/test_basic.py
Outdated
| outer_t.filter_variable(inner_var, allow_convert=True) | ||
|
|
||
|
|
||
| def test_filter_variable_rejects_incompatible_static_shapes(): |
There was a problem hiding this comment.
again this test is not specific to this PR, the functionality should be tested already elsewhere. It's definitely not scan specific
There was a problem hiding this comment.
it should be in the tests for type.py ig
|
ive left the dtype check in palce for now |
relax Scan broadcastability checks to allow inner outputs to be more specific than outer outputs_info
check_broadcastto only reject when outer requires broadcastable axes that inner doesn't haveScan.validate_inner_graphto useTensorType.is_superfor compatibility instead of exact broadcastable equality