Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,4 +47,7 @@
["py:class", r"numpy.(u?)int[\d]+"],
["py:class", r"pyrsistent.typing.(.+)"],
["py:class", r"typing_extensions(.+)"],
# As of 2022-10-20, it doesn't look like there's sphinx documentation
# available.
["py:class", r"immutables\.(.+)"],
]
43 changes: 35 additions & 8 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,17 @@ class AbstractResultWithNamedArrays(Mapping[str, NamedArray], Taggable, ABC):
tags: FrozenSet[Tag] = attrs.field(kw_only=True)
_mapper_method: ClassVar[str]

def _is_eq_valid(self) -> bool:
return self.__class__.__eq__ is AbstractResultWithNamedArrays.__eq__

def __post_init__(self) -> None:
# ensure that a developer does not uses dataclass' "__eq__"
# or "__hash__" implementation as they have exponential complexity.
assert self._is_eq_valid()

def __attrs_post_init__(self) -> None:
return self.__post_init__()

@abstractmethod
def __contains__(self, name: object) -> bool:
pass
Expand All @@ -780,6 +791,13 @@ def __getitem__(self, name: str) -> NamedArray:
def __len__(self) -> int:
pass

def __eq__(self, other: Any) -> bool:
if self is other:
return True

from pytato.equality import EqualityComparer
return EqualityComparer()(self, other)


@attrs.define(frozen=True, eq=False, init=False)
class DictOfNamedArrays(AbstractResultWithNamedArrays):
Expand Down Expand Up @@ -807,7 +825,7 @@ def __init__(self, data: Mapping[str, Array], *,
object.__setattr__(self, "tags", tags)

def __hash__(self) -> int:
return hash(frozenset(self._data.items()))
return hash((frozenset(self._data.items()), self.tags))

def __contains__(self, name: object) -> bool:
return name in self._data
Expand All @@ -826,13 +844,6 @@ def __len__(self) -> int:
def __iter__(self) -> Iterator[str]:
return iter(self._data)

def __eq__(self, other: Any) -> bool:
if self is other:
return True

from pytato.equality import EqualityComparer
return EqualityComparer()(self, other)

def __repr__(self) -> str:
return "DictOfNamedArrays(" + str(self._data) + ")"

Expand Down Expand Up @@ -2535,6 +2546,8 @@ def make_index_lambda(
# }}}


# {{{ dot, vdot

def dot(a: ArrayOrScalar, b: ArrayOrScalar) -> ArrayOrScalar:
"""
For 1-dimensional arrays *a* and *b* computes their inner product. See
Expand Down Expand Up @@ -2583,6 +2596,10 @@ def vdot(a: Array, b: Array) -> ArrayOrScalar:

return pt.dot(pt.conj(a), b)

# }}}


# {{{ broadcast_to

def broadcast_to(array: Array, shape: ShapeType) -> Array:
"""
Expand Down Expand Up @@ -2610,6 +2627,10 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array:
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=Map())

# }}}


# {{{ squeeze

def squeeze(array: Array) -> Array:
"""Remove single-dimensional entries from the shape of an array."""
Expand All @@ -2619,6 +2640,10 @@ def squeeze(array: Array) -> Array:
0 if are_shape_components_equal(s_i, 1) else slice(s_i)
for i, s_i in enumerate(array.shape))]

# }}}


# {{{ expand_dims

def expand_dims(array: Array, axis: Union[Tuple[int, ...], int]) -> Array:
"""
Expand Down Expand Up @@ -2660,4 +2685,6 @@ def expand_dims(array: Array, axis: Union[Tuple[int, ...], int]) -> Array:
| {ExpandedDimsReshape(tuple(normalized_axis))}),
axes=_get_default_axes(len(new_shape)))

# }}}

# vim: foldmethod=marker
5 changes: 4 additions & 1 deletion pytato/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def map_loopy_call(self, expr1: LoopyCall, expr2: Any) -> bool:
if isinstance(bnd, Array)
else bnd == expr2.bindings[name]
for name, bnd in expr1.bindings.items())
and expr1.tags == expr2.tags
)

def map_loopy_call_result(self, expr1: LoopyCallResult, expr2: Any) -> bool:
Expand All @@ -248,7 +249,9 @@ def map_dict_of_named_arrays(self, expr1: DictOfNamedArrays, expr2: Any) -> bool
return (expr1.__class__ is expr2.__class__
and frozenset(expr1._data.keys()) == frozenset(expr2._data.keys())
and all(self.rec(expr1._data[name], expr2._data[name])
for name in expr1._data))
for name in expr1._data)
and expr1.tags == expr2.tags
)

def map_distributed_send_ref_holder(
self, expr1: DistributedSendRefHolder, expr2: Any) -> bool:
Expand Down
15 changes: 1 addition & 14 deletions pytato/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def _entry_kernel(self) -> lp.LoopKernel:

def __hash__(self) -> int:
return hash((self.translation_unit, tuple(self.bindings.items()),
self.entrypoint))
self.entrypoint, self.tags))

def __contains__(self, name: object) -> bool:
return name in self._result_names
Expand All @@ -118,19 +118,6 @@ def __len__(self) -> int:
def __iter__(self) -> Iterator[str]:
return iter(self._result_names)

def __eq__(self, other: Any) -> bool:
if self is other:
return True

if not isinstance(other, LoopyCall):
return False

if ((self.entrypoint == other.entrypoint)
and (self.bindings == other.bindings)
and (self.translation_unit == other.translation_unit)):
return True
return False


class LoopyCallResult(NamedArray):
"""
Expand Down