diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py
index f241fe227..971f1f4ce 100644
--- a/pytato/visualization/dot.py
+++ b/pytato/visualization/dot.py
@@ -27,10 +27,11 @@
"""
+import gc
import html
+import re
from functools import partial
from typing import (
- TYPE_CHECKING,
Any,
Callable,
Mapping,
@@ -56,6 +57,7 @@
Stack,
)
from pytato.codegen import normalize_outputs
+from pytato.distributed.nodes import DistributedSendRefHolder
from pytato.distributed.partition import (
DistributedGraphPart,
DistributedGraphPartition,
@@ -67,10 +69,6 @@
from pytato.transform import ArrayOrNames, CachedMapper, InputGatherer
-if TYPE_CHECKING:
- from pytato.distributed.nodes import DistributedSendRefHolder
-
-
__doc__ = """
.. currentmodule:: pytato
@@ -88,6 +86,84 @@ class _SubgraphTree:
subgraphs: dict[str, _SubgraphTree]
+# Ordered (pattern, symbol) table.  The first pattern found in an expression
+# decides its symbol, so every multi-character operator is listed before the
+# shorter operator it contains ("<=" and "<<" before "<", "//" before "/",
+# "**" before "*").  The word operators use \b so that identifiers which
+# merely contain them (e.g. "floor") are not reported as "or".
+# NOTE(review): the expression text appears to come from HTML-escaped dot
+# labels -- confirm whether "<", ">" and "&" can arrive as entities instead.
+_OPERATION_PATTERNS = tuple(
+    (re.compile(pattern), symbol)
+    for pattern, symbol in (
+        (r"NaN_if", "if"),
+        (r"else", "else"),
+        (r"isnan", "is NaN"),
+        (r"\s*<=\s*", "<="),
+        (r"\s*>=\s*", ">="),
+        (r"\s*<<\s*", "<<"),
+        (r"\s*>>\s*", ">>"),
+        (r"<", "<"),
+        (r">", ">"),
+        (r"\s*==\s*", "=="),
+        (r"\s*!=\s*", "!="),
+        (r"\s*\+\s*", "+"),
+        (r"\s*\-\s*", "-"),
+        (r"\s*\*\*\s*", "**"),
+        (r"\s*\*\s*", "*"),
+        (r"\s*//\s*", "//"),
+        (r"\s*/\s*", "/"),
+        (r"\s*%\s*", "%"),
+        (r"\bor\b", "or"),
+        (r"\band\b", "and"),
+        (r"\bnot\b", "not"),
+        (r"\s*\|\s*", "|"),
+        (r"\s*\^\s*", "^"),
+        (r"~\s*", "~"),
+        (r"\s*@\s*", "@"),
+        (r"\s*SumReductionOperation\s*", "Σ"),
+        (r"&", "&"),
+    )
+)
+
+
+def extract_operation_symbol(expr):
+    """Return a short symbol for the first known operation found in *expr*.
+
+    :arg expr: the textual form of an expression.
+    :returns: the symbol of the first entry of the pattern table that occurs
+        anywhere in the stripped expression, or *expr* itself (unstripped)
+        when no entry matches.
+    """
+    stripped = expr.strip()
+
+    for pattern, symbol in _OPERATION_PATTERNS:
+        if pattern.search(stripped):
+            return symbol
+
+    return expr
+
+
+# Rewrite the dot line of an IndexLambda node so that its label shows only
+# the operation symbol, with the remaining field text moved into a tooltip.
+# A line that does not contain "IndexLambda", or in which the expr pattern is
+# not found, is returned unchanged.
+# NOTE(review): the regular expressions and label strings below look as if
+# their HTML table markup was lost (string literals are split across lines)
+# -- verify them against the upstream file before relying on this function.
+def simplify_indexlambda_node_to_symbol_only(s):
+ if "IndexLambda" in s:
+ # Locate the text of the node's "expr" field.
+ expr_match = re.search(
+ r'expr:
(.*?) | ', s
+ )
+
+ if expr_match:
+ original_expr = expr_match.group(1)
+ operation_symbol = extract_operation_symbol(original_expr)
+
+ # Every (key, value) pair found in the line becomes one
+ # "key: value" entry of the tooltip.
+ tooltip_content = []
+ tooltip_matches = re.findall(
+ r'| (.*?) | '
+ r'(.*?) |
',
+ s
+ )
+
+ for key, value in tooltip_matches:
+ tooltip_content.append(f"{key}: {value}")
+
+ tooltip_text = ",\n".join(tooltip_content)
+
+ # The replacement label contains only the operation symbol.
+ new_label = (
+ f'| '
+ f'{operation_symbol}'
+ f' |
'
+ )
+
+ # The whole line is replaced by the simplified node description.
+ s = (
+ f'{new_label}> '
+ f'style=filled fillcolor="white" '
+ f'tooltip="{tooltip_text}"];'
+ )
+
+ return s
+
+
class DotEmitter:
def __init__(self) -> None:
self.subgraph_to_lines: dict[tuple[str, ...], list[str]] = {}
@@ -102,7 +178,9 @@ def __call__(self, subgraph_path: tuple[str, ...], s: str) -> None:
s = remove_common_indentation(s)
for line in s.split("\n"):
- line_list.append(line)
+ simplified_line = simplify_indexlambda_node_to_symbol_only(
+ line)
+ line_list.append(simplified_line)
def _get_subgraph_tree(self) -> _SubgraphTree:
subgraph_tree = _SubgraphTree(contents=None, subgraphs={})
@@ -161,7 +239,9 @@ def emit_subgraph(sg: _SubgraphTree) -> None:
class _DotNodeInfo:
title: str
fields: dict[str, Any]
- edges: dict[str, ArrayOrNames | FunctionDefinition]
+ edges: dict[str, ArrayOrNames |
+ FunctionDefinition | tuple[int |
+ ArrayOrNames, ArrayOrNames], Array]
def stringify_tags(tags: frozenset[Tag | None]) -> str:
@@ -178,11 +258,22 @@ def stringify_shape(shape: ShapeType) -> str:
return "(" + ", ".join(components) + ")"
+def get_object_by_id(object_id: int) -> Any | ArrayOrNames:
+    """Return the live object whose :func:`id` equals *object_id*.
+
+    The search walks everything reported by :func:`gc.get_objects`; when no
+    object in that list has the requested identity, *None* is returned.
+    """
+    tracked_objects = gc.get_objects()
+    return next(
+        (candidate for candidate in tracked_objects
+         if id(candidate) == object_id),
+        None)
+
+
class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]):
- def __init__(self) -> None:
- super().__init__()
- self.node_to_dot: dict[ArrayOrNames, _DotNodeInfo] = {}
+ def __init__(self, count_duplicates: bool = False):
+ self.node_to_dot: dict[int | ArrayOrNames, _DotNodeInfo] = {}
self.functions: set[FunctionDefinition] = set()
+ self.count_duplicates = count_duplicates
+
+    def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames:
+        """Return the key identifying *expr*.
+
+        With ``count_duplicates`` set this is ``id(expr)``, so equal but
+        distinct objects get distinct keys; otherwise it is *expr* itself.
+        """
+        if self.count_duplicates:
+            return id(expr)
+        return expr
def get_common_dot_info(self, expr: Array) -> _DotNodeInfo:
title = type(expr).__name__
@@ -193,68 +284,91 @@ def get_common_dot_info(self, expr: Array) -> _DotNodeInfo:
"non_equality_tags": expr.non_equality_tags,
}
- edges: dict[str, ArrayOrNames | FunctionDefinition] = {}
+ edges: dict[str,
+ ArrayOrNames | FunctionDefinition |
+ tuple[int | AbstractResultWithNamedArrays |
+ Array, Array]] = {}
return _DotNodeInfo(title, fields, edges)
- # type-ignore-reason: incompatible with supertype
- def handle_unsupported_array(self, # type: ignore[override]
- expr: Array) -> None:
+    def process_node(self, expr: ArrayOrNames) -> None:
+        """Record dot node information for *expr* and the nodes it uses.
+
+        The handler is selected by testing the type of *expr* in the order
+        written below; anything not matched falls through to
+        :meth:`handle_unsupported_array`.
+        """
+        # Each handler stores its result in node_to_dot under this key, so a
+        # key that is already present means the node (and everything below
+        # it) has been handled.  Skipping it keeps shared sub-expressions
+        # from being traversed once per use.
+        if self.get_cache_key(expr) in self.node_to_dot:
+            return
+
+        if isinstance(expr, DataWrapper):
+            self.map_data_wrapper(expr)
+        elif isinstance(expr, IndexLambda):
+            self.map_index_lambda(expr)
+        elif isinstance(expr, Stack):
+            self.map_stack(expr)
+        elif isinstance(expr, IndexBase):
+            self.map_basic_index(expr)
+        elif isinstance(expr, Einsum):
+            self.map_einsum(expr)
+        elif isinstance(expr, DictOfNamedArrays):
+            self.map_dict_of_named_arrays(expr)
+        elif isinstance(expr, LoopyCall):
+            self.map_loopy_call(expr)
+        elif isinstance(expr, DistributedSendRefHolder):
+            self.map_distributed_send_ref_holder(expr)
+        elif isinstance(expr, Call):
+            self.map_call(expr)
+        elif isinstance(expr, NamedCallResult):
+            self.map_named_call_result(expr)
+        else:
+            self.handle_unsupported_array(expr)
+
+ def handle_unsupported_array(self,
+ expr: Array) -> None:
# Default handler, does its best to guess how to handle fields.
info = self.get_common_dot_info(expr)
-
- # pylint: disable=not-an-iterable
+ expr_key = self.get_cache_key(expr)
for field in attrs.fields(type(expr)):
if field.name in info.fields:
continue
attr = getattr(expr, field.name)
-
if isinstance(attr, Array):
- self.rec(attr)
- info.edges[field.name] = attr
-
+ self.process_node(attr)
+ key = self.get_cache_key(attr)
+ info.edges[field.name] = (key, attr)
elif isinstance(attr, AbstractResultWithNamedArrays):
- self.rec(attr)
- info.edges[field.name] = attr
-
+ self.process_node(attr)
+ key = self.get_cache_key(attr)
+ info.edges[field.name] = (key, attr)
elif isinstance(attr, tuple):
info.fields[field.name] = stringify_shape(attr)
-
else:
info.fields[field.name] = str(attr)
-
- self.node_to_dot[expr] = info
+ self.node_to_dot[expr_key] = info
def map_data_wrapper(self, expr: DataWrapper) -> None:
info = self.get_common_dot_info(expr)
if expr.name is not None:
info.fields["name"] = expr.name
- # Only show summarized data
import numpy as np
with np.printoptions(threshold=4, precision=2):
info.fields["data"] = str(expr.data)
- self.node_to_dot[expr] = info
+ self.node_to_dot[self.get_cache_key(expr)] = info
def map_index_lambda(self, expr: IndexLambda) -> None:
info = self.get_common_dot_info(expr)
info.fields["expr"] = str(expr.expr)
for name, val in expr.bindings.items():
- self.rec(val)
- info.edges[name] = val
+ self.process_node(val)
+ key = self.get_cache_key(val)
+ info.edges[name] = (key, val)
- self.node_to_dot[expr] = info
+ self.node_to_dot[self.get_cache_key(expr)] = info
def map_stack(self, expr: Stack) -> None:
info = self.get_common_dot_info(expr)
info.fields["axis"] = str(expr.axis)
for i, array in enumerate(expr.arrays):
- self.rec(array)
- info.edges[str(i)] = array
+ self.process_node(array)
+ key = self.get_cache_key(array)
+ info.edges[str(i)] = (key, array)
- self.node_to_dot[expr] = info
+ self.node_to_dot[self.get_cache_key(expr)] = info
map_concatenate = map_stack
@@ -270,9 +384,10 @@ def map_basic_index(self, expr: IndexBase) -> None:
elif isinstance(index, Array):
label = f"i{i}"
- self.rec(index)
+ self.process_node(index)
+ key = self.get_cache_key(index)
indices_parts.append(label)
- info.edges[label] = index
+ info.edges[label] = (key, index)
elif index is None:
indices_parts.append("newaxis")
@@ -282,10 +397,11 @@ def map_basic_index(self, expr: IndexBase) -> None:
info.fields["indices"] = ", ".join(indices_parts)
- self.rec(expr.array)
- info.edges["array"] = expr.array
+ self.process_node(expr.array)
+ key = self.get_cache_key(expr.array)
+ info.edges["array"] = (key, expr.array)
- self.node_to_dot[expr] = info
+ self.node_to_dot[self.get_cache_key(expr)] = info
map_contiguous_advanced_index = map_basic_index
map_non_contiguous_advanced_index = map_basic_index
@@ -295,30 +411,35 @@ def map_einsum(self, expr: Einsum) -> None:
for iarg, (access_descr, val) in enumerate(zip(expr.access_descriptors,
expr.args)):
- self.rec(val)
- info.edges[f"{iarg}: {access_descr}"] = val
+ self.process_node(val)
+ key = self.get_cache_key(val)
+ info.edges[f"{iarg}: {access_descr}"] = (key, val)
- self.node_to_dot[expr] = info
+ self.node_to_dot[self.get_cache_key(expr)] = info
def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None:
- edges: dict[str, ArrayOrNames | FunctionDefinition] = {}
+ edges: dict[str, ArrayOrNames | FunctionDefinition |
+ tuple[int | ArrayOrNames, Array]] = {}
for name, val in expr._data.items():
- edges[name] = val
- self.rec(val)
+ self.process_node(val)
+ key = self.get_cache_key(val)
+ edges[name] = (key, val)
- self.node_to_dot[expr] = _DotNodeInfo(
+ self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo(
title=type(expr).__name__,
fields={},
edges=edges)
def map_loopy_call(self, expr: LoopyCall) -> None:
- edges: dict[str, ArrayOrNames | FunctionDefinition] = {}
+ edges: dict[str, ArrayOrNames | FunctionDefinition |
+ tuple[int | ArrayOrNames, Array]] = {}
for name, arg in expr.bindings.items():
if isinstance(arg, Array):
- edges[name] = arg
- self.rec(arg)
+ self.process_node(arg)
+ key = self.get_cache_key(arg)
+ edges[name] = (key, arg)
- self.node_to_dot[expr] = _DotNodeInfo(
+ self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo(
title=type(expr).__name__,
fields={"addr": hex(id(expr)), "entrypoint": expr.entrypoint},
edges=edges)
@@ -328,29 +449,31 @@ def map_distributed_send_ref_holder(
info = self.get_common_dot_info(expr)
- self.rec(expr.passthrough_data)
- info.edges["passthrough"] = expr.passthrough_data
+ self.process_node(expr.passthrough_data)
+ key = self.get_cache_key(expr.passthrough_data)
+ info.edges["passthrough"] = (key, expr.passthrough_data)
- self.rec(expr.send.data)
- info.edges["sent"] = expr.send.data
+ self.process_node(expr.send.data)
+ key = self.get_cache_key(expr.send.data)
+ info.edges["sent"] = (key, expr.send.data)
info.fields["dest_rank"] = str(expr.send.dest_rank)
-
info.fields["comm_tag"] = str(expr.send.comm_tag)
- self.node_to_dot[expr] = info
+ self.node_to_dot[self.get_cache_key(expr)] = info
def map_call(self, expr: Call) -> None:
self.functions.add(expr.function)
for bnd in expr.bindings.values():
- self.rec(bnd)
+ self.process_node(bnd)
- self.node_to_dot[expr] = _DotNodeInfo(
+ self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo(
title=expr.__class__.__name__,
edges={
"": expr.function,
- **expr.bindings},
+ **{name: (self.get_cache_key(bnd), bnd)
+ for name, bnd in expr.bindings.items()}},
fields={
"addr": hex(id(expr)),
"tags": stringify_tags(expr.tags),
@@ -358,14 +481,16 @@ def map_call(self, expr: Call) -> None:
)
def map_named_call_result(self, expr: NamedCallResult) -> None:
- self.rec(expr._container)
- self.node_to_dot[expr] = _DotNodeInfo(
+ self.process_node(expr._container)
+ key = self.get_cache_key(expr._container)
+ self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo(
title=expr.__class__.__name__,
- edges={"": expr._container},
+ edges={"": (key, expr._container)},
fields={"addr": hex(id(expr)),
"name": expr.name},
)
+
# }}}
@@ -379,6 +504,12 @@ def dot_escape_leave_space(s: str) -> str:
return html.escape(s.replace("\\", "\\\\"))
+def get_array_key(array: ArrayOrNames | FunctionDefinition | int,
+                  count_duplicates: bool = False) -> Any:
+    """Return the lookup key for *array*.
+
+    An integer is returned unchanged.  Any other object is replaced by its
+    :func:`id` when *count_duplicates* is true and returned as-is otherwise.
+    """
+    if not count_duplicates or isinstance(array, int):
+        return array
+    return id(array)
+
+
# {{{ emit helpers
def _stringify_created_at(non_equality_tags: frozenset[Tag]) -> str:
@@ -391,7 +522,7 @@ def _stringify_created_at(non_equality_tags: frozenset[Tag]) -> str:
def _emit_array(emit: Callable[[str], None], title: str, fields: dict[str, Any],
- dot_node_id: str, color: str = "white") -> None:
+ dot_node_id: str, color: str = "white") -> None:
td_attrib = 'border="0"'
table_attrib = 'border="0" cellborder="1" cellspacing="0"'
@@ -416,8 +547,10 @@ def _emit_array(emit: Callable[[str], None], title: str, fields: dict[str, Any],
def _emit_name_cluster(
emit: DotEmitter, subgraph_path: tuple[str, ...],
names: Mapping[str, ArrayOrNames],
- array_to_id: Mapping[ArrayOrNames, str], id_gen: Callable[[str], str],
- label: str) -> None:
+ array_to_id: Mapping[
+ int | ArrayOrNames, str], id_gen: Callable[[str], str],
+ label: str,
+ count_duplicates: bool = False) -> None:
edges = []
cluster_subgraph_path = (*subgraph_path, f"cluster_{dot_escape(label)}")
@@ -428,7 +561,8 @@ def _emit_name_cluster(
for name, array in names.items():
name_id = id_gen(dot_escape(name))
emit_cluster(f'{name_id} [label="{dot_escape(name)}"]')
- array_id = array_to_id[array]
+ array_key = get_array_key(array, count_duplicates)
+ array_id = array_to_id[array_key]
# Edges must be outside the cluster.
edges.append((name_id, array_id))
@@ -439,16 +573,18 @@ def _emit_name_cluster(
def _emit_function(
emitter: DotEmitter, subgraph_path: tuple[str, ...],
id_gen: UniqueNameGenerator,
- node_to_dot: Mapping[ArrayOrNames, _DotNodeInfo],
+ node_to_dot: Mapping[int | ArrayOrNames, _DotNodeInfo],
func_to_id: Mapping[FunctionDefinition, str],
- outputs: Mapping[str, Array]) -> None:
+ outputs: Mapping[str, Array],
+ count_duplicates: bool = False) -> None:
input_arrays: list[Array] = []
- internal_arrays: list[ArrayOrNames] = []
- array_to_id: dict[ArrayOrNames, str] = {}
+ internal_arrays: list[int | ArrayOrNames] = []
+ array_to_id: dict[int | ArrayOrNames, str] = {}
emit = partial(emitter, subgraph_path)
for array in node_to_dot:
- array_to_id[array] = id_gen("array")
+ key = get_array_key(array, count_duplicates)
+ array_to_id[key] = id_gen("array")
if isinstance(array, InputArgumentBase):
input_arrays.append(array)
else:
@@ -460,36 +596,47 @@ def _emit_function(
emit_input('label="Arguments"')
for array in input_arrays:
+ key = get_array_key(array, count_duplicates)
_emit_array(
emit_input,
node_to_dot[array].title,
node_to_dot[array].fields,
- array_to_id[array])
+ array_to_id[key])
# Emit non-inputs.
for array in internal_arrays:
+ key = get_array_key(array, count_duplicates)
_emit_array(emit,
node_to_dot[array].title,
node_to_dot[array].fields,
- array_to_id[array])
+ array_to_id[key])
# Emit edges.
for array, node in node_to_dot.items():
- for label, tail_item in node.edges.items():
- head = array_to_id[array]
+ key = get_array_key(array, count_duplicates)
+ for label, edge_info in node.edges.items():
+ if isinstance(edge_info, tuple):
+ tail_key, tail_item = edge_info
+ else:
+ tail_item = edge_info
+ tail_key = get_array_key(tail_item, count_duplicates)
+
+ head = array_to_id[key]
if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)):
- tail = array_to_id[tail_item]
+ tail = array_to_id[tail_key]
elif isinstance(tail_item, FunctionDefinition):
tail = func_to_id[tail_item]
else:
raise ValueError(
- f"unexpected type of tail on edge: {type(tail_item)}")
+ f"unexpected type of tail on edge: {type(tail_item)}")
emit(f'{tail} -> {head} [label="{dot_escape(label)}"]')
# Emit output/namespace name mappings.
_emit_name_cluster(
- emitter, subgraph_path, outputs, array_to_id, id_gen, label="Returns")
+ emitter, subgraph_path,
+ outputs, array_to_id, id_gen,
+ label="Returns", count_duplicates=count_duplicates)
# }}}
@@ -507,20 +654,21 @@ def _get_function_name(f: FunctionDefinition) -> str | None:
def _gather_partition_node_information(
id_gen: UniqueNameGenerator,
- partition: DistributedGraphPartition
+ partition: DistributedGraphPartition,
+ count_duplicates: bool = False
) -> tuple[
- Mapping[PartId, Mapping[FunctionDefinition, str]],
- Mapping[tuple[PartId, FunctionDefinition | None],
- Mapping[ArrayOrNames, _DotNodeInfo]]
- ]:
+ dict[PartId, dict[FunctionDefinition, str]],
+ dict[tuple[PartId, FunctionDefinition | None],
+ dict[int | ArrayOrNames, _DotNodeInfo]]]:
part_id_to_func_to_id: dict[PartId, dict[FunctionDefinition, str]] = {}
part_id_func_to_node_info: dict[tuple[PartId, FunctionDefinition | None],
- dict[ArrayOrNames, _DotNodeInfo]] = {}
+ dict[int | ArrayOrNames,
+ _DotNodeInfo]] = {}
for part in partition.parts.values():
- mapper = ArrayToDotNodeInfoMapper()
+ mapper = ArrayToDotNodeInfoMapper(count_duplicates)
for out_name in part.output_names:
- mapper(partition.name_to_output[out_name])
+ mapper.process_node(partition.name_to_output[out_name])
part_id_func_to_node_info[part.pid, None] = mapper.node_to_dot
part_id_to_func_to_id[part.pid] = {}
@@ -535,9 +683,9 @@ def gather_function_info(f: FunctionDefinition) -> None:
if key in part_id_func_to_node_info:
return
- mapper = ArrayToDotNodeInfoMapper()
+ mapper = ArrayToDotNodeInfoMapper(count_duplicates)
for elem in f.returns.values():
- mapper(elem)
+ mapper.process_node(elem)
part_id_func_to_node_info[key] = mapper.node_to_dot
@@ -563,10 +711,12 @@ def gather_function_info(f: FunctionDefinition) -> None:
return part_id_to_func_to_id, part_id_func_to_node_info
+
# }}}
-def get_dot_graph(result: Array | DictOfNamedArrays) -> str:
+def get_dot_graph(result: Array | DictOfNamedArrays,
+ count_duplicates: bool = False) -> str:
r"""Return a string in the `dot `_ language depicting the
graph of the computation of *result*.
@@ -576,30 +726,32 @@ def get_dot_graph(result: Array | DictOfNamedArrays) -> str:
outputs: DictOfNamedArrays = normalize_outputs(result)
- return get_dot_graph_from_partition(
- DistributedGraphPartition(
- parts={
- None: DistributedGraphPart(
- pid=None,
- needed_pids=frozenset(),
- user_input_names=frozenset(
+ partition = DistributedGraphPartition(
+ parts={
+ None: DistributedGraphPart(
+ pid=None,
+ needed_pids=frozenset(),
+ user_input_names=frozenset(
expr.name
for expr in InputGatherer()(outputs)
if isinstance(expr, Placeholder)
),
- partition_input_names=frozenset(),
- output_names=frozenset(outputs.keys()),
- name_to_recv_node={},
- name_to_send_nodes={},
- )
- },
- name_to_output=outputs._data,
- overall_output_names=tuple(outputs),
- ))
-
-
-def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
- r"""Return a string in the `dot `_ language depicting the
+ partition_input_names=frozenset(),
+ output_names=frozenset(outputs.keys()),
+ name_to_recv_node={},
+ name_to_send_nodes={},
+ )
+ },
+ name_to_output=outputs._data,
+ overall_output_names=tuple(outputs),
+ )
+
+ return get_dot_graph_from_partition(partition, count_duplicates)
+
+
+def get_dot_graph_from_partition(partition: DistributedGraphPartition,
+ count_duplicates: bool = False) -> str:
+ """Return a string in the `dot `_ language depicting the
graph of the partitioned computation of *partition*.
:arg partition: Outputs of :func:`~pytato.find_distributed_partition`.
@@ -611,9 +763,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
# The "None" function is the body of the partition.
part_id_to_func_to_id, part_id_func_to_node_info = \
- _gather_partition_node_information(id_gen, partition)
-
- # }}}
+ _gather_partition_node_information(id_gen, partition, count_duplicates)
emitter = DotEmitter()
emit_root = partial(emitter, ())
@@ -622,8 +772,8 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
emit_root("node [shape=rectangle]")
- placeholder_to_id: dict[ArrayOrNames, str] = {}
- part_id_to_array_to_id: dict[PartId, dict[ArrayOrNames, str]] = {}
+ placeholder_to_id: dict[int | ArrayOrNames, str] = {}
+ part_id_to_array_to_id: dict[PartId, dict[int | ArrayOrNames, str]] = {}
part_id_to_id = {pid: dot_escape(str(pid)) for pid in partition.parts}
assert len(set(part_id_to_id.values())) == len(partition.parts)
@@ -633,16 +783,18 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
for part in partition.parts.values():
array_to_id = {}
for array in part_id_func_to_node_info[part.pid, None].keys():
+ if isinstance(array, int): # if the key is an ID
+ array = get_object_by_id(array)
+ key = get_array_key(array, count_duplicates)
if isinstance(array, Placeholder):
- # Placeholders are only emitted once
- if array in placeholder_to_id:
- node_id = placeholder_to_id[array]
+ if key in placeholder_to_id:
+ node_id = placeholder_to_id[key]
else:
node_id = id_gen("array")
- placeholder_to_id[array] = node_id
+ placeholder_to_id[key] = node_id
else:
node_id = id_gen("array")
- array_to_id[array] = node_id
+ array_to_id[key] = node_id
part_id_to_array_to_id[part.pid] = array_to_id
# }}}
@@ -679,22 +831,22 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
_emit_function(emitter, func_subgraph_path,
id_gen, part_id_func_to_node_info[part.pid, func],
part_id_to_func_to_id[part.pid],
- func.returns)
+ func.returns,
+ count_duplicates=count_duplicates)
# }}}
# {{{ emit receives nodes
part_dist_recv_var_name_to_node_id = {}
- for name, recv in (
- part.name_to_recv_node.items()):
+ for name, recv in part.name_to_recv_node.items():
node_id = id_gen("recv")
_emit_array(emit_part, "DistributedRecv", {
"shape": stringify_shape(recv.shape),
"dtype": str(recv.dtype),
"src_rank": str(recv.src_rank),
"comm_tag": str(recv.comm_tag),
- }, node_id)
+ }, node_id)
part_dist_recv_var_name_to_node_id[name] = node_id
@@ -705,6 +857,8 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
internal_arrays: list[ArrayOrNames] = []
for array in part_node_to_info.keys():
+ if isinstance(array, int): # if the key is an ID
+ array = get_object_by_id(array)
if isinstance(array, InputArgumentBase):
input_arrays.append(array)
else:
@@ -718,26 +872,26 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
# subgraphs.
for array in input_arrays:
+ key = array = get_array_key(array, count_duplicates)
if not isinstance(array, Placeholder):
_emit_array(emit_part,
part_node_to_info[array].title,
part_node_to_info[array].fields,
- array_to_id[array], "deepskyblue")
+ array_to_id[key], "deepskyblue")
else:
# Is a Placeholder
- if array in emitted_placeholders:
+ if key in emitted_placeholders:
continue
-
_emit_array(emit_root,
part_node_to_info[array].title,
part_node_to_info[array].fields,
- array_to_id[array], "deepskyblue")
+ array_to_id[key], "deepskyblue")
# Emit cross-partition edges
if array.name in part_dist_recv_var_name_to_node_id:
tgt = part_dist_recv_var_name_to_node_id[array.name]
- emit_root(f"{tgt} -> {array_to_id[array]} [style=dotted]")
- emitted_placeholders.add(array)
+ emit_root(f"{tgt} -> {array_to_id[key]} [style=dotted]")
+ emitted_placeholders.add(key)
elif array.name in part.user_input_names:
# no arrows for these
pass
@@ -750,18 +904,22 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
break
assert computing_pid is not None
tgt = part_id_to_array_to_id[computing_pid][
- partition.name_to_output[array.name]]
- emit_root(f"{tgt} -> {array_to_id[array]} [style=dashed]")
- emitted_placeholders.add(array)
+ id(partition.name_to_output[array.name])
+ if count_duplicates
+ else partition.name_to_output[array.name]]
+ emit_root(f"{tgt} -> {array_to_id[key]} [style=dashed]")
+ emitted_placeholders.add(key)
# }}}
# Emit internal nodes
+
for array in internal_arrays:
+ key = array = get_array_key(array, count_duplicates)
_emit_array(emit_part,
part_node_to_info[array].title,
part_node_to_info[array].fields,
- array_to_id[array])
+ array_to_id[key])
# {{{ emit send nodes if distributed
@@ -772,35 +930,47 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
_emit_array(emit_part, "DistributedSend", {
"dest_rank": str(send.dest_rank),
"comm_tag": str(send.comm_tag),
- }, node_id)
+ }, node_id)
# If an edge is emitted in a subgraph, it drags its
# nodes into the subgraph, too. Not what we want.
+ data = id(send.data) if count_duplicates else send.data
emit_root(
- f"{array_to_id[send.data]} -> {node_id}"
- f'[style=dotted, label="{dot_escape(name)}"]')
+ f"{array_to_id[data]} -> {node_id}"
+ f'[style=dotted, label="{dot_escape(name)}"]')
# }}}
# Emit intra-partition edges
for array, node in part_node_to_info.items():
- for label, tail_item in node.edges.items():
- head = array_to_id[array]
+ key = get_array_key(array, count_duplicates)
+
+ tail_item: Array | AbstractResultWithNamedArrays | FunctionDefinition
+ for label, edge_info in node.edges.items():
+ if isinstance(edge_info, tuple):
+ tail_key, tail_item = edge_info
+ else:
+ tail_item = edge_info
+ tail_key = get_array_key(tail_item, count_duplicates)
+
+ head = array_to_id[key]
if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)):
- tail = array_to_id[tail_item]
+ tail = array_to_id[tail_key]
elif isinstance(tail_item, FunctionDefinition):
tail = part_id_to_func_to_id[part.pid][tail_item]
else:
raise ValueError(
- f"unexpected type of tail on edge: {type(tail_item)}")
+ f"unexpected type of tail on edge: {type(tail_item)}")
emit_root(f'{tail} -> {head} [label="{dot_escape(label)}"]')
_emit_name_cluster(
- emitter, part_subgraph_path,
- {name: partition.name_to_output[name] for name in part.output_names},
- array_to_id, id_gen, "Part outputs")
+ emitter, part_subgraph_path,
+ {name: partition.name_to_output[name]
+ for name in part.output_names},
+ array_to_id, id_gen, "Part outputs",
+ count_duplicates)
# }}}
@@ -810,15 +980,16 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
# {{{ draw overall outputs
- combined_array_to_id: dict[ArrayOrNames, str] = {}
+ combined_array_to_id: dict[int | ArrayOrNames, str] = {}
for part_id in partition.parts.keys():
combined_array_to_id.update(part_id_to_array_to_id[part_id])
_emit_name_cluster(
- emitter, (),
- {name: partition.name_to_output[name]
- for name in partition.overall_output_names},
- combined_array_to_id, id_gen, "Overall outputs")
+ emitter, (),
+ {name: partition.name_to_output[name]
+ for name in partition.overall_output_names},
+ combined_array_to_id, id_gen, "Overall outputs",
+ count_duplicates)
# }}}
@@ -826,7 +997,8 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str:
def show_dot_graph(result: str | Array | DictOfNamedArrays | DistributedGraphPartition,
- **kwargs: Any) -> None:
+ count_duplicates: bool = False,
+ **kwargs: Any) -> None:
"""Show a graph representing the computation of *result* in a browser.
:arg result: Outputs of the computation (cf.
@@ -839,9 +1011,9 @@ def show_dot_graph(result: str | Array | DictOfNamedArrays | DistributedGraphPar
if isinstance(result, str):
dot_code = result
elif isinstance(result, DistributedGraphPartition):
- dot_code = get_dot_graph_from_partition(result)
+ dot_code = get_dot_graph_from_partition(result, count_duplicates)
else:
- dot_code = get_dot_graph(result)
+ dot_code = get_dot_graph(result, count_duplicates)
from pytools.graphviz import show_dot
show_dot(dot_code, **kwargs)
diff --git a/test/test_pytato.py b/test/test_pytato.py
index f67e7e5f1..7a7b79cf8 100644
--- a/test/test_pytato.py
+++ b/test/test_pytato.py
@@ -32,7 +32,6 @@
import attrs
import numpy as np
import pytest
-from testlib import RandomDAGContext, make_random_dag
from pyopencl.tools import ( # noqa
pytest_generate_tests_for_pyopencl as pytest_generate_tests,
@@ -765,6 +764,53 @@ def test_large_dag_with_duplicates_count():
dag, count_duplicates=False)
+def test_duplicate_node_count():
+    """Node count with duplicates minus repeats equals the dependency count."""
+    from testlib import get_random_pt_dag
+
+    from pytato.analysis import get_node_multiplicities, get_num_nodes
+
+    for seed in range(80):
+        dag = get_random_pt_dag(seed=seed, axis_len=5)
+
+        # Node count with duplicates included.
+        total_nodes = get_num_nodes(dag, count_duplicates=True)
+
+        # Every occurrence of a node beyond its first is a repeat.
+        multiplicities = get_node_multiplicities(dag)
+        repeats = sum(
+            max(count - 1, 0) for count in multiplicities.values())
+
+        dependencies = pt.transform.DependencyMapper()(dag)
+        assert total_nodes - repeats == len(dependencies)
+
+
+def test_duplicate_nodes_with_comm_count():
+    """Same check as above, on DAGs that contain send/receive nodes."""
+    from testlib import get_random_pt_dag_with_send_recv_nodes
+
+    from pytato.analysis import get_node_multiplicities, get_num_nodes
+
+    rank, size = 0, 2
+    for seed in range(20):
+        dag = get_random_pt_dag_with_send_recv_nodes(
+            seed=seed, rank=rank, size=size)
+
+        # Node count with duplicates included.
+        total_nodes = get_num_nodes(dag, count_duplicates=True)
+
+        # Every occurrence of a node beyond its first is a repeat.
+        multiplicities = get_node_multiplicities(dag)
+        repeats = sum(
+            max(count - 1, 0) for count in multiplicities.values())
+
+        dependencies = pt.transform.DependencyMapper()(dag)
+        assert total_nodes - repeats == len(dependencies)
+
+
def test_rec_get_user_nodes():
x1 = pt.make_placeholder("x1", shape=(10, 4))
x2 = pt.make_placeholder("x2", shape=(10, 4))
@@ -1336,6 +1382,7 @@ def test_rewrite_einsums_with_no_broadcasts():
def test_dot_visualizers():
+ from testlib import RandomDAGContext, make_random_dag
a = pt.make_placeholder("A", shape=(10, 4))
x1 = pt.make_placeholder("x1", shape=4)
x2 = pt.make_placeholder("x2", shape=4)
@@ -1364,6 +1411,79 @@ def test_dot_visualizers():
# }}}
+def test_duplicate_node_count_dot_graph():
+    """The dot output has as many node ids as ``get_num_nodes`` reports."""
+    from testlib import count_dot_graph_nodes, get_random_pt_dag
+
+    from pytato.analysis import get_num_nodes
+    from pytato.visualization.dot import get_dot_graph
+
+    for seed in range(80):
+        dag = get_random_pt_dag(seed=seed, axis_len=5)
+
+        # First with duplicates counted, then without.
+        for with_duplicates in (True, False):
+            dot_code = get_dot_graph(dag, count_duplicates=with_duplicates)
+            node_ids = count_dot_graph_nodes(dot_code)
+
+            assert len(node_ids) == get_num_nodes(
+                dag, count_duplicates=with_duplicates)
+
+
+def test_duplicate_nodes_with_comm_count_dot_graph():
+    """Dot node-id count matches ``get_num_nodes`` on send/receive DAGs."""
+    from testlib import count_dot_graph_nodes, get_random_pt_dag_with_send_recv_nodes
+
+    from pytato.analysis import get_num_nodes
+    from pytato.visualization.dot import get_dot_graph
+
+    rank, size = 0, 2
+    for seed in range(20):
+        dag = get_random_pt_dag_with_send_recv_nodes(
+            seed=seed, rank=rank, size=size)
+
+        # First with duplicates counted, then without.
+        for with_duplicates in (True, False):
+            dot_code = get_dot_graph(dag, count_duplicates=with_duplicates)
+            node_ids = count_dot_graph_nodes(dot_code)
+
+            assert len(node_ids) == get_num_nodes(
+                dag, count_duplicates=with_duplicates)
+
+
+def test_large_dot_graph_with_duplicates_count():
+    """Dot node-id count matches ``get_num_nodes`` on one large DAG."""
+    from testlib import count_dot_graph_nodes, make_large_dag
+
+    from pytato.analysis import get_num_nodes
+    from pytato.visualization.dot import get_dot_graph
+
+    iterations = 100
+    dag = make_large_dag(iterations, seed=42)
+
+    # First with duplicates counted, then without.
+    for with_duplicates in (True, False):
+        dot_code = get_dot_graph(dag, count_duplicates=with_duplicates)
+        node_ids = count_dot_graph_nodes(dot_code)
+
+        assert len(node_ids) == get_num_nodes(
+            dag, count_duplicates=with_duplicates)
+
+
if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
diff --git a/test/testlib.py b/test/testlib.py
index 53bf79436..3136a7fb4 100644
--- a/test/testlib.py
+++ b/test/testlib.py
@@ -2,6 +2,7 @@
import operator
import random
+import re
import types
from typing import Any, Callable, Sequence
@@ -395,6 +396,23 @@ def make_large_dag_with_duplicates(iterations: int,
result = pt.sum(combined_expr, axis=0)
return pt.make_dict_of_named_arrays({"result": result})
+
+def count_dot_graph_nodes(dot_graph: str) -> dict[Any, int]:
+    """
+    Count how often each node identifier occurs in *dot_graph*.
+
+    An identifier is the whole word ``array`` or ``array_<digits>``.  The
+    returned dictionary maps every identifier found to its number of
+    occurrences anywhere in the text.
+    """
+    node_counts: dict[Any, int] = {}
+
+    for match in re.finditer(r"(\barray_\d+\b|\barray\b)", dot_graph):
+        identifier = match.group(1)
+        if identifier in node_counts:
+            node_counts[identifier] += 1
+        else:
+            node_counts[identifier] = 1
+
+    return node_counts
+
# }}}