From a01789910b4f79fcd5bf7b42296dee76b41ac91a Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Fri, 4 Nov 2022 23:27:58 -0500 Subject: [PATCH 01/11] Introduce `pt.trace_call` Co-authored-by: Andreas Kloeckner --- pytato/__init__.py | 6 +- pytato/function.py | 387 +++++++++++++++++++++++++++++++++++++++++++++ pytato/tags.py | 13 +- 3 files changed, 404 insertions(+), 2 deletions(-) create mode 100644 pytato/function.py diff --git a/pytato/__init__.py b/pytato/__init__.py index 0117d6ba9..46794a9af 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -95,6 +95,7 @@ def set_debug_enabled(flag: bool) -> None: ) import pytato.analysis as analysis import pytato.tags as tags +import pytato.function as function import pytato.transform as transform from pytato.distributed.nodes import (make_distributed_send, make_distributed_recv, DistributedRecv, DistributedSend, @@ -111,6 +112,7 @@ def set_debug_enabled(flag: bool) -> None: from pytato.transform.remove_broadcasts_einsum import ( rewrite_einsums_with_no_broadcasts) from pytato.transform.metadata import unify_axes_tags +from pytato.function import trace_call __all__ = ( "dtype", @@ -156,6 +158,8 @@ def set_debug_enabled(flag: bool) -> None: "broadcast_to", "pad", + "trace_call", + "make_distributed_recv", "make_distributed_send", "DistributedRecv", "DistributedSend", "staple_distributed_send", "DistributedSendRefHolder", @@ -175,6 +179,6 @@ def set_debug_enabled(flag: bool) -> None: "unify_axes_tags", # sub-modules - "analysis", "tags", "transform", + "analysis", "tags", "transform", "function", ) diff --git a/pytato/function.py b/pytato/function.py new file mode 100644 index 000000000..004aac18a --- /dev/null +++ b/pytato/function.py @@ -0,0 +1,387 @@ +from __future__ import annotations + +__doc__ = """ +.. currentmodule:: pytato + +.. autofunction:: trace_call + +.. currentmodule:: pytato.function + +.. autoclass:: Call +.. autoclass:: NamedCallResult +.. autoclass:: FunctionDefinition +.. autoclass:: ReturnType + +.. class:: ReturnT + + A type variable corresponding to the return type of the function + :func:`pytato.trace_call`. +""" + +__copyright__ = """ +Copyright (C) 2022 Andreas Kloeckner +Copyright (C) 2022 Kaushik Kulkarni +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +import attrs +import re +import enum + +from typing import (Callable, Dict, FrozenSet, Tuple, Union, TypeVar, Optional, + Hashable, Sequence, ClassVar, Iterator, Iterable, Mapping) +from immutables import Map +from functools import cached_property +from pytato.array import (Array, AbstractResultWithNamedArrays, + Placeholder, NamedArray, ShapeType, _dtype_any, + InputArgumentBase) +from pytools.tag import Tag, Taggable + +ReturnT = TypeVar("ReturnT", Array, Tuple[Array, ...], Dict[str, Array]) + + +# {{{ Call/NamedCallResult + + +@enum.unique +class ReturnType(enum.Enum): + """ + Records the function body's return type in :class:`FunctionDefinition`. + """ + ARRAY = 0 + DICT_OF_ARRAYS = 1 + TUPLE_OF_ARRAYS = 2 + + +# eq=False to avoid equality comparison without EqualityMaper +@attrs.define(frozen=True, eq=False, hash=True) +class FunctionDefinition(Taggable): + r""" + A function definition that represents its outputs as instances of + :class:`~pytato.Array` with the inputs being + :class:`~pytato.array.Placeholder`\ s. The outputs of the function + can be a single :class:`pytato.Array`, a tuple of :class:`pytato.Array`\ s or an + instance of ``Dict[str, Array]``. + + .. attribute:: parameters + + Names of the input :class:`~pytato.array.Placeholder`\ s to the + function node. This is a superset of the names of + :class:`~pytato.array.Placeholder` instances encountered in + :attr:`returns`. Unused parameters are allowed. + + .. attribute:: return_type + + An instance of :class:`ReturnType`. + + .. attribute:: returns + + The outputs of the function call which are array expressions that + depend on the *parameters*. The keys of the mapping depend on + :attr:`return_type` as: + + - If the function returns a single :class:`pytato.Array`, then + *returns* contains a single array expression with ``"_"`` as the + key. + - If the function returns a :class:`tuple` of + :class:`pytato.Array`\ s, then *returns* contains entries with + the key ``"_N"`` mapping the ``N``-th entry of the result-tuple. + - If the function returns a :class:`dict` mapping identifiers to + :class:`pytato.Array`\ s, then *returns* uses the same mapping. + + .. automethod:: get_placeholder + + .. note:: + + A :class:`FunctionDefinition` comes with its own namespace based on + :attr:`parameters`. A :class:`~pytato.transform.Mapper`-implementer + must ensure **not** to reuse the cached result between the caller's + expressions and a function definition's expressions to avoid unsound + cache hits that could lead to incorrect mappings. + + .. note:: + + At this point, code generation/execution does not support + distributed-memory communication nodes (:class:`~pytato.DistributedSend`, + :class:`~pytato.DistributedRecv`) within function bodies. + """ + parameters: FrozenSet[str] + return_type: ReturnType + returns: Map[str, Array] + tags: FrozenSet[Tag] = attrs.field(kw_only=True) + + @cached_property + def _placeholders(self) -> Mapping[str, Placeholder]: + from pytato.transform import InputGatherer + from functools import reduce + + mapper = InputGatherer() + + all_input_args: FrozenSet[InputArgumentBase] = reduce( + frozenset.union, + (mapper(ary) for ary in self.returns.values()), + frozenset() + ) + + return Map({input_arg.name: input_arg + for input_arg in all_input_args + if isinstance(input_arg, Placeholder)}) + + def get_placeholder(self, name: str) -> Placeholder: + """ + Returns the instance of :class:`pytato.array.Placeholder` corresponding + to the parameter *name* in function body. + """ + return self._placeholders[name] + + def _with_new_tags( + self: FunctionDefinition, tags: FrozenSet[Tag]) -> FunctionDefinition: + return attrs.evolve(self, tags=tags) + + def __call__(self, **kwargs: Array + ) -> Union[Array, + Tuple[Array, ...], + Dict[str, Array]]: + from pytato.array import _get_default_tags + from pytato.utils import are_shapes_equal + + # {{{ sanity checks + + if self.parameters != frozenset(kwargs): + missing_params = self.parameters - frozenset(kwargs) + extra_params = self.parameters - frozenset(kwargs) + + raise TypeError( + "Incorrect arguments." + + (f" Missing: '{missing_params}'." if missing_params else "") + + (f" Extra: '{extra_params}'." if extra_params else "") + ) + + for argname, expected_arg in self._placeholders.items(): + if expected_arg.dtype != kwargs[argname].dtype: + raise ValueError(f"Argument '{argname}' expected to " + f" be of type '{expected_arg.dtype}', got" + f" '{kwargs[argname].dtype}'.") + if not are_shapes_equal(expected_arg.shape, kwargs[argname].shape): + raise ValueError(f"Argument '{argname}' expected to " + f" have shape '{expected_arg.shape}', got" + f" '{kwargs[argname].shape}'.") + + # }}} + + call_site = Call(self, bindings=Map(kwargs), tags=_get_default_tags()) + + if self.return_type == ReturnType.ARRAY: + return call_site["_"] + elif self.return_type == ReturnType.TUPLE_OF_ARRAYS: + return tuple(call_site[f"_{iarg}"] + for iarg in range(len(self.returns))) + elif self.return_type == ReturnType.DICT_OF_ARRAYS: + return {kw: call_site[kw] for kw in self.returns} + else: + raise NotImplementedError(self.return_type) + + +class NamedCallResult(NamedArray): + """ + One of the arrays that are returned from a call to :class:`FunctionDefinition`. + + .. attribute:: call + + The function invocation that led to *self*. + + .. attribute:: name + + The name by which the returned array is referred to in + :attr:`FunctionDefinition.returns`. + """ + call: Call + name: str + _mapper_method: ClassVar[str] = "map_named_call_result" + + def __init__(self, + call: Call, + name: str) -> None: + super().__init__(call, name, + axes=call.function.returns[name].axes, + tags=call.function.returns[name].tags) + + def with_tagged_axis(self, iaxis: int, + tags: Union[Sequence[Tag], Tag]) -> Array: + raise ValueError("Tagging a NamedCallResult's axis is illegal, use" + " Call.with_tagged_axis instead") + + def tagged(self, + tags: Union[Iterable[Tag], Tag, None]) -> NamedCallResult: + raise ValueError("Tagging a NamedCallResult is illegal, use" + " Call.tagged instead") + + def without_tags(self, + tags: Union[Iterable[Tag], Tag, None], + verify_existence: bool = True, + ) -> NamedCallResult: + raise ValueError("Untagging a NamedCallResult is illegal, use" + " Call.without_tags instead") + + @property + def shape(self) -> ShapeType: + assert isinstance(self._container, Call) + return self._container.function.returns[self.name].shape + + @property + def dtype(self) -> _dtype_any: + assert isinstance(self._container, Call) + return self._container.function.returns[self.name].dtype + + +# eq=False to avoid equality comparison without EqualityMaper +@attrs.define(frozen=True, eq=False, hash=True, cache_hash=True, repr=False) +class Call(AbstractResultWithNamedArrays): + """ + Records an invocation to a :class:`FunctionDefinition`. + + .. attribute:: function + + The instance of :class:`FunctionDefinition` being called by this call site. + + .. attribute:: bindings + + A mapping from the placeholder names of :class:`FunctionDefinition` to + their corresponding parameters in the invocation to :attr:`function`. + + """ + function: FunctionDefinition + bindings: Map[str, Array] + + _mapper_method: ClassVar[str] = "map_call" + + copy = attrs.evolve + + def __post_init__(self) -> None: + # check that the invocation parameters and the function definition + # parameters agree with each other. + assert frozenset(self.bindings) == self.function.parameters + super().__post_init__() + + def __contains__(self, name: object) -> bool: + return name in self.function.returns + + def __iter__(self) -> Iterator[str]: + return iter(self.function.returns) + + def __getitem__(self, name: str) -> NamedCallResult: + return NamedCallResult(self, name) + + def __len__(self) -> int: + return len(self.function.returns) + + def _with_new_tags(self: Call, tags: FrozenSet[Tag]) -> Call: + return attrs.evolve(self, tags=tags) + +# }}} + + +# {{{ user-facing routines + +class _Guess: + pass + + +RE_ARGNAME = re.compile(r"^_pt_(\d+)$") + + +def trace_call(f: Callable[..., ReturnT], + *args: Array, + identifier: Optional[Hashable] = _Guess, + **kwargs: Array) -> ReturnT: + """ + Returns the expressions returned after calling *f* with the arguments + *args* and keyword arguments *kwargs*. The subexpressions in the returned + expressions are outlined (opposite of 'inlined') as a + :class:`~pytato.function.FunctionDefinition`. + + :arg identifier: A hashable object that acts as + :attr:`pytato.tags.FunctionIdentifier.identifier` for the + :class:`~pytato.tags.FunctionIdentifier` tagged to the outlined + :class:`~pytato.function.FunctionDefinition`. If ``None`` the function + definition is not tagged with a + :class:`~pytato.tags.FunctionIdentifier` tag, if ``_Guess`` the + function identifier is guessed from ``f.__name__``. + """ + from pytato.tags import FunctionIdentifier + from pytato.array import _get_default_tags + + if identifier is _Guess: + # partials might not have a __name__ attribute + identifier = getattr(f, "__name__", None) + + for kw in kwargs: + if RE_ARGNAME.match(kw): + # avoid collision between argument names + raise ValueError(f"Kw argument named '{kw}' not allowed.") + + # Get placeholders from the ``args``, ``kwargs``. + pl_args = tuple(Placeholder(f"in__pt_{iarg}", arg.shape, arg.dtype, + axes=arg.axes, tags=arg.tags) + for iarg, arg in enumerate(args)) + pl_kwargs = {kw: Placeholder(f"in_{kw}", arg.shape, + arg.dtype, axes=arg.axes, tags=arg.tags) + for kw, arg in kwargs.items()} + + # Pass the placeholders + output = f(*pl_args, **pl_kwargs) + + if isinstance(output, Array): + returns = {"_": output} + return_type = ReturnType.ARRAY + elif isinstance(output, tuple): + assert all(isinstance(el, Array) for el in output) + returns = {f"_{iout}": out for iout, out in enumerate(output)} + return_type = ReturnType.TUPLE_OF_ARRAYS + elif isinstance(output, dict): + assert all(isinstance(el, Array) for el in output.values()) + returns = output + return_type = ReturnType.DICT_OF_ARRAYS + else: + raise ValueError("The function being traced must return one of" + f"pytato.Array, tuple, dict. Got {type(output)}.") + + # construct the function + function = FunctionDefinition( + frozenset(pl_arg.name for pl_arg in pl_args) | frozenset(pl_kwargs), + return_type, + Map(returns), + tags=_get_default_tags() | (frozenset([FunctionIdentifier(identifier)]) + if identifier + else frozenset()) + ) + + # type-ignore-reason: return type is dependent on dynamic state i.e. + # ret_type and hence mypy is unhappy + return function( # type: ignore[return-value] + **{pl.name: arg for pl, arg in zip(pl_args, args)}, + **{pl_kwargs[kw].name: arg for kw, arg in kwargs.items()} + ) + +# }}} + +# vim:foldmethod=marker diff --git a/pytato/tags.py b/pytato/tags.py index 8d080f66e..ad6d33d55 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -10,9 +10,10 @@ .. autoclass:: PrefixNamed .. autoclass:: AssumeNonNegative .. autoclass:: ExpandedDimsReshape +.. autoclass:: FunctionIdentifier """ -from typing import Tuple +from typing import Tuple, Hashable from pytools.tag import Tag, UniqueTag from dataclasses import dataclass @@ -125,3 +126,13 @@ class ExpandedDimsReshape(UniqueTag): frozenset({ExpandedDimsReshape(new_dims=(0, 2, 4))}) """ new_dims: Tuple[int, ...] + + +@dataclass(frozen=True) +class FunctionIdentifier(UniqueTag): + """ + A tag that can be attached to a + :class:`~pytato.function.FunctionDefinition` node to + to describe the function's identifier. + """ + identifier: Hashable From 988a2c3e3d066fdca0cf3cd6eeb9db7e90c9ad85 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Fri, 4 Nov 2022 23:44:17 -0500 Subject: [PATCH 02/11] Support mapper methods for FunctionDefintion, Call Co-authored-by: Andreas Kloeckner --- pytato/codegen.py | 11 +- pytato/equality.py | 27 ++++ pytato/tags.py | 28 +++- pytato/transform/__init__.py | 240 ++++++++++++++++++++++++++++++++++- pytato/visualization/dot.py | 25 +++- 5 files changed, 322 insertions(+), 9 deletions(-) diff --git a/pytato/codegen.py b/pytato/codegen.py index 63067fbd9..d95f96aa3 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -23,7 +23,7 @@ """ import dataclasses -from typing import Union, Dict, Tuple, List, Any +from typing import Union, Dict, Tuple, List, Any, Optional from pytato.array import (Array, DictOfNamedArrays, DataWrapper, Placeholder, DataInterface, SizeParam, InputArgumentBase, @@ -102,12 +102,17 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc] ====================================== ===================================== """ - def __init__(self, target: Target) -> None: + def __init__(self, target: Target, + kernels_seen: Optional[Dict[str, lp.LoopKernel]] = None + ) -> None: super().__init__() self.bound_arguments: Dict[str, DataInterface] = {} self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator() self.target = target - self.kernels_seen: Dict[str, lp.LoopKernel] = {} + self.kernels_seen: Dict[str, lp.LoopKernel] = kernels_seen or {} + + def clone_for_callee(self) -> CodeGenPreprocessor: + return CodeGenPreprocessor(self.target, self.kernels_seen) def map_size_param(self, expr: SizeParam) -> Array: name = expr.name diff --git a/pytato/equality.py b/pytato/equality.py index 5eb1814c6..42c2978cd 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -31,6 +31,8 @@ IndexBase, IndexLambda, NamedArray, Reshape, Roll, Stack, AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder, SizeParam) +from pytato.function import Call, NamedCallResult, FunctionDefinition +from pytools import memoize_method if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -273,6 +275,31 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: and expr1.tags == expr2.tags ) + @memoize_method + def map_function_definition(self, expr1: FunctionDefinition, expr2: Any + ) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.parameters == expr2.parameters + and (set(expr1.returns.keys()) == set(expr2.returns.keys())) + and all(self.rec(expr1.returns[k], expr2.returns[k]) + for k in expr1.returns) + and expr1.tags == expr2.tags + ) + + def map_call(self, expr1: Call, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and self.map_function_definition(expr1.function, expr2.function) + and frozenset(expr1.bindings) == frozenset(expr2.bindings) + and all(self.rec(bnd, + expr2.bindings[name]) + for name, bnd in expr1.bindings.items()) + ) + + def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool: + return (expr1.__class__ is expr2.__class__ + and expr1.name == expr2.name + and self.rec(expr1._container, expr2._container)) + # }}} # vim: fdm=marker diff --git a/pytato/tags.py b/pytato/tags.py index ad6d33d55..1e1180ebe 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -11,6 +11,8 @@ .. autoclass:: AssumeNonNegative .. autoclass:: ExpandedDimsReshape .. autoclass:: FunctionIdentifier +.. autoclass:: CallImplementationTag +.. autoclass:: InlineCallTag """ from typing import Tuple, Hashable @@ -131,8 +133,28 @@ class ExpandedDimsReshape(UniqueTag): @dataclass(frozen=True) class FunctionIdentifier(UniqueTag): """ - A tag that can be attached to a - :class:`~pytato.function.FunctionDefinition` node to - to describe the function's identifier. + A tag that can be attached to a :class:`~pytato.function.FunctionDefinition` + node to to describe the function's identifier. One can use this to refer + all instances of :class:`~pytato.function.FunctionDefinition`, for example in + transformations.transform.calls.concatenate_calls`. + + .. attribute:: identifier """ identifier: Hashable + + +@dataclass(frozen=True) +class CallImplementationTag(UniqueTag): + """ + A tag that can be attached to a :class:`~pytato.function.Call` node to + direct a :class:`~pytato.target.Target` how the call site should be + lowered. + """ + + +@dataclass(frozen=True) +class InlineCallTag(CallImplementationTag): + r""" + A :class:`CallImplementationTag` that directs the + :class:`pytato.target.Target` to inline the call site. + """ diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index a75083f9a..1ceb4c4bb 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1,5 +1,7 @@ from __future__ import annotations +from pytools import memoize_method + __copyright__ = """ Copyright (C) 2020 Matt Wala Copyright (C) 2020-21 Kaushik Kulkarni @@ -28,6 +30,7 @@ import logging import numpy as np +from immutables import Map from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, List, Mapping, Iterable, Tuple, Optional, Hashable) @@ -43,10 +46,12 @@ from pytato.distributed.nodes import ( DistributedSendRefHolder, DistributedRecv, DistributedSend) from pytato.loopy import LoopyCall, LoopyCallResult +from pytato.function import Call, NamedCallResult, FunctionDefinition from dataclasses import dataclass from pytato.tags import ImplStored from pymbolic.mapper.optimize import optimize_mapper + ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] MappedT = TypeVar("MappedT", Array, AbstractResultWithNamedArrays, ArrayOrNames) @@ -56,6 +61,7 @@ CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper IndexOrShapeExpr = TypeVar("IndexOrShapeExpr") R = FrozenSet[Array] +_SelfMapper = TypeVar("_SelfMapper", bound="Mapper") __doc__ = """ .. currentmodule:: pytato.transform @@ -100,6 +106,11 @@ .. class:: CombineT A type variable representing the type of a :class:`CombineMapper`. + +.. class:: _SelfMapper + + A type variable used to represent the type of a mapper in + :meth:`CopyMapper.clone_for_callee`. """ transform_logger = logging.getLogger(__file__) @@ -213,6 +224,8 @@ class CopyMapper(CachedMapper[ArrayOrNames]): The typical use of this mapper is to override individual ``map_`` methods in subclasses to permit term rewriting on an expression graph. + .. automethod:: clone_for_callee + .. note:: This does not copy the data of a :class:`pytato.array.DataWrapper`. @@ -229,6 +242,13 @@ def __call__(self, # type: ignore[override] expr: CopyMapperResultT) -> CopyMapperResultT: return self.rec(expr) + def clone_for_callee(self: _SelfMapper) -> _SelfMapper: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + return type(self)() + def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...] ) -> Tuple[IndexOrShapeExpr, ...]: # type-ignore-reason: apparently mypy cannot substitute typevars @@ -372,6 +392,32 @@ def map_distributed_recv(self, expr: DistributedRecv) -> Array: shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, tags=expr.tags, axes=expr.axes) + @memoize_method + def map_function_definition(self, + expr: FunctionDefinition) -> FunctionDefinition: + # spawn a new mapper to avoid unsound cache hits, since the namespace of the + # function's body is different from that of the caller. + new_mapper = self.clone_for_callee() + new_returns = {name: new_mapper(ret) + for name, ret in expr.returns.items()} + return FunctionDefinition(expr.parameters, + expr.return_type, + Map(new_returns), + tags=expr.tags + ) + + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: + return Call(self.map_function_definition(expr.function), + Map({name: self.rec(bnd) + for name, bnd in expr.bindings.items()}), + tags=expr.tags, + ) + + def map_named_call_result(self, expr: NamedCallResult) -> Array: + call = self.rec(expr._container) + assert isinstance(call, Call) + return NamedCallResult(call, expr.name) + class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]): """ @@ -575,6 +621,26 @@ def map_distributed_recv(self, expr: DistributedRecv, shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), dtype=expr.dtype, tags=expr.tags, axes=expr.axes) + def map_function_definition(self, expr: FunctionDefinition, + *args: Any, **kwargs: Any) -> FunctionDefinition: + raise NotImplementedError("Function definitions are purposefully left" + " unimplemented as the default arguments to a new" + " DAG traversal are tricky to guess.") + + def map_call(self, expr: Call, + *args: Any, **kwargs: Any) -> AbstractResultWithNamedArrays: + return Call(self.map_function_definition(expr.function, *args, **kwargs), + Map({name: self.rec(bnd, *args, **kwargs) + for name, bnd in expr.bindings.items()}), + tags=expr.tags, + ) + + def map_named_call_result(self, expr: NamedCallResult, + *args: Any, **kwargs: Any) -> Array: + call = self.rec(expr._container, *args, **kwargs) + assert isinstance(call, Call) + return NamedCallResult(call, expr.name) + # }}} @@ -685,6 +751,18 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> CombineT: return self.combine(*self.rec_idx_or_size_tuple(expr.shape)) + def map_function_definition(self, expr: FunctionDefinition) -> CombineT: + raise NotImplementedError("Combining results from a callee expression" + " is context-dependent. Derived classes" + " must override map_function_definition.") + + def map_call(self, expr: Call) -> CombineT: + return self.combine(self.map_function_definition(expr.function), + *[self.rec(bnd) for bnd in expr.bindings.values()]) + + def map_named_call_result(self, expr: NamedCallResult) -> CombineT: + return self.rec(expr._container) + # }}} @@ -752,6 +830,18 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> R: return self.combine(frozenset([expr]), super().map_distributed_recv(expr)) + def map_function_definition(self, expr: FunctionDefinition) -> R: + # do not include arrays from the function's body as it would involve + # putting arrays from different namespaces into the same collection. + return frozenset() + + def map_call(self, expr: Call) -> R: + return self.combine(self.map_function_definition(expr.function), + *[self.rec(bnd) for bnd in expr.bindings.values()]) + + def map_named_call_result(self, expr: NamedCallResult) -> R: + return self.rec(expr._container) + # }}} @@ -796,6 +886,27 @@ def map_data_wrapper(self, expr: DataWrapper) -> FrozenSet[InputArgumentBase]: def map_size_param(self, expr: SizeParam) -> FrozenSet[SizeParam]: return frozenset([expr]) + @memoize_method + def map_function_definition(self, expr: FunctionDefinition + ) -> FrozenSet[InputArgumentBase]: + # get rid of placeholders local to the function. + new_mapper = InputGatherer() + all_callee_inputs = new_mapper.combine(*[new_mapper(ret) + for ret in expr.returns.values()]) + result: Set[InputArgumentBase] = set() + for inp in all_callee_inputs: + if isinstance(inp, Placeholder): + if inp.name in expr.parameters: + # drop, reference to argument + pass + else: + raise ValueError("function definition refers to non-argument " + f"placeholder named '{inp.name}'") + else: + result.add(inp) + + return frozenset(result) + # }}} @@ -814,6 +925,12 @@ def combine(self, *args: FrozenSet[SizeParam] def map_size_param(self, expr: SizeParam) -> FrozenSet[SizeParam]: return frozenset([expr]) + @memoize_method + def map_function_definition(self, expr: FunctionDefinition + ) -> FrozenSet[SizeParam]: + return self.combine(*[self.rec(ret) + for ret in expr.returns.values()]) + # }}} @@ -830,6 +947,9 @@ class WalkMapper(Mapper): .. automethod:: post_visit """ + def clone_for_callee(self: _SelfMapper) -> _SelfMapper: + return type(self)() + def visit(self, expr: Any, *args: Any, **kwargs: Any) -> bool: """ If this method returns *True*, *expr* is traversed during the walk. @@ -982,6 +1102,36 @@ def map_loopy_call(self, expr: LoopyCall, *args: Any, **kwargs: Any) -> None: self.post_visit(expr, *args, **kwargs) + def map_function_definition(self, expr: FunctionDefinition, + *args: Any, **kwargs: Any) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee() + for subexpr in expr.returns.values(): + new_mapper(subexpr, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + + def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> None: + if not self.visit(expr): + return + + self.map_function_definition(expr.function) + for bnd in expr.bindings.values(): + self.rec(bnd) + + self.post_visit(expr) + + def map_named_call_result(self, expr: NamedCallResult, + *args: Any, **kwargs: Any) -> None: + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr._container, *args, **kwargs) + + self.post_visit(expr, *args, **kwargs) + # }}} @@ -1019,6 +1169,11 @@ class TopoSortMapper(CachedWalkMapper): """A mapper that creates a list of nodes in topological order. :members: topological_order + + .. note:: + + Does not consider the nodes inside a + :class:`~pytato.function.FunctionDefinition`. """ def __init__(self) -> None: @@ -1033,6 +1188,12 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] def post_visit(self, expr: Any) -> None: # type: ignore[override] self.topological_order.append(expr) + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def map_function_definition(self, # type: ignore[override] + expr: FunctionDefinition) -> None: + # do nothing as it includes arrays from a different namespace. + return + # }}} @@ -1048,6 +1209,11 @@ def __init__(self, map_fn: Callable[[ArrayOrNames], ArrayOrNames]) -> None: super().__init__() self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn + def clone_for_callee(self: _SelfMapper) -> _SelfMapper: + # type-ignore-reason: self.__init__ has a different function signature + # than Mapper.__init__ and does not have map_fn + return type(self)(self.map_fn) # type: ignore[call-arg,attr-defined] + # type-ignore-reason:incompatible with Mapper.rec() def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] if expr in self._cache: @@ -1102,7 +1268,14 @@ def _materialize_if_mpms(expr: Array, class MPMSMaterializer(Mapper): - """See :func:`materialize_with_mpms` for an explanation.""" + """ + See :func:`materialize_with_mpms` for an explanation. + + .. attribute:: nsuccessors + + A mapping from a node in the expression graph (i.e. an + :class:`~pytato.Array`) to its number of successors. + """ def __init__(self, nsuccessors: Mapping[Array, int]): super().__init__() self.nsuccessors = nsuccessors @@ -1252,6 +1425,52 @@ def map_distributed_recv(self, expr: DistributedRecv ) -> MPMSMaterializerAccumulator: return MPMSMaterializerAccumulator(frozenset([expr]), expr) + @memoize_method + def map_function_definition(self, expr: FunctionDefinition + ) -> FunctionDefinition: + # spawn a new traversal here. + from pytato.analysis import get_nusers + + returns_dict_of_named_arys = DictOfNamedArrays(expr.returns) + func_nsuccessors = get_nusers(returns_dict_of_named_arys) + new_mapper = MPMSMaterializer(func_nsuccessors) + new_returns = {name: new_mapper(ret) for name, ret in expr.returns.items()} + return FunctionDefinition(expr.parameters, + expr.return_type, + Map(new_returns), + tags=expr.tags) + + @memoize_method + def map_call(self, expr: Call) -> Call: + return Call(self.map_function_definition(expr.function), + Map({name: self.rec(bnd).expr + for name, bnd in expr.bindings.items()}), + tags=expr.tags) + + def map_named_call_result(self, expr: NamedCallResult + ) -> MPMSMaterializerAccumulator: + assert isinstance(expr._container, Call) + new_call = self.map_call(expr._container) + new_result = new_call[expr.name] + + assert isinstance(new_result, NamedCallResult) + assert isinstance(new_result._container, Call) + + # do not use _materialize_if_mpms as tagging a NamedArray is illegal. + if new_result.tags_of_type(ImplStored): + return MPMSMaterializerAccumulator(frozenset([new_result]), + new_result) + else: + from functools import reduce + materialized_predecessors: FrozenSet[Array] = ( + reduce(frozenset.union, + (self.rec(bnd).materialized_predecessors + for bnd in new_result._container.bindings.values()), + frozenset()) + ) + return MPMSMaterializerAccumulator(materialized_predecessors, + new_result) + # }}} @@ -1487,9 +1706,26 @@ def map_distributed_send_ref_holder( def map_distributed_recv(self, expr: DistributedRecv) -> None: self.rec_idx_or_size_tuple(expr, expr.shape) + def map_function_definition(self, expr: FunctionDefinition, *args: Any + ) -> None: + raise AssertionError("Control shouldn't reach at this point." + " Instantiate another UsersCollector to" + " traverse the callee function.") + + def map_call(self, expr: Call, *args: Any) -> None: + for bnd in expr.bindings.values(): + self.rec(bnd) + + def map_named_call(self, expr: NamedCallResult, *args: Any) -> None: + assert isinstance(expr._container, Call) + for bnd in expr._container.bindings.values(): + self.node_to_users.setdefault(bnd, set()).add(expr) + + self.rec(expr._container) + def get_users(expr: ArrayOrNames) -> Dict[ArrayOrNames, - Set[ArrayOrNames]]: + Set[ArrayOrNames]]: """ Returns a mapping from node in *expr* to its direct users. """ diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 6f302fad4..b836e3814 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -37,6 +37,7 @@ from pytools.tag import Tag from pytools.codegen import CodeGenerator as CodeGeneratorBase from pytato.loopy import LoopyCall +from pytato.function import Call, NamedCallResult from pytato.array import ( Array, DataWrapper, DictOfNamedArrays, IndexLambda, InputArgumentBase, @@ -243,6 +244,28 @@ def map_distributed_send_ref_holder( self.nodes[expr] = info + def map_call(self, expr: Call) -> None: + for bnd in expr.bindings.values(): + self.rec(bnd) + + self.nodes[expr] = DotNodeInfo( + title=expr.__class__.__name__, + edges=dict(expr.bindings), + fields={ + "addr": hex(id(expr)), + "tags": stringify_tags(expr.tags), + } + ) + + def map_named_call_result(self, expr: NamedCallResult) -> None: + self.rec(expr._container) + self.nodes[expr] = DotNodeInfo( + title=expr.__class__.__name__, + edges={"": expr._container}, + fields={"addr": hex(id(expr)), + "name": expr.name}, + ) + def dot_escape(s: str) -> str: # "\" and HTML are significant in graphviz. @@ -551,4 +574,4 @@ def show_dot_graph(result: Union[str, Array, DictOfNamedArrays, from pytools.graphviz import show_dot show_dot(dot_code, **kwargs) -# }}} +# vim:fdm=marker From 7eaf163d352af19509421f8fa02958c82903fd55 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 5 Nov 2022 00:01:15 -0500 Subject: [PATCH 03/11] Implement pt.inline_calls --- pytato/__init__.py | 3 + pytato/codegen.py | 11 ++-- pytato/target/loopy/codegen.py | 70 ++++++++++++-------- pytato/transform/__init__.py | 8 +++ pytato/transform/calls.py | 114 +++++++++++++++++++++++++++++++++ 5 files changed, 177 insertions(+), 29 deletions(-) create mode 100644 pytato/transform/calls.py diff --git a/pytato/__init__.py b/pytato/__init__.py index 46794a9af..0509a8c9f 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -93,6 +93,7 @@ def set_debug_enabled(flag: bool) -> None: get_dot_graph_from_partition, show_fancy_placeholder_data_flow, ) +from pytato.transform.calls import tag_all_calls_to_be_inlined, inline_calls import pytato.analysis as analysis import pytato.tags as tags import pytato.function as function @@ -165,6 +166,8 @@ def set_debug_enabled(flag: bool) -> None: "DistributedGraphPart", "DistributedGraphPartition", + "tag_all_calls_to_be_inlined", "inline_calls", + "find_distributed_partition", "number_distributed_tags", diff --git a/pytato/codegen.py b/pytato/codegen.py index d95f96aa3..0bc85d649 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -267,6 +267,7 @@ class PreprocessResult: def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult: """Preprocess a computation for code generation.""" from pytato.transform import copy_dict_of_named_arrays + from pytato.transform.calls import inline_calls check_validity_of_outputs(outputs) @@ -294,12 +295,14 @@ def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult: # }}} - mapper = CodeGenPreprocessor(target) + new_outputs = inline_calls(outputs) + assert isinstance(new_outputs, DictOfNamedArrays) - new_outputs = copy_dict_of_named_arrays(outputs, mapper) + mapper = CodeGenPreprocessor(target) + new_outputs = copy_dict_of_named_arrays(new_outputs, mapper) return PreprocessResult(outputs=new_outputs, - compute_order=tuple(output_order), - bound_arguments=mapper.bound_arguments) + compute_order=tuple(output_order), + bound_arguments=mapper.bound_arguments) # vim: fdm=marker diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index ad7b73491..7d8760ed6 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -46,6 +46,7 @@ from pytato.transform import Mapper from pytato.scalar_expr import ScalarExpression, INT_CLASSES from pytato.codegen import preprocess, normalize_outputs, SymbolicIndex +from pytato.function import Call, NamedCallResult from pytato.loopy import LoopyCall from pytato.tags import (ImplStored, ImplInlined, Named, PrefixNamed, ImplementationStrategy) @@ -575,6 +576,19 @@ def _get_sub_array_ref(array: Array, name: str) -> "lp.symbolic.SubArrayRef": state.update_kernel(kernel) + def map_named_call_result(self, expr: NamedCallResult, + state: CodeGenState) -> None: + raise NotImplementedError("LoopyTarget does not support outlined calls" + " (yet). As a fallback, the call" + " could be inlined using" + " pt.mark_all_calls_to_be_inlined.") + + def map_call(self, expr: Call, state: CodeGenState) -> None: + raise NotImplementedError("LoopyTarget does not support outlined calls" + " (yet). As a fallback, the call" + " could be inlined using" + " pt.mark_all_calls_to_be_inlined.") + # }}} @@ -972,36 +986,30 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]], .. note:: - :mod:`pytato` metadata :math:`\mapsto` :mod:`loopy` metadata semantics: - - - Inames that index over an :class:`~pytato.array.Array`'s axis in the - allocation instruction are tagged with the corresponding - :class:`~pytato.array.Axis`'s tags. The caller may choose to not - propagate axis tags of type *axis_tag_t_to_not_propagate*. - - :attr:`pytato.Array.tags` of inputs/outputs in *outputs* - would be copied over to the tags of the corresponding - :class:`loopy.ArrayArg`. The caller may choose to not - propagate array tags of type *array_tag_t_to_not_propagate*. - - Arrays tagged with :class:`pytato.tags.ImplStored` would have their - tags copied over to the tags of corresponding - :class:`loopy.TemporaryVariable`. The caller may choose to not - propagate array tags of type *array_tag_t_to_not_propagate*. + - :mod:`pytato` metadata :math:`\mapsto` :mod:`loopy` metadata semantics: + + - Inames that index over an :class:`~pytato.array.Array`'s axis in the + allocation instruction are tagged with the corresponding + :class:`~pytato.array.Axis`'s tags. The caller may choose to not + propagate axis tags of type *axis_tag_t_to_not_propagate*. + - :attr:`pytato.Array.tags` of inputs/outputs in *outputs* + would be copied over to the tags of the corresponding + :class:`loopy.ArrayArg`. The caller may choose to not + propagate array tags of type *array_tag_t_to_not_propagate*. + - Arrays tagged with :class:`pytato.tags.ImplStored` would have their + tags copied over to the tags of corresponding + :class:`loopy.TemporaryVariable`. The caller may choose to not + propagate array tags of type *array_tag_t_to_not_propagate*. + + .. warning:: + + Currently only :class:`~pytato.function.Call` nodes that are tagged with + :class:`pytato.tags.InlineCallTag` can be lowered to :mod:`loopy` IR. """ result_is_dict = isinstance(result, (dict, DictOfNamedArrays)) orig_outputs: DictOfNamedArrays = normalize_outputs(result) - # optimization: remove any ImplStored tags on outputs to avoid redundant - # store-load operations (see https://github.com/inducer/pytato/issues/415) - orig_outputs = DictOfNamedArrays( - {name: (output.without_tags(ImplStored(), - verify_existence=False) - if not isinstance(output, - InputArgumentBase) - else output) - for name, output in orig_outputs._data.items()}, - tags=orig_outputs.tags) - del result if cl_device is not None: @@ -1017,6 +1025,18 @@ def generate_loopy(result: Union[Array, DictOfNamedArrays, Dict[str, Array]], preproc_result = preprocess(orig_outputs, target) outputs = preproc_result.outputs + # optimization: remove any ImplStored tags on outputs to avoid redundant + # store-load operations (see https://github.com/inducer/pytato/issues/415) + # (This must be done after all the calls have been inlined) + outputs = DictOfNamedArrays( + {name: (output.without_tags(ImplStored(), + verify_existence=False) + if not isinstance(output, + InputArgumentBase) + else output) + for name, output in outputs._data.items()}, + tags=outputs.tags) + compute_order = preproc_result.compute_order if options is None: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 1ceb4c4bb..953cbfa46 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -96,6 +96,14 @@ .. autofunction:: tag_user_nodes .. autofunction:: rec_get_user_nodes + +Transforming call sites +----------------------- + +.. automodule:: pytato.transform.calls + +.. currentmodule:: pytato.transform + Internal stuff that is only here because the documentation tool wants it ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py new file mode 100644 index 000000000..a25dd8311 --- /dev/null +++ b/pytato/transform/calls.py @@ -0,0 +1,114 @@ +""" +.. currentmodule:: pytato.transform.calls + +.. autofunction:: inline_calls +.. autofunction:: tag_all_calls_to_be_inlined +""" +__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +from immutables import Map +from pytato.transform import (ArrayOrNames, CopyMapper) +from pytato.array import (AbstractResultWithNamedArrays, Array, + DictOfNamedArrays, Placeholder) + +from pytato.function import Call, NamedCallResult +from pytato.tags import InlineCallTag + + +# {{{ inlining + +class PlaceholderSubstitutor(CopyMapper): + """ + .. attribute:: substitutions + + A mapping from the placeholder name to the array that it is to be + substituted with. + """ + def __init__(self, substitutions: Map[str, Array]) -> None: + super().__init__() + self.substitutions = substitutions + + def map_placeholder(self, expr: Placeholder) -> Array: + return self.substitutions[expr.name] + + +class Inliner(CopyMapper): + """ + Primary mapper for :func:`inline_calls`. + """ + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: + # inline call sites within the callee. + new_expr = super().map_call(expr) + assert isinstance(new_expr, Call) + + if expr.tags_of_type(InlineCallTag): + substitutor = PlaceholderSubstitutor(expr.bindings) + + return DictOfNamedArrays( + {name: substitutor(ret) + for name, ret in new_expr.function.returns.items()}, + tags=expr.tags + ) + else: + return new_expr + + def map_named_call_result(self, expr: NamedCallResult) -> Array: + new_call = self.rec(expr._container) + assert isinstance(new_call, AbstractResultWithNamedArrays) + return new_call[expr.name] + + +class InlineMarker(CopyMapper): + """ + Primary mapper for :func:`tag_all_calls_to_be_inlined`. + """ + def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: + return super().map_call(expr).tagged(InlineCallTag()) + + +def inline_calls(expr: ArrayOrNames) -> ArrayOrNames: + """ + Returns a copy of *expr* with call sites tagged with + :class:`pytato.tags.InlineCallTag` inlined into the expression graph. + """ + inliner = Inliner() + return inliner(expr) + + +def tag_all_calls_to_be_inlined(expr: ArrayOrNames) -> ArrayOrNames: + """ + Returns a copy of *expr* with all reachable instances of + :class:`pytato.function.Call` nodes tagged with + :class:`pytato.tags.InlineCallTag`. + + .. note:: + + This routine does NOT inline calls, to inline the calls + use :func:`tag_all_calls_to_be_inlined` on this routine's + output. + """ + return InlineMarker()(expr) + +# }}} + +# vim:foldmethod=marker From 235a62722a82a55bd42b05cdb7514aba0265d291 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Sat, 5 Nov 2022 01:02:31 -0500 Subject: [PATCH 04/11] Add pt.tracing module to docs --- doc/array.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/doc/array.rst b/doc/array.rst index a31fb2e53..56772affa 100644 --- a/doc/array.rst +++ b/doc/array.rst @@ -5,12 +5,17 @@ Array Expressions .. automodule:: pytato.array +Functions in Pytato IR +---------------------- + +.. automodule:: pytato.function Raising :class:`~pytato.array.IndexLambda` nodes ------------------------------------------------ .. automodule:: pytato.raising + Calling :mod:`loopy` kernels in an array expression --------------------------------------------------- From 7e0619af88577d51c21d830ff4c31a621bb1beb4 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Thu, 18 May 2023 13:42:27 -0500 Subject: [PATCH 05/11] Test function calls via inlining --- test/test_codegen.py | 87 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/test/test_codegen.py b/test/test_codegen.py index 874906528..e02fbc85c 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1854,6 +1854,93 @@ def test_pad(ctx_factory): np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array) +def test_function_call(ctx_factory): + cl_ctx = ctx_factory() + cq = cl.CommandQueue(cl_ctx) + + def f(x): + return 2*x + + def g(x): + return 2*x, 3*x + + def h(x, y): + return {"twice": 2*x+y, "thrice": 3*x+y} + + def build_expression(tracer): + x = pt.arange(500, dtype=np.float32) + twice_x = tracer(f, x) + twice_x_2, thrice_x_2 = tracer(g, x) + + result = tracer(h, x, 2*x) + twice_x_3 = result["twice"] + thrice_x_3 = result["thrice"] + + return {"foo": 3.14 + twice_x_3, + "bar": 4 * thrice_x_3, + "baz": 65 * twice_x, + "quux": 7 * twice_x_2} + + result1 = pt.tag_all_calls_to_be_inlined( + pt.make_dict_of_named_arrays(build_expression(pt.trace_call))) + result2 = pt.make_dict_of_named_arrays( + build_expression(lambda fn, *args: fn(*args))) + + _, outputs = pt.generate_loopy(result1)(cq, out_host=True) + _, expected = pt.generate_loopy(result2)(cq, out_host=True) + + assert len(outputs) == len(expected) + + for key in outputs.keys(): + np.testing.assert_allclose(outputs[key], expected[key]) + + +def test_nested_function_calls(ctx_factory): + from functools import partial + + ctx = ctx_factory() + cq = cl.CommandQueue(ctx) + + rng = np.random.default_rng(0) + x_np = rng.random((10,)) + + x = pt.make_placeholder("x", 10, np.float64).tagged(pt.tags.ImplStored()) + prg = pt.generate_loopy({"out1": 3*x, "out2": x}) + _, out = prg(cq, x=x_np) + np.testing.assert_allclose(out["out1"], 3*x_np) + np.testing.assert_allclose(out["out2"], x_np) + ref_tracer = lambda f, *args, identifier: f(*args) # noqa: E731 + + def foo(tracer, x, y): + return 2*x + 3*y + + def bar(tracer, x, y): + foo_x_y = tracer(partial(foo, tracer), x, y, identifier="foo") + return foo_x_y * x * y + + def call_bar(tracer, x, y): + return tracer(partial(bar, tracer), x, y, identifier="bar") + + x1_np, y1_np = rng.random((2, 13, 29)) + x2_np, y2_np = rng.random((2, 4, 29)) + x1, y1 = pt.make_data_wrapper(x1_np), pt.make_data_wrapper(y1_np) + x2, y2 = pt.make_data_wrapper(x2_np), pt.make_data_wrapper(y2_np) + result = pt.make_dict_of_named_arrays({"out1": call_bar(pt.trace_call, x1, y1), + "out2": call_bar(pt.trace_call, x2, y2)} + ) + result = pt.tag_all_calls_to_be_inlined(result) + expect = pt.make_dict_of_named_arrays({"out1": call_bar(ref_tracer, x1, y1), + "out2": call_bar(ref_tracer, x2, y2)} + ) + + _, result_out = pt.generate_loopy(result)(cq) + _, expect_out = pt.generate_loopy(expect)(cq) + + assert result_out.keys() == expect_out.keys() + for k in expect_out: + np.testing.assert_allclose(result_out[k], expect_out[k]) + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) From c7066dc2f64b88ad4bbbe3b2b44a1772f0dc3b5f Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Tue, 14 Mar 2023 13:11:15 -0500 Subject: [PATCH 06/11] Add pt.analysis.get_num_call_sites --- pytato/analysis/__init__.py | 58 +++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index a3c5e99ce..f279a3524 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -31,9 +31,11 @@ DictOfNamedArrays, NamedArray, IndexBase, IndexRemappingBase, InputArgumentBase, ShapeType) +from pytato.function import FunctionDefinition, Call from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper from pytato.loopy import LoopyCall from pymbolic.mapper.optimize import optimize_mapper +from pytools import memoize_method if TYPE_CHECKING: from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder @@ -47,6 +49,8 @@ .. autofunction:: get_num_nodes +.. autofunction:: get_num_call_sites + .. autoclass:: DirectPredecessorsGetter """ @@ -388,3 +392,57 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: return ncm.count # }}} + + +# {{{ CallSiteCountMapper + +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) +class CallSiteCountMapper(CachedWalkMapper): + """ + Counts the number of :class:`~pytato.Call` nodes in a DAG. + + .. attribute:: count + + The number of nodes. + """ + + def __init__(self) -> None: + super().__init__() + self.count = 0 + + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] + return id(expr) + + @memoize_method + def map_function_definition(self, /, expr: FunctionDefinition, + *args: Any, **kwargs: Any) -> None: + if not self.visit(expr): + return + + new_mapper = self.clone_for_callee() + for subexpr in expr.returns.values(): + new_mapper(subexpr, *args, **kwargs) + + self.count += new_mapper.count + + self.post_visit(expr, *args, **kwargs) + + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def post_visit(self, expr: Any) -> None: # type: ignore[override] + if isinstance(expr, Call): + self.count += 1 + + +def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: + """Returns the number of nodes in DAG *outputs*.""" + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + cscm = CallSiteCountMapper() + cscm(outputs) + + return cscm.count + +# }}} From 37cbe3115dcfd78e190d72fda4f796a95ee839e9 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 29 May 2023 17:32:46 -0500 Subject: [PATCH 07/11] Drop asciidag visualization --- .pylintrc-local.yml | 1 - examples/visualization.py | 2 - pytato/__init__.py | 5 +- pytato/visualization/__init__.py | 4 - pytato/visualization/ascii.py | 124 ------------------------------- requirements.txt | 1 - test/test_pytato.py | 31 -------- 7 files changed, 2 insertions(+), 166 deletions(-) delete mode 100644 pytato/visualization/ascii.py diff --git a/.pylintrc-local.yml b/.pylintrc-local.yml index f5a171e96..ae0f92032 100644 --- a/.pylintrc-local.yml +++ b/.pylintrc-local.yml @@ -1,6 +1,5 @@ - arg: ignored-modules val: - - asciidag - matplotlib - ipykernel - ply diff --git a/examples/visualization.py b/examples/visualization.py index 2b569a392..ac71e6060 100755 --- a/examples/visualization.py +++ b/examples/visualization.py @@ -22,8 +22,6 @@ def main(): stack = pt.stack([array, 2*array, array + 6]) result = stack @ stack.T - pt.show_ascii_graph(result) - dot_code = pt.get_dot_graph(result) with open(GRAPH_DOT, "w") as outf: diff --git a/pytato/__init__.py b/pytato/__init__.py index 0509a8c9f..572e4a7ab 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -89,7 +89,6 @@ def set_debug_enabled(flag: bool) -> None: from pytato.target.loopy import LoopyPyOpenCLTarget from pytato.target.python.jax import generate_jax from pytato.visualization import (get_dot_graph, show_dot_graph, - get_ascii_graph, show_ascii_graph, get_dot_graph_from_partition, show_fancy_placeholder_data_flow, ) @@ -136,8 +135,8 @@ def set_debug_enabled(flag: bool) -> None: "Target", "LoopyPyOpenCLTarget", - "get_dot_graph", "show_dot_graph", "get_ascii_graph", - "show_ascii_graph", "get_dot_graph_from_partition", + "get_dot_graph", "show_dot_graph", + "get_dot_graph_from_partition", "show_fancy_placeholder_data_flow", "abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", "sinh", "cosh", diff --git a/pytato/visualization/__init__.py b/pytato/visualization/__init__.py index 6fed19fec..e0138a78c 100644 --- a/pytato/visualization/__init__.py +++ b/pytato/visualization/__init__.py @@ -2,19 +2,15 @@ .. currentmodule:: pytato .. automodule:: pytato.visualization.dot -.. automodule:: pytato.visualization.ascii .. automodule:: pytato.visualization.fancy_placeholder_data_flow """ from .dot import get_dot_graph, show_dot_graph, get_dot_graph_from_partition -from .ascii import get_ascii_graph, show_ascii_graph from .fancy_placeholder_data_flow import show_fancy_placeholder_data_flow __all__ = [ "get_dot_graph", "show_dot_graph", "get_dot_graph_from_partition", - "get_ascii_graph", "show_ascii_graph", - "show_fancy_placeholder_data_flow", ] diff --git a/pytato/visualization/ascii.py b/pytato/visualization/ascii.py deleted file mode 100644 index e417a9bfc..000000000 --- a/pytato/visualization/ascii.py +++ /dev/null @@ -1,124 +0,0 @@ -""" -.. currentmodule:: pytato - -.. autofunction:: get_ascii_graph -.. autofunction:: show_ascii_graph -""" -__copyright__ = """ -Copyright (C) 2021 University of Illinois Board of Trustees -""" - -__license__ = """ -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in -all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -THE SOFTWARE. -""" - -from typing import Union, List, Dict -from pytato.transform import ArrayOrNames -from pytato.array import Array, DictOfNamedArrays, InputArgumentBase -from pytato.visualization.dot import ArrayToDotNodeInfoMapper -from pytato.codegen import normalize_outputs -from pytools import UniqueNameGenerator - - -# {{{ Show ASCII representation of DAG - -def get_ascii_graph(result: Union[Array, DictOfNamedArrays], - use_color: bool = True) -> str: - """Return a string representing the computation of *result* - using the `asciidag `_ package. - - :arg result: Outputs of the computation (cf. - :func:`pytato.generate_loopy`). - :arg use_color: Colorized output - """ - outputs: DictOfNamedArrays = normalize_outputs(result) - del result - - mapper = ArrayToDotNodeInfoMapper() - for elem in outputs._data.values(): - mapper(elem) - - nodes = mapper.nodes - - input_arrays: List[Array] = [] - internal_arrays: List[ArrayOrNames] = [] - array_to_id: Dict[ArrayOrNames, str] = {} - - id_gen = UniqueNameGenerator() - for array in nodes: - array_to_id[array] = id_gen("array") - if isinstance(array, InputArgumentBase): - input_arrays.append(array) - else: - internal_arrays.append(array) - - # Since 'asciidag' prints the DAG from top to bottom (ie, with the inputs - # at the bottom), we need to invert our representation of it, that is, the - # 'parents' constructor argument to Node() actually means 'children'. - from asciidag.node import Node # type: ignore[import] - asciidag_nodes: Dict[ArrayOrNames, Node] = {} - - from collections import defaultdict - asciidag_edges: Dict[ArrayOrNames, List[ArrayOrNames]] = defaultdict(list) - - # Reverse edge directions - for array in internal_arrays: - for _, v in nodes[array].edges.items(): - asciidag_edges[v].append(array) - - # Add the internal arrays in reversed order - for array in internal_arrays[::-1]: - ary_edges = [asciidag_nodes[v] for v in asciidag_edges[array]] - - if array == internal_arrays[-1]: - ary_edges.append(Node("Outputs")) - - asciidag_nodes[array] = Node(f"{nodes[array].title}", - parents=ary_edges) - - # Add the input arrays last since they have no predecessors - for array in input_arrays: - ary_edges = [asciidag_nodes[v] for v in asciidag_edges[array]] - asciidag_nodes[array] = Node(f"{nodes[array].title}", parents=ary_edges) - - input_node = Node("Inputs", parents=[asciidag_nodes[v] for v in input_arrays]) - - from asciidag.graph import Graph # type: ignore[import] - from io import StringIO - - f = StringIO() - graph = Graph(fh=f, use_color=use_color) - - graph.show_nodes([input_node]) - - # Get the graph and remove trailing whitespace - res = "\n".join([s.rstrip() for s in f.getvalue().split("\n")]) - - return res - - -def show_ascii_graph(result: Union[Array, DictOfNamedArrays]) -> None: - """Print a graph representing the computation of *result* to stdout using the - `asciidag `_ package. - - :arg result: Outputs of the computation (cf. - :func:`pytato.generate_loopy`) or the output of :func:`get_dot_graph`. - """ - - print(get_ascii_graph(result, use_color=True)) diff --git a/requirements.txt b/requirements.txt index 7e69f17f9..4f9b14d65 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,4 @@ git+https://github.com/inducer/pymbolic.git#egg=pymbolic git+https://github.com/inducer/genpy.git#egg=genpy git+https://github.com/inducer/loopy.git#egg=loopy -asciidag mako diff --git a/test/test_pytato.py b/test/test_pytato.py index 5d2343ee9..5b35bdde7 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -371,37 +371,6 @@ def test_userscollector(): assert nuc[dag] == 0 -def test_asciidag(): - pytest.importorskip("asciidag") - - n = pt.make_size_param("n") - array = pt.make_placeholder(name="array", shape=n, dtype=np.float64) - stack = pt.stack([array, 2*array, array + 6]) - y = stack @ stack.T - - from pytato import get_ascii_graph - - res = get_ascii_graph(y, use_color=False) - - ref_str = r"""* Inputs -*-. Placeholder -|\ \ -* | | IndexLambda -| |/ -|/| -| * IndexLambda -|/ -* Stack -|\ -* | AxisPermutation -|/ -* Einsum -* Outputs -""" - - assert res == ref_str - - def test_linear_complexity_inequality(): # See https://github.com/inducer/pytato/issues/163 import pytato as pt From d5a36a9df1066b65b03246db4b1d2825d99fcb5f Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 29 May 2023 19:35:48 -0500 Subject: [PATCH 08/11] Better explain DistributedGraphPartition.name_to_output --- pytato/distributed/partition.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index a59bd0e40..15f4f89ca 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -249,7 +249,9 @@ class DistributedGraphPartition: .. attribute:: name_to_output Mapping of placeholder names to the respective :class:`pytato.array.Array` - they represent. + they represent. This is where the actual expressions are stored, for + all parts. Observe that the :class:`DistributedGraphPart`, for the most + part, only stores names. """ parts: Mapping[PartId, DistributedGraphPart] name_to_output: Mapping[str, Array] From 96a389855ec65a0c295f8ce67f95987050742add Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 29 May 2023 19:36:19 -0500 Subject: [PATCH 09/11] Fix sectioning markers in pt.analysis --- pytato/analysis/__init__.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index f279a3524..d0fd6ef1e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -55,6 +55,8 @@ """ +# {{{ NUserCollector + class NUserCollector(Mapper): """ A :class:`pytato.transform.CachedWalkMapper` that records the number of @@ -172,6 +174,8 @@ def map_distributed_recv(self, expr: DistributedRecv) -> None: self.nusers[dim] += 1 self.rec(dim) +# }}} + def get_nusers(outputs: Union[Array, DictOfNamedArrays]) -> Mapping[Array, int]: """ @@ -185,6 +189,8 @@ def get_nusers(outputs: Union[Array, DictOfNamedArrays]) -> Mapping[Array, int]: return nuser_collector.nusers +# {{{ is_einsum_similar_to_subscript + def _get_indices_from_input_subscript(subscript: str, is_output: bool, ) -> Tuple[str, ...]: @@ -212,8 +218,6 @@ def _get_indices_from_input_subscript(subscript: str, # }}} - # {{{ - if is_output: if len(normalized_indices) != len(set(normalized_indices)): repeated_idx = next(idx @@ -277,6 +281,8 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: return True +# }}} + # {{{ DirectPredecessorsGetter @@ -446,3 +452,5 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: return cscm.count # }}} + +# vim: fdm=marker From f89c796626200d831e2eb1b0b91bf0e112c7cc00 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 May 2023 17:37:30 -0500 Subject: [PATCH 10/11] Distributed partition: preserve names of overall computation outputs --- pytato/distributed/execute.py | 2 +- pytato/distributed/partition.py | 16 ++++++++++++++-- pytato/distributed/tags.py | 1 + 3 files changed, 16 insertions(+), 3 deletions(-) diff --git a/pytato/distributed/execute.py b/pytato/distributed/execute.py index 79444cd22..fd1e6c6c4 100644 --- a/pytato/distributed/execute.py +++ b/pytato/distributed/execute.py @@ -231,7 +231,7 @@ def wait_for_some_recvs() -> None: assert count == 0 assert name not in context - return context + return {name: context[name] for name in partition.overall_output_names} # }}} diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 15f4f89ca..c8ab048f3 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -251,10 +251,19 @@ class DistributedGraphPartition: Mapping of placeholder names to the respective :class:`pytato.array.Array` they represent. This is where the actual expressions are stored, for all parts. Observe that the :class:`DistributedGraphPart`, for the most - part, only stores names. + part, only stores names. These "outputs" may be 'part outputs' (i.e. + data computed in one part for use by another, effectively tempoarary + variables), or 'overall outputs' of the comutation. + + .. attribute:: overall_output_names + + The names of the outputs (in :attr:`name_to_output`) that were given to + :func:`find_distributed_partition` to specify the overall computaiton. + """ parts: Mapping[PartId, DistributedGraphPart] name_to_output: Mapping[str, Array] + overall_output_names: Sequence[str] # }}} @@ -367,6 +376,7 @@ def _make_distributed_partition( sptpo_ary_to_name: Mapping[Array, str], local_recv_id_to_recv_node: Dict[CommunicationOpIdentifier, DistributedRecv], local_send_id_to_send_node: Dict[CommunicationOpIdentifier, DistributedSend], + overall_output_names: Sequence[str], ) -> DistributedGraphPartition: name_to_output = {} parts: Dict[PartId, DistributedGraphPart] = {} @@ -404,6 +414,7 @@ def _make_distributed_partition( result = DistributedGraphPartition( parts=parts, name_to_output=name_to_output, + overall_output_names=overall_output_names, ) return result @@ -969,7 +980,8 @@ def gen_array_name(ary: Array) -> str: sent_ary_to_name, sptpo_ary_to_name, lsrdg.local_recv_id_to_recv_node, - lsrdg.local_send_id_to_send_node) + lsrdg.local_send_id_to_send_node, + tuple(outputs)) from pytato.distributed.verify import _run_partition_diagnostics _run_partition_diagnostics(outputs, partition) diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py index 4fa419cc2..9e3bde8d0 100644 --- a/pytato/distributed/tags.py +++ b/pytato/distributed/tags.py @@ -123,6 +123,7 @@ def set_union( for pid, part in partition.parts.items() }, name_to_output=partition.name_to_output, + overall_output_names=partition.overall_output_names, ), next_tag # }}} From 419da72b86b3ddd9d5ab45c856ce6abaf13646bf Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Wed, 31 May 2023 20:38:48 -0500 Subject: [PATCH 11/11] Rework visualization to render function bodies Also, use distributed vis functionality even in the non-distributed case, to cut down on code duplication. --- pytato/visualization/dot.py | 714 ++++++++++++++++++++++++------------ test/test_codegen.py | 23 +- 2 files changed, 490 insertions(+), 247 deletions(-) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index b836e3814..0008cea20 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -26,18 +26,19 @@ """ -import contextlib +from functools import partial import dataclasses import html -from typing import (TYPE_CHECKING, Callable, Dict, Union, Iterator, List, - Mapping, Hashable, Any, FrozenSet) +from typing import (TYPE_CHECKING, Callable, Dict, Tuple, Union, List, + Mapping, Any, FrozenSet, Set, Optional) from pytools import UniqueNameGenerator from pytools.tag import Tag -from pytools.codegen import CodeGenerator as CodeGeneratorBase from pytato.loopy import LoopyCall -from pytato.function import Call, NamedCallResult +from pytato.function import Call, FunctionDefinition, NamedCallResult +from pytato.tags import FunctionIdentifier +from pytools.codegen import remove_common_indentation from pytato.array import ( Array, DataWrapper, DictOfNamedArrays, IndexLambda, InputArgumentBase, @@ -45,7 +46,7 @@ IndexBase) from pytato.codegen import normalize_outputs -from pytato.transform import CachedMapper, ArrayOrNames +from pytato.transform import CachedMapper, ArrayOrNames, InputGatherer from pytato.distributed.partition import ( DistributedGraphPartition, DistributedGraphPart, PartId) @@ -63,13 +64,88 @@ """ +# {{{ _DotEmitter + +@dataclasses.dataclass +class _SubgraphTree: + contents: Optional[List[str]] + subgraphs: Dict[str, _SubgraphTree] + + +class DotEmitter: + def __init__(self) -> None: + self.subgraph_to_lines: Dict[Tuple[str, ...], List[str]] = {} + + def __call__(self, subgraph_path: Tuple[str, ...], s: str) -> None: + line_list = self.subgraph_to_lines.setdefault(subgraph_path, []) + + if not s.strip(): + line_list.append("") + else: + if "\n" in s: + s = remove_common_indentation(s) + + for line in s.split("\n"): + line_list.append(line) + + def _get_subgraph_tree(self) -> _SubgraphTree: + subgraph_tree = _SubgraphTree(contents=None, subgraphs={}) + + def insert_into_subgraph_tree( + root: _SubgraphTree, path: Tuple[str, ...], contents: List[str] + ) -> None: + if not path: + assert root.contents is None + root.contents = contents + + else: + subgraph = root.subgraphs.setdefault( + path[0], + _SubgraphTree(contents=None, subgraphs={})) + + insert_into_subgraph_tree(subgraph, path[1:], contents) + + for sgp, lines in self.subgraph_to_lines.items(): + insert_into_subgraph_tree(subgraph_tree, sgp, lines) + + return subgraph_tree + + def generate(self) -> str: + result = ["digraph computation {"] + + indent_level = 1 + + def emit_subgraph(sg: _SubgraphTree) -> None: + nonlocal indent_level + + indent = (indent_level*4)*" " + if sg.contents: + for ln in sg.contents: + result.append(indent + ln) + + indent_level += 1 + for sg_name, sub_sg in sg.subgraphs.items(): + result.append(f"{indent}subgraph {sg_name} {{") + emit_subgraph(sub_sg) + result.append(f"{indent}" "}") + indent_level -= 1 + + emit_subgraph(self._get_subgraph_tree()) + + result.append("}") + + return "\n".join(result) + +# }}} + + # {{{ array -> dot node converter @dataclasses.dataclass -class DotNodeInfo: +class _DotNodeInfo: title: str fields: Dict[str, str] - edges: Dict[str, ArrayOrNames] + edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] def stringify_tags(tags: FrozenSet[Tag]) -> str: @@ -89,17 +165,18 @@ def stringify_shape(shape: ShapeType) -> str: class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]): def __init__(self) -> None: super().__init__() - self.nodes: Dict[ArrayOrNames, DotNodeInfo] = {} + self.node_to_dot: Dict[ArrayOrNames, _DotNodeInfo] = {} + self.functions: Set[FunctionDefinition] = set() - def get_common_dot_info(self, expr: Array) -> DotNodeInfo: + def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: title = type(expr).__name__ fields = {"addr": hex(id(expr)), "shape": stringify_shape(expr.shape), "dtype": str(expr.dtype), "tags": stringify_tags(expr.tags)} - edges: Dict[str, ArrayOrNames] = {} - return DotNodeInfo(title, fields, edges) + edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} + return _DotNodeInfo(title, fields, edges) # type-ignore-reason: incompatible with supertype def handle_unsupported_array(self, # type: ignore[override] @@ -126,7 +203,7 @@ def handle_unsupported_array(self, # type: ignore[override] else: info.fields[field] = str(attr) - self.nodes[expr] = info + self.node_to_dot[expr] = info def map_data_wrapper(self, expr: DataWrapper) -> None: info = self.get_common_dot_info(expr) @@ -138,7 +215,7 @@ def map_data_wrapper(self, expr: DataWrapper) -> None: with np.printoptions(threshold=4, precision=2): info.fields["data"] = str(expr.data) - self.nodes[expr] = info + self.node_to_dot[expr] = info def map_index_lambda(self, expr: IndexLambda) -> None: info = self.get_common_dot_info(expr) @@ -148,7 +225,7 @@ def map_index_lambda(self, expr: IndexLambda) -> None: self.rec(val) info.edges[name] = val - self.nodes[expr] = info + self.node_to_dot[expr] = info def map_stack(self, expr: Stack) -> None: info = self.get_common_dot_info(expr) @@ -158,7 +235,7 @@ def map_stack(self, expr: Stack) -> None: self.rec(array) info.edges[str(i)] = array - self.nodes[expr] = info + self.node_to_dot[expr] = info map_concatenate = map_stack @@ -189,7 +266,7 @@ def map_basic_index(self, expr: IndexBase) -> None: self.rec(expr.array) info.edges["array"] = expr.array - self.nodes[expr] = info + self.node_to_dot[expr] = info map_contiguous_advanced_index = map_basic_index map_non_contiguous_advanced_index = map_basic_index @@ -202,27 +279,27 @@ def map_einsum(self, expr: Einsum) -> None: self.rec(val) info.edges[f"{iarg}: {access_descr}"] = val - self.nodes[expr] = info + self.node_to_dot[expr] = info def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: - edges: Dict[str, ArrayOrNames] = {} + edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} for name, val in expr._data.items(): edges[name] = val self.rec(val) - self.nodes[expr] = DotNodeInfo( + self.node_to_dot[expr] = _DotNodeInfo( title=type(expr).__name__, fields={}, edges=edges) def map_loopy_call(self, expr: LoopyCall) -> None: - edges: Dict[str, ArrayOrNames] = {} + edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} for name, arg in expr.bindings.items(): if isinstance(arg, Array): edges[name] = arg self.rec(arg) - self.nodes[expr] = DotNodeInfo( + self.node_to_dot[expr] = _DotNodeInfo( title=type(expr).__name__, fields={"addr": hex(id(expr)), "entrypoint": expr.entrypoint}, edges=edges) @@ -242,15 +319,19 @@ def map_distributed_send_ref_holder( info.fields["comm_tag"] = str(expr.send.comm_tag) - self.nodes[expr] = info + self.node_to_dot[expr] = info def map_call(self, expr: Call) -> None: + self.functions.add(expr.function) + for bnd in expr.bindings.values(): self.rec(bnd) - self.nodes[expr] = DotNodeInfo( + self.node_to_dot[expr] = _DotNodeInfo( title=expr.__class__.__name__, - edges=dict(expr.bindings), + edges={ + "": expr.function, + **expr.bindings}, fields={ "addr": hex(id(expr)), "tags": stringify_tags(expr.tags), @@ -259,30 +340,24 @@ def map_call(self, expr: Call) -> None: def map_named_call_result(self, expr: NamedCallResult) -> None: self.rec(expr._container) - self.nodes[expr] = DotNodeInfo( + self.node_to_dot[expr] = _DotNodeInfo( title=expr.__class__.__name__, edges={"": expr._container}, fields={"addr": hex(id(expr)), "name": expr.name}, ) +# }}} + def dot_escape(s: str) -> str: # "\" and HTML are significant in graphviz. - return html.escape(s.replace("\\", "\\\\")) - + return html.escape(s.replace("\\", "\\\\").replace(" ", "_")) -class DotEmitter(CodeGeneratorBase): - @contextlib.contextmanager - def block(self, name: str) -> Iterator[None]: - self(name + " {") - self.indent() - yield - self.dedent() - self("}") +# {{{ emit helpers -def _emit_array(emit: DotEmitter, title: str, fields: Dict[str, str], +def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, str], dot_node_id: str, color: str = "white") -> None: td_attrib = 'border="0"' table_attrib = 'border="0" cellborder="1" cellspacing="0"' @@ -300,84 +375,189 @@ def _emit_array(emit: DotEmitter, title: str, fields: Dict[str, str], emit("%s [label=<%s> style=filled fillcolor=%s]" % (dot_node_id, table, color)) -def _emit_name_cluster(emit: DotEmitter, names: Mapping[str, ArrayOrNames], +def _emit_name_cluster( + emit: DotEmitter, subgraph_path: Tuple[str, ...], + names: Mapping[str, ArrayOrNames], array_to_id: Mapping[ArrayOrNames, str], id_gen: Callable[[str], str], label: str) -> None: edges = [] - with emit.block("subgraph cluster_%s" % label): - emit("node [shape=ellipse]") - emit('label="%s"' % label) + cluster_subgraph_path = subgraph_path + (f"cluster_{dot_escape(label)}",) + emit_cluster = partial(emit, cluster_subgraph_path) + emit_cluster("node [shape=ellipse]") + emit_cluster(f'label="{label}"') - for name, array in names.items(): - name_id = id_gen(label) - emit('%s [label="%s"]' % (name_id, dot_escape(name))) - array_id = array_to_id[array] - # Edges must be outside the cluster. - edges.append((name_id, array_id)) + for name, array in names.items(): + name_id = id_gen(dot_escape(name)) + emit_cluster('%s [label="%s"]' % (name_id, dot_escape(name))) + array_id = array_to_id[array] + # Edges must be outside the cluster. + edges.append((name_id, array_id)) for name_id, array_id in edges: - emit("%s -> %s" % (name_id, array_id)) - -# }}} + emit(subgraph_path, "%s -> %s" % (array_id, name_id)) -def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str: - r"""Return a string in the `dot `_ language depicting the - graph of the computation of *result*. - - :arg result: Outputs of the computation (cf. - :func:`pytato.generate_loopy`). - """ - outputs: DictOfNamedArrays = normalize_outputs(result) - del result - - mapper = ArrayToDotNodeInfoMapper() - for elem in outputs._data.values(): - mapper(elem) - - nodes = mapper.nodes - +def _emit_function( + emitter: DotEmitter, subgraph_path: Tuple[str, ...], + id_gen: UniqueNameGenerator, + node_to_dot: Mapping[ArrayOrNames, _DotNodeInfo], + func_to_id: Mapping[FunctionDefinition, str], + outputs: Mapping[str, Array]) -> None: input_arrays: List[Array] = [] internal_arrays: List[ArrayOrNames] = [] array_to_id: Dict[ArrayOrNames, str] = {} - id_gen = UniqueNameGenerator() - for array in nodes: + emit = partial(emitter, subgraph_path) + for array in node_to_dot: array_to_id[array] = id_gen("array") if isinstance(array, InputArgumentBase): input_arrays.append(array) else: internal_arrays.append(array) - emit = DotEmitter() + # Emit inputs. + input_subgraph_path = subgraph_path + ("cluster_inputs",) + emit_input = partial(emitter, input_subgraph_path) + emit_input('label="Arguments"') + + for array in input_arrays: + _emit_array( + emit_input, + node_to_dot[array].title, + node_to_dot[array].fields, + array_to_id[array]) + + # Emit non-inputs. + for array in internal_arrays: + _emit_array(emit, + node_to_dot[array].title, + node_to_dot[array].fields, + array_to_id[array]) + + # Emit edges. + for array, node in node_to_dot.items(): + for label, tail_item in node.edges.items(): + head = array_to_id[array] + if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)): + tail = array_to_id[tail_item] + elif isinstance(tail_item, FunctionDefinition): + tail = func_to_id[tail_item] + else: + raise ValueError( + f"unexpected type of tail on edge: {type(tail_item)}") + + emit('%s -> %s [label="%s"]' % (tail, head, dot_escape(label))) - with emit.block("digraph computation"): - emit("node [shape=rectangle]") + # Emit output/namespace name mappings. + _emit_name_cluster( + emitter, subgraph_path, outputs, array_to_id, id_gen, label="Returns") - # Emit inputs. - with emit.block("subgraph cluster_Inputs"): - emit('label="Inputs"') - for array in input_arrays: - _emit_array(emit, - nodes[array].title, nodes[array].fields, array_to_id[array]) +# }}} - # Emit non-inputs. - for array in internal_arrays: - _emit_array(emit, - nodes[array].title, nodes[array].fields, array_to_id[array]) - # Emit edges. - for array, node in nodes.items(): - for label, tail_array in node.edges.items(): - tail = array_to_id[tail_array] - head = array_to_id[array] - emit('%s -> %s [label="%s"]' % (tail, head, dot_escape(label))) +# {{{ information gathering - # Emit output/namespace name mappings. - _emit_name_cluster(emit, outputs._data, array_to_id, id_gen, label="Outputs") +def _get_function_name(f: FunctionDefinition) -> Optional[str]: + func_id_tags = f.tags_of_type(FunctionIdentifier) + if func_id_tags: + func_id_tag, = func_id_tags + return str(func_id_tag.identifier) + else: + return None + + +def _gather_partition_node_information( + id_gen: UniqueNameGenerator, + partition: DistributedGraphPartition + ) -> Tuple[ + Mapping[PartId, Mapping[FunctionDefinition, str]], + Mapping[Tuple[PartId, Optional[FunctionDefinition]], + Mapping[ArrayOrNames, _DotNodeInfo]] + ]: + part_id_to_func_to_id: Dict[PartId, Dict[FunctionDefinition, str]] = {} + part_id_func_to_node_info: Dict[Tuple[PartId, Optional[FunctionDefinition]], + Dict[ArrayOrNames, _DotNodeInfo]] = {} + + for part in partition.parts.values(): + mapper = ArrayToDotNodeInfoMapper() + for out_name in part.output_names: + mapper(partition.name_to_output[out_name]) + + part_id_func_to_node_info[part.pid, None] = mapper.node_to_dot + part_id_to_func_to_id[part.pid] = {} - return emit.get() + # It is important that seen functions are emitted callee-first. + # (Otherwise function 'entry' nodes will get declared in the wrong + # cluster.) So use a data type that preserves order. + seen_functions: List[FunctionDefinition] = [] + + def gather_function_info(f: FunctionDefinition) -> None: + key = (part.pid, f) # noqa: B023 + if key in part_id_func_to_node_info: + return + + mapper = ArrayToDotNodeInfoMapper() + for elem in f.returns.values(): + mapper(elem) + + part_id_func_to_node_info[key] = mapper.node_to_dot + + for subfunc in mapper.functions: + gather_function_info(subfunc) + + if f not in seen_functions: # noqa: B023 + seen_functions.append(f) # noqa: B023 + + for f in mapper.functions: + gather_function_info(f) + + # Again, important to preserve function order. Here we're relying + # on dicts to preserve order. + for f in seen_functions: + func_name = _get_function_name(f) + if func_name is not None: + fid = id_gen(dot_escape(func_name)) + else: + fid = id_gen("func") + + part_id_to_func_to_id.setdefault(part.pid, {})[f] = fid + + return part_id_to_func_to_id, part_id_func_to_node_info + +# }}} + + +def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str: + r"""Return a string in the `dot `_ language depicting the + graph of the computation of *result*. + + :arg result: Outputs of the computation (cf. + :func:`pytato.generate_loopy`). + """ + + outputs: DictOfNamedArrays = normalize_outputs(result) + + return get_dot_graph_from_partition( + DistributedGraphPartition( + parts={ + None: DistributedGraphPart( + pid=None, + needed_pids=frozenset(), + user_input_names=frozenset( + expr.name + for expr in InputGatherer()(outputs) + if isinstance(expr, Placeholder) + ), + partition_input_names=frozenset(), + output_names=frozenset(outputs.keys()), + name_to_recv_node={}, + name_to_send_nodes={}, + ) + }, + name_to_output=outputs._data, + overall_output_names=tuple(outputs), + )) def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: @@ -386,170 +566,226 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: :arg partition: Outputs of :func:`~pytato.find_distributed_partition`. """ - # Maps each partition to a dict of its arrays with the node info - part_id_to_node_info: Dict[Hashable, Dict[ArrayOrNames, DotNodeInfo]] = {} + id_gen = UniqueNameGenerator() - for part in partition.parts.values(): - mapper = ArrayToDotNodeInfoMapper() - for out_name in part.output_names: - mapper(partition.name_to_output[out_name]) + # {{{ gather up node info, per partition and per function - part_id_to_node_info[part.pid] = mapper.nodes + # The "None" function is the body of the partition. - id_gen = UniqueNameGenerator() + part_id_to_func_to_id, part_id_func_to_node_info = \ + _gather_partition_node_information(id_gen, partition) - emit = DotEmitter() + # }}} + + emitter = DotEmitter() + emit_root = partial(emitter, ()) emitted_placeholders = set() - with emit.block("digraph computation"): - emit("node [shape=rectangle]") - placeholder_to_id: Dict[ArrayOrNames, str] = {} - part_id_to_array_to_id: Dict[PartId, Dict[ArrayOrNames, str]] = {} - - # First pass: generate names for all nodes - for part in partition.parts.values(): - array_to_id = {} - for array, _ in part_id_to_node_info[part.pid].items(): - if isinstance(array, Placeholder): - # Placeholders are only emitted once - if array in placeholder_to_id: - node_id = placeholder_to_id[array] - else: - node_id = id_gen("array") - placeholder_to_id[array] = node_id + emit_root("node [shape=rectangle]") + + placeholder_to_id: Dict[ArrayOrNames, str] = {} + part_id_to_array_to_id: Dict[PartId, Dict[ArrayOrNames, str]] = {} + + part_id_to_id = {pid: dot_escape(str(pid)) for pid in partition.parts} + assert len(set(part_id_to_id.values())) == len(partition.parts) + + # {{{ generate names for all nodes in the root/None function + + for part in partition.parts.values(): + array_to_id = {} + for array in part_id_func_to_node_info[part.pid, None].keys(): + if isinstance(array, Placeholder): + # Placeholders are only emitted once + if array in placeholder_to_id: + node_id = placeholder_to_id[array] else: node_id = id_gen("array") - array_to_id[array] = node_id - part_id_to_array_to_id[part.pid] = array_to_id - - # Second pass: emit the graph. - for part in partition.parts.values(): - array_to_id = part_id_to_array_to_id[part.pid] - - # {{{ emit receives nodes if distributed - - if isinstance(part, DistributedGraphPart): - part_dist_recv_var_name_to_node_id = {} - for name, recv in ( - part.name_to_recv_node.items()): - node_id = id_gen("recv") - _emit_array(emit, "DistributedRecv", { - "shape": stringify_shape(recv.shape), - "dtype": str(recv.dtype), - "src_rank": str(recv.src_rank), - "comm_tag": str(recv.comm_tag), - }, node_id) + placeholder_to_id[array] = node_id + else: + node_id = id_gen("array") + array_to_id[array] = node_id + part_id_to_array_to_id[part.pid] = array_to_id + + # }}} + + # {{{ emit the graph + + for part in partition.parts.values(): + array_to_id = part_id_to_array_to_id[part.pid] - part_dist_recv_var_name_to_node_id[name] = node_id + is_trivial_partition = part.pid is None and len(partition.parts) == 1 + if is_trivial_partition: + part_subgraph_path: Tuple[str, ...] = () + else: + part_subgraph_path = (f"cluster_{part_id_to_id[part.pid]}",) + + emit_part = partial(emitter, part_subgraph_path) + + if not is_trivial_partition: + emit_part("style=dashed") + emit_part(f'label="{part.pid}"') + + # {{{ emit functions + + # It is important that seen functions are emitted callee-first. + # Here we're relying on the part_id_to_func_to_id dict to preserve order. + + for func, fid in part_id_to_func_to_id[part.pid].items(): + func_subgraph_path = part_subgraph_path + (f"cluster_{fid}",) + label = _get_function_name(func) or fid + + emitter(func_subgraph_path, f'label="{label}"') + emitter(func_subgraph_path, f'{fid} [label="{label}",shape="ellipse"]') + + _emit_function(emitter, func_subgraph_path, + id_gen, part_id_func_to_node_info[part.pid, func], + part_id_to_func_to_id[part.pid], + func.returns) + + # }}} + + # {{{ emit receives nodes + + part_dist_recv_var_name_to_node_id = {} + for name, recv in ( + part.name_to_recv_node.items()): + node_id = id_gen("recv") + _emit_array(emit_part, "DistributedRecv", { + "shape": stringify_shape(recv.shape), + "dtype": str(recv.dtype), + "src_rank": str(recv.src_rank), + "comm_tag": str(recv.comm_tag), + }, node_id) + + part_dist_recv_var_name_to_node_id[name] = node_id + + # }}} + + part_node_to_info = part_id_func_to_node_info[part.pid, None] + input_arrays: List[Array] = [] + internal_arrays: List[ArrayOrNames] = [] + + for array in part_node_to_info.keys(): + if isinstance(array, InputArgumentBase): + input_arrays.append(array) else: - part_dist_recv_var_name_to_node_id = {} + internal_arrays.append(array) - # }}} + # {{{ emit inputs - part_node_to_info = part_id_to_node_info[part.pid] - input_arrays: List[Array] = [] - internal_arrays: List[ArrayOrNames] = [] + # Placeholders are unique, i.e. the same Placeholder object may be + # shared among partitions. Therefore, they should not live inside + # the (dot) subgraph, otherwise they would be forced into multiple + # subgraphs. - for array in part_node_to_info.keys(): - if isinstance(array, InputArgumentBase): - input_arrays.append(array) + for array in input_arrays: + if not isinstance(array, Placeholder): + _emit_array(emit_part, + part_node_to_info[array].title, + part_node_to_info[array].fields, + array_to_id[array], "deepskyblue") + else: + # Is a Placeholder + if array in emitted_placeholders: + continue + + _emit_array(emit_root, + part_node_to_info[array].title, + part_node_to_info[array].fields, + array_to_id[array], "deepskyblue") + + # Emit cross-partition edges + if array.name in part_dist_recv_var_name_to_node_id: + tgt = part_dist_recv_var_name_to_node_id[array.name] + emit_root(f"{tgt} -> {array_to_id[array]} [style=dotted]") + emitted_placeholders.add(array) + elif array.name in part.user_input_names: + # no arrows for these + pass else: - internal_arrays.append(array) - - # {{{ emit inputs - - # Placeholders are unique, i.e. the same Placeholder object may be - # shared among partitions. Therefore, they should not live inside - # the (dot) subgraph, otherwise they would be forced into multiple - # subgraphs. - - for array in input_arrays: - # Non-Placeholders are emitted *inside* their subgraphs below. - if isinstance(array, Placeholder): - if array not in emitted_placeholders: - _emit_array(emit, - part_node_to_info[array].title, - part_node_to_info[array].fields, - array_to_id[array], "deepskyblue") - - # Emit cross-partition edges - if array.name in part_dist_recv_var_name_to_node_id: - tgt = part_dist_recv_var_name_to_node_id[array.name] - emit(f"{tgt} -> {array_to_id[array]} [style=dotted]") - emitted_placeholders.add(array) - elif array.name in part.user_input_names: - # These are placeholders for external input. They - # are cleanly associated with a single partition - # and thus emitted below. - pass - else: - # placeholder for a value from a different partition - computing_pid = None - for other_part in partition.parts.values(): - if array.name in other_part.output_names: - computing_pid = other_part.pid - break - assert computing_pid is not None - tgt = part_id_to_array_to_id[computing_pid][ - partition.name_to_output[array.name]] - emit(f"{tgt} -> {array_to_id[array]} [style=dashed]") - emitted_placeholders.add(array) - - # }}} - - with emit.block(f'subgraph "cluster_part_{part.pid}"'): - emit("style=dashed") - emit(f'label="{part.pid}"') - - for array in input_arrays: - if (not isinstance(array, Placeholder) - or array.name in part.user_input_names): - _emit_array(emit, - part_node_to_info[array].title, - part_node_to_info[array].fields, - array_to_id[array], "deepskyblue") - - # Emit internal nodes - for array in internal_arrays: - _emit_array(emit, - part_node_to_info[array].title, - part_node_to_info[array].fields, - array_to_id[array]) - - # {{{ emit send nodes if distributed - - deferred_send_edges = [] - if isinstance(part, DistributedGraphPart): - for name, sends in ( - part.name_to_send_nodes.items()): - for send in sends: - node_id = id_gen("send") - _emit_array(emit, "DistributedSend", { - "dest_rank": str(send.dest_rank), - "comm_tag": str(send.comm_tag), - }, node_id) - - deferred_send_edges.append( - f"{array_to_id[send.data]} -> {node_id}" - f'[style=dotted, label="{dot_escape(name)}"]') - - # }}} - - # If an edge is emitted in a subgraph, it drags its nodes into the - # subgraph, too. Not what we want. - for edge in deferred_send_edges: - emit(edge) - - # Emit intra-partition edges - for array, node in part_node_to_info.items(): - for label, tail_array in node.edges.items(): - tail = array_to_id[tail_array] - head = array_to_id[array] - emit('%s -> %s [label="%s"]' % - (tail, head, dot_escape(label))) - - return emit.get() + # placeholder for a value from a different partition + computing_pid = None + for other_part in partition.parts.values(): + if array.name in other_part.output_names: + computing_pid = other_part.pid + break + assert computing_pid is not None + tgt = part_id_to_array_to_id[computing_pid][ + partition.name_to_output[array.name]] + emit_root(f"{tgt} -> {array_to_id[array]} [style=dashed]") + emitted_placeholders.add(array) + + # }}} + + # Emit internal nodes + for array in internal_arrays: + _emit_array(emit_part, + part_node_to_info[array].title, + part_node_to_info[array].fields, + array_to_id[array]) + + # {{{ emit send nodes if distributed + + if isinstance(part, DistributedGraphPart): + for name, sends in part.name_to_send_nodes.items(): + for send in sends: + node_id = id_gen("send") + _emit_array(emit_part, "DistributedSend", { + "dest_rank": str(send.dest_rank), + "comm_tag": str(send.comm_tag), + }, node_id) + + # If an edge is emitted in a subgraph, it drags its + # nodes into the subgraph, too. Not what we want. + emit_root( + f"{array_to_id[send.data]} -> {node_id}" + f'[style=dotted, label="{dot_escape(name)}"]') + + # }}} + + # Emit intra-partition edges + for array, node in part_node_to_info.items(): + for label, tail_item in node.edges.items(): + head = array_to_id[array] + + if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)): + tail = array_to_id[tail_item] + elif isinstance(tail_item, FunctionDefinition): + tail = part_id_to_func_to_id[part.pid][tail_item] + else: + raise ValueError( + f"unexpected type of tail on edge: {type(tail_item)}") + + emit_root('%s -> %s [label="%s"]' % + (tail, head, dot_escape(label))) + + _emit_name_cluster( + emitter, part_subgraph_path, + {name: partition.name_to_output[name] for name in part.output_names}, + array_to_id, id_gen, "Part outputs") + + # }}} + + # Arrays may occur in multiple partitions, they get drawn separately anyhow + # (unless they're Placeholders). Don't be tempted to use + # combined_array_to_id everywhere. + + # {{{ draw overall outputs + + combined_array_to_id: Dict[ArrayOrNames, str] = {} + for part_id in partition.parts.keys(): + combined_array_to_id.update(part_id_to_array_to_id[part_id]) + + _emit_name_cluster( + emitter, (), + {name: partition.name_to_output[name] + for name in partition.overall_output_names}, + combined_array_to_id, id_gen, "Overall outputs") + + # }}} + + return emitter.generate() def show_dot_graph(result: Union[str, Array, DictOfNamedArrays, diff --git a/test/test_codegen.py b/test/test_codegen.py index e02fbc85c..b6d0bba8b 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1854,15 +1854,16 @@ def test_pad(ctx_factory): np.testing.assert_allclose(np_out * mask_array, pt_out * mask_array) -def test_function_call(ctx_factory): +def test_function_call(ctx_factory, visualize=False): + from functools import partial cl_ctx = ctx_factory() cq = cl.CommandQueue(cl_ctx) def f(x): return 2*x - def g(x): - return 2*x, 3*x + def g(tracer, x): + return tracer(f, x), 3*x def h(x, y): return {"twice": 2*x+y, "thrice": 3*x+y} @@ -1870,7 +1871,7 @@ def h(x, y): def build_expression(tracer): x = pt.arange(500, dtype=np.float32) twice_x = tracer(f, x) - twice_x_2, thrice_x_2 = tracer(g, x) + twice_x_2, thrice_x_2 = tracer(partial(g, tracer), x) result = tracer(h, x, 2*x) twice_x_3 = result["twice"] @@ -1881,13 +1882,19 @@ def build_expression(tracer): "baz": 65 * twice_x, "quux": 7 * twice_x_2} - result1 = pt.tag_all_calls_to_be_inlined( + result_with_functions = pt.tag_all_calls_to_be_inlined( pt.make_dict_of_named_arrays(build_expression(pt.trace_call))) - result2 = pt.make_dict_of_named_arrays( + result_without_functions = pt.make_dict_of_named_arrays( build_expression(lambda fn, *args: fn(*args))) - _, outputs = pt.generate_loopy(result1)(cq, out_host=True) - _, expected = pt.generate_loopy(result2)(cq, out_host=True) + # test that visualizing graphs with functions works + dot = pt.get_dot_graph(result_with_functions) + + if visualize: + pt.show_dot_graph(dot) + + _, outputs = pt.generate_loopy(result_with_functions)(cq, out_host=True) + _, expected = pt.generate_loopy(result_without_functions)(cq, out_host=True) assert len(outputs) == len(expected)