Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"https://documen.tician.de/islpy/": None,
"https://pyrsistent.readthedocs.io/en/latest/": None,
"https://jax.readthedocs.io/en/latest/": None,
"https://www.attrs.org/en/stable/": None,
}

# Some modules need to import things just so that sphinx can resolve symbols in
Expand Down
16 changes: 16 additions & 0 deletions doc/design.rst
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,22 @@ that relies on memory layout information to do its job is undefined in :mod:`pyt
At the most basic level, the attribute :attr:`numpy.ndarray.strides` is not available
on subclasses of :class:`pytato.Array`.

Dataclasses / :mod:`attrs`
--------------------------

:mod:`dataclasses` helps us reduce most of the boilerplate involved in
instantiating a new type. However, :mod:`dataclasses` does not support
keyword-only argument until Python-3.10. To overcome this, we prefer
:mod:`attrs` which gives us all the required functionality of
:mod:`dataclasses` and works with Python-3.8.


We have checks in place to avoid developer errors that could happen by using
the defaults of these libraries. For eg. both :mod:`dataclasses` and
:mod:`attrs` override the implementation of ``__eq__`` for the class being
implemented, which could potentially lead lead to an `exponential complex
operation <https://github.com/inducer/pytato/issues/163>`_.

Lessons learned
===============

Expand Down
22 changes: 19 additions & 3 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,18 @@ class Array(Taggable):
# disallow numpy arithmetic from taking precedence
__array_priority__: ClassVar[int] = 1

def _is_eq_valid(self) -> bool:
return (self.__class__.__eq__ is Array.__eq__
and self.__class__.__hash__ is Array.__hash__)

def __post_init__(self) -> None:
# ensure that a developer does not uses dataclass' "__eq__"
# or "__hash__" implementation as they have exponential complexity.
assert self._is_eq_valid()

def __attrs_post_init__(self) -> None:
return self.__post_init__()

def copy(self: ArrayT, **kwargs: Any) -> ArrayT:
for field in self._fields:
if field not in kwargs:
Expand Down Expand Up @@ -1471,6 +1483,7 @@ class Reshape(IndexRemappingBase):
def __post_init__(self) -> None:
# FIXME: Get rid of this restriction
assert self.order == "C"
super().__post_init__()

__attrs_post_init__ = __post_init__

Expand Down Expand Up @@ -1689,12 +1702,15 @@ class DataWrapper(InputArgumentBase):
def name(self) -> None:
return None

def _is_eq_valid(self) -> bool:
# we override __hash__ as hashing DataInterface is impractical
# => promise the __post_init__ that the change was intentional
# and valid by returning True
return True

def __hash__(self) -> int:
return id(self)

def __eq__(self, other: Any) -> bool:
return self is other

@property
def shape(self) -> ShapeType:
return self._shape
Expand Down
16 changes: 16 additions & 0 deletions test/test_pytato.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

import numpy as np
import pytest
import attrs

import pytato as pt

Expand Down Expand Up @@ -941,6 +942,21 @@ def test_with_tagged_reduction():
.tags_of_type(FooRednTag))


def test_derived_class_uses_correct_array_eq():
@attrs.define(frozen=True)
class MyNewArrayT(pt.Array):
pass

with pytest.raises(AssertionError):
MyNewArrayT(tags=frozenset(), axes=())

@attrs.define(frozen=True, eq=False)
class MyNewAndCorrectArrayT(pt.Array):
pass

MyNewAndCorrectArrayT(tags=frozenset(), axes=())


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down