From dddc9d8feb64ca6b3e106feee1dd715218541d38 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Fri, 4 Nov 2022 19:43:18 -0500 Subject: [PATCH 1/2] ignore immutables in docs --- doc/conf.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/conf.py b/doc/conf.py index bdc505745..b5e198577 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -47,4 +47,7 @@ ["py:class", r"numpy.(u?)int[\d]+"], ["py:class", r"pyrsistent.typing.(.+)"], ["py:class", r"typing_extensions(.+)"], + # As of 2022-10-20, it doesn't look like there's sphinx documentation + # available. + ["py:class", r"immutables\.(.+)"], ] From 13ba86fbee741ecb91fc7da01d7edb3535bf7fb2 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Fri, 4 Nov 2022 19:55:22 -0500 Subject: [PATCH 2/2] adds check for AbstractResultWithNamedArrays.__eq__ --- pytato/array.py | 43 +++++++++++++++++++++++++++++++++++-------- pytato/equality.py | 5 ++++- pytato/loopy.py | 15 +-------------- 3 files changed, 40 insertions(+), 23 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index e0a2d418f..f9ab5fff3 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -768,6 +768,17 @@ class AbstractResultWithNamedArrays(Mapping[str, NamedArray], Taggable, ABC): tags: FrozenSet[Tag] = attrs.field(kw_only=True) _mapper_method: ClassVar[str] + def _is_eq_valid(self) -> bool: + return self.__class__.__eq__ is AbstractResultWithNamedArrays.__eq__ + + def __post_init__(self) -> None: + # ensure that a developer does not uses dataclass' "__eq__" + # or "__hash__" implementation as they have exponential complexity. + assert self._is_eq_valid() + + def __attrs_post_init__(self) -> None: + return self.__post_init__() + @abstractmethod def __contains__(self, name: object) -> bool: pass @@ -780,6 +791,13 @@ def __getitem__(self, name: str) -> NamedArray: def __len__(self) -> int: pass + def __eq__(self, other: Any) -> bool: + if self is other: + return True + + from pytato.equality import EqualityComparer + return EqualityComparer()(self, other) + @attrs.define(frozen=True, eq=False, init=False) class DictOfNamedArrays(AbstractResultWithNamedArrays): @@ -807,7 +825,7 @@ def __init__(self, data: Mapping[str, Array], *, object.__setattr__(self, "tags", tags) def __hash__(self) -> int: - return hash(frozenset(self._data.items())) + return hash((frozenset(self._data.items()), self.tags)) def __contains__(self, name: object) -> bool: return name in self._data @@ -826,13 +844,6 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[str]: return iter(self._data) - def __eq__(self, other: Any) -> bool: - if self is other: - return True - - from pytato.equality import EqualityComparer - return EqualityComparer()(self, other) - def __repr__(self) -> str: return "DictOfNamedArrays(" + str(self._data) + ")" @@ -2535,6 +2546,8 @@ def make_index_lambda( # }}} +# {{{ dot, vdot + def dot(a: ArrayOrScalar, b: ArrayOrScalar) -> ArrayOrScalar: """ For 1-dimensional arrays *a* and *b* computes their inner product. See @@ -2583,6 +2596,10 @@ def vdot(a: Array, b: Array) -> ArrayOrScalar: return pt.dot(pt.conj(a), b) +# }}} + + +# {{{ broadcast_to def broadcast_to(array: Array, shape: ShapeType) -> Array: """ @@ -2610,6 +2627,10 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array: axes=_get_default_axes(len(shape)), var_to_reduction_descr=Map()) +# }}} + + +# {{{ squeeze def squeeze(array: Array) -> Array: """Remove single-dimensional entries from the shape of an array.""" @@ -2619,6 +2640,10 @@ def squeeze(array: Array) -> Array: 0 if are_shape_components_equal(s_i, 1) else slice(s_i) for i, s_i in enumerate(array.shape))] +# }}} + + +# {{{ expand_dims def expand_dims(array: Array, axis: Union[Tuple[int, ...], int]) -> Array: """ @@ -2660,4 +2685,6 @@ def expand_dims(array: Array, axis: Union[Tuple[int, ...], int]) -> Array: | {ExpandedDimsReshape(tuple(normalized_axis))}), axes=_get_default_axes(len(new_shape))) +# }}} + # vim: foldmethod=marker diff --git a/pytato/equality.py b/pytato/equality.py index eef5c26c7..76c831868 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -235,6 +235,7 @@ def map_loopy_call(self, expr1: LoopyCall, expr2: Any) -> bool: if isinstance(bnd, Array) else bnd == expr2.bindings[name] for name, bnd in expr1.bindings.items()) + and expr1.tags == expr2.tags ) def map_loopy_call_result(self, expr1: LoopyCallResult, expr2: Any) -> bool: @@ -248,7 +249,9 @@ def map_dict_of_named_arrays(self, expr1: DictOfNamedArrays, expr2: Any) -> bool return (expr1.__class__ is expr2.__class__ and frozenset(expr1._data.keys()) == frozenset(expr2._data.keys()) and all(self.rec(expr1._data[name], expr2._data[name]) - for name in expr1._data)) + for name in expr1._data) + and expr1.tags == expr2.tags + ) def map_distributed_send_ref_holder( self, expr1: DistributedSendRefHolder, expr2: Any) -> bool: diff --git a/pytato/loopy.py b/pytato/loopy.py index 66a5c9aff..ea71e5b48 100644 --- a/pytato/loopy.py +++ b/pytato/loopy.py @@ -93,7 +93,7 @@ def _entry_kernel(self) -> lp.LoopKernel: def __hash__(self) -> int: return hash((self.translation_unit, tuple(self.bindings.items()), - self.entrypoint)) + self.entrypoint, self.tags)) def __contains__(self, name: object) -> bool: return name in self._result_names @@ -118,19 +118,6 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[str]: return iter(self._result_names) - def __eq__(self, other: Any) -> bool: - if self is other: - return True - - if not isinstance(other, LoopyCall): - return False - - if ((self.entrypoint == other.entrypoint) - and (self.bindings == other.bindings) - and (self.translation_unit == other.translation_unit)): - return True - return False - class LoopyCallResult(NamedArray): """