diff --git a/pytato/equality.py b/pytato/equality.py index 79d038d72..8f58153f1 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -135,7 +135,10 @@ def map_size_param(self, expr1: SizeParam, expr2: Any) -> bool: ) def map_data_wrapper(self, expr1: DataWrapper, expr2: Any) -> bool: - return expr1 is expr2 + import numpy as np + return (expr1.__class__ is expr2.__class__ + and np.array_equal(expr1.data.get(), expr2.data.get()) + ) def map_index_lambda(self, expr1: IndexLambda, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__