From 3922c86f00dc5431b598db850c9947cdba951218 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 19 Oct 2022 12:34:36 -0500 Subject: [PATCH 1/3] implement checks to avoid developer mishaps --- pytato/array.py | 22 +++++++++++++++++++--- test/test_pytato.py | 16 ++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 4ea42079f..1140e22c0 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -446,6 +446,18 @@ class Array(Taggable): # disallow numpy arithmetic from taking precedence __array_priority__: ClassVar[int] = 1 + def _is_eq_valid(self) -> bool: + return (self.__class__.__eq__ is Array.__eq__ + and self.__class__.__hash__ is Array.__hash__) + + def __post_init__(self) -> None: + # ensure that a developer does not uses dataclass' "__eq__" + # or "__hash__" implementation as they have exponential complexity. + assert self._is_eq_valid() + + def __attrs_post_init__(self) -> None: + return self.__post_init__() + def copy(self: ArrayT, **kwargs: Any) -> ArrayT: for field in self._fields: if field not in kwargs: @@ -1471,6 +1483,7 @@ class Reshape(IndexRemappingBase): def __post_init__(self) -> None: # FIXME: Get rid of this restriction assert self.order == "C" + super().__post_init__() __attrs_post_init__ = __post_init__ @@ -1689,12 +1702,15 @@ class DataWrapper(InputArgumentBase): def name(self) -> None: return None + def _is_eq_valid(self) -> bool: + # we override __hash__ as hashing DataInterface is impractical + # => promise the __post_init__ that the change was intentional + # and valid by returning True + return True + def __hash__(self) -> int: return id(self) - def __eq__(self, other: Any) -> bool: - return self is other - @property def shape(self) -> ShapeType: return self._shape diff --git a/test/test_pytato.py b/test/test_pytato.py index 3b0b4fa55..88a11c6e6 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -29,6 +29,7 @@ import numpy as np import pytest +import attrs import pytato as pt @@ -941,6 +942,21 @@ def test_with_tagged_reduction(): .tags_of_type(FooRednTag)) +def test_derived_class_uses_correct_array_eq(): + @attrs.define(frozen=True) + class MyNewArrayT(pt.Array): + pass + + with pytest.raises(AssertionError): + MyNewArrayT(tags=frozenset(), axes=()) + + @attrs.define(frozen=True, eq=False) + class MyNewAndCorrectArrayT(pt.Array): + pass + + MyNewAndCorrectArrayT(tags=frozenset(), axes=()) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) From 18fc31da07a9bb77d75afd4d14dbbcc27f82f186 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 19 Oct 2022 12:36:10 -0500 Subject: [PATCH 2/3] explain the choice of attrs over dataclasses --- doc/design.rst | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/doc/design.rst b/doc/design.rst index b5db78a30..5eaee8023 100644 --- a/doc/design.rst +++ b/doc/design.rst @@ -231,6 +231,22 @@ that relies on memory layout information to do its job is undefined in :mod:`pyt At the most basic level, the attribute :attr:`numpy.ndarray.strides` is not available on subclasses of :class:`pytato.Array`. +Dataclasses / :mod:`attrs` +-------------------------- + +:mod:`dataclasses` helps us reduce most of the boilerplate involved in +instantiating a new type. However, :mod:`dataclasses` does not support +keyword-only argument until Python-3.10. To overcome this, we prefer +:mod:`attrs` which gives us all the required functionality of +:mod:`dataclasses` and works with Python-3.8. + + +We have checks in place to avoid developer errors that could happen by using +the defaults of these libraries. For eg. both :mod:`dataclasses` and +:mod:`attrs` override the implementation of ``__eq__`` for the class being +implemented, which could potentially lead lead to an `exponential complex +operation `_. + Lessons learned =============== From 321d293a9e7e8b5f593f6ca6662bc95d5e9f4451 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 19 Oct 2022 12:40:33 -0500 Subject: [PATCH 3/3] add attrs intersphinx mapping --- doc/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/conf.py b/doc/conf.py index d5abd6d27..bdc505745 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -28,6 +28,7 @@ "https://documen.tician.de/islpy/": None, "https://pyrsistent.readthedocs.io/en/latest/": None, "https://jax.readthedocs.io/en/latest/": None, + "https://www.attrs.org/en/stable/": None, } # Some modules need to import things just so that sphinx can resolve symbols in