From d0aacd99baec77591f0d236b1468fabbb4232da4 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 16 Oct 2022 22:32:24 -0500 Subject: [PATCH 1/2] make attrs a dep --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 4363f3044..41772cb79 100644 --- a/setup.py +++ b/setup.py @@ -38,6 +38,7 @@ "pytools>=2021.1", "pyrsistent", "immutables", + "attrs", ], package_data={"pytato": ["py.typed"]}, author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei", From b3c35c5fc13e38d5221766807421399593d2aee0 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sun, 16 Oct 2022 11:05:28 -0500 Subject: [PATCH 2/2] use attrs for pt.Array --- .pylintrc-local.yml | 8 + pytato/array.py | 313 ++++++++++++++--------------------- pytato/codegen.py | 2 +- pytato/distributed.py | 52 +++--- pytato/loopy.py | 14 +- pytato/partition.py | 6 +- pytato/transform/__init__.py | 31 ++-- test/test_pytato.py | 4 +- 8 files changed, 185 insertions(+), 245 deletions(-) diff --git a/.pylintrc-local.yml b/.pylintrc-local.yml index 98a555e4e..f5a171e96 100644 --- a/.pylintrc-local.yml +++ b/.pylintrc-local.yml @@ -6,3 +6,11 @@ - ply - pygments.lexers - pygments.formatters + +# https://github.com/PyCQA/pylint/issues/7623 +- arg: disable + val: + - unexpected-keyword-arg + - too-many-function-args + - redundant-keyword-arg + - no-value-for-parameter diff --git a/pytato/array.py b/pytato/array.py index 3a1bdc1b1..68d48098c 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -148,14 +148,19 @@ .. class:: AxesT A :class:`tuple` of :class:`Axis` objects. + +.. class:: IntegralT + + An integer data type which is a union of integral types of :mod:`numpy` and + :class:`int`. """ # }}} -from abc import ABC, abstractmethod, abstractproperty +from abc import ABC, abstractmethod from functools import partialmethod, cached_property import operator -from dataclasses import dataclass +import attrs from typing import ( Optional, Callable, ClassVar, Dict, Any, Mapping, Tuple, Union, Protocol, Sequence, cast, TYPE_CHECKING, List, Iterator, TypeVar, @@ -288,7 +293,7 @@ def _truediv_result_type(arg1: DtypeOrScalar, arg2: DtypeOrScalar) -> np.dtype[A return dtype -@dataclass(frozen=True, eq=True) +@attrs.define(frozen=True) class NormalizedSlice: """ A normalized version of :class:`slice`. "Normalized" is explained in @@ -313,7 +318,7 @@ class NormalizedSlice: step: IntegralT -@dataclass(eq=True, frozen=True) +@attrs.define(frozen=True) class Axis(Taggable): """ A type for recording the information about an :class:`~pytato.Array`'s @@ -322,11 +327,11 @@ class Axis(Taggable): tags: FrozenSet[Tag] def _with_new_tags(self, tags: FrozenSet[Tag]) -> Taggable: - from dataclasses import replace + from attrs import evolve as replace return replace(self, tags=tags) -@dataclass(eq=True, frozen=True) +@attrs.define(frozen=True) class ReductionDescriptor(Taggable): """ Records information about a reduction dimension in an @@ -335,10 +340,11 @@ class ReductionDescriptor(Taggable): tags: FrozenSet[Tag] def _with_new_tags(self, tags: FrozenSet[Tag]) -> ReductionDescriptor: - from dataclasses import replace + from attrs import evolve as replace return replace(self, tags=tags) +@attrs.define(frozen=True, eq=False, repr=False) class Array(Taggable): r""" A base class (abstract interface + supplemental functionality) for lazily @@ -428,16 +434,17 @@ class Array(Taggable): .. attribute:: ndim """ + axes: AxesT = attrs.field(kw_only=True) + tags: FrozenSet[Tag] = attrs.field(kw_only=True) + _mapper_method: ClassVar[str] + # A tuple of field names. Fields must be equality comparable and # hashable. Dicts of hashable keys and values are also permitted. _fields: ClassVar[Tuple[str, ...]] = ("axes", "tags",) - __array_priority__ = 1 # disallow numpy arithmetic to take precedence - - def __init__(self, axes: AxesT, tags: FrozenSet[Tag]) -> None: - self.axes = axes - self.tags = tags + # disallow numpy arithmetic from taking precedence + __array_priority__: ClassVar[int] = 1 def copy(self: ArrayT, **kwargs: Any) -> ArrayT: for field in self._fields: @@ -450,17 +457,17 @@ def _with_new_tags(self: ArrayT, tags: FrozenSet[Tag]) -> ArrayT: @property def shape(self) -> ShapeType: - raise NotImplementedError + raise NotImplementedError() + + @property + def dtype(self) -> _dtype_any: + raise NotImplementedError() @property def size(self) -> ShapeComponent: from pytools import product return product(self.shape) # type: ignore[no-any-return] - @property - def dtype(self) -> np.dtype[Any]: - raise NotImplementedError - def __len__(self) -> ShapeComponent: if self.ndim == 0: raise TypeError("len() of unsized object") @@ -663,18 +670,12 @@ def __repr__(self) -> str: # {{{ mixins -class _SuppliedShapeAndDtypeMixin(object): +class _SuppliedShapeAndDtypeMixin: """A mixin class for when an array must store its own *shape* and *dtype*, rather than when it can derive them easily from inputs. """ - - def __init__(self, - shape: ShapeType, - dtype: np.dtype[Any], - **kwargs: Any): - super().__init__(**kwargs) - self._shape = shape - self._dtype = dtype + _shape: ShapeType + _dtype: np.dtype[Any] @property def shape(self) -> ShapeType: @@ -689,6 +690,7 @@ def dtype(self) -> np.dtype[Any]: # {{{ dict of named arrays +@attrs.define(frozen=True, eq=False, repr=False) class NamedArray(Array): """An entry in a :class:`AbstractResultWithNamedArrays`. Holds a reference back to thecontaining instance as well as the name by which *self* is @@ -696,17 +698,11 @@ class NamedArray(Array): .. automethod:: __init__ """ - _fields = Array._fields + ("_container", "name") - _mapper_method = "map_named_array" + _container: AbstractResultWithNamedArrays + name: str - def __init__(self, - container: AbstractResultWithNamedArrays, - name: str, - axes: AxesT, - tags: FrozenSet[Tag] = frozenset()) -> None: - super().__init__(axes=axes, tags=tags) - self._container = container - self.name = name + _fields: ClassVar[Tuple[str, ...]] = ("_container", "name", "axes", "tags",) + _mapper_method: ClassVar[str] = "map_named_array" # type-ignore reason: `copy` signature incompatible with super-class def copy(self, *, # type: ignore[override] @@ -780,7 +776,7 @@ class DictOfNamedArrays(AbstractResultWithNamedArrays): .. automethod:: __init__ """ - _mapper_method = "map_dict_of_named_arrays" + _mapper_method: ClassVar[str] = "map_dict_of_named_arrays" def __init__(self, data: Mapping[str, Array]): super().__init__() @@ -797,7 +793,8 @@ def __getitem__(self, name: str) -> NamedArray: if name not in self._data: raise KeyError(name) return NamedArray(self, name, - axes=_get_default_axes(self._data[name].ndim)) + axes=self._data[name].axes, + tags=self._data[name].tags) def __len__(self) -> int: return len(self._data) @@ -820,6 +817,7 @@ def __repr__(self) -> str: # {{{ index lambda +@attrs.define(frozen=True, eq=False, repr=False) class IndexLambda(_SuppliedShapeAndDtypeMixin, Array): r"""Represents an array that can be computed by evaluating :attr:`expr` for every value of the input indices. The @@ -854,25 +852,16 @@ class IndexLambda(_SuppliedShapeAndDtypeMixin, Array): .. automethod:: with_tagged_reduction """ + expr: prim.Expression + _shape: ShapeType + _dtype: np.dtype[Any] + bindings: Dict[str, Array] + var_to_reduction_descr: Mapping[str, ReductionDescriptor] - _fields = Array._fields + ("expr", "shape", "dtype", - "bindings", "var_to_reduction_descr") - _mapper_method = "map_index_lambda" - - def __init__(self, - expr: prim.Expression, - shape: ShapeType, - dtype: np.dtype[Any], - bindings: Dict[str, Array], - axes: AxesT, - var_to_reduction_descr: Mapping[str, ReductionDescriptor], - tags: FrozenSet[Tag] = frozenset()): - - super().__init__(shape=shape, dtype=dtype, axes=axes, tags=tags) - - self.expr = expr - self.bindings = bindings - self.var_to_reduction_descr = var_to_reduction_descr + _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("expr", "shape", "dtype", + "bindings", + "var_to_reduction_descr") + _mapper_method: ClassVar[str] = "map_index_lambda" def with_tagged_reduction(self, reduction_variable: str, @@ -924,7 +913,7 @@ class EinsumAxisDescriptor: pass -@dataclass(eq=True, frozen=True) +@attrs.define(frozen=True) class EinsumElementwiseAxis(EinsumAxisDescriptor): """ Describes an elementwise access pattern of an array's axis. In terms of the @@ -934,7 +923,7 @@ class EinsumElementwiseAxis(EinsumAxisDescriptor): dim: int -@dataclass(eq=True, frozen=True) +@attrs.define(frozen=True) class EinsumReductionAxis(EinsumAxisDescriptor): """ Describes a reduction access pattern of an array's axis. In terms of the @@ -944,6 +933,7 @@ class EinsumReductionAxis(EinsumAxisDescriptor): dim: int +@attrs.define(frozen=True, eq=False, repr=False) class Einsum(Array): """ An array expression using the `Einstein summation convention @@ -976,25 +966,17 @@ class Einsum(Array): .. automethod:: with_tagged_reduction """ - _fields = Array._fields + ("access_descriptors", - "args", - "redn_axis_to_redn_descr", - "index_to_access_descr") - _mapper_method = "map_einsum" - - def __init__(self, - access_descriptors: Tuple[Tuple[EinsumAxisDescriptor, ...], ...], - args: Tuple[Array, ...], - axes: AxesT, - redn_axis_to_redn_descr: Mapping[EinsumReductionAxis, - ReductionDescriptor], - index_to_access_descr: Mapping[str, EinsumAxisDescriptor], - tags: FrozenSet[Tag] = frozenset()): - super().__init__(axes=axes, tags=tags) - self.access_descriptors = access_descriptors - self.args = args - self.redn_axis_to_redn_descr = redn_axis_to_redn_descr - self.index_to_access_descr = index_to_access_descr + + access_descriptors: Tuple[Tuple[EinsumAxisDescriptor, ...], ...] + args: Tuple[Array, ...] + redn_axis_to_redn_descr: Mapping[EinsumReductionAxis, + ReductionDescriptor] + index_to_access_descr: Mapping[str, EinsumAxisDescriptor] + _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("access_descriptors", + "args", + "redn_axis_to_redn_descr", + "index_to_access_descr") + _mapper_method: ClassVar[str] = "map_einsum" @memoize_method def _access_descr_to_axis_len(self @@ -1298,6 +1280,7 @@ def einsum(subscripts: str, *operands: Array, # {{{ stack +@attrs.define(frozen=True, eq=False, repr=False) class Stack(Array): """Join a sequence of arrays along a new axis. @@ -1310,18 +1293,11 @@ class Stack(Array): The output axis """ + arrays: Tuple[Array, ...] + axis: int - _fields = Array._fields + ("arrays", "axis") - _mapper_method = "map_stack" - - def __init__(self, - arrays: Tuple[Array, ...], - axis: int, - axes: AxesT, - tags: FrozenSet[Tag] = frozenset()): - super().__init__(axes=axes, tags=tags) - self.arrays = arrays - self.axis = axis + _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("arrays", "axis") + _mapper_method: ClassVar[str] = "map_stack" @property def dtype(self) -> np.dtype[Any]: @@ -1338,6 +1314,7 @@ def shape(self) -> ShapeType: # {{{ concatenate +@attrs.define(frozen=True, eq=False, repr=False) class Concatenate(Array): """Join a sequence of arrays along an existing axis. @@ -1350,18 +1327,11 @@ class Concatenate(Array): The axis along which the *arrays* are to be concatenated. """ + arrays: Tuple[Array, ...] + axis: int - _fields = Array._fields + ("arrays", "axis") - _mapper_method = "map_concatenate" - - def __init__(self, - arrays: Tuple[Array, ...], - axis: int, - axes: AxesT, - tags: FrozenSet[Tag] = frozenset()): - super().__init__(axes=axes, tags=tags) - self.arrays = arrays - self.axis = axis + _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("arrays", "axis") + _mapper_method: ClassVar[str] = "map_concatenate" @property def dtype(self) -> np.dtype[Any]: @@ -1382,6 +1352,7 @@ def shape(self) -> ShapeType: # {{{ index remapping +@attrs.define(frozen=True, eq=False, repr=False) class IndexRemappingBase(Array): """Base class for operations that remap the indices of an array. @@ -1393,14 +1364,8 @@ class IndexRemappingBase(Array): The input :class:`~pytato.Array` """ - _fields = Array._fields + ("array",) - - def __init__(self, - array: Array, - axes: AxesT, - tags: FrozenSet[Tag] = frozenset()): - super().__init__(axes=axes, tags=tags) - self.array = array + array: Array + _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("array",) @property def dtype(self) -> np.dtype[Any]: @@ -1411,6 +1376,7 @@ def dtype(self) -> np.dtype[Any]: # {{{ roll +@attrs.define(frozen=True, eq=False, repr=False) class Roll(IndexRemappingBase): """Roll an array along an axis. @@ -1422,18 +1388,12 @@ class Roll(IndexRemappingBase): Shift axis. """ - _fields = IndexRemappingBase._fields + ("shift", "axis") - _mapper_method = "map_roll" + shift: int + axis: int - def __init__(self, - array: Array, - shift: int, - axis: int, - axes: AxesT, - tags: FrozenSet[Tag] = frozenset()): - super().__init__(array, axes, tags) - self.shift = shift - self.axis = axis + _fields: ClassVar[Tuple[str, ...]] = IndexRemappingBase._fields + ("shift", + "axis") + _mapper_method: ClassVar[str] = "map_roll" @property def shape(self) -> ShapeType: @@ -1444,6 +1404,7 @@ def shape(self) -> ShapeType: # {{{ axis permutation +@attrs.define(frozen=True, eq=False, repr=False) class AxisPermutation(IndexRemappingBase): r"""Permute the axes of an array. @@ -1453,17 +1414,11 @@ class AxisPermutation(IndexRemappingBase): A permutation of the input axes. """ - _fields = IndexRemappingBase._fields + ("axis_permutation",) - _mapper_method = "map_axis_permutation" + axis_permutation: Tuple[int, ...] - def __init__(self, - array: Array, - axis_permutation: Tuple[int, ...], - axes: AxesT, - tags: FrozenSet[Tag] = frozenset()): - super().__init__(array, axes, tags) - self.array = array - self.axis_permutation = axis_permutation + _fields: ClassVar[Tuple[str, ...]] = (IndexRemappingBase._fields + + ("axis_permutation",)) + _mapper_method: ClassVar[str] = "map_axis_permutation" @property def shape(self) -> ShapeType: @@ -1478,6 +1433,7 @@ def shape(self) -> ShapeType: # {{{ reshape +@attrs.define(frozen=True, eq=False, repr=False) class Reshape(IndexRemappingBase): """ Reshape an array. @@ -1494,22 +1450,18 @@ class Reshape(IndexRemappingBase): Output layout order, either ``C`` or ``F``. """ + newshape: ShapeType + order: str - _fields = IndexRemappingBase._fields + ("newshape", "order") - _mapper_method = "map_reshape" + _fields: ClassVar[Tuple[str, ...]] = IndexRemappingBase._fields + ("newshape", + "order") + _mapper_method: ClassVar[str] = "map_reshape" - def __init__(self, - array: Array, - newshape: ShapeType, - order: str, - axes: AxesT, - tags: FrozenSet[Tag] = frozenset()): + def __post_init__(self) -> None: # FIXME: Get rid of this restriction - assert order == "C" + assert self.order == "C" - super().__init__(array, axes, tags) - self.newshape = newshape - self.order = order + __attrs_post_init__ = __post_init__ @property def shape(self) -> ShapeType: @@ -1520,25 +1472,15 @@ def shape(self) -> ShapeType: # {{{ indexing -class IndexBase(IndexRemappingBase, ABC): +@attrs.define(frozen=True, eq=False, repr=False) +class IndexBase(IndexRemappingBase): """ Abstract class for all index expressions on an array. .. attribute:: indices """ - _fields = IndexRemappingBase._fields + ("indices",) - - def __init__(self, - array: Array, - indices: Tuple[IndexExpr, ...], - axes: AxesT, - tags: FrozenSet[Tag] = frozenset()): - super().__init__(array, axes, tags) - self.indices = indices - - @abstractproperty - def shape(self) -> ShapeType: - pass + indices: Tuple[IndexExpr, ...] + _fields: ClassVar[Tuple[str, ...]] = IndexRemappingBase._fields + ("indices",) class BasicIndex(IndexBase): @@ -1546,7 +1488,7 @@ class BasicIndex(IndexBase): An indexing expression with all indices being either an :class:`int` or :class:`slice`. """ - _mapper_method = "map_basic_index" + _mapper_method: ClassVar[str] = "map_basic_index" @property def shape(self) -> ShapeType: @@ -1570,7 +1512,7 @@ class AdvancedIndexInContiguousAxes(IndexBase): :class:`AdvancedIndexInNoncontiguousAxes` is that :mod:`numpy` treats those two cases differently, and we're bound to follow its precedent. """ - _mapper_method = "map_contiguous_advanced_index" + _mapper_method: ClassVar[str] = "map_contiguous_advanced_index" @property def shape(self) -> ShapeType: @@ -1614,7 +1556,7 @@ class AdvancedIndexInNoncontiguousAxes(IndexBase): :class:`AdvancedIndexInContiguousAxes` is that :mod:`numpy` treats those two cases differently, and we're bound to follow its precedent. """ - _mapper_method = "map_non_contiguous_advanced_index" + _mapper_method: ClassVar[str] = "map_non_contiguous_advanced_index" @property def shape(self) -> ShapeType: @@ -1645,6 +1587,7 @@ def shape(self) -> ShapeType: # {{{ base class for arguments +@attrs.define(frozen=True, eq=False, repr=False) class InputArgumentBase(Array): r"""Base class for input arguments. @@ -1684,6 +1627,7 @@ def dtype(self) -> np.dtype[Any]: pass +@attrs.define(frozen=True, eq=False, repr=False) class DataWrapper(InputArgumentBase): """Takes concrete array data and packages it to be compatible with the :class:`Array` interface. @@ -1723,19 +1667,12 @@ class DataWrapper(InputArgumentBase): wrapped, a :class:`DataWrapper` instances compare equal to themselves (i.e. the very same instance). """ + data: DataInterface + _shape: ShapeType - _fields = InputArgumentBase._fields + ("data", "shape") - _mapper_method = "map_data_wrapper" - - def __init__(self, - data: DataInterface, - shape: ShapeType, - axes: AxesT, - tags: FrozenSet[Tag] = frozenset()): - super().__init__(axes=axes, tags=tags) - - self.data = data - self._shape = shape + _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("data", + "shape") + _mapper_method: ClassVar[str] = "map_data_wrapper" @property def name(self) -> None: @@ -1760,6 +1697,7 @@ def dtype(self) -> np.dtype[Any]: # {{{ placeholder +@attrs.define(frozen=True, eq=False, repr=False) class Placeholder(_SuppliedShapeAndDtypeMixin, InputArgumentBase): r"""A named placeholder for an array whose concrete value is supplied by the user during evaluation. @@ -1771,27 +1709,22 @@ class Placeholder(_SuppliedShapeAndDtypeMixin, InputArgumentBase): .. automethod:: __init__ """ + name: str + _shape: ShapeType + _dtype: np.dtype[Any] - _fields = InputArgumentBase._fields + ("shape", "dtype", "name") - _mapper_method = "map_placeholder" + _fields: ClassVar[Tuple[str, ...]] = InputArgumentBase._fields + ("shape", + "dtype", + "name") - def __init__(self, - name: str, - shape: ShapeType, - dtype: np.dtype[Any], - axes: AxesT, - tags: FrozenSet[Tag] = frozenset()): - """Should not be called directly. Use :func:`make_placeholder` - instead. - """ - super().__init__(shape=shape, dtype=dtype, axes=axes, tags=tags) - self.name = name + _mapper_method: ClassVar[str] = "map_placeholder" # }}} # {{{ size parameter +@attrs.define(frozen=True, eq=False, repr=False) class SizeParam(InputArgumentBase): r"""A named placeholder for a scalar that may be used as a variable in symbolic expressions for array sizes. @@ -1801,17 +1734,11 @@ class SizeParam(InputArgumentBase): The name by which a value is supplied for the argument once computation begins. """ + name: str + axes: AxesT = attrs.field(kw_only=True, default=()) - _mapper_method = "map_size_param" - - _fields = InputArgumentBase._fields + ("name",) - - def __init__(self, - name: str, - axes: AxesT = (), - tags: FrozenSet[Tag] = frozenset()): - super().__init__(axes=axes, tags=tags) - self.name = name + _mapper_method: ClassVar[str] = "map_size_param" + _fields: ClassVar[Tuple[str, ...]] = InputArgumentBase._fields + ("name",) @property def shape(self) -> ShapeType: @@ -2223,7 +2150,7 @@ def eye(N: int, M: Optional[int] = None, k: int = 0, # noqa: N803 # {{{ arange -@dataclass +@attrs.define class _ArangeInfo: start: Optional[int] stop: Optional[int] diff --git a/pytato/codegen.py b/pytato/codegen.py index df77f5229..223ede340 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -121,7 +121,7 @@ def __init__(self, target: Target) -> None: def map_size_param(self, expr: SizeParam) -> Array: name = expr.name assert name is not None - return SizeParam(name=name, tags=expr.tags) + return SizeParam(name, tags=expr.tags) def map_placeholder(self, expr: Placeholder) -> Array: name = expr.name diff --git a/pytato/distributed.py b/pytato/distributed.py index ee4b4f940..7e48ad95b 100644 --- a/pytato/distributed.py +++ b/pytato/distributed.py @@ -25,11 +25,12 @@ """ from typing import (Any, Dict, Hashable, Tuple, Optional, Set, # noqa: F401 - List, FrozenSet, Callable, cast, Mapping, Iterable + List, FrozenSet, Callable, cast, Mapping, Iterable, + ClassVar ) # Mapping required by sphinx from pyrsistent.typing import PMap as PMapT -from dataclasses import dataclass +import attrs from pytools import UniqueNameGenerator from pytools.tag import Taggable, UniqueTag, Tag @@ -53,7 +54,7 @@ __doc__ = r""" Distributed-memory evaluation of expression graphs is accomplished by :ref:`partitioning ` the graph to reveal communication-free -pieces of the computation. Communication (i.e. sending/receving data) is then +pieces of the computation. Communication (i.e. sending/receiving data) is then accomplished at the boundaries of the parts of the resulting graph partitioning. Recall the requirement for partitioning that, "no part may depend on its own @@ -171,6 +172,7 @@ def copy(self, **kwargs: Any) -> DistributedSend: tags=tags if tags is not None else self.tags) +@attrs.define(frozen=True, eq=False, repr=False, init=False) class DistributedSendRefHolder(Array): """A node acting as an identity on :attr:`passthrough_data` while also holding a reference to a :class:`DistributedSend` in :attr:`send`. Since @@ -205,15 +207,17 @@ class DistributedSendRefHolder(Array): :class:`DistributedSendRefHolder` to be constructed and yet to not become part of the graph constructed by the user. """ + send: DistributedSend + passthrough_data: Array - _mapper_method = "map_distributed_send_ref_holder" - _fields = Array._fields + ("passthrough_data", "send",) + _mapper_method: ClassVar[str] = "map_distributed_send_ref_holder" + _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("passthrough_data", "send") def __init__(self, send: DistributedSend, passthrough_data: Array, tags: FrozenSet[Tag] = frozenset()) -> None: super().__init__(axes=passthrough_data.axes, tags=tags) - self.send = send - self.passthrough_data = passthrough_data + object.__setattr__(self, "send", send) + object.__setattr__(self, "passthrough_data", passthrough_data) @property def shape(self) -> ShapeType: @@ -238,6 +242,7 @@ def copy(self, **kwargs: Any) -> DistributedSendRefHolder: tags) +@attrs.define(frozen=True, eq=False) class DistributedRecv(_SuppliedShapeAndDtypeMixin, Array): """Class representing a distributed receive operation. @@ -264,21 +269,14 @@ class DistributedRecv(_SuppliedShapeAndDtypeMixin, Array): :class:`DistributedRecv` to be constructed and yet to not become part of the graph constructed by the user. """ + src_rank: int + comm_tag: CommTagType + _shape: ShapeType + _dtype: Any # FIXME: sphinx does not like `_dtype: _dtype_any` - _fields = Array._fields + ("shape", "dtype", "src_rank", "comm_tag") - _mapper_method = "map_distributed_recv" - - def __init__(self, src_rank: int, comm_tag: CommTagType, - shape: ShapeType, dtype: Any, - tags: Optional[FrozenSet[Tag]] = frozenset(), - axes: Optional[AxesT] = None) -> None: - - if not axes: - axes = _get_default_axes(len(shape)) - super().__init__(shape=shape, dtype=dtype, tags=tags, - axes=axes) - self.src_rank = src_rank - self.comm_tag = comm_tag + _fields: ClassVar[Tuple[str, ...]] = Array._fields + ("shape", "dtype", + "src_rank", "comm_tag") + _mapper_method: ClassVar[str] = "map_distributed_recv" def make_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagType, @@ -309,14 +307,14 @@ def make_distributed_recv(src_rank: int, comm_tag: CommTagType, if axes is None: axes = _get_default_axes(len(shape)) dtype = np.dtype(dtype) - return DistributedRecv(src_rank, comm_tag, shape, dtype, tags, axes=axes) + return DistributedRecv(src_rank, comm_tag, shape, dtype, tags=tags, axes=axes) # }}} # {{{ distributed info collection -@dataclass(frozen=True) +@attrs.define(frozen=True, slots=False) class DistributedGraphPart(GraphPart): """For one graph partition, record send/receive information for input/ output names. @@ -330,7 +328,7 @@ class DistributedGraphPart(GraphPart): distributed_sends: List[DistributedSend] -@dataclass(frozen=True) +@attrs.define(frozen=True, slots=False) class DistributedGraphPartition(GraphPartition): """Store information about distributed graph partitions. This has the same attributes as :class:`~pytato.partition.GraphPartition`, @@ -348,8 +346,8 @@ def _map_distributed_graph_partion_nodes( mapped by *map_array* and all :class:`DistributedSend` instances mapped by *map_send*. """ + from attrs import evolve as replace - from dataclasses import replace return replace( gp, var_name_to_result={name: map_array(ary) @@ -776,7 +774,7 @@ def _get_materialized_arrays_promoted_to_partition_outputs( if users != {stored_ary_to_part_id[ary]}}) -@dataclass(frozen=True, eq=True, repr=True) +@attrs.define(frozen=True, eq=True, repr=True) class PartIDTag(UniqueTag): """ A tag applicable to a :class:`pytato.Array` recording to which part the @@ -1068,7 +1066,7 @@ def set_union( else: sym_tag_to_int_tag, next_tag = mpi_communicator.bcast(None, root=root_rank) - from dataclasses import replace + from attrs import evolve as replace return DistributedGraphPartition( parts={ pid: replace(part, diff --git a/pytato/loopy.py b/pytato/loopy.py index 317943d24..b33e5a11c 100644 --- a/pytato/loopy.py +++ b/pytato/loopy.py @@ -167,16 +167,20 @@ def expr(self) -> Array: @property def shape(self) -> ShapeType: - loopy_arg = self._container._entry_kernel.arg_dict[ # type:ignore - self.name] - shape: ShapeType = self._container._to_pytato( # type:ignore + # pylint: disable=E1101 + # reason: (pylint doesn't respect the asserts) + assert isinstance(self._container, LoopyCall) + loopy_arg = self._container._entry_kernel.arg_dict[self.name] + shape: ShapeType = self._container._to_pytato( # type:ignore[assignment] loopy_arg.shape) return shape @property def dtype(self) -> np.dtype[Any]: - loopy_arg = self._container._entry_kernel.arg_dict[ # type:ignore - self.name] + # pylint: disable=E1101 + # reason: (pylint doesn't respect the asserts) + assert isinstance(self._container, LoopyCall) + loopy_arg = self._container._entry_kernel.arg_dict[self.name] return np.dtype(loopy_arg.dtype.numpy_dtype) diff --git a/pytato/partition.py b/pytato/partition.py index a74719a48..0f7bc754a 100644 --- a/pytato/partition.py +++ b/pytato/partition.py @@ -26,7 +26,7 @@ from typing import (Any, Callable, Dict, Union, Set, List, Hashable, Tuple, TypeVar, FrozenSet, Mapping, Optional, Type) -from dataclasses import dataclass +import attrs import logging logger = logging.getLogger(__name__) @@ -240,7 +240,7 @@ def map_placeholder(self, expr: Placeholder, *args: Any) -> Any: # {{{ graph partition -@dataclass(frozen=True) +@attrs.define(frozen=True, slots=False) class GraphPart: """ .. attribute:: pid @@ -284,7 +284,7 @@ def all_input_names(self) -> FrozenSet[str]: return self.user_input_names | self. partition_input_names -@dataclass(frozen=True) +@attrs.define(frozen=True, slots=False) class GraphPartition: """Store information about a partitioning of an expression graph. diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index ad26ecead..2553cea55 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -284,7 +284,7 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: def map_size_param(self, expr: SizeParam) -> Array: assert expr.name is not None - return SizeParam(name=expr.name, axes=expr.axes, tags=expr.tags) + return SizeParam(expr.name, axes=expr.axes, tags=expr.tags) def map_einsum(self, expr: Einsum) -> Array: return Einsum(expr.access_descriptors, @@ -344,7 +344,7 @@ def map_distributed_recv(self, expr: DistributedRecv) -> Array: return DistributedRecv( src_rank=expr.src_rank, comm_tag=expr.comm_tag, shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, tags=expr.tags) + dtype=expr.dtype, tags=expr.tags, axes=expr.axes) class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]): @@ -470,7 +470,7 @@ def map_data_wrapper(self, expr: DataWrapper, def map_size_param(self, expr: SizeParam, *args: Any, **kwargs: Any) -> Array: assert expr.name is not None - return SizeParam(name=expr.name, axes=expr.axes, tags=expr.tags) + return SizeParam(expr.name, axes=expr.axes, tags=expr.tags) def map_einsum(self, expr: Einsum, *args: Any, **kwargs: Any) -> Array: return Einsum(expr.access_descriptors, @@ -536,7 +536,7 @@ def map_distributed_recv(self, expr: DistributedRecv, return DistributedRecv( src_rank=expr.src_rank, comm_tag=expr.comm_tag, shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), - dtype=expr.dtype, tags=expr.tags) + dtype=expr.dtype, tags=expr.tags, axes=expr.axes) # }}} @@ -1098,7 +1098,7 @@ def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator: def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: rec_arrays = [self.rec(ary) for ary in expr.arrays] new_expr = Stack(tuple(ary.expr for ary in rec_arrays), - expr.axis, expr.axes, expr.tags) + expr.axis, axes=expr.axes, tags=expr.tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1108,15 +1108,16 @@ def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator: rec_arrays = [self.rec(ary) for ary in expr.arrays] new_expr = Concatenate(tuple(ary.expr for ary in rec_arrays), expr.axis, - expr.axes, - expr.tags) + axes=expr.axes, + tags=expr.tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], rec_arrays) def map_roll(self, expr: Roll) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) - new_expr = Roll(rec_array.expr, expr.shift, expr.axis, expr.axes, expr.tags) + new_expr = Roll(rec_array.expr, expr.shift, expr.axis, axes=expr.axes, + tags=expr.tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) @@ -1124,7 +1125,7 @@ def map_axis_permutation(self, expr: AxisPermutation ) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) new_expr = AxisPermutation(rec_array.expr, expr.axis_permutation, - expr.axes, expr.tags) + axes=expr.axes, tags=expr.tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) @@ -1141,8 +1142,8 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: else expr.indices[i] for i in range( len(expr.indices))), - expr.axes, - expr.tags) + axes=expr.axes, + tags=expr.tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1156,7 +1157,7 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) new_expr = Reshape(rec_array.expr, expr.newshape, - expr.order, expr.axes, expr.tags) + expr.order, axes=expr.axes, tags=expr.tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1166,10 +1167,10 @@ def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator: rec_arrays = [self.rec(ary) for ary in expr.args] new_expr = Einsum(expr.access_descriptors, tuple(ary.expr for ary in rec_arrays), - expr.axes, expr.redn_axis_to_redn_descr, expr.index_to_access_descr, - expr.tags) + axes=expr.axes, + tags=expr.tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1682,7 +1683,7 @@ def map_placeholder(self, expr: Placeholder, *args: Any) -> Placeholder: def map_size_param(self, expr: SizeParam, *args: Any) -> SizeParam: assert expr.name - return SizeParam(name=expr.name, axes=expr.axes, tags=expr.tags) + return SizeParam(expr.name, axes=expr.axes, tags=expr.tags) def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: return LoopyCall( diff --git a/test/test_pytato.py b/test/test_pytato.py index e1f281d45..3b0b4fa55 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -775,7 +775,9 @@ def count_data_wrappers(expr): a = pt.make_data_wrapper(np.arange(27)) b = pt.make_data_wrapper(np.arange(27)) - c = pt.make_data_wrapper(a.data.view()) + # pylint-disable-reason: pylint is correct, DataInterface doesn't declare a + # view method, but for numpy-like arrays it should be OK. + c = pt.make_data_wrapper(a.data.view()) # pylint: disable=E1101 d = pt.make_data_wrapper(np.arange(1, 28)) res = a+b+c+d