Skip to content

Commit 1df4b5a

Browse files
committed
put actx-specific matrix multiplication implementations inside sparse_matmul instead of using derived matrix classes
1 parent 1ec5343 commit 1df4b5a

3 files changed

Lines changed: 70 additions & 103 deletions

File tree

arraycontext/context.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,8 @@ class SparseMatrix(ABC):
175175
axes: tuple[ToTagSetConvertible, ...] = dataclasses.field(kw_only=True)
176176
_actx: ArrayContext = dataclasses.field(kw_only=True)
177177

178-
@abstractmethod
179178
def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer:
180-
...
179+
return self._actx.sparse_matmul(self, other)
181180

182181

183182
@dataclasses.dataclass(frozen=True, eq=False, repr=False)
@@ -186,21 +185,6 @@ class CSRMatrix(SparseMatrix):
186185
elem_col_indices: Array
187186
row_starts: Array
188187

189-
@override
190-
def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer:
191-
def _matmul(ary: ArrayOrScalar) -> ArrayOrScalar:
192-
assert self._actx.is_array_type(ary)
193-
prg = self._actx._get_csr_matmul_prg(len(ary.shape))
194-
out_ary = self._actx.call_loopy(
195-
prg, elem_values=self.elem_values,
196-
elem_col_indices=self.elem_col_indices,
197-
row_starts=self.row_starts, array=ary)["out"]
198-
# FIXME
199-
# return self.tag(tagged, out_ary)
200-
return out_ary
201-
202-
return cast("ArrayOrContainer", rec_map_container(_matmul, other))
203-
204188

205189
# {{{ ArrayContext
206190

@@ -627,7 +611,22 @@ def sparse_matmul(
627611
:arg x1: the sparse matrix.
628612
:arg x2: the array.
629613
"""
630-
return x1 @ x2
614+
if isinstance(x1, CSRMatrix):
615+
def _matmul(ary: ArrayOrScalar) -> ArrayOrScalar:
616+
assert self.is_array_type(ary)
617+
prg = self._get_csr_matmul_prg(len(ary.shape))
618+
out_ary = self.call_loopy(
619+
prg, elem_values=x1.elem_values,
620+
elem_col_indices=x1.elem_col_indices,
621+
row_starts=x1.row_starts, array=ary)["out"]
622+
# FIXME
623+
# return self.tag(tagged, out_ary)
624+
return out_ary
625+
626+
return cast("ArrayOrContainer", rec_map_container(_matmul, x2))
627+
628+
else:
629+
raise TypeError(f"unrecognized sparse matrix type '{type(x1).__name__}'")
631630

632631
@abstractmethod
633632
def clone(self) -> Self:

arraycontext/impl/numpy/__init__.py

Lines changed: 27 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,12 @@
3333
THE SOFTWARE.
3434
"""
3535

36-
from dataclasses import dataclass
37-
from functools import cached_property
3836
from typing import TYPE_CHECKING, Any, cast, overload
3937

4038
import numpy as np
4139
from typing_extensions import override
4240

4341
import loopy as lp
44-
from pytools.tag import normalize_tags
4542

4643
from arraycontext.container.traversal import (
4744
rec_map_array_container as rec_map_array_container,
@@ -50,23 +47,23 @@
5047
)
5148
from arraycontext.context import (
5249
ArrayContext,
53-
CSRMatrix as _BaseCSRMatrix,
50+
CSRMatrix,
51+
SparseMatrix,
5452
UntransformedCodeWarning,
5553
)
5654
from arraycontext.typing import (
5755
Array,
5856
ArrayOrContainer,
5957
ArrayOrContainerOrScalar,
6058
ArrayOrContainerOrScalarT,
59+
ArrayOrScalar,
6160
ContainerOrScalarT,
6261
NumpyOrContainerOrScalar,
6362
is_scalar_like,
6463
)
6564

6665

6766
if TYPE_CHECKING:
68-
import scipy.sparse
69-
7067
from pymbolic import Scalar
7168
from pytools.tag import Tag, ToTagSetConvertible
7269

@@ -86,29 +83,6 @@ class NumpyNonObjectArray(metaclass=NumpyNonObjectArrayMetaclass):
8683
pass
8784

8885

89-
@dataclass(frozen=True, eq=False, repr=False)
90-
class CSRMatrix(_BaseCSRMatrix):
91-
@cached_property
92-
def _np_matrix(self) -> scipy.sparse.csr_matrix:
93-
assert isinstance(self.elem_values, np.ndarray)
94-
assert isinstance(self.elem_col_indices, np.ndarray)
95-
assert isinstance(self.row_starts, np.ndarray)
96-
# FIXME: Not sure if the scipy dependency is OK or if it should just use the
97-
# call_loopy fallback? Currently getting errors with the loopy version:
98-
# loopy.diagnostic.LoopyError: One of the kernels in the program has
99-
# been preprocessed, cannot modify target now.
100-
from scipy.sparse import csr_matrix
101-
return csr_matrix(
102-
(self.elem_values, self.elem_col_indices, self.row_starts),
103-
shape=self.shape)
104-
105-
@override
106-
def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer:
107-
return cast(
108-
"ArrayOrContainer",
109-
rec_map_container(lambda ary: self._np_matrix @ ary, other))
110-
111-
11286
class NumpyArrayContext(ArrayContext):
11387
"""
11488
A :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays.
@@ -233,22 +207,30 @@ def einsum(self, spec, *args, arg_names=None, tagged=()):
233207
return np.einsum(spec, *args, optimize="optimal")
234208

235209
@override
236-
def make_csr_matrix(
237-
self,
238-
shape: tuple[int, int],
239-
elem_values: Array,
240-
elem_col_indices: Array,
241-
row_starts: Array,
242-
*,
243-
tags: ToTagSetConvertible = _EMPTY_TAG_SET,
244-
axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix:
245-
tags = normalize_tags(tags)
246-
if axes is None:
247-
axes = (frozenset(), frozenset())
248-
return CSRMatrix(
249-
shape, elem_values, elem_col_indices, row_starts,
250-
tags=tags, axes=axes,
251-
_actx=self)
210+
def sparse_matmul(
211+
self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer:
212+
if isinstance(x1, CSRMatrix):
213+
assert isinstance(x1.elem_values, np.ndarray)
214+
assert isinstance(x1.elem_col_indices, np.ndarray)
215+
assert isinstance(x1.row_starts, np.ndarray)
216+
217+
# FIXME: Not sure if the scipy dependency is OK or if it should just use
218+
# the call_loopy fallback? Currently getting errors with the loopy version:
219+
# loopy.diagnostic.LoopyError: One of the kernels in the program has
220+
# been preprocessed, cannot modify target now.
221+
from scipy.sparse import csr_matrix
222+
np_matrix = csr_matrix(
223+
(x1.elem_values, x1.elem_col_indices, x1.row_starts),
224+
shape=x1.shape)
225+
226+
def _matmul(ary: ArrayOrScalar) -> ArrayOrScalar:
227+
assert isinstance(ary, np.ndarray)
228+
return np_matrix @ ary
229+
230+
return cast("ArrayOrContainer", rec_map_container(_matmul, x2))
231+
232+
else:
233+
raise TypeError(f"unrecognized sparse matrix type '{type(x1).__name__}'")
252234

253235
@property
254236
def permits_inplace_modification(self):

arraycontext/impl/pytato/__init__.py

Lines changed: 26 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,12 @@
5454
import abc
5555
import sys
5656
from dataclasses import dataclass
57-
from functools import cached_property
5857
from typing import TYPE_CHECKING, Any, cast
5958

6059
import numpy as np
6160
from typing_extensions import override
6261

63-
from pytools import memoize_method
62+
from pytools import memoize_in, memoize_method
6463
from pytools.tag import Tag, ToTagSetConvertible, normalize_tags
6564

6665
from arraycontext.container.traversal import (
@@ -69,7 +68,7 @@
6968
)
7069
from arraycontext.context import (
7170
ArrayContext,
72-
CSRMatrix as _BaseCSRMatrix,
71+
CSRMatrix,
7372
P,
7473
SparseMatrix,
7574
UntransformedCodeWarning,
@@ -144,28 +143,6 @@ class _NotOnlyDataWrappers(Exception): # noqa: N818
144143
pass
145144

146145

147-
@dataclass(frozen=True, eq=False, repr=False)
148-
class CSRMatrix(_BaseCSRMatrix):
149-
@cached_property
150-
def _pt_matrix(self) -> pt.CSRMatrix:
151-
import pytato as pt
152-
assert isinstance(self.elem_values, pt.Array)
153-
assert isinstance(self.elem_col_indices, pt.Array)
154-
assert isinstance(self.row_starts, pt.Array)
155-
return pt.make_csr_matrix(
156-
self.shape, self.elem_values, self.elem_col_indices, self.row_starts,
157-
tags=_preprocess_array_tags(self.tags), axes=self.axes)
158-
159-
@override
160-
def __matmul__(self, other: ArrayOrContainer) -> ArrayOrContainer:
161-
def _matmul(ary: ArrayOrScalar) -> ArrayOrScalar:
162-
import pytato as pt
163-
assert isinstance(ary, pt.Array)
164-
return self._pt_matrix @ ary
165-
166-
return cast("ArrayOrContainer", rec_map_container(_matmul, other))
167-
168-
169146
# {{{ _BasePytatoArrayContext
170147

171148
class _BasePytatoArrayContext(ArrayContext, abc.ABC):
@@ -904,21 +881,30 @@ def preprocess_arg(name, arg):
904881
]).tagged(_preprocess_array_tags(tagged))
905882

906883
@override
907-
def make_csr_matrix(
908-
self,
909-
shape: tuple[int, int],
910-
elem_values: Array,
911-
elem_col_indices: Array,
912-
row_starts: Array,
913-
*,
914-
tags: ToTagSetConvertible = _EMPTY_TAG_SET,
915-
axes: tuple[ToTagSetConvertible, ...] | None = None) -> CSRMatrix:
916-
if axes is None:
917-
axes = (frozenset(), frozenset())
918-
return CSRMatrix(
919-
shape, elem_values, elem_col_indices, row_starts,
920-
tags=tags, axes=axes,
921-
_actx=self)
884+
def sparse_matmul(
885+
self, x1: SparseMatrix, x2: ArrayOrContainer) -> ArrayOrContainer:
886+
import pytato as pt
887+
888+
if isinstance(x1, CSRMatrix):
889+
@memoize_in(x1, "pt_matrix")
890+
def _get_pt_matrix() -> pt.CSRMatrix:
891+
assert isinstance(x1.elem_values, pt.Array)
892+
assert isinstance(x1.elem_col_indices, pt.Array)
893+
assert isinstance(x1.row_starts, pt.Array)
894+
return pt.make_csr_matrix(
895+
x1.shape, x1.elem_values, x1.elem_col_indices, x1.row_starts,
896+
tags=_preprocess_array_tags(x1.tags), axes=x1.axes)
897+
898+
pt_matrix: pt.CSRMatrix = _get_pt_matrix()
899+
900+
def _matmul(ary: ArrayOrScalar) -> ArrayOrScalar:
901+
assert isinstance(ary, pt.Array)
902+
return pt_matrix @ ary
903+
904+
return cast("ArrayOrContainer", rec_map_container(_matmul, x2))
905+
906+
else:
907+
raise TypeError(f"unrecognized sparse matrix type '{type(x1).__name__}'")
922908

923909
def clone(self):
924910
return type(self)(self.queue, self.allocator)

0 commit comments

Comments
 (0)