114114"""
115115
116116
117+ import dataclasses
117118from abc import ABC , abstractmethod
118119from collections .abc import Callable , Hashable , Mapping
119120from typing import (
138139 from numpy .typing import DTypeLike
139140
140141 import loopy
141- from pytools .tag import ToTagSetConvertible
142+ from pytools .tag import Tag , ToTagSetConvertible
142143
143144 from .fake_numpy import BaseFakeNumpyNamespace
144145 from .typing import (
156157P = ParamSpec ("P" )
157158
158159
160+ @dataclasses .dataclass (frozen = True , eq = False , repr = False )
161+ class _CSRMatrix :
162+ # FIXME: Type for shape?
163+ shape : Any
164+ elem_values : Array
165+ elem_col_indices : Array
166+ row_starts : Array
167+ tags : frozenset [Tag ] = dataclasses .field (kw_only = True )
168+ axes : tuple [ToTagSetConvertible , ...] = dataclasses .field (kw_only = True )
169+ _matmul_func : Callable [[_CSRMatrix , Array ], Array ] = \
170+ dataclasses .field (kw_only = True )
171+
172+ def __matmul__ (self , other : Array ) -> Array :
173+ return self ._matmul_func (self , other )
174+
175+
159176# {{{ ArrayContext
160177
161178class ArrayContext (ABC ):
@@ -172,6 +189,8 @@ class ArrayContext(ABC):
172189 .. automethod:: to_numpy
173190 .. automethod:: call_loopy
174191 .. automethod:: einsum
192+ .. automethod:: make_csr_matrix
193+ .. automethod:: sparse_matmul
175194 .. attribute:: np
176195
177196 Provides access to a namespace that serves as a work-alike to
@@ -424,6 +443,177 @@ def einsum(self,
424443 )["out" ]
425444 return self .tag (tagged , out_ary )
426445
# FIXME: Not sure what type annotations to use for shape and result
def make_csr_matrix(
        self,
        shape,
        elem_values: Array,
        elem_col_indices: Array,
        row_starts: Array,
        *,
        tags: ToTagSetConvertible = frozenset(),
        axes: tuple[ToTagSetConvertible, ...] | None = None) -> Any:
    """Return a context-dependent object that represents a sparse matrix in
    compressed sparse row (CSR) format. Result is suitable for passing to
    :meth:`sparse_matmul`.

    :arg shape: the (two-dimensional) shape of the matrix
    :arg elem_values: a one-dimensional array containing the values of all of
        the nonzero entries of the matrix, grouped by row.
    :arg elem_col_indices: a one-dimensional array containing the column index
        values corresponding to each entry in *elem_values*.
    :arg row_starts: a one-dimensional array of length ``nrows+1``, where each
        entry gives the starting index in *elem_values* and
        *elem_col_indices* for the given row, with the last entry being equal
        to the total number of stored entries (i.e. the length of
        *elem_values*).
    :arg tags: tags to attach to the matrix; normalized to a
        :class:`frozenset` before being stored.
    :arg axes: per-axis tags for the matrix. If *None*, two empty tag sets
        are used.
    """
    # Local import: only this method needs the helper.
    from pytools.tag import normalize_tags

    if axes is None:
        axes = (frozenset(), frozenset())

    return _CSRMatrix(
        shape, elem_values, elem_col_indices, row_starts,
        # The field is declared as frozenset[Tag], while callers may pass
        # any ToTagSetConvertible (a single tag, an iterable, or None).
        tags=normalize_tags(tags),
        axes=axes,
        _matmul_func=self.sparse_matmul)
476+
@memoize_method
def _get_csr_matmul_prg(self, out_ndim: int) -> loopy.TranslationUnit:
    """Build the :mod:`loopy` program ``csr_matmul_kernel`` for a result with
    *out_ndim* axes.

    The kernel loops over ``irow`` (bounded by ``nrows``) and over one extra
    iname ``i<k>`` (bounded by ``n<k>``) for each ``k`` in
    ``1 .. out_ndim-1``. For every such index it reads the bounds of the
    ``iel`` loop from ``row_starts[irow]`` and ``row_starts[irow + 1]`` and
    writes the sum over ``iel`` of
    ``elem_values[iel] * array[elem_col_indices[iel], i1, ...]`` to ``out``.

    Value arguments and the index arrays get the kernel's index dtype;
    ``elem_values``, ``array`` and ``out`` are declared as ``float64``.
    """
    import numpy as np

    import loopy as lp
    from loopy.kernel.instruction import make_assignment
    from loopy.version import MOST_RECENT_LANGUAGE_VERSION
    from pymbolic import var

    from .loopy import _DEFAULT_LOOPY_OPTIONS

    # {{{ names of loop variables and of their extents

    trailing_axes = range(1, out_ndim)
    extra_inames = tuple(f"i{axis}" for axis in trailing_axes)
    extra_extent_names = tuple(f"n{axis}" for axis in trailing_axes)

    all_inames = ("irow", *extra_inames)
    all_extent_names = ("nrows", *extra_extent_names)
    within = frozenset(all_inames)

    # }}}

    # {{{ loop domains

    bounds = " and ".join(
        f"0 <= {iname} < {extent_name}"
        for iname, extent_name in zip(
            all_inames, all_extent_names, strict=True))

    domains: list[str] = [
        "{ [" + ",".join(all_inames) + "] : " + bounds + " }",
        # The bounds of the element loop are read from row_starts at run time.
        "{ [iel] : iel_lbound <= iel < iel_ubound }",
    ]

    # }}}

    # FIXME: Need to do anything with tags?
    temporary_variables: Mapping[str, lp.TemporaryVariable] = {
        bound_name: lp.TemporaryVariable(
            bound_name,
            shape=(),
            address_space=lp.AddressSpace.GLOBAL)
        for bound_name in ("iel_lbound", "iel_ubound")}

    # {{{ instructions

    irow = var("irow")
    iel = var("iel")
    row_starts = var("row_starts")
    extra_iname_vars = tuple(var(iname) for iname in extra_inames)

    row_product = lp.Reduction(
        "sum",
        (iel,),
        var("elem_values")[(iel,)]
        * var("array")[(var("elem_col_indices")[(iel,)], *extra_iname_vars)])

    # FIXME: Need tags for any of these?
    instructions: list[lp.Assignment | lp.CallInstruction] = [
        make_assignment(
            (var("iel_lbound"),),
            row_starts[irow],
            id="insn0",
            within_inames=within),
        make_assignment(
            (var("iel_ubound"),),
            row_starts[irow + 1],
            id="insn1",
            within_inames=within),
        make_assignment(
            (var("out")[(irow, *extra_iname_vars)],),
            row_product,
            id="insn2",
            within_inames=within,
            # Both loop bounds must be written before the reduction runs.
            depends_on=frozenset({"insn0", "insn1"})),
    ]

    # }}}

    # {{{ kernel arguments

    extra_extent_args = [
        lp.ValueArg(extent_name, is_input=True)
        for extent_name in extra_extent_names]
    trailing_shape = tuple(
        var(extent_name) for extent_name in extra_extent_names)
    nels = var("nels")

    kernel_data = [
        lp.ValueArg("nrows", is_input=True),
        lp.ValueArg("ncols", is_input=True),
        lp.ValueArg("nels", is_input=True),
        *extra_extent_args,
        lp.GlobalArg("elem_values", shape=(nels,), is_input=True),
        lp.GlobalArg("elem_col_indices", shape=(nels,), is_input=True),
        lp.GlobalArg("row_starts", shape=lp.auto, is_input=True),
        lp.GlobalArg(
            "array",
            shape=(var("ncols"), *trailing_shape),
            # order="C",
            is_input=True),
        lp.GlobalArg(
            "out",
            shape=(var("nrows"), *trailing_shape),
            # order="C",
            is_input=False),
        ...]

    # }}}

    knl = lp.make_kernel(
        domains=domains,
        instructions=instructions,
        temporary_variables=temporary_variables,
        kernel_data=kernel_data,
        name="csr_matmul_kernel",
        lang_version=MOST_RECENT_LANGUAGE_VERSION,
        options=_DEFAULT_LOOPY_OPTIONS,
        default_order=lp.auto,
        default_offset=lp.auto,
        # FIXME: Need to do anything with tags?
    )

    idx_dtype = knl.default_entrypoint.index_dtype
    index_valued_args = ",".join(
        ["ncols", "nrows", "nels", *extra_extent_names])

    return lp.add_and_infer_dtypes(
        knl,
        {
            index_valued_args: idx_dtype,
            "elem_values,array,out": np.float64,
            "elem_col_indices,row_starts": idx_dtype,
        })
597+
# FIXME: Not sure what type annotation to use for x1
def sparse_matmul(self, x1, x2: Array) -> Array:
    """Return the product of the sparse matrix *x1* and the array *x2*.

    :arg x1: the sparse matrix, as returned by :meth:`make_csr_matrix`.
    :arg x2: the array.
    :raises TypeError: if *x1* is not a recognized sparse matrix type.
    """
    if not isinstance(x1, _CSRMatrix):
        raise TypeError(f"unrecognized matrix type '{type(x1).__name__}'.")

    # The program is specialized on the number of axes of the operand.
    csr_prg = self._get_csr_matmul_prg(len(x2.shape))

    kernel_args = {
        "elem_values": x1.elem_values,
        "elem_col_indices": x1.elem_col_indices,
        "row_starts": x1.row_starts,
        "array": x2,
    }
    kernel_results = self.call_loopy(csr_prg, **kernel_args)

    # FIXME
    # return self.tag(tagged, out_ary)
    return kernel_results["out"]
616+
427617 @abstractmethod
428618 def clone (self ) -> Self :
429619 """If possible, return a version of *self* that is semantically
0 commit comments