We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 89fcb49 commit 9be0522Copy full SHA for 9be0522
1 file changed
meshmode/array_context.py
@@ -264,9 +264,13 @@ def transform_dag(self,
264
def untag_loopy_call_results(
265
expr: pt.Array | pt.AbstractResultWithNamedArrays
266
) -> pt.Array | pt.AbstractResultWithNamedArrays:
267
- if isinstance(expr, pt.NamedArray):
268
- return expr.copy(tags=frozenset(),
269
- axes=(pt.Axis(frozenset()),)*expr.ndim)
+ if isinstance(expr, pt.loopy.LoopyCallResult):
+ new_tags = frozenset()
+ if any(axis.tags for axis in expr.axes):
270
+ new_axes = (pt.Axis(frozenset()),)*expr.ndim
271
+ else:
272
+ new_axes = expr.axes
273
+ return expr.replace_if_different(tags=new_tags, axes=new_axes)
274
else:
275
return expr
276
0 commit comments