Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@ Array Expressions

.. automodule:: pytato.array

Functions in Pytato IR
----------------------

.. automodule:: pytato.function

Raising :class:`~pytato.array.IndexLambda` nodes
------------------------------------------------

.. automodule:: pytato.raising


Calling :mod:`loopy` kernels in an array expression
---------------------------------------------------

Expand Down
12 changes: 11 additions & 1 deletion pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,10 @@ def set_debug_enabled(flag: bool) -> None:
from pytato.visualization import (get_dot_graph, show_dot_graph,
get_ascii_graph, show_ascii_graph,
get_dot_graph_from_partition)
from pytato.transform.calls import tag_all_calls_to_be_inlined, inline_calls
import pytato.analysis as analysis
import pytato.tags as tags
import pytato.function as function
import pytato.transform as transform
from pytato.distributed.nodes import (make_distributed_send, make_distributed_recv,
DistributedRecv, DistributedSend,
Expand All @@ -108,6 +110,8 @@ def set_debug_enabled(flag: bool) -> None:
from pytato.transform.remove_broadcasts_einsum import (
rewrite_einsums_with_no_broadcasts)
from pytato.transform.metadata import unify_axes_tags
from pytato.function import trace_call
from pytato.transform.calls import concatenate_calls

from pytato.partition import generate_code_for_partition

Expand Down Expand Up @@ -154,11 +158,17 @@ def set_debug_enabled(flag: bool) -> None:

"broadcast_to", "pad",

"trace_call",

"concatenate_calls",

"make_distributed_recv", "make_distributed_send", "DistributedRecv",
"DistributedSend", "staple_distributed_send", "DistributedSendRefHolder",

"DistributedGraphPart",
"DistributedGraphPartition",
"tag_all_calls_to_be_inlined", "inline_calls",

"find_distributed_partition",

"number_distributed_tags",
Expand All @@ -174,6 +184,6 @@ def set_debug_enabled(flag: bool) -> None:
"unify_axes_tags",

# sub-modules
"analysis", "tags", "transform",
"analysis", "tags", "transform", "function",

)
58 changes: 58 additions & 0 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@
DictOfNamedArrays, NamedArray,
IndexBase, IndexRemappingBase, InputArgumentBase,
ShapeType)
from pytato.function import FunctionDefinition, Call
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper
from pytato.loopy import LoopyCall
from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method

if TYPE_CHECKING:
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
Expand All @@ -47,6 +49,8 @@

.. autofunction:: get_num_nodes

.. autofunction:: get_num_call_sites

.. autoclass:: DirectPredecessorsGetter
"""

Expand Down Expand Up @@ -388,3 +392,57 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
return ncm.count

# }}}


# {{{ CallSiteCountMapper

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class CallSiteCountMapper(CachedWalkMapper):
"""
Counts the number of nodes in a DAG.

.. attribute:: count

The number of nodes.
"""

def __init__(self) -> None:
super().__init__()
self.count = 0

# type-ignore-reason: dropped the extra `*args, **kwargs`.
def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override]
return id(expr)

@memoize_method
def map_function_definition(self, /, expr: FunctionDefinition,
*args: Any, **kwargs: Any) -> None:
if not self.visit(expr):
return

new_mapper = self.clone_for_callee()
for subexpr in expr.returns.values():
new_mapper(subexpr, *args, **kwargs)

self.count += new_mapper.count

self.post_visit(expr, *args, **kwargs)

# type-ignore-reason: dropped the extra `*args, **kwargs`.
def post_visit(self, expr: Any) -> None: # type: ignore[override]
if isinstance(expr, Call):
self.count += 1


def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int:
"""Returns the number of nodes in DAG *outputs*."""

from pytato.codegen import normalize_outputs
outputs = normalize_outputs(outputs)

cscm = CallSiteCountMapper()
cscm(outputs)

return cscm.count

# }}}
22 changes: 15 additions & 7 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"""

import dataclasses
from typing import Union, Dict, Tuple, List, Any
from typing import Union, Dict, Tuple, List, Any, Optional

from pytato.array import (Array, DictOfNamedArrays, DataWrapper, Placeholder,
DataInterface, SizeParam, InputArgumentBase,
Expand Down Expand Up @@ -102,12 +102,17 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc]
====================================== =====================================
"""

def __init__(self, target: Target) -> None:
def __init__(self, target: Target,
kernels_seen: Optional[Dict[str, lp.LoopKernel]] = None
) -> None:
super().__init__()
self.bound_arguments: Dict[str, DataInterface] = {}
self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
self.target = target
self.kernels_seen: Dict[str, lp.LoopKernel] = {}
self.kernels_seen: Dict[str, lp.LoopKernel] = kernels_seen or {}

def clone_for_callee(self) -> CodeGenPreprocessor:
return CodeGenPreprocessor(self.target, self.kernels_seen)

def map_size_param(self, expr: SizeParam) -> Array:
name = expr.name
Expand Down Expand Up @@ -262,6 +267,7 @@ class PreprocessResult:
def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult:
"""Preprocess a computation for code generation."""
from pytato.transform import copy_dict_of_named_arrays
from pytato.transform.calls import inline_calls

check_validity_of_outputs(outputs)

Expand Down Expand Up @@ -289,12 +295,14 @@ def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult:

# }}}

mapper = CodeGenPreprocessor(target)
new_outputs = inline_calls(outputs)
assert isinstance(new_outputs, DictOfNamedArrays)

new_outputs = copy_dict_of_named_arrays(outputs, mapper)
mapper = CodeGenPreprocessor(target)
new_outputs = copy_dict_of_named_arrays(new_outputs, mapper)

return PreprocessResult(outputs=new_outputs,
compute_order=tuple(output_order),
bound_arguments=mapper.bound_arguments)
compute_order=tuple(output_order),
bound_arguments=mapper.bound_arguments)

# vim: fdm=marker
27 changes: 27 additions & 0 deletions pytato/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
IndexBase, IndexLambda, NamedArray,
Reshape, Roll, Stack, AbstractResultWithNamedArrays,
Array, DictOfNamedArrays, Placeholder, SizeParam)
from pytato.function import Call, NamedCallResult, FunctionDefinition
from pytools import memoize_method

if TYPE_CHECKING:
from pytato.loopy import LoopyCall, LoopyCallResult
Expand Down Expand Up @@ -273,6 +275,31 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool:
and expr1.tags == expr2.tags
)

@memoize_method
def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
) -> bool:
return (expr1.__class__ is expr2.__class__
and expr1.parameters == expr2.parameters
and (set(expr1.returns.keys()) == set(expr2.returns.keys()))
and all(self.rec(expr1.returns[k], expr2.returns[k])
for k in expr1.returns)
and expr1.tags == expr2.tags
)

def map_call(self, expr1: Call, expr2: Any) -> bool:
return (expr1.__class__ is expr2.__class__
and self.map_function_definition(expr1.function, expr2.function)
and frozenset(expr1.bindings) == frozenset(expr2.bindings)
and all(self.rec(bnd,
expr2.bindings[name])
for name, bnd in expr1.bindings.items())
)

def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool:
return (expr1.__class__ is expr2.__class__
and expr1.name == expr2.name
and self.rec(expr1._container, expr2._container))

# }}}

# vim: fdm=marker
Loading