Skip to content

Commit e21e8f8

Browse files
authored
Merge pull request #245 from inducer/pytools-tag-for-instruction
No more string tags
2 parents 65aa654 + 42d5f55 commit e21e8f8

16 files changed

Lines changed: 260 additions & 78 deletions

File tree

doc/ref_kernel.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,12 @@ Barrier Instructions
493493

494494
.. autoclass:: BarrierInstruction
495495

496+
Instruction Tags
497+
^^^^^^^^^^^^^^^^
498+
499+
.. autoclass:: LegacyStringInstructionTag
500+
.. autoclass:: UseStreamingStoreTag
501+
496502
.. }}}
497503
498504
Data: Arguments and Temporaries

loopy/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
default_function_mangler, single_arg_function_mangler)
3333

3434
from loopy.kernel.instruction import (
35+
LegacyStringInstructionTag, UseStreamingStoreTag,
3536
MemoryOrdering, memory_ordering,
3637
MemoryScope, memory_scope,
3738
VarAtomicity, OrderedAtomic, AtomicInit, AtomicUpdate,
@@ -155,6 +156,7 @@
155156
"LoopKernel",
156157
"KernelState", "kernel_state", # lower case is deprecated
157158

159+
"LegacyStringInstructionTag", "UseStreamingStoreTag",
158160
"MemoryOrdering", "memory_ordering", # lower case is deprecated
159161
"MemoryScope", "memory_scope", # lower case is deprecated
160162

loopy/frontend/fortran/translator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from islpy import dim_type
3535
from loopy.symbolic import IdentityMapper
3636
from loopy.diagnostic import LoopyError
37+
from loopy.kernel.instruction import LegacyStringInstructionTag
3738
from pymbolic.primitives import Wildcard
3839

3940

@@ -640,16 +641,16 @@ def map_Comment(self, node):
640641
stripped_comment_line)
641642

642643
if begin_tag_match:
643-
tag = begin_tag_match.group(1)
644+
tag = LegacyStringInstructionTag(begin_tag_match.group(1))
644645
if tag in self.instruction_tags:
645-
raise TranslationError("nested begin tag for tag '%s'" % tag)
646+
raise TranslationError(f"nested begin tag for tag '{tag.value}'")
646647
self.instruction_tags.append(tag)
647648

648649
elif end_tag_match:
649-
tag = end_tag_match.group(1)
650+
tag = LegacyStringInstructionTag(end_tag_match.group(1))
650651
if tag not in self.instruction_tags:
651652
raise TranslationError(
652-
"end tag without begin tag for tag '%s'" % tag)
653+
f"end tag without begin tag for tag '{tag.value}'")
653654
self.instruction_tags.remove(tag)
654655

655656
elif faulty_loopy_pragma_match is not None:

loopy/kernel/creation.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,38 @@ def __init__(self, name):
7272
# }}}
7373

7474

75+
# {{{ tag normalization
76+
77+
def _normalize_string_tag(tag):
78+
from pytools.tag import Tag
79+
80+
from loopy.kernel.instruction import (
81+
UseStreamingStoreTag, LegacyStringInstructionTag)
82+
if tag == "!streaming_store":
83+
return UseStreamingStoreTag()
84+
else:
85+
from pytools import resolve_name
86+
try:
87+
tag_cls = resolve_name(tag)
88+
except ImportError:
89+
pass
90+
except AttributeError:
91+
pass
92+
else:
93+
if issubclass(tag_cls, Tag):
94+
return tag_cls()
95+
96+
return LegacyStringInstructionTag(tag)
97+
98+
99+
def _normalize_tags(tags):
100+
return frozenset(
101+
_normalize_string_tag(t) if isinstance(t, str) else t
102+
for t in tags)
103+
104+
# }}}
105+
106+
75107
# {{{ expand defines
76108

77109
WORD_RE = re.compile(r"\b([a-zA-Z0-9_]+)\b")
@@ -328,9 +360,9 @@ def parse_nosync_option(opt_value):
328360
del new_predicates
329361

330362
elif opt_key == "tags" and opt_value is not None:
331-
result["tags"] = frozenset(
363+
result["tags"] = _normalize_tags([
332364
tag.strip() for tag in opt_value.split(":")
333-
if tag.strip())
365+
if tag.strip()])
334366

335367
elif opt_key == "atomic":
336368
if is_with_block:
@@ -803,6 +835,10 @@ def intern_if_str(s):
803835
| insn_options_stack[-1]["conflicts_with_groups"]),
804836
**kwargs)
805837

838+
norm_tags = _normalize_tags(insn.tags)
839+
if norm_tags != insn.tags:
840+
insn = insn.copy(tags=norm_tags)
841+
806842
new_instructions.append(insn)
807843
inames_to_dup.append([])
808844

loopy/kernel/instruction.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,59 @@
2222

2323
from sys import intern
2424
from pytools import ImmutableRecord, memoize_method
25+
from pytools.tag import Tag, tag_dataclass, Taggable
2526
from loopy.diagnostic import LoopyError
2627
from loopy.tools import Optional
2728
from warnings import warn
2829
import islpy as isl
2930

3031

32+
# {{{ instruction tags
33+
34+
@tag_dataclass
35+
class LegacyStringInstructionTag(Tag):
36+
"""A subclass of :class:`pytools.tag.Tag` for use in
37+
:attr:`InstructionBase.tags` used for forward compatibility of the old
38+
string-based tagging mechanism. String-based tags are automatically converted
39+
to this type.
40+
41+
.. attribute:: value
42+
"""
43+
value: str
44+
45+
# FIXME: This class should be deprecated as soon as there is a viable
46+
# alternative. For now, pattern matching and the textual syntax are
47+
# only able to generate string tags, which is why the deprecation is not
48+
# yet in effect.
49+
50+
51+
@tag_dataclass
52+
class UseStreamingStoreTag(Tag):
53+
"""A subclass of :class:`pytools.tag.Tag` for use in
54+
:attr:`InstructionBase.tags` used to indicate that if the instruction is an
55+
:class:`Assignment` or a :class:`CallInstruction`, then the 'store' part of
56+
the assignment should be realized using streaming stores.
57+
58+
.. note::
59+
60+
This tag is advisory in nature and may be ignored by targets
61+
that do not understand it or in situations where it does not
62+
apply.
63+
64+
.. warning::
65+
66+
This is a dodgy shortcut, and no promise is made that this will
67+
continue to work. Whether this is safe is target-dependent and
68+
program-dependent. No promise of safety is made.
69+
"""
70+
pass
71+
72+
# }}}
73+
74+
3175
# {{{ instructions: base class
3276

33-
class InstructionBase(ImmutableRecord):
77+
class InstructionBase(ImmutableRecord, Taggable):
3478
"""A base class for all types of instruction that can occur in
3579
a kernel.
3680
@@ -135,11 +179,11 @@ class InstructionBase(ImmutableRecord):
135179
136180
.. attribute:: tags
137181
138-
A :class:`frozenset` of string identifiers that can be used to
139-
identify groups of instructions.
140-
141-
Tags starting with exclamation marks (``!``) are reserved and may have
142-
specific meanings defined by :mod:`loopy` or its targets.
182+
A :class:`frozenset` of subclasses of :class:`pytools.tag.Tag` used to
183+
provide metadata on this object. Legacy string tags are converted to
184+
:class:`LegacyStringInstructionTag` or, if they used to carry
185+
a functional meaning, the tag carrying that same fucntional meaning
186+
(e.g. :class:`UseStreamingStoreTag`).
143187
144188
.. automethod:: __init__
145189
.. automethod:: assignee_var_names
@@ -148,6 +192,8 @@ class InstructionBase(ImmutableRecord):
148192
.. automethod:: write_dependency_names
149193
.. automethod:: dependency_names
150194
.. automethod:: copy
195+
196+
Inherits from :class:`pytools.tag.Taggable`.
151197
"""
152198

153199
# within_inames_is_final is deprecated and will be removed in version 2017.x.
@@ -257,8 +303,13 @@ def __init__(self, id, depends_on, depends_on_is_final,
257303
within_inames=within_inames,
258304
priority=priority,
259305
predicates=predicates,
306+
# Yes, tags is set by both this and the Taggable constructor.
307+
# Here, we set it so that ImmutableRecord knows about it.
308+
# The Taggable constructor call does extra validation.
260309
tags=tags)
261310

311+
Taggable.__init__(self, tags)
312+
262313
# {{{ abstract interface
263314

264315
def read_dependency_names(self):
@@ -347,7 +398,9 @@ def get_str_options(self):
347398
if self.priority:
348399
result.append("priority=%d" % self.priority)
349400
if self.tags:
350-
result.append("tags=%s" % ":".join(self.tags))
401+
from loopy.kernel.tools import stringify_instruction_tag
402+
result.append("tags=%s" % ":".join(
403+
stringify_instruction_tag(t) for t in self.tags))
351404
if hasattr(self, "atomicity") and self.atomicity:
352405
result.append("atomic=%s" % ":".join(str(a) for a in self.atomicity))
353406

loopy/kernel/tools.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1415,6 +1415,14 @@ def conform_to_uniform_length(s):
14151415

14161416
# {{{ stringify_instruction_list
14171417

1418+
def stringify_instruction_tag(tag):
1419+
from loopy.kernel.instruction import LegacyStringInstructionTag
1420+
if isinstance(tag, LegacyStringInstructionTag):
1421+
return f"S({tag.value})"
1422+
else:
1423+
return str(tag)
1424+
1425+
14181426
def stringify_instruction_list(kernel):
14191427
# {{{ topological sort
14201428

@@ -1529,7 +1537,8 @@ def adapt_to_new_inames_list(new_inames):
15291537
if insn.priority:
15301538
options.append("priority=%d" % insn.priority)
15311539
if insn.tags:
1532-
options.append("tags=%s" % ":".join(insn.tags))
1540+
options.append("tags=%s" % ":".join(
1541+
stringify_instruction_tag(t) for t in insn.tags))
15331542
if isinstance(insn, lp.Assignment) and insn.atomicity:
15341543
options.append("atomic=%s" % ":".join(
15351544
str(a) for a in insn.atomicity))

loopy/match.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,16 @@ def __call__(self, kernel, matchable):
243243

244244
class Tagged(GlobMatchExpressionBase):
245245
def __call__(self, kernel, matchable):
246+
from loopy.kernel.instruction import LegacyStringInstructionTag
246247
if matchable.tags:
247-
return any(self.re.match(tag) for tag in matchable.tags)
248+
return any(
249+
self.re.match(tag.value)
250+
if isinstance(tag, LegacyStringInstructionTag)
251+
else
252+
253+
False
254+
255+
for tag in matchable.tags)
248256
else:
249257
return False
250258

loopy/statistics.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -589,10 +589,11 @@ class MemAccess(Record):
589589
A :class:`str` that specifies the variable name of the data
590590
accessed.
591591
592-
.. attribute:: variable_tag
592+
.. attribute:: variable_tags
593593
594-
A :class:`str` that specifies the variable tag of a
595-
:class:`loopy.symbolic.TaggedVariable`.
594+
A :class:`frozenset` of subclasses of :class:`~pytools.tag.Tag`
595+
that reflects :attr:`~loopy.symbolic.TaggedVariable.tags` of
596+
an accessed variable.
596597
597598
.. attribute:: count_granularity
598599
@@ -610,27 +611,60 @@ class MemAccess(Record):
610611
"""
611612

612613
def __init__(self, mtype=None, dtype=None, lid_strides=None, gid_strides=None,
613-
direction=None, variable=None, variable_tag=None,
614+
direction=None, variable=None,
615+
*, variable_tags=None, variable_tag=None,
614616
count_granularity=None):
615617

616618
if count_granularity not in CountGranularity.ALL+[None]:
617619
raise ValueError("Op.__init__: count_granularity '%s' is "
618620
"not allowed. count_granularity options: %s"
619621
% (count_granularity, CountGranularity.ALL+[None]))
620622

623+
# {{{ normalize variable_tags
624+
625+
if variable_tags is not None and variable_tag is not None:
626+
raise TypeError(
627+
"may not specify both 'variable_tags' and 'variable_tag'")
628+
if variable_tag is not None:
629+
from loopy.kernel.creation import _normalize_string_tag
630+
variable_tags = frozenset({_normalize_string_tag(variable_tag)})
631+
632+
from warnings import warn
633+
warn("Passing 'variable_tag' to MemAccess is deprecated and will "
634+
"stop working in 2022. Pass variable_tags instead.")
635+
636+
if variable_tags is None:
637+
variable_tags = frozenset()
638+
639+
# }}}
640+
621641
if dtype is None:
622642
Record.__init__(self, mtype=mtype, dtype=dtype, lid_strides=lid_strides,
623643
gid_strides=gid_strides, direction=direction,
624-
variable=variable, variable_tag=variable_tag,
644+
variable=variable, variable_tags=variable_tags,
625645
count_granularity=count_granularity)
626646
else:
627647
from loopy.types import to_loopy_type
628648
Record.__init__(self, mtype=mtype, dtype=to_loopy_type(dtype),
629649
lid_strides=lid_strides, gid_strides=gid_strides,
630650
direction=direction, variable=variable,
631-
variable_tag=variable_tag,
651+
variable_tags=variable_tags,
632652
count_granularity=count_granularity)
633653

654+
@property
655+
def variable_tag(self):
656+
from warnings import warn
657+
warn("Accessing MemAccess.variable_tag is deprecated and will stop working "
658+
"in 2022. Use MemAccess.variable_tags instead.", DeprecationWarning,
659+
stacklevel=2)
660+
661+
if len(self.variable_tags) != 1:
662+
raise ValueError("cannot access MemAccess.variable_tag: access has "
663+
f"{len(self.variable_tags)} tags")
664+
665+
tag, = self.variable_tags
666+
return tag
667+
634668
def __hash__(self):
635669
# Note that this means lid_strides and gid_strides must be sorted
636670
# in self.__repr__()
@@ -647,7 +681,7 @@ def __repr__(self):
647681
sorted(self.gid_strides.items())),
648682
self.direction,
649683
self.variable,
650-
self.variable_tag,
684+
self.variable_tags,
651685
self.count_granularity)
652686

653687
# }}}
@@ -1031,9 +1065,9 @@ def map_variable(self, expr):
10311065
def map_subscript(self, expr):
10321066
name = expr.aggregate.name
10331067
try:
1034-
var_tag = expr.aggregate.tag
1068+
var_tags = expr.aggregate.tags
10351069
except AttributeError:
1036-
var_tag = None
1070+
var_tags = frozenset()
10371071

10381072
if name in self.knl.arg_dict:
10391073
array = self.knl.arg_dict[name]
@@ -1062,7 +1096,7 @@ def map_subscript(self, expr):
10621096
lid_strides=dict(sorted(lid_strides.items())),
10631097
gid_strides=dict(sorted(gid_strides.items())),
10641098
variable=name,
1065-
variable_tag=var_tag,
1099+
variable_tags=var_tags,
10661100
count_granularity=count_granularity
10671101
): 1}
10681102
) + self.rec(expr.index_tuple)
@@ -1678,7 +1712,7 @@ def get_mem_access_map(knl, numpy_types=True, count_redundant_work=False,
16781712
gid_strides=mem_access.gid_strides,
16791713
direction=mem_access.direction,
16801714
variable=mem_access.variable,
1681-
variable_tag=mem_access.variable_tag,
1715+
variable_tags=mem_access.variable_tags,
16821716
count_granularity=mem_access.count_granularity):
16831717
ct
16841718
for mem_access, ct in access_map.count_map.items()},

0 commit comments

Comments
 (0)