From 1f7947a4d2887481c36da666fdf8c495b0d4904a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 20 May 2022 10:01:18 -0500 Subject: [PATCH 1/4] THG's substitution applier --- pytato/target/loopy/codegen.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 09945286e..64a1741e2 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -50,6 +50,9 @@ from pytools.tag import Tag import pytato.reductions as red +from loopy.symbolic import IdentityMapper as LoopyIdentityMapper +from pymbolic.mapper.subst_applier import SubstitutionApplier + # set in doc/conf.py if getattr(sys, "PYTATO_BUILDING_SPHINX_DOCS", False): # Avoid import unless building docs to avoid creating a hard @@ -76,10 +79,11 @@ """ -def loopy_substitute(expression: Any, variable_assigments: Mapping[str, Any]) -> Any: - from loopy.symbolic import SubstitutionMapper - from pymbolic.mapper.substitutor import make_subst_func +class LoopySubstitutionApplier(SubstitutionApplier, LoopyIdentityMapper): + pass + +def loopy_substitute(expression: Any, variable_assigments: Mapping[str, Any]) -> Any: # {{{ early exit for identity substitution if all(isinstance(v, prim.Variable) and v.name == k @@ -89,7 +93,7 @@ def loopy_substitute(expression: Any, variable_assigments: Mapping[str, Any]) -> # }}} - return SubstitutionMapper(make_subst_func(variable_assigments))(expression) + return prim.Substitution(expression, *zip(*variable_assigments.items())) # SymbolicIndex and ShapeType are semantically distinct but identical at the @@ -787,7 +791,8 @@ def add_store(name: str, expr: Array, result: ImplementedResult, for d in range(expr.ndim)) indices = tuple(prim.Variable(iname) for iname in inames) loopy_expr_context = PersistentExpressionContext(state) - loopy_expr = result.to_loopy_expression(indices, loopy_expr_context) + loopy_expr = LoopySubstitutionApplier()( + result.to_loopy_expression(indices, loopy_expr_context)) # Make the instruction from loopy.kernel.instruction import make_assignment From fad967e1c53ff9b93c944823d5dd7bb09ddc48b5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 1 Jun 2022 16:17:00 -0500 Subject: [PATCH 2/4] better caching --- pytato/target/loopy/codegen.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 64a1741e2..ea33c81e1 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -80,7 +80,8 @@ class LoopySubstitutionApplier(SubstitutionApplier, LoopyIdentityMapper): - pass + def get_cache_key(self, expr, current_substs): + return (type(expr), expr, tuple(sorted(current_substs.items()))) def loopy_substitute(expression: Any, variable_assigments: Mapping[str, Any]) -> Any: From b863dd43638fe776c926223ba4214abbcab0fce1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 1 Jun 2022 16:29:31 -0500 Subject: [PATCH 3/4] mypy fixes --- pytato/target/loopy/codegen.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index ea33c81e1..49f558ce5 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -79,8 +79,12 @@ """ -class LoopySubstitutionApplier(SubstitutionApplier, LoopyIdentityMapper): - def get_cache_key(self, expr, current_substs): +# type-ignore-reason: superclasses have no type information +class LoopySubstitutionApplier( + SubstitutionApplier, LoopyIdentityMapper): # type: ignore + def get_cache_key(self, expr: ScalarExpression, + current_substs: Dict[ScalarExpression, ScalarExpression])\ + -> Tuple[Any, ScalarExpression, Any]: return (type(expr), expr, tuple(sorted(current_substs.items()))) From cd49225880c4ad5a8c8e3f07b05d1760ed9e4e5e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 1 Jun 2022 16:33:25 -0500 Subject: [PATCH 4/4] spelling fix --- pytato/target/loopy/codegen.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 49f558ce5..d05c94b13 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -88,17 +88,17 @@ def get_cache_key(self, expr: ScalarExpression, return (type(expr), expr, tuple(sorted(current_substs.items()))) -def loopy_substitute(expression: Any, variable_assigments: Mapping[str, Any]) -> Any: +def loopy_substitute(expression: Any, variable_assignments: Mapping[str, Any]) -> Any: # {{{ early exit for identity substitution if all(isinstance(v, prim.Variable) and v.name == k - for k, v in variable_assigments.items()): + for k, v in variable_assignments.items()): # Nothing to do here, move on. return expression # }}} - return prim.Substitution(expression, *zip(*variable_assigments.items())) + return prim.Substitution(expression, *zip(*variable_assignments.items())) # SymbolicIndex and ShapeType are semantically distinct but identical at the