diff --git a/.pylintrc-local.yml b/.pylintrc-local.yml
index f5a171e96..ae0f92032 100644
--- a/.pylintrc-local.yml
+++ b/.pylintrc-local.yml
@@ -1,6 +1,5 @@
- arg: ignored-modules
val:
- - asciidag
- matplotlib
- ipykernel
- ply
diff --git a/doc/array.rst b/doc/array.rst
index a31fb2e53..56772affa 100644
--- a/doc/array.rst
+++ b/doc/array.rst
@@ -5,12 +5,17 @@ Array Expressions
.. automodule:: pytato.array
+Functions in Pytato IR
+----------------------
+
+.. automodule:: pytato.function
Raising :class:`~pytato.array.IndexLambda` nodes
------------------------------------------------
.. automodule:: pytato.raising
+
Calling :mod:`loopy` kernels in an array expression
---------------------------------------------------
diff --git a/examples/visualization.py b/examples/visualization.py
index 2b569a392..ac71e6060 100755
--- a/examples/visualization.py
+++ b/examples/visualization.py
@@ -22,8 +22,6 @@ def main():
stack = pt.stack([array, 2*array, array + 6])
result = stack @ stack.T
- pt.show_ascii_graph(result)
-
dot_code = pt.get_dot_graph(result)
with open(GRAPH_DOT, "w") as outf:
diff --git a/pytato/__init__.py b/pytato/__init__.py
index 0117d6ba9..572e4a7ab 100644
--- a/pytato/__init__.py
+++ b/pytato/__init__.py
@@ -89,12 +89,13 @@ def set_debug_enabled(flag: bool) -> None:
from pytato.target.loopy import LoopyPyOpenCLTarget
from pytato.target.python.jax import generate_jax
from pytato.visualization import (get_dot_graph, show_dot_graph,
- get_ascii_graph, show_ascii_graph,
get_dot_graph_from_partition,
show_fancy_placeholder_data_flow,
)
+from pytato.transform.calls import tag_all_calls_to_be_inlined, inline_calls
import pytato.analysis as analysis
import pytato.tags as tags
+import pytato.function as function
import pytato.transform as transform
from pytato.distributed.nodes import (make_distributed_send, make_distributed_recv,
DistributedRecv, DistributedSend,
@@ -111,6 +112,7 @@ def set_debug_enabled(flag: bool) -> None:
from pytato.transform.remove_broadcasts_einsum import (
rewrite_einsums_with_no_broadcasts)
from pytato.transform.metadata import unify_axes_tags
+from pytato.function import trace_call
__all__ = (
"dtype",
@@ -133,8 +135,8 @@ def set_debug_enabled(flag: bool) -> None:
"Target", "LoopyPyOpenCLTarget",
- "get_dot_graph", "show_dot_graph", "get_ascii_graph",
- "show_ascii_graph", "get_dot_graph_from_partition",
+ "get_dot_graph", "show_dot_graph",
+ "get_dot_graph_from_partition",
"show_fancy_placeholder_data_flow",
"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", "sinh", "cosh",
@@ -156,11 +158,15 @@ def set_debug_enabled(flag: bool) -> None:
"broadcast_to", "pad",
+ "trace_call",
+
"make_distributed_recv", "make_distributed_send", "DistributedRecv",
"DistributedSend", "staple_distributed_send", "DistributedSendRefHolder",
"DistributedGraphPart",
"DistributedGraphPartition",
+ "tag_all_calls_to_be_inlined", "inline_calls",
+
"find_distributed_partition",
"number_distributed_tags",
@@ -175,6 +181,6 @@ def set_debug_enabled(flag: bool) -> None:
"unify_axes_tags",
# sub-modules
- "analysis", "tags", "transform",
+ "analysis", "tags", "transform", "function",
)
diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py
index a3c5e99ce..d0fd6ef1e 100644
--- a/pytato/analysis/__init__.py
+++ b/pytato/analysis/__init__.py
@@ -31,9 +31,11 @@
DictOfNamedArrays, NamedArray,
IndexBase, IndexRemappingBase, InputArgumentBase,
ShapeType)
+from pytato.function import FunctionDefinition, Call
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper
from pytato.loopy import LoopyCall
from pymbolic.mapper.optimize import optimize_mapper
+from pytools import memoize_method
if TYPE_CHECKING:
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
@@ -47,10 +49,14 @@
.. autofunction:: get_num_nodes
+.. autofunction:: get_num_call_sites
+
.. autoclass:: DirectPredecessorsGetter
"""
+# {{{ NUserCollector
+
class NUserCollector(Mapper):
"""
A :class:`pytato.transform.CachedWalkMapper` that records the number of
@@ -168,6 +174,8 @@ def map_distributed_recv(self, expr: DistributedRecv) -> None:
self.nusers[dim] += 1
self.rec(dim)
+# }}}
+
def get_nusers(outputs: Union[Array, DictOfNamedArrays]) -> Mapping[Array, int]:
"""
@@ -181,6 +189,8 @@ def get_nusers(outputs: Union[Array, DictOfNamedArrays]) -> Mapping[Array, int]:
return nuser_collector.nusers
+# {{{ is_einsum_similar_to_subscript
+
def _get_indices_from_input_subscript(subscript: str,
is_output: bool,
) -> Tuple[str, ...]:
@@ -208,8 +218,6 @@ def _get_indices_from_input_subscript(subscript: str,
# }}}
- # {{{
-
if is_output:
if len(normalized_indices) != len(set(normalized_indices)):
repeated_idx = next(idx
@@ -273,6 +281,8 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:
return True
+# }}}
+
# {{{ DirectPredecessorsGetter
@@ -388,3 +398,59 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
return ncm.count
# }}}
+
+
+# {{{ CallSiteCountMapper
+
+@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
+class CallSiteCountMapper(CachedWalkMapper):
+ """
+    Counts the number of :class:`~pytato.function.Call` nodes in a DAG.
+
+ .. attribute:: count
+
+        The number of call sites.
+ """
+
+ def __init__(self) -> None:
+ super().__init__()
+ self.count = 0
+
+ # type-ignore-reason: dropped the extra `*args, **kwargs`.
+ def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override]
+ return id(expr)
+
+ @memoize_method
+ def map_function_definition(self, /, expr: FunctionDefinition,
+ *args: Any, **kwargs: Any) -> None:
+ if not self.visit(expr):
+ return
+
+ new_mapper = self.clone_for_callee()
+ for subexpr in expr.returns.values():
+ new_mapper(subexpr, *args, **kwargs)
+
+ self.count += new_mapper.count
+
+ self.post_visit(expr, *args, **kwargs)
+
+ # type-ignore-reason: dropped the extra `*args, **kwargs`.
+ def post_visit(self, expr: Any) -> None: # type: ignore[override]
+ if isinstance(expr, Call):
+ self.count += 1
+
+
+def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int:
+ """Returns the number of nodes in DAG *outputs*."""
+
+ from pytato.codegen import normalize_outputs
+ outputs = normalize_outputs(outputs)
+
+ cscm = CallSiteCountMapper()
+ cscm(outputs)
+
+ return cscm.count
+
+# }}}
+
+# vim: fdm=marker
diff --git a/pytato/codegen.py b/pytato/codegen.py
index 63067fbd9..0bc85d649 100644
--- a/pytato/codegen.py
+++ b/pytato/codegen.py
@@ -23,7 +23,7 @@
"""
import dataclasses
-from typing import Union, Dict, Tuple, List, Any
+from typing import Union, Dict, Tuple, List, Any, Optional
from pytato.array import (Array, DictOfNamedArrays, DataWrapper, Placeholder,
DataInterface, SizeParam, InputArgumentBase,
@@ -102,12 +102,17 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc]
====================================== =====================================
"""
- def __init__(self, target: Target) -> None:
+ def __init__(self, target: Target,
+ kernels_seen: Optional[Dict[str, lp.LoopKernel]] = None
+ ) -> None:
super().__init__()
self.bound_arguments: Dict[str, DataInterface] = {}
self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
self.target = target
- self.kernels_seen: Dict[str, lp.LoopKernel] = {}
+ self.kernels_seen: Dict[str, lp.LoopKernel] = kernels_seen or {}
+
+ def clone_for_callee(self) -> CodeGenPreprocessor:
+ return CodeGenPreprocessor(self.target, self.kernels_seen)
def map_size_param(self, expr: SizeParam) -> Array:
name = expr.name
@@ -262,6 +267,7 @@ class PreprocessResult:
def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult:
"""Preprocess a computation for code generation."""
from pytato.transform import copy_dict_of_named_arrays
+ from pytato.transform.calls import inline_calls
check_validity_of_outputs(outputs)
@@ -289,12 +295,14 @@ def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult:
# }}}
- mapper = CodeGenPreprocessor(target)
+ new_outputs = inline_calls(outputs)
+ assert isinstance(new_outputs, DictOfNamedArrays)
- new_outputs = copy_dict_of_named_arrays(outputs, mapper)
+ mapper = CodeGenPreprocessor(target)
+ new_outputs = copy_dict_of_named_arrays(new_outputs, mapper)
return PreprocessResult(outputs=new_outputs,
- compute_order=tuple(output_order),
- bound_arguments=mapper.bound_arguments)
+ compute_order=tuple(output_order),
+ bound_arguments=mapper.bound_arguments)
# vim: fdm=marker
diff --git a/pytato/distributed/execute.py b/pytato/distributed/execute.py
index 79444cd22..fd1e6c6c4 100644
--- a/pytato/distributed/execute.py
+++ b/pytato/distributed/execute.py
@@ -231,7 +231,7 @@ def wait_for_some_recvs() -> None:
assert count == 0
assert name not in context
- return context
+ return {name: context[name] for name in partition.overall_output_names}
# }}}
diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py
index a59bd0e40..c8ab048f3 100644
--- a/pytato/distributed/partition.py
+++ b/pytato/distributed/partition.py
@@ -249,10 +249,21 @@ class DistributedGraphPartition:
.. attribute:: name_to_output
Mapping of placeholder names to the respective :class:`pytato.array.Array`
- they represent.
+ they represent. This is where the actual expressions are stored, for
+ all parts. Observe that the :class:`DistributedGraphPart`, for the most
+ part, only stores names. These "outputs" may be 'part outputs' (i.e.
+        data computed in one part for use by another, effectively temporary
+        variables), or 'overall outputs' of the computation.
+
+ .. attribute:: overall_output_names
+
+ The names of the outputs (in :attr:`name_to_output`) that were given to
+        :func:`find_distributed_partition` to specify the overall computation.
+
"""
parts: Mapping[PartId, DistributedGraphPart]
name_to_output: Mapping[str, Array]
+ overall_output_names: Sequence[str]
# }}}
@@ -365,6 +376,7 @@ def _make_distributed_partition(
sptpo_ary_to_name: Mapping[Array, str],
local_recv_id_to_recv_node: Dict[CommunicationOpIdentifier, DistributedRecv],
local_send_id_to_send_node: Dict[CommunicationOpIdentifier, DistributedSend],
+ overall_output_names: Sequence[str],
) -> DistributedGraphPartition:
name_to_output = {}
parts: Dict[PartId, DistributedGraphPart] = {}
@@ -402,6 +414,7 @@ def _make_distributed_partition(
result = DistributedGraphPartition(
parts=parts,
name_to_output=name_to_output,
+ overall_output_names=overall_output_names,
)
return result
@@ -967,7 +980,8 @@ def gen_array_name(ary: Array) -> str:
sent_ary_to_name,
sptpo_ary_to_name,
lsrdg.local_recv_id_to_recv_node,
- lsrdg.local_send_id_to_send_node)
+ lsrdg.local_send_id_to_send_node,
+ tuple(outputs))
from pytato.distributed.verify import _run_partition_diagnostics
_run_partition_diagnostics(outputs, partition)
diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py
index 4fa419cc2..9e3bde8d0 100644
--- a/pytato/distributed/tags.py
+++ b/pytato/distributed/tags.py
@@ -123,6 +123,7 @@ def set_union(
for pid, part in partition.parts.items()
},
name_to_output=partition.name_to_output,
+ overall_output_names=partition.overall_output_names,
), next_tag
# }}}
diff --git a/pytato/equality.py b/pytato/equality.py
index 5eb1814c6..42c2978cd 100644
--- a/pytato/equality.py
+++ b/pytato/equality.py
@@ -31,6 +31,8 @@
IndexBase, IndexLambda, NamedArray,
Reshape, Roll, Stack, AbstractResultWithNamedArrays,
Array, DictOfNamedArrays, Placeholder, SizeParam)
+from pytato.function import Call, NamedCallResult, FunctionDefinition
+from pytools import memoize_method
if TYPE_CHECKING:
from pytato.loopy import LoopyCall, LoopyCallResult
@@ -273,6 +275,31 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool:
and expr1.tags == expr2.tags
)
+ @memoize_method
+ def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
+ ) -> bool:
+ return (expr1.__class__ is expr2.__class__
+ and expr1.parameters == expr2.parameters
+ and (set(expr1.returns.keys()) == set(expr2.returns.keys()))
+ and all(self.rec(expr1.returns[k], expr2.returns[k])
+ for k in expr1.returns)
+ and expr1.tags == expr2.tags
+ )
+
+ def map_call(self, expr1: Call, expr2: Any) -> bool:
+ return (expr1.__class__ is expr2.__class__
+ and self.map_function_definition(expr1.function, expr2.function)
+ and frozenset(expr1.bindings) == frozenset(expr2.bindings)
+ and all(self.rec(bnd,
+ expr2.bindings[name])
+ for name, bnd in expr1.bindings.items())
+ )
+
+ def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool:
+ return (expr1.__class__ is expr2.__class__
+ and expr1.name == expr2.name
+ and self.rec(expr1._container, expr2._container))
+
# }}}
# vim: fdm=marker
diff --git a/pytato/function.py b/pytato/function.py
new file mode 100644
index 000000000..004aac18a
--- /dev/null
+++ b/pytato/function.py
@@ -0,0 +1,387 @@
+from __future__ import annotations
+
+__doc__ = """
+.. currentmodule:: pytato
+
+.. autofunction:: trace_call
+
+.. currentmodule:: pytato.function
+
+.. autoclass:: Call
+.. autoclass:: NamedCallResult
+.. autoclass:: FunctionDefinition
+.. autoclass:: ReturnType
+
+.. class:: ReturnT
+
+ A type variable corresponding to the return type of the function
+ :func:`pytato.trace_call`.
+"""
+
+__copyright__ = """
+Copyright (C) 2022 Andreas Kloeckner
+Copyright (C) 2022 Kaushik Kulkarni
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import attrs
+import re
+import enum
+
+from typing import (Callable, Dict, FrozenSet, Tuple, Union, TypeVar, Optional,
+ Hashable, Sequence, ClassVar, Iterator, Iterable, Mapping)
+from immutables import Map
+from functools import cached_property
+from pytato.array import (Array, AbstractResultWithNamedArrays,
+ Placeholder, NamedArray, ShapeType, _dtype_any,
+ InputArgumentBase)
+from pytools.tag import Tag, Taggable
+
+ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Dict[str, Array])
+
+
+# {{{ Call/NamedCallResult
+
+
+@enum.unique
+class ReturnType(enum.Enum):
+ """
+ Records the function body's return type in :class:`FunctionDefinition`.
+ """
+ ARRAY = 0
+ DICT_OF_ARRAYS = 1
+ TUPLE_OF_ARRAYS = 2
+
+
+# eq=False to avoid equality comparison without EqualityMapper
+@attrs.define(frozen=True, eq=False, hash=True)
+class FunctionDefinition(Taggable):
+ r"""
+ A function definition that represents its outputs as instances of
+ :class:`~pytato.Array` with the inputs being
+ :class:`~pytato.array.Placeholder`\ s. The outputs of the function
+ can be a single :class:`pytato.Array`, a tuple of :class:`pytato.Array`\ s or an
+ instance of ``Dict[str, Array]``.
+
+ .. attribute:: parameters
+
+ Names of the input :class:`~pytato.array.Placeholder`\ s to the
+ function node. This is a superset of the names of
+ :class:`~pytato.array.Placeholder` instances encountered in
+ :attr:`returns`. Unused parameters are allowed.
+
+ .. attribute:: return_type
+
+ An instance of :class:`ReturnType`.
+
+ .. attribute:: returns
+
+ The outputs of the function call which are array expressions that
+ depend on the *parameters*. The keys of the mapping depend on
+ :attr:`return_type` as:
+
+ - If the function returns a single :class:`pytato.Array`, then
+ *returns* contains a single array expression with ``"_"`` as the
+ key.
+ - If the function returns a :class:`tuple` of
+ :class:`pytato.Array`\ s, then *returns* contains entries with
+ the key ``"_N"`` mapping the ``N``-th entry of the result-tuple.
+ - If the function returns a :class:`dict` mapping identifiers to
+ :class:`pytato.Array`\ s, then *returns* uses the same mapping.
+
+ .. automethod:: get_placeholder
+
+ .. note::
+
+ A :class:`FunctionDefinition` comes with its own namespace based on
+ :attr:`parameters`. A :class:`~pytato.transform.Mapper`-implementer
+ must ensure **not** to reuse the cached result between the caller's
+ expressions and a function definition's expressions to avoid unsound
+ cache hits that could lead to incorrect mappings.
+
+ .. note::
+
+ At this point, code generation/execution does not support
+ distributed-memory communication nodes (:class:`~pytato.DistributedSend`,
+ :class:`~pytato.DistributedRecv`) within function bodies.
+ """
+ parameters: FrozenSet[str]
+ return_type: ReturnType
+ returns: Map[str, Array]
+ tags: FrozenSet[Tag] = attrs.field(kw_only=True)
+
+ @cached_property
+ def _placeholders(self) -> Mapping[str, Placeholder]:
+ from pytato.transform import InputGatherer
+ from functools import reduce
+
+ mapper = InputGatherer()
+
+ all_input_args: FrozenSet[InputArgumentBase] = reduce(
+ frozenset.union,
+ (mapper(ary) for ary in self.returns.values()),
+ frozenset()
+ )
+
+ return Map({input_arg.name: input_arg
+ for input_arg in all_input_args
+ if isinstance(input_arg, Placeholder)})
+
+ def get_placeholder(self, name: str) -> Placeholder:
+ """
+ Returns the instance of :class:`pytato.array.Placeholder` corresponding
+ to the parameter *name* in function body.
+ """
+ return self._placeholders[name]
+
+ def _with_new_tags(
+ self: FunctionDefinition, tags: FrozenSet[Tag]) -> FunctionDefinition:
+ return attrs.evolve(self, tags=tags)
+
+ def __call__(self, **kwargs: Array
+ ) -> Union[Array,
+ Tuple[Array, ...],
+ Dict[str, Array]]:
+ from pytato.array import _get_default_tags
+ from pytato.utils import are_shapes_equal
+
+ # {{{ sanity checks
+
+ if self.parameters != frozenset(kwargs):
+ missing_params = self.parameters - frozenset(kwargs)
+ extra_params = self.parameters - frozenset(kwargs)
+
+ raise TypeError(
+ "Incorrect arguments."
+ + (f" Missing: '{missing_params}'." if missing_params else "")
+ + (f" Extra: '{extra_params}'." if extra_params else "")
+ )
+
+ for argname, expected_arg in self._placeholders.items():
+ if expected_arg.dtype != kwargs[argname].dtype:
+ raise ValueError(f"Argument '{argname}' expected to "
+ f" be of type '{expected_arg.dtype}', got"
+ f" '{kwargs[argname].dtype}'.")
+ if not are_shapes_equal(expected_arg.shape, kwargs[argname].shape):
+ raise ValueError(f"Argument '{argname}' expected to "
+ f" have shape '{expected_arg.shape}', got"
+ f" '{kwargs[argname].shape}'.")
+
+ # }}}
+
+ call_site = Call(self, bindings=Map(kwargs), tags=_get_default_tags())
+
+ if self.return_type == ReturnType.ARRAY:
+ return call_site["_"]
+ elif self.return_type == ReturnType.TUPLE_OF_ARRAYS:
+ return tuple(call_site[f"_{iarg}"]
+ for iarg in range(len(self.returns)))
+ elif self.return_type == ReturnType.DICT_OF_ARRAYS:
+ return {kw: call_site[kw] for kw in self.returns}
+ else:
+ raise NotImplementedError(self.return_type)
+
+
+class NamedCallResult(NamedArray):
+ """
+ One of the arrays that are returned from a call to :class:`FunctionDefinition`.
+
+ .. attribute:: call
+
+ The function invocation that led to *self*.
+
+ .. attribute:: name
+
+ The name by which the returned array is referred to in
+ :attr:`FunctionDefinition.returns`.
+ """
+ call: Call
+ name: str
+ _mapper_method: ClassVar[str] = "map_named_call_result"
+
+ def __init__(self,
+ call: Call,
+ name: str) -> None:
+ super().__init__(call, name,
+ axes=call.function.returns[name].axes,
+ tags=call.function.returns[name].tags)
+
+ def with_tagged_axis(self, iaxis: int,
+ tags: Union[Sequence[Tag], Tag]) -> Array:
+ raise ValueError("Tagging a NamedCallResult's axis is illegal, use"
+ " Call.with_tagged_axis instead")
+
+ def tagged(self,
+ tags: Union[Iterable[Tag], Tag, None]) -> NamedCallResult:
+ raise ValueError("Tagging a NamedCallResult is illegal, use"
+ " Call.tagged instead")
+
+ def without_tags(self,
+ tags: Union[Iterable[Tag], Tag, None],
+ verify_existence: bool = True,
+ ) -> NamedCallResult:
+ raise ValueError("Untagging a NamedCallResult is illegal, use"
+ " Call.without_tags instead")
+
+ @property
+ def shape(self) -> ShapeType:
+ assert isinstance(self._container, Call)
+ return self._container.function.returns[self.name].shape
+
+ @property
+ def dtype(self) -> _dtype_any:
+ assert isinstance(self._container, Call)
+ return self._container.function.returns[self.name].dtype
+
+
+# eq=False to avoid equality comparison without EqualityMapper
+@attrs.define(frozen=True, eq=False, hash=True, cache_hash=True, repr=False)
+class Call(AbstractResultWithNamedArrays):
+ """
+ Records an invocation to a :class:`FunctionDefinition`.
+
+ .. attribute:: function
+
+ The instance of :class:`FunctionDefinition` being called by this call site.
+
+ .. attribute:: bindings
+
+ A mapping from the placeholder names of :class:`FunctionDefinition` to
+ their corresponding parameters in the invocation to :attr:`function`.
+
+ """
+ function: FunctionDefinition
+ bindings: Map[str, Array]
+
+ _mapper_method: ClassVar[str] = "map_call"
+
+ copy = attrs.evolve
+
+ def __post_init__(self) -> None:
+ # check that the invocation parameters and the function definition
+ # parameters agree with each other.
+ assert frozenset(self.bindings) == self.function.parameters
+ super().__post_init__()
+
+ def __contains__(self, name: object) -> bool:
+ return name in self.function.returns
+
+ def __iter__(self) -> Iterator[str]:
+ return iter(self.function.returns)
+
+ def __getitem__(self, name: str) -> NamedCallResult:
+ return NamedCallResult(self, name)
+
+ def __len__(self) -> int:
+ return len(self.function.returns)
+
+ def _with_new_tags(self: Call, tags: FrozenSet[Tag]) -> Call:
+ return attrs.evolve(self, tags=tags)
+
+# }}}
+
+
+# {{{ user-facing routines
+
+class _Guess:
+ pass
+
+
+RE_ARGNAME = re.compile(r"^_pt_(\d+)$")
+
+
+def trace_call(f: Callable[..., ReturnT],
+ *args: Array,
+ identifier: Optional[Hashable] = _Guess,
+ **kwargs: Array) -> ReturnT:
+ """
+ Returns the expressions returned after calling *f* with the arguments
+ *args* and keyword arguments *kwargs*. The subexpressions in the returned
+ expressions are outlined (opposite of 'inlined') as a
+ :class:`~pytato.function.FunctionDefinition`.
+
+ :arg identifier: A hashable object that acts as
+ :attr:`pytato.tags.FunctionIdentifier.identifier` for the
+ :class:`~pytato.tags.FunctionIdentifier` tagged to the outlined
+ :class:`~pytato.function.FunctionDefinition`. If ``None`` the function
+ definition is not tagged with a
+ :class:`~pytato.tags.FunctionIdentifier` tag, if ``_Guess`` the
+ function identifier is guessed from ``f.__name__``.
+ """
+ from pytato.tags import FunctionIdentifier
+ from pytato.array import _get_default_tags
+
+ if identifier is _Guess:
+ # partials might not have a __name__ attribute
+ identifier = getattr(f, "__name__", None)
+
+ for kw in kwargs:
+ if RE_ARGNAME.match(kw):
+ # avoid collision between argument names
+ raise ValueError(f"Kw argument named '{kw}' not allowed.")
+
+ # Get placeholders from the ``args``, ``kwargs``.
+ pl_args = tuple(Placeholder(f"in__pt_{iarg}", arg.shape, arg.dtype,
+ axes=arg.axes, tags=arg.tags)
+ for iarg, arg in enumerate(args))
+ pl_kwargs = {kw: Placeholder(f"in_{kw}", arg.shape,
+ arg.dtype, axes=arg.axes, tags=arg.tags)
+ for kw, arg in kwargs.items()}
+
+ # Pass the placeholders
+ output = f(*pl_args, **pl_kwargs)
+
+ if isinstance(output, Array):
+ returns = {"_": output}
+ return_type = ReturnType.ARRAY
+ elif isinstance(output, tuple):
+ assert all(isinstance(el, Array) for el in output)
+ returns = {f"_{iout}": out for iout, out in enumerate(output)}
+ return_type = ReturnType.TUPLE_OF_ARRAYS
+ elif isinstance(output, dict):
+ assert all(isinstance(el, Array) for el in output.values())
+ returns = output
+ return_type = ReturnType.DICT_OF_ARRAYS
+ else:
+ raise ValueError("The function being traced must return one of"
+ f"pytato.Array, tuple, dict. Got {type(output)}.")
+
+ # construct the function
+ function = FunctionDefinition(
+ frozenset(pl_arg.name for pl_arg in pl_args) | frozenset(pl_kwargs),
+ return_type,
+ Map(returns),
+ tags=_get_default_tags() | (frozenset([FunctionIdentifier(identifier)])
+ if identifier
+ else frozenset())
+ )
+
+ # type-ignore-reason: return type is dependent on dynamic state i.e.
+ # ret_type and hence mypy is unhappy
+ return function( # type: ignore[return-value]
+ **{pl.name: arg for pl, arg in zip(pl_args, args)},
+ **{pl_kwargs[kw].name: arg for kw, arg in kwargs.items()}
+ )
+
+# }}}
+
+# vim:foldmethod=marker
diff --git a/pytato/tags.py b/pytato/tags.py
index 8d080f66e..1e1180ebe 100644
--- a/pytato/tags.py
+++ b/pytato/tags.py
@@ -10,9 +10,12 @@
.. autoclass:: PrefixNamed
.. autoclass:: AssumeNonNegative
.. autoclass:: ExpandedDimsReshape
+.. autoclass:: FunctionIdentifier
+.. autoclass:: CallImplementationTag
+.. autoclass:: InlineCallTag
"""
-from typing import Tuple
+from typing import Tuple, Hashable
from pytools.tag import Tag, UniqueTag
from dataclasses import dataclass
@@ -125,3 +128,33 @@ class ExpandedDimsReshape(UniqueTag):
frozenset({ExpandedDimsReshape(new_dims=(0, 2, 4))})
"""
new_dims: Tuple[int, ...]
+
+
+@dataclass(frozen=True)
+class FunctionIdentifier(UniqueTag):
+ """
+ A tag that can be attached to a :class:`~pytato.function.FunctionDefinition`
+    node to describe the function's identifier. One can use this to refer to
+    all instances of :class:`~pytato.function.FunctionDefinition`, for example
+    in transformations like :func:`pytato.transform.calls.concatenate_calls`.
+
+ .. attribute:: identifier
+ """
+ identifier: Hashable
+
+
+@dataclass(frozen=True)
+class CallImplementationTag(UniqueTag):
+ """
+ A tag that can be attached to a :class:`~pytato.function.Call` node to
+ direct a :class:`~pytato.target.Target` how the call site should be
+ lowered.
+ """
+
+
+@dataclass(frozen=True)
+class InlineCallTag(CallImplementationTag):
+ r"""
+ A :class:`CallImplementationTag` that directs the
+ :class:`pytato.target.Target` to inline the call site.
+ """
diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py
index ad7b73491..7d8760ed6 100644
--- a/pytato/target/loopy/codegen.py
+++ b/pytato/target/loopy/codegen.py
@@ -46,6 +46,7 @@
from pytato.transform import Mapper
from pytato.scalar_expr import ScalarExpression, INT_CLASSES
from pytato.codegen import preprocess, normalize_outputs, SymbolicIndex
+from pytato.function import Call, NamedCallResult
from pytato.loopy import LoopyCall
from pytato.tags import (ImplStored, ImplInlined, Named, PrefixNamed,
ImplementationStrategy)
@@ -575,6 +576,19 @@ def _get_sub_array_ref(array: Array, name: str) -> "lp.symbolic.SubArrayRef":
state.update_kernel(kernel)
+ def map_named_call_result(self, expr: NamedCallResult,
+ state: CodeGenState) -> None:
+ raise NotImplementedError("LoopyTarget does not support outlined calls"
+ " (yet). As a fallback, the call"
+ " could be inlined using"
+ " pt.mark_all_calls_to_be_inlined.")
+
+ def map_call(self, expr: Call, state: CodeGenState) -> None:
+ raise NotImplementedError("LoopyTarget does not support outlined calls"
+ " (yet). As a fallback, the call"
+ " could be inlined using"
+ " pt.mark_all_calls_to_be_inlined.")
+
# }}}
@@ -972,36 +986,30 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]],
.. note::
- :mod:`pytato` metadata :math:`\mapsto` :mod:`loopy` metadata semantics:
-
- - Inames that index over an :class:`~pytato.array.Array`'s axis in the
- allocation instruction are tagged with the corresponding
- :class:`~pytato.array.Axis`'s tags. The caller may choose to not
- propagate axis tags of type *axis_tag_t_to_not_propagate*.
- - :attr:`pytato.Array.tags` of inputs/outputs in *outputs*
- would be copied over to the tags of the corresponding
- :class:`loopy.ArrayArg`. The caller may choose to not
- propagate array tags of type *array_tag_t_to_not_propagate*.
- - Arrays tagged with :class:`pytato.tags.ImplStored` would have their
- tags copied over to the tags of corresponding
- :class:`loopy.TemporaryVariable`. The caller may choose to not
- propagate array tags of type *array_tag_t_to_not_propagate*.
+ - :mod:`pytato` metadata :math:`\mapsto` :mod:`loopy` metadata semantics:
+
+ - Inames that index over an :class:`~pytato.array.Array`'s axis in the
+ allocation instruction are tagged with the corresponding
+ :class:`~pytato.array.Axis`'s tags. The caller may choose to not
+ propagate axis tags of type *axis_tag_t_to_not_propagate*.
+ - :attr:`pytato.Array.tags` of inputs/outputs in *outputs*
+ would be copied over to the tags of the corresponding
+ :class:`loopy.ArrayArg`. The caller may choose to not
+ propagate array tags of type *array_tag_t_to_not_propagate*.
+ - Arrays tagged with :class:`pytato.tags.ImplStored` would have their
+ tags copied over to the tags of corresponding
+ :class:`loopy.TemporaryVariable`. The caller may choose to not
+ propagate array tags of type *array_tag_t_to_not_propagate*.
+
+ .. warning::
+
+ Currently only :class:`~pytato.function.Call` nodes that are tagged with
+ :class:`pytato.tags.InlineCallTag` can be lowered to :mod:`loopy` IR.
"""
result_is_dict = isinstance(result, (dict, DictOfNamedArrays))
orig_outputs: DictOfNamedArrays = normalize_outputs(result)
- # optimization: remove any ImplStored tags on outputs to avoid redundant
- # store-load operations (see https://github.com/inducer/pytato/issues/415)
- orig_outputs = DictOfNamedArrays(
- {name: (output.without_tags(ImplStored(),
- verify_existence=False)
- if not isinstance(output,
- InputArgumentBase)
- else output)
- for name, output in orig_outputs._data.items()},
- tags=orig_outputs.tags)
-
del result
if cl_device is not None:
@@ -1017,6 +1025,18 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]],
preproc_result = preprocess(orig_outputs, target)
outputs = preproc_result.outputs
+ # optimization: remove any ImplStored tags on outputs to avoid redundant
+ # store-load operations (see https://github.com/inducer/pytato/issues/415)
+ # (This must be done after all the calls have been inlined)
+ outputs = DictOfNamedArrays(
+ {name: (output.without_tags(ImplStored(),
+ verify_existence=False)
+ if not isinstance(output,
+ InputArgumentBase)
+ else output)
+ for name, output in outputs._data.items()},
+ tags=outputs.tags)
+
compute_order = preproc_result.compute_order
if options is None:
diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py
index a75083f9a..953cbfa46 100644
--- a/pytato/transform/__init__.py
+++ b/pytato/transform/__init__.py
@@ -1,5 +1,7 @@
from __future__ import annotations
+from pytools import memoize_method
+
__copyright__ = """
Copyright (C) 2020 Matt Wala
Copyright (C) 2020-21 Kaushik Kulkarni
@@ -28,6 +30,7 @@
import logging
import numpy as np
+from immutables import Map
from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic,
List, Mapping, Iterable, Tuple, Optional,
Hashable)
@@ -43,10 +46,12 @@
from pytato.distributed.nodes import (
DistributedSendRefHolder, DistributedRecv, DistributedSend)
from pytato.loopy import LoopyCall, LoopyCallResult
+from pytato.function import Call, NamedCallResult, FunctionDefinition
from dataclasses import dataclass
from pytato.tags import ImplStored
from pymbolic.mapper.optimize import optimize_mapper
+
ArrayOrNames = Union[Array, AbstractResultWithNamedArrays]
MappedT = TypeVar("MappedT",
Array, AbstractResultWithNamedArrays, ArrayOrNames)
@@ -56,6 +61,7 @@
CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper
IndexOrShapeExpr = TypeVar("IndexOrShapeExpr")
R = FrozenSet[Array]
+_SelfMapper = TypeVar("_SelfMapper", bound="Mapper")
__doc__ = """
.. currentmodule:: pytato.transform
@@ -90,6 +96,14 @@
.. autofunction:: tag_user_nodes
.. autofunction:: rec_get_user_nodes
+
+Transforming call sites
+-----------------------
+
+.. automodule:: pytato.transform.calls
+
+.. currentmodule:: pytato.transform
+
Internal stuff that is only here because the documentation tool wants it
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -100,6 +114,11 @@
.. class:: CombineT
A type variable representing the type of a :class:`CombineMapper`.
+
+.. class:: _SelfMapper
+
+ A type variable used to represent the type of a mapper in
+ :meth:`CopyMapper.clone_for_callee`.
"""
transform_logger = logging.getLogger(__file__)
@@ -213,6 +232,8 @@ class CopyMapper(CachedMapper[ArrayOrNames]):
The typical use of this mapper is to override individual ``map_`` methods
in subclasses to permit term rewriting on an expression graph.
+ .. automethod:: clone_for_callee
+
.. note::
This does not copy the data of a :class:`pytato.array.DataWrapper`.
@@ -229,6 +250,13 @@ def __call__(self, # type: ignore[override]
expr: CopyMapperResultT) -> CopyMapperResultT:
return self.rec(expr)
+ def clone_for_callee(self: _SelfMapper) -> _SelfMapper:
+ """
+ Called to clone *self* before starting traversal of a
+ :class:`pytato.function.FunctionDefinition`.
+ """
+ return type(self)()
+
def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...]
) -> Tuple[IndexOrShapeExpr, ...]:
# type-ignore-reason: apparently mypy cannot substitute typevars
@@ -372,6 +400,32 @@ def map_distributed_recv(self, expr: DistributedRecv) -> Array:
shape=self.rec_idx_or_size_tuple(expr.shape),
dtype=expr.dtype, tags=expr.tags, axes=expr.axes)
+ @memoize_method
+ def map_function_definition(self,
+ expr: FunctionDefinition) -> FunctionDefinition:
+ # spawn a new mapper to avoid unsound cache hits, since the namespace of the
+ # function's body is different from that of the caller.
+ new_mapper = self.clone_for_callee()
+ new_returns = {name: new_mapper(ret)
+ for name, ret in expr.returns.items()}
+ return FunctionDefinition(expr.parameters,
+ expr.return_type,
+ Map(new_returns),
+ tags=expr.tags
+ )
+
+ def map_call(self, expr: Call) -> AbstractResultWithNamedArrays:
+ return Call(self.map_function_definition(expr.function),
+ Map({name: self.rec(bnd)
+ for name, bnd in expr.bindings.items()}),
+ tags=expr.tags,
+ )
+
+ def map_named_call_result(self, expr: NamedCallResult) -> Array:
+ call = self.rec(expr._container)
+ assert isinstance(call, Call)
+ return NamedCallResult(call, expr.name)
+
class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]):
"""
@@ -575,6 +629,26 @@ def map_distributed_recv(self, expr: DistributedRecv,
shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs),
dtype=expr.dtype, tags=expr.tags, axes=expr.axes)
+ def map_function_definition(self, expr: FunctionDefinition,
+ *args: Any, **kwargs: Any) -> FunctionDefinition:
+ raise NotImplementedError("Function definitions are purposefully left"
+ " unimplemented as the default arguments to a new"
+ " DAG traversal are tricky to guess.")
+
+ def map_call(self, expr: Call,
+ *args: Any, **kwargs: Any) -> AbstractResultWithNamedArrays:
+ return Call(self.map_function_definition(expr.function, *args, **kwargs),
+ Map({name: self.rec(bnd, *args, **kwargs)
+ for name, bnd in expr.bindings.items()}),
+ tags=expr.tags,
+ )
+
+ def map_named_call_result(self, expr: NamedCallResult,
+ *args: Any, **kwargs: Any) -> Array:
+ call = self.rec(expr._container, *args, **kwargs)
+ assert isinstance(call, Call)
+ return NamedCallResult(call, expr.name)
+
# }}}
@@ -685,6 +759,18 @@ def map_distributed_send_ref_holder(
def map_distributed_recv(self, expr: DistributedRecv) -> CombineT:
return self.combine(*self.rec_idx_or_size_tuple(expr.shape))
+ def map_function_definition(self, expr: FunctionDefinition) -> CombineT:
+ raise NotImplementedError("Combining results from a callee expression"
+ " is context-dependent. Derived classes"
+ " must override map_function_definition.")
+
+ def map_call(self, expr: Call) -> CombineT:
+ return self.combine(self.map_function_definition(expr.function),
+ *[self.rec(bnd) for bnd in expr.bindings.values()])
+
+ def map_named_call_result(self, expr: NamedCallResult) -> CombineT:
+ return self.rec(expr._container)
+
# }}}
@@ -752,6 +838,18 @@ def map_distributed_send_ref_holder(
def map_distributed_recv(self, expr: DistributedRecv) -> R:
return self.combine(frozenset([expr]), super().map_distributed_recv(expr))
+ def map_function_definition(self, expr: FunctionDefinition) -> R:
+ # do not include arrays from the function's body as it would involve
+ # putting arrays from different namespaces into the same collection.
+ return frozenset()
+
+ def map_call(self, expr: Call) -> R:
+ return self.combine(self.map_function_definition(expr.function),
+ *[self.rec(bnd) for bnd in expr.bindings.values()])
+
+ def map_named_call_result(self, expr: NamedCallResult) -> R:
+ return self.rec(expr._container)
+
# }}}
@@ -796,6 +894,27 @@ def map_data_wrapper(self, expr: DataWrapper) -> FrozenSet[InputArgumentBase]:
def map_size_param(self, expr: SizeParam) -> FrozenSet[SizeParam]:
return frozenset([expr])
+ @memoize_method
+ def map_function_definition(self, expr: FunctionDefinition
+ ) -> FrozenSet[InputArgumentBase]:
+ # get rid of placeholders local to the function.
+ new_mapper = InputGatherer()
+ all_callee_inputs = new_mapper.combine(*[new_mapper(ret)
+ for ret in expr.returns.values()])
+ result: Set[InputArgumentBase] = set()
+ for inp in all_callee_inputs:
+ if isinstance(inp, Placeholder):
+ if inp.name in expr.parameters:
+ # drop, reference to argument
+ pass
+ else:
+ raise ValueError("function definition refers to non-argument "
+ f"placeholder named '{inp.name}'")
+ else:
+ result.add(inp)
+
+ return frozenset(result)
+
# }}}
@@ -814,6 +933,12 @@ def combine(self, *args: FrozenSet[SizeParam]
def map_size_param(self, expr: SizeParam) -> FrozenSet[SizeParam]:
return frozenset([expr])
+ @memoize_method
+ def map_function_definition(self, expr: FunctionDefinition
+ ) -> FrozenSet[SizeParam]:
+ return self.combine(*[self.rec(ret)
+ for ret in expr.returns.values()])
+
# }}}
@@ -830,6 +955,9 @@ class WalkMapper(Mapper):
.. automethod:: post_visit
"""
+ def clone_for_callee(self: _SelfMapper) -> _SelfMapper:
+ return type(self)()
+
def visit(self, expr: Any, *args: Any, **kwargs: Any) -> bool:
"""
If this method returns *True*, *expr* is traversed during the walk.
@@ -982,6 +1110,36 @@ def map_loopy_call(self, expr: LoopyCall, *args: Any, **kwargs: Any) -> None:
self.post_visit(expr, *args, **kwargs)
+ def map_function_definition(self, expr: FunctionDefinition,
+ *args: Any, **kwargs: Any) -> None:
+ if not self.visit(expr):
+ return
+
+ new_mapper = self.clone_for_callee()
+ for subexpr in expr.returns.values():
+ new_mapper(subexpr, *args, **kwargs)
+
+ self.post_visit(expr, *args, **kwargs)
+
+ def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> None:
+ if not self.visit(expr):
+ return
+
+ self.map_function_definition(expr.function)
+ for bnd in expr.bindings.values():
+ self.rec(bnd)
+
+ self.post_visit(expr)
+
+ def map_named_call_result(self, expr: NamedCallResult,
+ *args: Any, **kwargs: Any) -> None:
+ if not self.visit(expr, *args, **kwargs):
+ return
+
+ self.rec(expr._container, *args, **kwargs)
+
+ self.post_visit(expr, *args, **kwargs)
+
# }}}
@@ -1019,6 +1177,11 @@ class TopoSortMapper(CachedWalkMapper):
"""A mapper that creates a list of nodes in topological order.
:members: topological_order
+
+ .. note::
+
+ Does not consider the nodes inside a
+ :class:`~pytato.function.FunctionDefinition`.
"""
def __init__(self) -> None:
@@ -1033,6 +1196,12 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override]
def post_visit(self, expr: Any) -> None: # type: ignore[override]
self.topological_order.append(expr)
+ # type-ignore-reason: dropped the extra `*args, **kwargs`.
+ def map_function_definition(self, # type: ignore[override]
+ expr: FunctionDefinition) -> None:
+ # do nothing as it includes arrays from a different namespace.
+ return
+
# }}}
@@ -1048,6 +1217,11 @@ def __init__(self, map_fn: Callable[[ArrayOrNames], ArrayOrNames]) -> None:
super().__init__()
self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn
+ def clone_for_callee(self: _SelfMapper) -> _SelfMapper:
+ # type-ignore-reason: self.__init__ has a different function signature
+ # than Mapper.__init__ and does not have map_fn
+ return type(self)(self.map_fn) # type: ignore[call-arg,attr-defined]
+
# type-ignore-reason:incompatible with Mapper.rec()
def rec(self, expr: MappedT) -> MappedT: # type: ignore[override]
if expr in self._cache:
@@ -1102,7 +1276,14 @@ def _materialize_if_mpms(expr: Array,
class MPMSMaterializer(Mapper):
- """See :func:`materialize_with_mpms` for an explanation."""
+ """
+ See :func:`materialize_with_mpms` for an explanation.
+
+ .. attribute:: nsuccessors
+
+ A mapping from a node in the expression graph (i.e. an
+ :class:`~pytato.Array`) to its number of successors.
+ """
def __init__(self, nsuccessors: Mapping[Array, int]):
super().__init__()
self.nsuccessors = nsuccessors
@@ -1252,6 +1433,52 @@ def map_distributed_recv(self, expr: DistributedRecv
) -> MPMSMaterializerAccumulator:
return MPMSMaterializerAccumulator(frozenset([expr]), expr)
+ @memoize_method
+ def map_function_definition(self, expr: FunctionDefinition
+ ) -> FunctionDefinition:
+ # spawn a new traversal here.
+ from pytato.analysis import get_nusers
+
+ returns_dict_of_named_arys = DictOfNamedArrays(expr.returns)
+ func_nsuccessors = get_nusers(returns_dict_of_named_arys)
+ new_mapper = MPMSMaterializer(func_nsuccessors)
+ new_returns = {name: new_mapper(ret) for name, ret in expr.returns.items()}
+ return FunctionDefinition(expr.parameters,
+ expr.return_type,
+ Map(new_returns),
+ tags=expr.tags)
+
+ @memoize_method
+ def map_call(self, expr: Call) -> Call:
+ return Call(self.map_function_definition(expr.function),
+ Map({name: self.rec(bnd).expr
+ for name, bnd in expr.bindings.items()}),
+ tags=expr.tags)
+
+ def map_named_call_result(self, expr: NamedCallResult
+ ) -> MPMSMaterializerAccumulator:
+ assert isinstance(expr._container, Call)
+ new_call = self.map_call(expr._container)
+ new_result = new_call[expr.name]
+
+ assert isinstance(new_result, NamedCallResult)
+ assert isinstance(new_result._container, Call)
+
+ # do not use _materialize_if_mpms as tagging a NamedArray is illegal.
+ if new_result.tags_of_type(ImplStored):
+ return MPMSMaterializerAccumulator(frozenset([new_result]),
+ new_result)
+ else:
+ from functools import reduce
+ materialized_predecessors: FrozenSet[Array] = (
+ reduce(frozenset.union,
+ (self.rec(bnd).materialized_predecessors
+ for bnd in new_result._container.bindings.values()),
+ frozenset())
+ )
+ return MPMSMaterializerAccumulator(materialized_predecessors,
+ new_result)
+
# }}}
@@ -1487,9 +1714,26 @@ def map_distributed_send_ref_holder(
def map_distributed_recv(self, expr: DistributedRecv) -> None:
self.rec_idx_or_size_tuple(expr, expr.shape)
+ def map_function_definition(self, expr: FunctionDefinition, *args: Any
+ ) -> None:
+ raise AssertionError("Control shouldn't reach at this point."
+ " Instantiate another UsersCollector to"
+ " traverse the callee function.")
+
+ def map_call(self, expr: Call, *args: Any) -> None:
+ for bnd in expr.bindings.values():
+ self.rec(bnd)
+
+ def map_named_call(self, expr: NamedCallResult, *args: Any) -> None:
+ assert isinstance(expr._container, Call)
+ for bnd in expr._container.bindings.values():
+ self.node_to_users.setdefault(bnd, set()).add(expr)
+
+ self.rec(expr._container)
+
def get_users(expr: ArrayOrNames) -> Dict[ArrayOrNames,
- Set[ArrayOrNames]]:
+ Set[ArrayOrNames]]:
"""
Returns a mapping from node in *expr* to its direct users.
"""
diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py
new file mode 100644
index 000000000..a25dd8311
--- /dev/null
+++ b/pytato/transform/calls.py
@@ -0,0 +1,114 @@
+"""
+.. currentmodule:: pytato.transform.calls
+
+.. autofunction:: inline_calls
+.. autofunction:: tag_all_calls_to_be_inlined
+"""
+__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+from immutables import Map
+from pytato.transform import (ArrayOrNames, CopyMapper)
+from pytato.array import (AbstractResultWithNamedArrays, Array,
+ DictOfNamedArrays, Placeholder)
+
+from pytato.function import Call, NamedCallResult
+from pytato.tags import InlineCallTag
+
+
+# {{{ inlining
+
+class PlaceholderSubstitutor(CopyMapper):
+ """
+ .. attribute:: substitutions
+
+ A mapping from the placeholder name to the array that it is to be
+ substituted with.
+ """
+ def __init__(self, substitutions: Map[str, Array]) -> None:
+ super().__init__()
+ self.substitutions = substitutions
+
+ def map_placeholder(self, expr: Placeholder) -> Array:
+ return self.substitutions[expr.name]
+
+
+class Inliner(CopyMapper):
+ """
+ Primary mapper for :func:`inline_calls`.
+ """
+ def map_call(self, expr: Call) -> AbstractResultWithNamedArrays:
+ # inline call sites within the callee.
+ new_expr = super().map_call(expr)
+ assert isinstance(new_expr, Call)
+
+ if expr.tags_of_type(InlineCallTag):
+ substitutor = PlaceholderSubstitutor(expr.bindings)
+
+ return DictOfNamedArrays(
+ {name: substitutor(ret)
+ for name, ret in new_expr.function.returns.items()},
+ tags=expr.tags
+ )
+ else:
+ return new_expr
+
+ def map_named_call_result(self, expr: NamedCallResult) -> Array:
+ new_call = self.rec(expr._container)
+ assert isinstance(new_call, AbstractResultWithNamedArrays)
+ return new_call[expr.name]
+
+
+class InlineMarker(CopyMapper):
+ """
+ Primary mapper for :func:`tag_all_calls_to_be_inlined`.
+ """
+ def map_call(self, expr: Call) -> AbstractResultWithNamedArrays:
+ return super().map_call(expr).tagged(InlineCallTag())
+
+
+def inline_calls(expr: ArrayOrNames) -> ArrayOrNames:
+ """
+ Returns a copy of *expr* with call sites tagged with
+ :class:`pytato.tags.InlineCallTag` inlined into the expression graph.
+ """
+ inliner = Inliner()
+ return inliner(expr)
+
+
+def tag_all_calls_to_be_inlined(expr: ArrayOrNames) -> ArrayOrNames:
+ """
+ Returns a copy of *expr* with all reachable instances of
+ :class:`pytato.function.Call` nodes tagged with
+ :class:`pytato.tags.InlineCallTag`.
+
+ .. note::
+
+    This routine does NOT inline calls; to inline the calls,
+    use :func:`inline_calls` on this routine's
+    output.
+ """
+ return InlineMarker()(expr)
+
+# }}}
+
+# vim:foldmethod=marker
diff --git a/pytato/visualization/__init__.py b/pytato/visualization/__init__.py
index 6fed19fec..e0138a78c 100644
--- a/pytato/visualization/__init__.py
+++ b/pytato/visualization/__init__.py
@@ -2,19 +2,15 @@
.. currentmodule:: pytato
.. automodule:: pytato.visualization.dot
-.. automodule:: pytato.visualization.ascii
.. automodule:: pytato.visualization.fancy_placeholder_data_flow
"""
from .dot import get_dot_graph, show_dot_graph, get_dot_graph_from_partition
-from .ascii import get_ascii_graph, show_ascii_graph
from .fancy_placeholder_data_flow import show_fancy_placeholder_data_flow
__all__ = [
"get_dot_graph", "show_dot_graph", "get_dot_graph_from_partition",
- "get_ascii_graph", "show_ascii_graph",
-
"show_fancy_placeholder_data_flow",
]
diff --git a/pytato/visualization/ascii.py b/pytato/visualization/ascii.py
deleted file mode 100644
index e417a9bfc..000000000
--- a/pytato/visualization/ascii.py
+++ /dev/null
@@ -1,124 +0,0 @@
-"""
-.. currentmodule:: pytato
-
-.. autofunction:: get_ascii_graph
-.. autofunction:: show_ascii_graph
-"""
-__copyright__ = """
-Copyright (C) 2021 University of Illinois Board of Trustees
-"""
-
-__license__ = """
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in
-all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
-THE SOFTWARE.
-"""
-
-from typing import Union, List, Dict
-from pytato.transform import ArrayOrNames
-from pytato.array import Array, DictOfNamedArrays, InputArgumentBase
-from pytato.visualization.dot import ArrayToDotNodeInfoMapper
-from pytato.codegen import normalize_outputs
-from pytools import UniqueNameGenerator
-
-
-# {{{ Show ASCII representation of DAG
-
-def get_ascii_graph(result: Union[Array, DictOfNamedArrays],
- use_color: bool = True) -> str:
- """Return a string representing the computation of *result*
- using the `asciidag `_ package.
-
- :arg result: Outputs of the computation (cf.
- :func:`pytato.generate_loopy`).
- :arg use_color: Colorized output
- """
- outputs: DictOfNamedArrays = normalize_outputs(result)
- del result
-
- mapper = ArrayToDotNodeInfoMapper()
- for elem in outputs._data.values():
- mapper(elem)
-
- nodes = mapper.nodes
-
- input_arrays: List[Array] = []
- internal_arrays: List[ArrayOrNames] = []
- array_to_id: Dict[ArrayOrNames, str] = {}
-
- id_gen = UniqueNameGenerator()
- for array in nodes:
- array_to_id[array] = id_gen("array")
- if isinstance(array, InputArgumentBase):
- input_arrays.append(array)
- else:
- internal_arrays.append(array)
-
- # Since 'asciidag' prints the DAG from top to bottom (ie, with the inputs
- # at the bottom), we need to invert our representation of it, that is, the
- # 'parents' constructor argument to Node() actually means 'children'.
- from asciidag.node import Node # type: ignore[import]
- asciidag_nodes: Dict[ArrayOrNames, Node] = {}
-
- from collections import defaultdict
- asciidag_edges: Dict[ArrayOrNames, List[ArrayOrNames]] = defaultdict(list)
-
- # Reverse edge directions
- for array in internal_arrays:
- for _, v in nodes[array].edges.items():
- asciidag_edges[v].append(array)
-
- # Add the internal arrays in reversed order
- for array in internal_arrays[::-1]:
- ary_edges = [asciidag_nodes[v] for v in asciidag_edges[array]]
-
- if array == internal_arrays[-1]:
- ary_edges.append(Node("Outputs"))
-
- asciidag_nodes[array] = Node(f"{nodes[array].title}",
- parents=ary_edges)
-
- # Add the input arrays last since they have no predecessors
- for array in input_arrays:
- ary_edges = [asciidag_nodes[v] for v in asciidag_edges[array]]
- asciidag_nodes[array] = Node(f"{nodes[array].title}", parents=ary_edges)
-
- input_node = Node("Inputs", parents=[asciidag_nodes[v] for v in input_arrays])
-
- from asciidag.graph import Graph # type: ignore[import]
- from io import StringIO
-
- f = StringIO()
- graph = Graph(fh=f, use_color=use_color)
-
- graph.show_nodes([input_node])
-
- # Get the graph and remove trailing whitespace
- res = "\n".join([s.rstrip() for s in f.getvalue().split("\n")])
-
- return res
-
-
-def show_ascii_graph(result: Union[Array, DictOfNamedArrays]) -> None:
- """Print a graph representing the computation of *result* to stdout using the
- `asciidag `_ package.
-
- :arg result: Outputs of the computation (cf.
- :func:`pytato.generate_loopy`) or the output of :func:`get_dot_graph`.
- """
-
- print(get_ascii_graph(result, use_color=True))
diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py
index 6f302fad4..0008cea20 100644
--- a/pytato/visualization/dot.py
+++ b/pytato/visualization/dot.py
@@ -26,17 +26,19 @@
"""
-import contextlib
+from functools import partial
import dataclasses
import html
-from typing import (TYPE_CHECKING, Callable, Dict, Union, Iterator, List,
- Mapping, Hashable, Any, FrozenSet)
+from typing import (TYPE_CHECKING, Callable, Dict, Tuple, Union, List,
+ Mapping, Any, FrozenSet, Set, Optional)
from pytools import UniqueNameGenerator
from pytools.tag import Tag
-from pytools.codegen import CodeGenerator as CodeGeneratorBase
from pytato.loopy import LoopyCall
+from pytato.function import Call, FunctionDefinition, NamedCallResult
+from pytato.tags import FunctionIdentifier
+from pytools.codegen import remove_common_indentation
from pytato.array import (
Array, DataWrapper, DictOfNamedArrays, IndexLambda, InputArgumentBase,
@@ -44,7 +46,7 @@
IndexBase)
from pytato.codegen import normalize_outputs
-from pytato.transform import CachedMapper, ArrayOrNames
+from pytato.transform import CachedMapper, ArrayOrNames, InputGatherer
from pytato.distributed.partition import (
DistributedGraphPartition, DistributedGraphPart, PartId)
@@ -62,13 +64,88 @@
"""
+# {{{ _DotEmitter
+
+@dataclasses.dataclass
+class _SubgraphTree:
+ contents: Optional[List[str]]
+ subgraphs: Dict[str, _SubgraphTree]
+
+
+class DotEmitter:
+ def __init__(self) -> None:
+ self.subgraph_to_lines: Dict[Tuple[str, ...], List[str]] = {}
+
+ def __call__(self, subgraph_path: Tuple[str, ...], s: str) -> None:
+ line_list = self.subgraph_to_lines.setdefault(subgraph_path, [])
+
+ if not s.strip():
+ line_list.append("")
+ else:
+ if "\n" in s:
+ s = remove_common_indentation(s)
+
+ for line in s.split("\n"):
+ line_list.append(line)
+
+ def _get_subgraph_tree(self) -> _SubgraphTree:
+ subgraph_tree = _SubgraphTree(contents=None, subgraphs={})
+
+ def insert_into_subgraph_tree(
+ root: _SubgraphTree, path: Tuple[str, ...], contents: List[str]
+ ) -> None:
+ if not path:
+ assert root.contents is None
+ root.contents = contents
+
+ else:
+ subgraph = root.subgraphs.setdefault(
+ path[0],
+ _SubgraphTree(contents=None, subgraphs={}))
+
+ insert_into_subgraph_tree(subgraph, path[1:], contents)
+
+ for sgp, lines in self.subgraph_to_lines.items():
+ insert_into_subgraph_tree(subgraph_tree, sgp, lines)
+
+ return subgraph_tree
+
+ def generate(self) -> str:
+ result = ["digraph computation {"]
+
+ indent_level = 1
+
+ def emit_subgraph(sg: _SubgraphTree) -> None:
+ nonlocal indent_level
+
+ indent = (indent_level*4)*" "
+ if sg.contents:
+ for ln in sg.contents:
+ result.append(indent + ln)
+
+ indent_level += 1
+ for sg_name, sub_sg in sg.subgraphs.items():
+ result.append(f"{indent}subgraph {sg_name} {{")
+ emit_subgraph(sub_sg)
+ result.append(f"{indent}" "}")
+ indent_level -= 1
+
+ emit_subgraph(self._get_subgraph_tree())
+
+ result.append("}")
+
+ return "\n".join(result)
+
+# }}}
+
+
# {{{ array -> dot node converter
@dataclasses.dataclass
-class DotNodeInfo:
+class _DotNodeInfo:
title: str
fields: Dict[str, str]
- edges: Dict[str, ArrayOrNames]
+ edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]]
def stringify_tags(tags: FrozenSet[Tag]) -> str:
@@ -88,17 +165,18 @@ def stringify_shape(shape: ShapeType) -> str:
class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
def __init__(self) -> None:
super().__init__()
- self.nodes: Dict[ArrayOrNames, DotNodeInfo] = {}
+ self.node_to_dot: Dict[ArrayOrNames, _DotNodeInfo] = {}
+ self.functions: Set[FunctionDefinition] = set()
- def get_common_dot_info(self, expr: Array) -> DotNodeInfo:
+ def get_common_dot_info(self, expr: Array) -> _DotNodeInfo:
title = type(expr).__name__
fields = {"addr": hex(id(expr)),
"shape": stringify_shape(expr.shape),
"dtype": str(expr.dtype),
"tags": stringify_tags(expr.tags)}
- edges: Dict[str, ArrayOrNames] = {}
- return DotNodeInfo(title, fields, edges)
+ edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {}
+ return _DotNodeInfo(title, fields, edges)
# type-ignore-reason: incompatible with supertype
def handle_unsupported_array(self, # type: ignore[override]
@@ -125,7 +203,7 @@ def handle_unsupported_array(self, # type: ignore[override]
else:
info.fields[field] = str(attr)
- self.nodes[expr] = info
+ self.node_to_dot[expr] = info
def map_data_wrapper(self, expr: DataWrapper) -> None:
info = self.get_common_dot_info(expr)
@@ -137,7 +215,7 @@ def map_data_wrapper(self, expr: DataWrapper) -> None:
with np.printoptions(threshold=4, precision=2):
info.fields["data"] = str(expr.data)
- self.nodes[expr] = info
+ self.node_to_dot[expr] = info
def map_index_lambda(self, expr: IndexLambda) -> None:
info = self.get_common_dot_info(expr)
@@ -147,7 +225,7 @@ def map_index_lambda(self, expr: IndexLambda) -> None:
self.rec(val)
info.edges[name] = val
- self.nodes[expr] = info
+ self.node_to_dot[expr] = info
def map_stack(self, expr: Stack) -> None:
info = self.get_common_dot_info(expr)
@@ -157,7 +235,7 @@ def map_stack(self, expr: Stack) -> None:
self.rec(array)
info.edges[str(i)] = array
- self.nodes[expr] = info
+ self.node_to_dot[expr] = info
map_concatenate = map_stack
@@ -188,7 +266,7 @@ def map_basic_index(self, expr: IndexBase) -> None:
self.rec(expr.array)
info.edges["array"] = expr.array
- self.nodes[expr] = info
+ self.node_to_dot[expr] = info
map_contiguous_advanced_index = map_basic_index
map_non_contiguous_advanced_index = map_basic_index
@@ -201,27 +279,27 @@ def map_einsum(self, expr: Einsum) -> None:
self.rec(val)
info.edges[f"{iarg}: {access_descr}"] = val
- self.nodes[expr] = info
+ self.node_to_dot[expr] = info
def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
- edges: Dict[str, ArrayOrNames] = {}
+ edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {}
for name, val in expr._data.items():
edges[name] = val
self.rec(val)
- self.nodes[expr] = DotNodeInfo(
+ self.node_to_dot[expr] = _DotNodeInfo(
title=type(expr).__name__,
fields={},
edges=edges)
def map_loopy_call(self, expr: LoopyCall) -> None:
- edges: Dict[str, ArrayOrNames] = {}
+ edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {}
for name, arg in expr.bindings.items():
if isinstance(arg, Array):
edges[name] = arg
self.rec(arg)
- self.nodes[expr] = DotNodeInfo(
+ self.node_to_dot[expr] = _DotNodeInfo(
title=type(expr).__name__,
fields={"addr": hex(id(expr)), "entrypoint": expr.entrypoint},
edges=edges)
@@ -241,25 +319,45 @@ def map_distributed_send_ref_holder(
info.fields["comm_tag"] = str(expr.send.comm_tag)
- self.nodes[expr] = info
+ self.node_to_dot[expr] = info
+
+ def map_call(self, expr: Call) -> None:
+ self.functions.add(expr.function)
+
+ for bnd in expr.bindings.values():
+ self.rec(bnd)
+
+ self.node_to_dot[expr] = _DotNodeInfo(
+ title=expr.__class__.__name__,
+ edges={
+ "": expr.function,
+ **expr.bindings},
+ fields={
+ "addr": hex(id(expr)),
+ "tags": stringify_tags(expr.tags),
+ }
+ )
+
+ def map_named_call_result(self, expr: NamedCallResult) -> None:
+ self.rec(expr._container)
+ self.node_to_dot[expr] = _DotNodeInfo(
+ title=expr.__class__.__name__,
+ edges={"": expr._container},
+ fields={"addr": hex(id(expr)),
+ "name": expr.name},
+ )
+
+# }}}
def dot_escape(s: str) -> str:
# "\" and HTML are significant in graphviz.
- return html.escape(s.replace("\\", "\\\\"))
+ return html.escape(s.replace("\\", "\\\\").replace(" ", "_"))
-class DotEmitter(CodeGeneratorBase):
- @contextlib.contextmanager
- def block(self, name: str) -> Iterator[None]:
- self(name + " {")
- self.indent()
- yield
- self.dedent()
- self("}")
+# {{{ emit helpers
-
-def _emit_array(emit: DotEmitter, title: str, fields: Dict[str, str],
+def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, str],
dot_node_id: str, color: str = "white") -> None:
td_attrib = 'border="0"'
table_attrib = 'border="0" cellborder="1" cellspacing="0"'
@@ -277,84 +375,189 @@ def _emit_array(emit: DotEmitter, title: str, fields: Dict[str, str],
emit("%s [label=<%s> style=filled fillcolor=%s]" % (dot_node_id, table, color))
-def _emit_name_cluster(emit: DotEmitter, names: Mapping[str, ArrayOrNames],
+def _emit_name_cluster(
+ emit: DotEmitter, subgraph_path: Tuple[str, ...],
+ names: Mapping[str, ArrayOrNames],
array_to_id: Mapping[ArrayOrNames, str], id_gen: Callable[[str], str],
label: str) -> None:
edges = []
- with emit.block("subgraph cluster_%s" % label):
- emit("node [shape=ellipse]")
- emit('label="%s"' % label)
+ cluster_subgraph_path = subgraph_path + (f"cluster_{dot_escape(label)}",)
+ emit_cluster = partial(emit, cluster_subgraph_path)
+ emit_cluster("node [shape=ellipse]")
+ emit_cluster(f'label="{label}"')
- for name, array in names.items():
- name_id = id_gen(label)
- emit('%s [label="%s"]' % (name_id, dot_escape(name)))
- array_id = array_to_id[array]
- # Edges must be outside the cluster.
- edges.append((name_id, array_id))
+ for name, array in names.items():
+ name_id = id_gen(dot_escape(name))
+ emit_cluster('%s [label="%s"]' % (name_id, dot_escape(name)))
+ array_id = array_to_id[array]
+ # Edges must be outside the cluster.
+ edges.append((name_id, array_id))
for name_id, array_id in edges:
- emit("%s -> %s" % (name_id, array_id))
-
-# }}}
-
-
-def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str:
- r"""Return a string in the `dot `_ language depicting the
- graph of the computation of *result*.
-
- :arg result: Outputs of the computation (cf.
- :func:`pytato.generate_loopy`).
- """
- outputs: DictOfNamedArrays = normalize_outputs(result)
- del result
-
- mapper = ArrayToDotNodeInfoMapper()
- for elem in outputs._data.values():
- mapper(elem)
+ emit(subgraph_path, "%s -> %s" % (array_id, name_id))
- nodes = mapper.nodes
+def _emit_function(
+ emitter: DotEmitter, subgraph_path: Tuple[str, ...],
+ id_gen: UniqueNameGenerator,
+ node_to_dot: Mapping[ArrayOrNames, _DotNodeInfo],
+ func_to_id: Mapping[FunctionDefinition, str],
+ outputs: Mapping[str, Array]) -> None:
input_arrays: List[Array] = []
internal_arrays: List[ArrayOrNames] = []
array_to_id: Dict[ArrayOrNames, str] = {}
- id_gen = UniqueNameGenerator()
- for array in nodes:
+ emit = partial(emitter, subgraph_path)
+ for array in node_to_dot:
array_to_id[array] = id_gen("array")
if isinstance(array, InputArgumentBase):
input_arrays.append(array)
else:
internal_arrays.append(array)
- emit = DotEmitter()
+ # Emit inputs.
+ input_subgraph_path = subgraph_path + ("cluster_inputs",)
+ emit_input = partial(emitter, input_subgraph_path)
+ emit_input('label="Arguments"')
+
+ for array in input_arrays:
+ _emit_array(
+ emit_input,
+ node_to_dot[array].title,
+ node_to_dot[array].fields,
+ array_to_id[array])
+
+ # Emit non-inputs.
+ for array in internal_arrays:
+ _emit_array(emit,
+ node_to_dot[array].title,
+ node_to_dot[array].fields,
+ array_to_id[array])
+
+ # Emit edges.
+ for array, node in node_to_dot.items():
+ for label, tail_item in node.edges.items():
+ head = array_to_id[array]
+ if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)):
+ tail = array_to_id[tail_item]
+ elif isinstance(tail_item, FunctionDefinition):
+ tail = func_to_id[tail_item]
+ else:
+ raise ValueError(
+ f"unexpected type of tail on edge: {type(tail_item)}")
- with emit.block("digraph computation"):
- emit("node [shape=rectangle]")
+ emit('%s -> %s [label="%s"]' % (tail, head, dot_escape(label)))
- # Emit inputs.
- with emit.block("subgraph cluster_Inputs"):
- emit('label="Inputs"')
- for array in input_arrays:
- _emit_array(emit,
- nodes[array].title, nodes[array].fields, array_to_id[array])
+ # Emit output/namespace name mappings.
+ _emit_name_cluster(
+ emitter, subgraph_path, outputs, array_to_id, id_gen, label="Returns")
- # Emit non-inputs.
- for array in internal_arrays:
- _emit_array(emit,
- nodes[array].title, nodes[array].fields, array_to_id[array])
+# }}}
+
+
+# {{{ information gathering
+
+def _get_function_name(f: FunctionDefinition) -> Optional[str]:
+ func_id_tags = f.tags_of_type(FunctionIdentifier)
+ if func_id_tags:
+ func_id_tag, = func_id_tags
+ return str(func_id_tag.identifier)
+ else:
+ return None
- # Emit edges.
- for array, node in nodes.items():
- for label, tail_array in node.edges.items():
- tail = array_to_id[tail_array]
- head = array_to_id[array]
- emit('%s -> %s [label="%s"]' % (tail, head, dot_escape(label)))
- # Emit output/namespace name mappings.
- _emit_name_cluster(emit, outputs._data, array_to_id, id_gen, label="Outputs")
+def _gather_partition_node_information(
+ id_gen: UniqueNameGenerator,
+ partition: DistributedGraphPartition
+ ) -> Tuple[
+ Mapping[PartId, Mapping[FunctionDefinition, str]],
+ Mapping[Tuple[PartId, Optional[FunctionDefinition]],
+ Mapping[ArrayOrNames, _DotNodeInfo]]
+ ]:
+ part_id_to_func_to_id: Dict[PartId, Dict[FunctionDefinition, str]] = {}
+ part_id_func_to_node_info: Dict[Tuple[PartId, Optional[FunctionDefinition]],
+ Dict[ArrayOrNames, _DotNodeInfo]] = {}
- return emit.get()
+ for part in partition.parts.values():
+ mapper = ArrayToDotNodeInfoMapper()
+ for out_name in part.output_names:
+ mapper(partition.name_to_output[out_name])
+
+ part_id_func_to_node_info[part.pid, None] = mapper.node_to_dot
+ part_id_to_func_to_id[part.pid] = {}
+
+ # It is important that seen functions are emitted callee-first.
+ # (Otherwise function 'entry' nodes will get declared in the wrong
+ # cluster.) So use a data type that preserves order.
+ seen_functions: List[FunctionDefinition] = []
+
+ def gather_function_info(f: FunctionDefinition) -> None:
+ key = (part.pid, f) # noqa: B023
+ if key in part_id_func_to_node_info:
+ return
+
+ mapper = ArrayToDotNodeInfoMapper()
+ for elem in f.returns.values():
+ mapper(elem)
+
+ part_id_func_to_node_info[key] = mapper.node_to_dot
+
+ for subfunc in mapper.functions:
+ gather_function_info(subfunc)
+
+ if f not in seen_functions: # noqa: B023
+ seen_functions.append(f) # noqa: B023
+
+ for f in mapper.functions:
+ gather_function_info(f)
+
+ # Again, important to preserve function order. Here we're relying
+ # on dicts to preserve order.
+ for f in seen_functions:
+ func_name = _get_function_name(f)
+ if func_name is not None:
+ fid = id_gen(dot_escape(func_name))
+ else:
+ fid = id_gen("func")
+
+ part_id_to_func_to_id.setdefault(part.pid, {})[f] = fid
+
+ return part_id_to_func_to_id, part_id_func_to_node_info
+
+# }}}
+
+
+def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str:
+ r"""Return a string in the `dot `_ language depicting the
+ graph of the computation of *result*.
+
+ :arg result: Outputs of the computation (cf.
+ :func:`pytato.generate_loopy`).
+ """
+
+ outputs: DictOfNamedArrays = normalize_outputs(result)
+
+ return get_dot_graph_from_partition(
+ DistributedGraphPartition(
+ parts={
+ None: DistributedGraphPart(
+ pid=None,
+ needed_pids=frozenset(),
+ user_input_names=frozenset(
+ expr.name
+ for expr in InputGatherer()(outputs)
+ if isinstance(expr, Placeholder)
+ ),
+ partition_input_names=frozenset(),
+ output_names=frozenset(outputs.keys()),
+ name_to_recv_node={},
+ name_to_send_nodes={},
+ )
+ },
+ name_to_output=outputs._data,
+ overall_output_names=tuple(outputs),
+ ))
def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
@@ -363,170 +566,226 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
:arg partition: Outputs of :func:`~pytato.find_distributed_partition`.
"""
- # Maps each partition to a dict of its arrays with the node info
- part_id_to_node_info: Dict[Hashable, Dict[ArrayOrNames, DotNodeInfo]] = {}
+ id_gen = UniqueNameGenerator()
- for part in partition.parts.values():
- mapper = ArrayToDotNodeInfoMapper()
- for out_name in part.output_names:
- mapper(partition.name_to_output[out_name])
+ # {{{ gather up node info, per partition and per function
- part_id_to_node_info[part.pid] = mapper.nodes
+ # The "None" function is the body of the partition.
- id_gen = UniqueNameGenerator()
+ part_id_to_func_to_id, part_id_func_to_node_info = \
+ _gather_partition_node_information(id_gen, partition)
+
+ # }}}
- emit = DotEmitter()
+ emitter = DotEmitter()
+ emit_root = partial(emitter, ())
emitted_placeholders = set()
- with emit.block("digraph computation"):
- emit("node [shape=rectangle]")
- placeholder_to_id: Dict[ArrayOrNames, str] = {}
- part_id_to_array_to_id: Dict[PartId, Dict[ArrayOrNames, str]] = {}
-
- # First pass: generate names for all nodes
- for part in partition.parts.values():
- array_to_id = {}
- for array, _ in part_id_to_node_info[part.pid].items():
- if isinstance(array, Placeholder):
- # Placeholders are only emitted once
- if array in placeholder_to_id:
- node_id = placeholder_to_id[array]
- else:
- node_id = id_gen("array")
- placeholder_to_id[array] = node_id
+ emit_root("node [shape=rectangle]")
+
+ placeholder_to_id: Dict[ArrayOrNames, str] = {}
+ part_id_to_array_to_id: Dict[PartId, Dict[ArrayOrNames, str]] = {}
+
+ part_id_to_id = {pid: dot_escape(str(pid)) for pid in partition.parts}
+ assert len(set(part_id_to_id.values())) == len(partition.parts)
+
+ # {{{ generate names for all nodes in the root/None function
+
+ for part in partition.parts.values():
+ array_to_id = {}
+ for array in part_id_func_to_node_info[part.pid, None].keys():
+ if isinstance(array, Placeholder):
+ # Placeholders are only emitted once
+ if array in placeholder_to_id:
+ node_id = placeholder_to_id[array]
else:
node_id = id_gen("array")
- array_to_id[array] = node_id
- part_id_to_array_to_id[part.pid] = array_to_id
-
- # Second pass: emit the graph.
- for part in partition.parts.values():
- array_to_id = part_id_to_array_to_id[part.pid]
-
- # {{{ emit receives nodes if distributed
-
- if isinstance(part, DistributedGraphPart):
- part_dist_recv_var_name_to_node_id = {}
- for name, recv in (
- part.name_to_recv_node.items()):
- node_id = id_gen("recv")
- _emit_array(emit, "DistributedRecv", {
- "shape": stringify_shape(recv.shape),
- "dtype": str(recv.dtype),
- "src_rank": str(recv.src_rank),
- "comm_tag": str(recv.comm_tag),
- }, node_id)
+ placeholder_to_id[array] = node_id
+ else:
+ node_id = id_gen("array")
+ array_to_id[array] = node_id
+ part_id_to_array_to_id[part.pid] = array_to_id
+
+ # }}}
+
+ # {{{ emit the graph
+
+ for part in partition.parts.values():
+ array_to_id = part_id_to_array_to_id[part.pid]
+
+ is_trivial_partition = part.pid is None and len(partition.parts) == 1
+ if is_trivial_partition:
+ part_subgraph_path: Tuple[str, ...] = ()
+ else:
+ part_subgraph_path = (f"cluster_{part_id_to_id[part.pid]}",)
+
+ emit_part = partial(emitter, part_subgraph_path)
+
+ if not is_trivial_partition:
+ emit_part("style=dashed")
+ emit_part(f'label="{part.pid}"')
+
+ # {{{ emit functions
+
+ # It is important that seen functions are emitted callee-first.
+ # Here we're relying on the part_id_to_func_to_id dict to preserve order.
+
+ for func, fid in part_id_to_func_to_id[part.pid].items():
+ func_subgraph_path = part_subgraph_path + (f"cluster_{fid}",)
+ label = _get_function_name(func) or fid
+
+ emitter(func_subgraph_path, f'label="{label}"')
+ emitter(func_subgraph_path, f'{fid} [label="{label}",shape="ellipse"]')
- part_dist_recv_var_name_to_node_id[name] = node_id
+ _emit_function(emitter, func_subgraph_path,
+ id_gen, part_id_func_to_node_info[part.pid, func],
+ part_id_to_func_to_id[part.pid],
+ func.returns)
+
+ # }}}
+
+ # {{{ emit receives nodes
+
+ part_dist_recv_var_name_to_node_id = {}
+ for name, recv in (
+ part.name_to_recv_node.items()):
+ node_id = id_gen("recv")
+ _emit_array(emit_part, "DistributedRecv", {
+ "shape": stringify_shape(recv.shape),
+ "dtype": str(recv.dtype),
+ "src_rank": str(recv.src_rank),
+ "comm_tag": str(recv.comm_tag),
+ }, node_id)
+
+ part_dist_recv_var_name_to_node_id[name] = node_id
+
+ # }}}
+
+ part_node_to_info = part_id_func_to_node_info[part.pid, None]
+ input_arrays: List[Array] = []
+ internal_arrays: List[ArrayOrNames] = []
+
+ for array in part_node_to_info.keys():
+ if isinstance(array, InputArgumentBase):
+ input_arrays.append(array)
else:
- part_dist_recv_var_name_to_node_id = {}
+ internal_arrays.append(array)
- # }}}
+ # {{{ emit inputs
- part_node_to_info = part_id_to_node_info[part.pid]
- input_arrays: List[Array] = []
- internal_arrays: List[ArrayOrNames] = []
+        # A given Placeholder is a single object that may be shared among
+        # partitions. Therefore, it should not be emitted inside a (dot)
+        # subgraph, otherwise referencing it from several clusters would
+        # force it into multiple subgraphs.
- for array in part_node_to_info.keys():
- if isinstance(array, InputArgumentBase):
- input_arrays.append(array)
+ for array in input_arrays:
+ if not isinstance(array, Placeholder):
+ _emit_array(emit_part,
+ part_node_to_info[array].title,
+ part_node_to_info[array].fields,
+ array_to_id[array], "deepskyblue")
+ else:
+ # Is a Placeholder
+ if array in emitted_placeholders:
+ continue
+
+ _emit_array(emit_root,
+ part_node_to_info[array].title,
+ part_node_to_info[array].fields,
+ array_to_id[array], "deepskyblue")
+
+ # Emit cross-partition edges
+ if array.name in part_dist_recv_var_name_to_node_id:
+ tgt = part_dist_recv_var_name_to_node_id[array.name]
+ emit_root(f"{tgt} -> {array_to_id[array]} [style=dotted]")
+ emitted_placeholders.add(array)
+ elif array.name in part.user_input_names:
+                # External user input: not produced by any partition, so no edge.
+ pass
else:
- internal_arrays.append(array)
-
- # {{{ emit inputs
-
- # Placeholders are unique, i.e. the same Placeholder object may be
- # shared among partitions. Therefore, they should not live inside
- # the (dot) subgraph, otherwise they would be forced into multiple
- # subgraphs.
-
- for array in input_arrays:
- # Non-Placeholders are emitted *inside* their subgraphs below.
- if isinstance(array, Placeholder):
- if array not in emitted_placeholders:
- _emit_array(emit,
- part_node_to_info[array].title,
- part_node_to_info[array].fields,
- array_to_id[array], "deepskyblue")
-
- # Emit cross-partition edges
- if array.name in part_dist_recv_var_name_to_node_id:
- tgt = part_dist_recv_var_name_to_node_id[array.name]
- emit(f"{tgt} -> {array_to_id[array]} [style=dotted]")
- emitted_placeholders.add(array)
- elif array.name in part.user_input_names:
- # These are placeholders for external input. They
- # are cleanly associated with a single partition
- # and thus emitted below.
- pass
- else:
- # placeholder for a value from a different partition
- computing_pid = None
- for other_part in partition.parts.values():
- if array.name in other_part.output_names:
- computing_pid = other_part.pid
- break
- assert computing_pid is not None
- tgt = part_id_to_array_to_id[computing_pid][
- partition.name_to_output[array.name]]
- emit(f"{tgt} -> {array_to_id[array]} [style=dashed]")
- emitted_placeholders.add(array)
-
- # }}}
-
- with emit.block(f'subgraph "cluster_part_{part.pid}"'):
- emit("style=dashed")
- emit(f'label="{part.pid}"')
-
- for array in input_arrays:
- if (not isinstance(array, Placeholder)
- or array.name in part.user_input_names):
- _emit_array(emit,
- part_node_to_info[array].title,
- part_node_to_info[array].fields,
- array_to_id[array], "deepskyblue")
-
- # Emit internal nodes
- for array in internal_arrays:
- _emit_array(emit,
- part_node_to_info[array].title,
- part_node_to_info[array].fields,
- array_to_id[array])
-
- # {{{ emit send nodes if distributed
-
- deferred_send_edges = []
- if isinstance(part, DistributedGraphPart):
- for name, sends in (
- part.name_to_send_nodes.items()):
- for send in sends:
- node_id = id_gen("send")
- _emit_array(emit, "DistributedSend", {
- "dest_rank": str(send.dest_rank),
- "comm_tag": str(send.comm_tag),
- }, node_id)
-
- deferred_send_edges.append(
- f"{array_to_id[send.data]} -> {node_id}"
- f'[style=dotted, label="{dot_escape(name)}"]')
-
- # }}}
-
- # If an edge is emitted in a subgraph, it drags its nodes into the
- # subgraph, too. Not what we want.
- for edge in deferred_send_edges:
- emit(edge)
-
- # Emit intra-partition edges
- for array, node in part_node_to_info.items():
- for label, tail_array in node.edges.items():
- tail = array_to_id[tail_array]
- head = array_to_id[array]
- emit('%s -> %s [label="%s"]' %
- (tail, head, dot_escape(label)))
-
- return emit.get()
+ # placeholder for a value from a different partition
+ computing_pid = None
+ for other_part in partition.parts.values():
+ if array.name in other_part.output_names:
+ computing_pid = other_part.pid
+ break
+ assert computing_pid is not None
+ tgt = part_id_to_array_to_id[computing_pid][
+ partition.name_to_output[array.name]]
+ emit_root(f"{tgt} -> {array_to_id[array]} [style=dashed]")
+ emitted_placeholders.add(array)
+
+ # }}}
+
+ # Emit internal nodes
+ for array in internal_arrays:
+ _emit_array(emit_part,
+ part_node_to_info[array].title,
+ part_node_to_info[array].fields,
+ array_to_id[array])
+
+ # {{{ emit send nodes if distributed
+
+ if isinstance(part, DistributedGraphPart):
+ for name, sends in part.name_to_send_nodes.items():
+ for send in sends:
+ node_id = id_gen("send")
+ _emit_array(emit_part, "DistributedSend", {
+ "dest_rank": str(send.dest_rank),
+ "comm_tag": str(send.comm_tag),
+ }, node_id)
+
+ # If an edge is emitted in a subgraph, it drags its
+ # nodes into the subgraph, too. Not what we want.
+ emit_root(
+ f"{array_to_id[send.data]} -> {node_id}"
+ f'[style=dotted, label="{dot_escape(name)}"]')
+
+ # }}}
+
+ # Emit intra-partition edges
+ for array, node in part_node_to_info.items():
+ for label, tail_item in node.edges.items():
+ head = array_to_id[array]
+
+ if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)):
+ tail = array_to_id[tail_item]
+ elif isinstance(tail_item, FunctionDefinition):
+ tail = part_id_to_func_to_id[part.pid][tail_item]
+ else:
+ raise ValueError(
+ f"unexpected type of tail on edge: {type(tail_item)}")
+
+ emit_root('%s -> %s [label="%s"]' %
+ (tail, head, dot_escape(label)))
+
+ _emit_name_cluster(
+ emitter, part_subgraph_path,
+ {name: partition.name_to_output[name] for name in part.output_names},
+ array_to_id, id_gen, "Part outputs")
+
+ # }}}
+
+    # Arrays may occur in multiple partitions; they are drawn separately in
+    # each partition anyhow (unless they're Placeholders). Don't be tempted
+    # to use combined_array_to_id everywhere.
+
+ # {{{ draw overall outputs
+
+ combined_array_to_id: Dict[ArrayOrNames, str] = {}
+ for part_id in partition.parts.keys():
+ combined_array_to_id.update(part_id_to_array_to_id[part_id])
+
+ _emit_name_cluster(
+ emitter, (),
+ {name: partition.name_to_output[name]
+ for name in partition.overall_output_names},
+ combined_array_to_id, id_gen, "Overall outputs")
+
+ # }}}
+
+ return emitter.generate()
def show_dot_graph(result: Union[str, Array, DictOfNamedArrays,
@@ -551,4 +810,4 @@ def show_dot_graph(result: Union[str, Array, DictOfNamedArrays,
from pytools.graphviz import show_dot
show_dot(dot_code, **kwargs)
-# }}}
+# vim:fdm=marker
diff --git a/requirements.txt b/requirements.txt
index 7e69f17f9..4f9b14d65 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,5 +3,4 @@ git+https://github.com/inducer/pymbolic.git#egg=pymbolic
git+https://github.com/inducer/genpy.git#egg=genpy
git+https://github.com/inducer/loopy.git#egg=loopy
-asciidag
mako
diff --git a/test/test_codegen.py b/test/test_codegen.py
index 874906528..b6d0bba8b 100755
--- a/test/test_codegen.py
+++ b/test/test_codegen.py
@@ -1854,6 +1854,100 @@ def test_pad(ctx_factory):
np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array)
+def test_function_call(ctx_factory, visualize=False):
+ from functools import partial
+ cl_ctx = ctx_factory()
+ cq = cl.CommandQueue(cl_ctx)
+
+ def f(x):
+ return 2*x
+
+ def g(tracer, x):
+ return tracer(f, x), 3*x
+
+ def h(x, y):
+ return {"twice": 2*x+y, "thrice": 3*x+y}
+
+ def build_expression(tracer):
+ x = pt.arange(500, dtype=np.float32)
+ twice_x = tracer(f, x)
+ twice_x_2, thrice_x_2 = tracer(partial(g, tracer), x)
+
+ result = tracer(h, x, 2*x)
+ twice_x_3 = result["twice"]
+ thrice_x_3 = result["thrice"]
+
+ return {"foo": 3.14 + twice_x_3,
+ "bar": 4 * thrice_x_3,
+ "baz": 65 * twice_x,
+ "quux": 7 * twice_x_2}
+
+ result_with_functions = pt.tag_all_calls_to_be_inlined(
+ pt.make_dict_of_named_arrays(build_expression(pt.trace_call)))
+ result_without_functions = pt.make_dict_of_named_arrays(
+ build_expression(lambda fn, *args: fn(*args)))
+
+ # test that visualizing graphs with functions works
+ dot = pt.get_dot_graph(result_with_functions)
+
+ if visualize:
+ pt.show_dot_graph(dot)
+
+ _, outputs = pt.generate_loopy(result_with_functions)(cq, out_host=True)
+ _, expected = pt.generate_loopy(result_without_functions)(cq, out_host=True)
+
+ assert len(outputs) == len(expected)
+
+ for key in outputs.keys():
+ np.testing.assert_allclose(outputs[key], expected[key])
+
+
+def test_nested_function_calls(ctx_factory):
+ from functools import partial
+
+ ctx = ctx_factory()
+ cq = cl.CommandQueue(ctx)
+
+ rng = np.random.default_rng(0)
+ x_np = rng.random((10,))
+
+ x = pt.make_placeholder("x", 10, np.float64).tagged(pt.tags.ImplStored())
+ prg = pt.generate_loopy({"out1": 3*x, "out2": x})
+ _, out = prg(cq, x=x_np)
+ np.testing.assert_allclose(out["out1"], 3*x_np)
+ np.testing.assert_allclose(out["out2"], x_np)
+ ref_tracer = lambda f, *args, identifier: f(*args) # noqa: E731
+
+ def foo(tracer, x, y):
+ return 2*x + 3*y
+
+ def bar(tracer, x, y):
+ foo_x_y = tracer(partial(foo, tracer), x, y, identifier="foo")
+ return foo_x_y * x * y
+
+ def call_bar(tracer, x, y):
+ return tracer(partial(bar, tracer), x, y, identifier="bar")
+
+ x1_np, y1_np = rng.random((2, 13, 29))
+ x2_np, y2_np = rng.random((2, 4, 29))
+ x1, y1 = pt.make_data_wrapper(x1_np), pt.make_data_wrapper(y1_np)
+ x2, y2 = pt.make_data_wrapper(x2_np), pt.make_data_wrapper(y2_np)
+ result = pt.make_dict_of_named_arrays({"out1": call_bar(pt.trace_call, x1, y1),
+ "out2": call_bar(pt.trace_call, x2, y2)}
+ )
+ result = pt.tag_all_calls_to_be_inlined(result)
+ expect = pt.make_dict_of_named_arrays({"out1": call_bar(ref_tracer, x1, y1),
+ "out2": call_bar(ref_tracer, x2, y2)}
+ )
+
+ _, result_out = pt.generate_loopy(result)(cq)
+ _, expect_out = pt.generate_loopy(expect)(cq)
+
+ assert result_out.keys() == expect_out.keys()
+ for k in expect_out:
+ np.testing.assert_allclose(result_out[k], expect_out[k])
+
+
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
diff --git a/test/test_pytato.py b/test/test_pytato.py
index 5d2343ee9..5b35bdde7 100644
--- a/test/test_pytato.py
+++ b/test/test_pytato.py
@@ -371,37 +371,6 @@ def test_userscollector():
assert nuc[dag] == 0
-def test_asciidag():
- pytest.importorskip("asciidag")
-
- n = pt.make_size_param("n")
- array = pt.make_placeholder(name="array", shape=n, dtype=np.float64)
- stack = pt.stack([array, 2*array, array + 6])
- y = stack @ stack.T
-
- from pytato import get_ascii_graph
-
- res = get_ascii_graph(y, use_color=False)
-
- ref_str = r"""* Inputs
-*-. Placeholder
-|\ \
-* | | IndexLambda
-| |/
-|/|
-| * IndexLambda
-|/
-* Stack
-|\
-* | AxisPermutation
-|/
-* Einsum
-* Outputs
-"""
-
- assert res == ref_str
-
-
def test_linear_complexity_inequality():
# See https://github.com/inducer/pytato/issues/163
import pytato as pt