Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .pylintrc-local.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
- arg: ignored-modules
val:
- asciidag
- matplotlib
- ipykernel
- ply
Expand Down
5 changes: 5 additions & 0 deletions doc/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@ Array Expressions

.. automodule:: pytato.array

Functions in Pytato IR
----------------------

.. automodule:: pytato.function

Raising :class:`~pytato.array.IndexLambda` nodes
------------------------------------------------

.. automodule:: pytato.raising


Calling :mod:`loopy` kernels in an array expression
---------------------------------------------------

Expand Down
2 changes: 0 additions & 2 deletions examples/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ def main():
stack = pt.stack([array, 2*array, array + 6])
result = stack @ stack.T

pt.show_ascii_graph(result)

dot_code = pt.get_dot_graph(result)

with open(GRAPH_DOT, "w") as outf:
Expand Down
14 changes: 10 additions & 4 deletions pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,13 @@ def set_debug_enabled(flag: bool) -> None:
from pytato.target.loopy import LoopyPyOpenCLTarget
from pytato.target.python.jax import generate_jax
from pytato.visualization import (get_dot_graph, show_dot_graph,
get_ascii_graph, show_ascii_graph,
get_dot_graph_from_partition,
show_fancy_placeholder_data_flow,
)
from pytato.transform.calls import tag_all_calls_to_be_inlined, inline_calls
import pytato.analysis as analysis
import pytato.tags as tags
import pytato.function as function
import pytato.transform as transform
from pytato.distributed.nodes import (make_distributed_send, make_distributed_recv,
DistributedRecv, DistributedSend,
Expand All @@ -111,6 +112,7 @@ def set_debug_enabled(flag: bool) -> None:
from pytato.transform.remove_broadcasts_einsum import (
rewrite_einsums_with_no_broadcasts)
from pytato.transform.metadata import unify_axes_tags
from pytato.function import trace_call

__all__ = (
"dtype",
Expand All @@ -133,8 +135,8 @@ def set_debug_enabled(flag: bool) -> None:

"Target", "LoopyPyOpenCLTarget",

"get_dot_graph", "show_dot_graph", "get_ascii_graph",
"show_ascii_graph", "get_dot_graph_from_partition",
"get_dot_graph", "show_dot_graph",
"get_dot_graph_from_partition",
"show_fancy_placeholder_data_flow",

"abs", "sin", "cos", "tan", "arcsin", "arccos", "arctan", "sinh", "cosh",
Expand All @@ -156,11 +158,15 @@ def set_debug_enabled(flag: bool) -> None:

"broadcast_to", "pad",

"trace_call",

"make_distributed_recv", "make_distributed_send", "DistributedRecv",
"DistributedSend", "staple_distributed_send", "DistributedSendRefHolder",

"DistributedGraphPart",
"DistributedGraphPartition",
"tag_all_calls_to_be_inlined", "inline_calls",

"find_distributed_partition",

"number_distributed_tags",
Expand All @@ -175,6 +181,6 @@ def set_debug_enabled(flag: bool) -> None:
"unify_axes_tags",

# sub-modules
"analysis", "tags", "transform",
"analysis", "tags", "transform", "function",

)
70 changes: 68 additions & 2 deletions pytato/analysis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@
DictOfNamedArrays, NamedArray,
IndexBase, IndexRemappingBase, InputArgumentBase,
ShapeType)
from pytato.function import FunctionDefinition, Call
from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper
from pytato.loopy import LoopyCall
from pymbolic.mapper.optimize import optimize_mapper
from pytools import memoize_method

if TYPE_CHECKING:
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
Expand All @@ -47,10 +49,14 @@

.. autofunction:: get_num_nodes

.. autofunction:: get_num_call_sites

.. autoclass:: DirectPredecessorsGetter
"""


# {{{ NUserCollector

class NUserCollector(Mapper):
"""
A :class:`pytato.transform.CachedWalkMapper` that records the number of
Expand Down Expand Up @@ -168,6 +174,8 @@ def map_distributed_recv(self, expr: DistributedRecv) -> None:
self.nusers[dim] += 1
self.rec(dim)

# }}}


def get_nusers(outputs: Union[Array, DictOfNamedArrays]) -> Mapping[Array, int]:
"""
Expand All @@ -181,6 +189,8 @@ def get_nusers(outputs: Union[Array, DictOfNamedArrays]) -> Mapping[Array, int]:
return nuser_collector.nusers


# {{{ is_einsum_similar_to_subscript

def _get_indices_from_input_subscript(subscript: str,
is_output: bool,
) -> Tuple[str, ...]:
Expand Down Expand Up @@ -208,8 +218,6 @@ def _get_indices_from_input_subscript(subscript: str,

# }}}

# {{{

if is_output:
if len(normalized_indices) != len(set(normalized_indices)):
repeated_idx = next(idx
Expand Down Expand Up @@ -273,6 +281,8 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool:

return True

# }}}


# {{{ DirectPredecessorsGetter

Expand Down Expand Up @@ -388,3 +398,59 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int:
return ncm.count

# }}}


# {{{ CallSiteCountMapper

@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True)
class CallSiteCountMapper(CachedWalkMapper):
    """
    Counts the number of :class:`~pytato.Call` nodes in a DAG.

    .. attribute:: count

       The number of :class:`~pytato.Call` nodes (call sites) counted so
       far, including those found while walking the return values of the
       function definitions that were traversed.
    """

    def __init__(self) -> None:
        super().__init__()
        # Running total; updated by post_visit and map_function_definition.
        self.count = 0

    # type-ignore-reason: dropped the extra `*args, **kwargs`.
    def get_cache_key(self, expr: ArrayOrNames) -> int:  # type: ignore[override]
        # The key is the object's identity, not its value.
        return id(expr)

    @memoize_method
    def map_function_definition(self, /, expr: FunctionDefinition,
                                *args: Any, **kwargs: Any) -> None:
        if not self.visit(expr):
            return

        # The function's return values are walked by a separate mapper
        # obtained from clone_for_callee(); whatever that mapper counted is
        # then added to this mapper's total.
        new_mapper = self.clone_for_callee()
        for subexpr in expr.returns.values():
            new_mapper(subexpr, *args, **kwargs)

        self.count += new_mapper.count

        self.post_visit(expr, *args, **kwargs)

    # type-ignore-reason: dropped the extra `*args, **kwargs`.
    def post_visit(self, expr: Any) -> None:  # type: ignore[override]
        # Only Call nodes contribute to the count; any other node leaves it
        # unchanged.
        if isinstance(expr, Call):
            self.count += 1


def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int:
    """Returns the number of :class:`~pytato.Call` nodes (call sites) in the
    DAG *outputs*, as counted by :class:`CallSiteCountMapper`."""

    from pytato.codegen import normalize_outputs

    normalized = normalize_outputs(outputs)

    counter = CallSiteCountMapper()
    counter(normalized)
    return counter.count

# }}}

# vim: fdm=marker
22 changes: 15 additions & 7 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"""

import dataclasses
from typing import Union, Dict, Tuple, List, Any
from typing import Union, Dict, Tuple, List, Any, Optional

from pytato.array import (Array, DictOfNamedArrays, DataWrapper, Placeholder,
DataInterface, SizeParam, InputArgumentBase,
Expand Down Expand Up @@ -102,12 +102,17 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc]
====================================== =====================================
"""

def __init__(self, target: Target,
             kernels_seen: Optional[Dict[str, lp.LoopKernel]] = None
             ) -> None:
    """
    :arg target: the code generation target; stored as ``self.target``.
    :arg kernels_seen: an optional dictionary to use as
        ``self.kernels_seen``. When a dictionary is passed (even one that
        is currently empty), that very object is stored, so entries added
        through this mapper are visible to every other holder of the
        dictionary. When *None*, a fresh empty dictionary is created.
    """
    super().__init__()
    self.bound_arguments: Dict[str, DataInterface] = {}
    self.var_name_gen: UniqueNameGenerator = UniqueNameGenerator()
    self.target = target
    # Compare against None instead of relying on truthiness:
    # ``kernels_seen or {}`` would swap a passed-in *empty* dict for a brand
    # new one and thereby silently drop the sharing the caller asked for.
    self.kernels_seen: Dict[str, lp.LoopKernel] = (
        {} if kernels_seen is None else kernels_seen)

def clone_for_callee(self) -> CodeGenPreprocessor:
    """Return a new :class:`CodeGenPreprocessor` constructed from this
    mapper's ``target`` and ``kernels_seen``."""
    callee_mapper = CodeGenPreprocessor(target=self.target,
                                        kernels_seen=self.kernels_seen)
    return callee_mapper

def map_size_param(self, expr: SizeParam) -> Array:
name = expr.name
Expand Down Expand Up @@ -262,6 +267,7 @@ class PreprocessResult:
def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult:
"""Preprocess a computation for code generation."""
from pytato.transform import copy_dict_of_named_arrays
from pytato.transform.calls import inline_calls

check_validity_of_outputs(outputs)

Expand Down Expand Up @@ -289,12 +295,14 @@ def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult:

# }}}

mapper = CodeGenPreprocessor(target)
new_outputs = inline_calls(outputs)
assert isinstance(new_outputs, DictOfNamedArrays)

new_outputs = copy_dict_of_named_arrays(outputs, mapper)
mapper = CodeGenPreprocessor(target)
new_outputs = copy_dict_of_named_arrays(new_outputs, mapper)

return PreprocessResult(outputs=new_outputs,
compute_order=tuple(output_order),
bound_arguments=mapper.bound_arguments)
compute_order=tuple(output_order),
bound_arguments=mapper.bound_arguments)

# vim: fdm=marker
2 changes: 1 addition & 1 deletion pytato/distributed/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def wait_for_some_recvs() -> None:
assert count == 0
assert name not in context

return context
return {name: context[name] for name in partition.overall_output_names}

# }}}

Expand Down
18 changes: 16 additions & 2 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,21 @@ class DistributedGraphPartition:
.. attribute:: name_to_output

Mapping of placeholder names to the respective :class:`pytato.array.Array`
they represent.
they represent. This is where the actual expressions are stored, for
all parts. Observe that the :class:`DistributedGraphPart`, for the most
part, only stores names. These "outputs" may be 'part outputs' (i.e.
data computed in one part for use by another, effectively temporary
variables), or 'overall outputs' of the computation.

.. attribute:: overall_output_names

The names of the outputs (in :attr:`name_to_output`) that were given to
:func:`find_distributed_partition` to specify the overall computation.

"""
parts: Mapping[PartId, DistributedGraphPart]
name_to_output: Mapping[str, Array]
overall_output_names: Sequence[str]

# }}}

Expand Down Expand Up @@ -365,6 +376,7 @@ def _make_distributed_partition(
sptpo_ary_to_name: Mapping[Array, str],
local_recv_id_to_recv_node: Dict[CommunicationOpIdentifier, DistributedRecv],
local_send_id_to_send_node: Dict[CommunicationOpIdentifier, DistributedSend],
overall_output_names: Sequence[str],
) -> DistributedGraphPartition:
name_to_output = {}
parts: Dict[PartId, DistributedGraphPart] = {}
Expand Down Expand Up @@ -402,6 +414,7 @@ def _make_distributed_partition(
result = DistributedGraphPartition(
parts=parts,
name_to_output=name_to_output,
overall_output_names=overall_output_names,
)

return result
Expand Down Expand Up @@ -967,7 +980,8 @@ def gen_array_name(ary: Array) -> str:
sent_ary_to_name,
sptpo_ary_to_name,
lsrdg.local_recv_id_to_recv_node,
lsrdg.local_send_id_to_send_node)
lsrdg.local_send_id_to_send_node,
tuple(outputs))

from pytato.distributed.verify import _run_partition_diagnostics
_run_partition_diagnostics(outputs, partition)
Expand Down
1 change: 1 addition & 0 deletions pytato/distributed/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def set_union(
for pid, part in partition.parts.items()
},
name_to_output=partition.name_to_output,
overall_output_names=partition.overall_output_names,
), next_tag

# }}}
Expand Down
27 changes: 27 additions & 0 deletions pytato/equality.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
IndexBase, IndexLambda, NamedArray,
Reshape, Roll, Stack, AbstractResultWithNamedArrays,
Array, DictOfNamedArrays, Placeholder, SizeParam)
from pytato.function import Call, NamedCallResult, FunctionDefinition
from pytools import memoize_method

if TYPE_CHECKING:
from pytato.loopy import LoopyCall, LoopyCallResult
Expand Down Expand Up @@ -273,6 +275,31 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool:
and expr1.tags == expr2.tags
)

@memoize_method
def map_function_definition(self, expr1: FunctionDefinition, expr2: Any
                            ) -> bool:
    """Compare two function definitions: same class, same parameters, same
    set of return names, pairwise-equal return values (via :meth:`rec`) and
    equal tags."""
    if expr1.__class__ is not expr2.__class__:
        return False
    if not (expr1.parameters == expr2.parameters):
        return False
    if not (set(expr1.returns.keys()) == set(expr2.returns.keys())):
        return False

    # Return names agree at this point, so indexing expr2.returns is safe.
    for ret_name in expr1.returns:
        if not self.rec(expr1.returns[ret_name], expr2.returns[ret_name]):
            return False

    return expr1.tags == expr2.tags

def map_call(self, expr1: Call, expr2: Any) -> bool:
    """Compare two calls: same class, equal called function (via
    :meth:`map_function_definition`), same binding names and pairwise-equal
    bound values (via :meth:`rec`)."""
    if expr1.__class__ is not expr2.__class__:
        return False
    if not self.map_function_definition(expr1.function, expr2.function):
        return False
    if not (frozenset(expr1.bindings) == frozenset(expr2.bindings)):
        return False

    # Binding names agree at this point, so indexing expr2.bindings is safe.
    return all(self.rec(bound_value, expr2.bindings[arg_name])
               for arg_name, bound_value in expr1.bindings.items())

def map_named_call_result(self, expr1: NamedCallResult, expr2: Any) -> bool:
    """Compare two named call results: same class, same name and equal
    containers (via :meth:`rec`)."""
    if expr1.__class__ is not expr2.__class__:
        return False
    if not (expr1.name == expr2.name):
        return False
    return self.rec(expr1._container, expr2._container)

# }}}

# vim: fdm=marker
Loading