3333THE SOFTWARE.
3434"""
3535
36- from dataclasses import dataclass
37- from functools import cached_property
3836from typing import TYPE_CHECKING , Any , cast , overload
3937
4038import numpy as np
4139from typing_extensions import override
4240
4341import loopy as lp
44- from pytools .tag import normalize_tags
4542
4643from arraycontext .container .traversal import (
4744 rec_map_array_container as rec_map_array_container ,
5047)
5148from arraycontext .context import (
5249 ArrayContext ,
53- CSRMatrix as _BaseCSRMatrix ,
50+ CSRMatrix ,
51+ SparseMatrix ,
5452 UntransformedCodeWarning ,
5553)
5654from arraycontext .typing import (
5755 Array ,
5856 ArrayOrContainer ,
5957 ArrayOrContainerOrScalar ,
6058 ArrayOrContainerOrScalarT ,
59+ ArrayOrScalar ,
6160 ContainerOrScalarT ,
6261 NumpyOrContainerOrScalar ,
6362 is_scalar_like ,
6463)
6564
6665
6766if TYPE_CHECKING :
68- import scipy .sparse
69-
7067 from pymbolic import Scalar
7168 from pytools .tag import Tag , ToTagSetConvertible
7269
@@ -86,29 +83,6 @@ class NumpyNonObjectArray(metaclass=NumpyNonObjectArrayMetaclass):
8683 pass
8784
8885
89- @dataclass (frozen = True , eq = False , repr = False )
90- class CSRMatrix (_BaseCSRMatrix ):
91- @cached_property
92- def _np_matrix (self ) -> scipy .sparse .csr_matrix :
93- assert isinstance (self .elem_values , np .ndarray )
94- assert isinstance (self .elem_col_indices , np .ndarray )
95- assert isinstance (self .row_starts , np .ndarray )
96- # FIXME: Not sure if the scipy dependency is OK or if it should just use the
97- # call_loopy fallback? Currently getting errors with the loopy version:
98- # loopy.diagnostic.LoopyError: One of the kernels in the program has
99- # been preprocessed, cannot modify target now.
100- from scipy .sparse import csr_matrix
101- return csr_matrix (
102- (self .elem_values , self .elem_col_indices , self .row_starts ),
103- shape = self .shape )
104-
105- @override
106- def __matmul__ (self , other : ArrayOrContainer ) -> ArrayOrContainer :
107- return cast (
108- "ArrayOrContainer" ,
109- rec_map_container (lambda ary : self ._np_matrix @ ary , other ))
110-
111-
11286class NumpyArrayContext (ArrayContext ):
11387 """
11488 A :class:`ArrayContext` that uses :class:`numpy.ndarray` to represent arrays.
@@ -233,22 +207,30 @@ def einsum(self, spec, *args, arg_names=None, tagged=()):
233207 return np .einsum (spec , * args , optimize = "optimal" )
234208
235209 @override
236- def make_csr_matrix (
237- self ,
238- shape : tuple [int , int ],
239- elem_values : Array ,
240- elem_col_indices : Array ,
241- row_starts : Array ,
242- * ,
243- tags : ToTagSetConvertible = _EMPTY_TAG_SET ,
244- axes : tuple [ToTagSetConvertible , ...] | None = None ) -> CSRMatrix :
245- tags = normalize_tags (tags )
246- if axes is None :
247- axes = (frozenset (), frozenset ())
248- return CSRMatrix (
249- shape , elem_values , elem_col_indices , row_starts ,
250- tags = tags , axes = axes ,
251- _actx = self )
210+ def sparse_matmul (
211+ self , x1 : SparseMatrix , x2 : ArrayOrContainer ) -> ArrayOrContainer :
212+ if isinstance (x1 , CSRMatrix ):
213+ assert isinstance (x1 .elem_values , np .ndarray )
214+ assert isinstance (x1 .elem_col_indices , np .ndarray )
215+ assert isinstance (x1 .row_starts , np .ndarray )
216+
217+ # FIXME: Not sure if the scipy dependency is OK or if it should just use
218+ # the call_loopy fallback? Currently getting errors with the loopy version:
219+ # loopy.diagnostic.LoopyError: One of the kernels in the program has
220+ # been preprocessed, cannot modify target now.
221+ from scipy .sparse import csr_matrix
222+ np_matrix = csr_matrix (
223+ (x1 .elem_values , x1 .elem_col_indices , x1 .row_starts ),
224+ shape = x1 .shape )
225+
226+ def _matmul (ary : ArrayOrScalar ) -> ArrayOrScalar :
227+ assert isinstance (ary , np .ndarray )
228+ return np_matrix @ ary
229+
230+ return cast ("ArrayOrContainer" , rec_map_container (_matmul , x2 ))
231+
232+ else :
233+ raise TypeError (f"unrecognized sparse matrix type '{ type (x1 ).__name__ } '" )
252234
253235 @property
254236 def permits_inplace_modification (self ):
0 commit comments