Skip to content

relax scan broadcastability checks#1861

Open
eclipse1605 wants to merge 3 commits intopymc-devs:mainfrom
eclipse1605:fix-scan-broadcast-compatibility
Open

relax scan broadcastability checks#1861
eclipse1605 wants to merge 3 commits intopymc-devs:mainfrom
eclipse1605:fix-scan-broadcast-compatibility

Conversation

@eclipse1605
Copy link

relax Scan broadcastability checks to allow inner outputs to be more specific than outer outputs_info

  • modified check_broadcast to only reject when outer requires broadcastable axes that inner doesn't have
  • updated Scan.validate_inner_graph to use TensorType.is_super for compatibility instead of exact broadcastable equality

@eclipse1605
Copy link
Author

related to pymc-devs/pymc#7892 and pymc-devs/pymc#8016

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should try to get rid of this function altogether, whether broadcastable matches or not is irrelevant: If the outer input is more precise, we can add a specify_shape to the inner output. Conversely, if the inner output is more precise we can add a specify_shape to the outer input.

Only case that is invalid is if we have static shape in both variables that is inconsistent

if isinstance(type_input, TensorType) and isinstance(
type_output, TensorType
):
if not type_input.is_super(type_output):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can just use something like type_input.filter_variable(type_output), which works regardless of the type (TensorVariable or otherwise). There's also a kwarg that allows it to automatically add a specify_shape or something.

This can be inside a try/except to still provide a specialized error message to the user

@eclipse1605
Copy link
Author

is this what you meant?

type_input.dtype != type_output.dtype
or type_input.broadcastable != type_output.broadcastable
):
if type_input.dtype != type_output.dtype:
Copy link
Member

@ricardoV94 ricardoV94 Jan 30, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case is covered by the new try/except below, so we don't need it separately

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wouldnt removing the dtype check change semantics? type api will be allowed to perform dtype conversions instead of raising

)
# Error, because the broadcast patterns are inconsistent.
# This should now work with relaxed broadcast checks
g = grad(y.sum(), x)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's verify the gradient is correct with utt.verify_grad, and update the test name

assert g is not None


def test_filter_variable_allows_inner_more_specific():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can't be right, can't convert a vector to a matrix. I don't think you want this test anyway

outer_t.filter_variable(inner_var, allow_convert=True)


def test_filter_variable_rejects_incompatible_static_shapes():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

again this test is not specific to this PR, the functionality should be tested already elsewhere. It's definitely not scan specific

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it should be in the tests for type.py ig

@eclipse1605
Copy link
Author

ive left the dtype check in palce for now

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants