@@ -589,10 +589,11 @@ class MemAccess(Record):
589589 A :class:`str` that specifies the variable name of the data
590590 accessed.
591591
592- .. attribute:: variable_tag
592+ .. attribute:: variable_tags
593593
594- A :class:`str` that specifies the variable tag of a
595- :class:`loopy.symbolic.TaggedVariable`.
594+ A :class:`frozenset` of subclasses of :class:`~pytools.tag.Tag`
595+ that reflects :attr:`~loopy.symbolic.TaggedVariable.tags` of
596+ an accessed variable.
596597
597598 .. attribute:: count_granularity
598599
@@ -610,27 +611,60 @@ class MemAccess(Record):
610611 """
611612
612613 def __init__ (self , mtype = None , dtype = None , lid_strides = None , gid_strides = None ,
613- direction = None , variable = None , variable_tag = None ,
614+ direction = None , variable = None ,
615+ * , variable_tags = None , variable_tag = None ,
614616 count_granularity = None ):
615617
616618 if count_granularity not in CountGranularity .ALL + [None ]:
617619 raise ValueError ("Op.__init__: count_granularity '%s' is "
618620 "not allowed. count_granularity options: %s"
619621 % (count_granularity , CountGranularity .ALL + [None ]))
620622
623+ # {{{ normalize variable_tags
624+
625+ if variable_tags is not None and variable_tag is not None :
626+ raise TypeError (
627+ "may not specify both 'variable_tags' and 'variable_tag'" )
628+ if variable_tag is not None :
629+ from loopy .kernel .creation import _normalize_string_tag
630+ variable_tags = frozenset ({_normalize_string_tag (variable_tag )})
631+
632+ from warnings import warn
633+ warn ("Passing 'variable_tag' to MemAccess is deprecated and will "
634+ "stop working in 2022. Pass variable_tags instead." )
635+
636+ if variable_tags is None :
637+ variable_tags = frozenset ()
638+
639+ # }}}
640+
621641 if dtype is None :
622642 Record .__init__ (self , mtype = mtype , dtype = dtype , lid_strides = lid_strides ,
623643 gid_strides = gid_strides , direction = direction ,
624- variable = variable , variable_tag = variable_tag ,
644+ variable = variable , variable_tags = variable_tags ,
625645 count_granularity = count_granularity )
626646 else :
627647 from loopy .types import to_loopy_type
628648 Record .__init__ (self , mtype = mtype , dtype = to_loopy_type (dtype ),
629649 lid_strides = lid_strides , gid_strides = gid_strides ,
630650 direction = direction , variable = variable ,
631- variable_tag = variable_tag ,
651+ variable_tags = variable_tags ,
632652 count_granularity = count_granularity )
633653
654+ @property
655+ def variable_tag (self ):
656+ from warnings import warn
657+ warn ("Accessing MemAccess.variable_tag is deprecated and will stop working "
658+ "in 2022. Use MemAccess.variable_tags instead." , DeprecationWarning ,
659+ stacklevel = 2 )
660+
661+ if len (self .variable_tags ) != 1 :
662+ raise ValueError ("cannot access MemAccess.variable_tag: access has "
663+ f"{ len (self .variable_tags )} tags" )
664+
665+ tag , = self .variable_tags
666+ return tag
667+
634668 def __hash__ (self ):
635669 # Note that this means lid_strides and gid_strides must be sorted
636670 # in self.__repr__()
@@ -647,7 +681,7 @@ def __repr__(self):
647681 sorted (self .gid_strides .items ())),
648682 self .direction ,
649683 self .variable ,
650- self .variable_tag ,
684+ self .variable_tags ,
651685 self .count_granularity )
652686
653687# }}}
@@ -1031,9 +1065,9 @@ def map_variable(self, expr):
10311065 def map_subscript (self , expr ):
10321066 name = expr .aggregate .name
10331067 try :
1034- var_tag = expr .aggregate .tag
1068+ var_tags = expr .aggregate .tags
10351069 except AttributeError :
1036- var_tag = None
1070+ var_tags = frozenset ()
10371071
10381072 if name in self .knl .arg_dict :
10391073 array = self .knl .arg_dict [name ]
@@ -1062,7 +1096,7 @@ def map_subscript(self, expr):
10621096 lid_strides = dict (sorted (lid_strides .items ())),
10631097 gid_strides = dict (sorted (gid_strides .items ())),
10641098 variable = name ,
1065- variable_tag = var_tag ,
1099+ variable_tags = var_tags ,
10661100 count_granularity = count_granularity
10671101 ): 1 }
10681102 ) + self .rec (expr .index_tuple )
@@ -1678,7 +1712,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
16781712 gid_strides = mem_access .gid_strides ,
16791713 direction = mem_access .direction ,
16801714 variable = mem_access .variable ,
1681- variable_tag = mem_access .variable_tag ,
1715+ variable_tags = mem_access .variable_tags ,
16821716 count_granularity = mem_access .count_granularity ):
16831717 ct
16841718 for mem_access , ct in access_map .count_map .items ()},
0 commit comments