Skip to content

Commit 9be0522

Browse files
committed
avoid duplication in untag_loopy_call_results
1 parent 89fcb49 commit 9be0522

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

meshmode/array_context.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,13 @@ def transform_dag(self,
264264
def untag_loopy_call_results(
265265
expr: pt.Array | pt.AbstractResultWithNamedArrays
266266
) -> pt.Array | pt.AbstractResultWithNamedArrays:
267-
if isinstance(expr, pt.NamedArray):
268-
return expr.copy(tags=frozenset(),
269-
axes=(pt.Axis(frozenset()),)*expr.ndim)
267+
if isinstance(expr, pt.loopy.LoopyCallResult):
268+
new_tags = frozenset()
269+
if any(axis.tags for axis in expr.axes):
270+
new_axes = (pt.Axis(frozenset()),)*expr.ndim
271+
else:
272+
new_axes = expr.axes
273+
return expr.replace_if_different(tags=new_tags, axes=new_axes)
270274
else:
271275
return expr
272276

0 commit comments

Comments
 (0)