diff --git a/meshmode/array_context.py b/meshmode/array_context.py index 9ac45932..8c3607bf 100644 --- a/meshmode/array_context.py +++ b/meshmode/array_context.py @@ -273,7 +273,10 @@ def untag_loopy_call_results( new_axes = (pt.Axis(frozenset()),)*expr.ndim else: new_axes = expr.axes - return expr.replace_if_different(tags=new_tags, axes=new_axes) + return expr.replace_if_different( + tags=expr.tags if expr.tags == new_tags else new_tags, + axes=expr.axes if expr.axes == new_axes else new_axes, + ) else: return expr