From 3e19358d7cf2f57a45bccd6759a61ca72890eab7 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 17 Jun 2024 19:08:57 -0600 Subject: [PATCH 01/19] Add node counter tests --- pytato/analysis/__init__.py | 35 +++++++++++---- test/test_pytato.py | 90 ++++++++++++++++++++++++++++++++++++- test/testlib.py | 26 ++++++++++- 3 files changed, 139 insertions(+), 12 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5bf374746..41c21aa9d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,7 +26,7 @@ """ from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, - TYPE_CHECKING) + Type, TYPE_CHECKING) from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, DictOfNamedArrays, NamedArray, IndexBase, IndexRemappingBase, InputArgumentBase, @@ -49,6 +49,8 @@ .. autofunction:: get_num_nodes +.. autofunction:: get_node_type_counts + .. autofunction:: get_num_call_sites .. autoclass:: DirectPredecessorsGetter @@ -381,23 +383,38 @@ def map_named_call_result(self, expr: NamedCallResult) -> FrozenSet[Array]: @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class NodeCountMapper(CachedWalkMapper): """ - Counts the number of nodes in a DAG. + Counts the number of nodes of a given type in a DAG. - .. attribute:: count + .. attribute:: counts - The number of nodes. + Dictionary mapping node types to number of nodes of that type. """ def __init__(self) -> None: + from collections import defaultdict super().__init__() - self.count = 0 + self.counts = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: - return id(expr) + # does NOT account for duplicate nodes + return expr def post_visit(self, expr: Any) -> None: - self.count += 1 + self.counts[type(expr)] += 1 + +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: + """ + Returns a dictionary mapping node types to node count for that type + in DAG *outputs*. + """ + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeCountMapper() + ncm(outputs) + return ncm.counts def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """Returns the number of nodes in DAG *outputs*.""" @@ -408,7 +425,7 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: ncm = NodeCountMapper() ncm(outputs) - return ncm.count + return sum(ncm.counts.values()) # }}} @@ -463,4 +480,4 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} -# vim: fdm=marker +# vim: fdm=marker \ No newline at end of file diff --git a/test/test_pytato.py b/test/test_pytato.py index 8939073cb..cea10480c 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -26,7 +26,6 @@ """ import sys - import numpy as np import pytest import attrs @@ -585,7 +584,7 @@ def test_repr_array_is_deterministic(): assert repr(dag) == repr(dag) -def test_nodecountmapper(): +def test_node_count_mapper(): from testlib import RandomDAGContext, make_random_dag from pytato.analysis import get_num_nodes @@ -600,6 +599,93 @@ def test_nodecountmapper(): assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag)) +def test_empty_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + empty_dag = pt.make_dict_of_named_arrays({}) + + # Verify that get_num_nodes returns 0 for an empty DAG + assert get_num_nodes(empty_dag) - 1 == 0 + + counts = get_node_type_counts(empty_dag) + assert len(counts) == 1 + +def test_single_node_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + data = np.random.rand(4, 4) + single_node_dag = pt.make_dict_of_named_arrays({"result": pt.make_data_wrapper(data)}) + + # Get counts per node type + node_counts = get_node_type_counts(single_node_dag) + + # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays + # DictOfNamedArrays is automatically added + assert node_counts == {pt.DataWrapper: 1, pt.DictOfNamedArrays: 1} + assert sum(node_counts.values()) - 1 == 1 # Total node count is 1 + + # Get total number of nodes + total_nodes = get_num_nodes(single_node_dag) + + assert total_nodes - 1 == 1 + + +def test_small_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + # Make a DAG using two nodes and one operation + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + b = a + 1 + dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + + # Verify that get_num_nodes returns 2 for a DAG with two nodes + assert get_num_nodes(dag) - 1 == 2 + + counts = get_node_type_counts(dag) + assert len(counts) - 1 == 2 + assert counts[pt.array.Placeholder] == 1 # "a" + assert counts[pt.array.IndexLambda] == 1 # single operation + + +def test_large_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + from testlib import make_large_dag + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Verify that the number of nodes is equal to iterations + 1 (placeholder) + assert get_num_nodes(dag) - 1 == iterations + 1 + + # Verify that the counts dictionary has correct counts for the complicated DAG + counts = get_node_type_counts(dag) + assert len(counts) >= 1 + assert counts[pt.array.Placeholder] == 1 + assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert sum(counts.values()) - 1 == iterations + 1 + + +def test_random_dag_count(): + from testlib import get_random_pt_dag + from pytato.analysis import get_num_nodes + for i in range(80): + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. + assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + +def test_random_dag_with_comm_count(): + from testlib import get_random_pt_dag_with_send_recv_nodes + from pytato.analysis import get_num_nodes + rank = 0 + size = 2 + for i in range(10): + dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + + # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. + assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) diff --git a/test/testlib.py b/test/testlib.py index 5cd1342d3..cdf827e96 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -311,6 +311,30 @@ def gen_comm(rdagc: RandomDAGContext) -> pt.Array: convert_dws_to_placeholders=convert_dws_to_placeholders, additional_generators=[(comm_fake_probability, gen_comm)]) +def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: + """ + Builds a DAG with emphasis on number of operations. + """ + import random + import operator + + rng = np.random.default_rng(seed) + random.seed(seed) + + # Begin with a placeholder + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + current = a + + # Will randomly choose from the operators + operations = [operator.add, operator.sub, operator.mul, operator.truediv] + + for _ in range(iterations): + operation = random.choice(operations) + value = rng.uniform(1, 10) + current = operation(current, value) + + return pt.make_dict_of_named_arrays({"result": current}) + # }}} @@ -369,4 +393,4 @@ class QuuxTag(TestlibTag): # }}} -# vim: foldmethod=marker +# vim: foldmethod=marker \ No newline at end of file From ea2402c9fe933dbf30cee8a71cf5247387461e65 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 17 Jun 2024 19:41:47 -0600 Subject: [PATCH 02/19] CI fixes --- doc/conf.py | 1 + pytato/analysis/__init__.py | 10 ++++++---- test/test_pytato.py | 14 ++++++++------ test/testlib.py | 37 +++++++++++++++++++------------------ 4 files changed, 34 insertions(+), 28 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 081642f1d..e6f7ac0c0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -46,6 +46,7 @@ nitpick_ignore_regex = [ ["py:class", r"numpy.(u?)int[\d]+"], ["py:class", r"typing_extensions(.+)"], + ["py:class", r"numpy.bool_"], # As of 2023-10-05, it doesn't look like there's sphinx documentation # available. ["py:class", r"immutabledict(.*)"], diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 41c21aa9d..5da4ea70e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -393,16 +393,17 @@ class NodeCountMapper(CachedWalkMapper): def __init__(self) -> None: from collections import defaultdict super().__init__() - self.counts = defaultdict(int) + self.counts = defaultdict(int) # type: Dict[Type[Any], int] - def get_cache_key(self, expr: ArrayOrNames) -> int: + def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: # does NOT account for duplicate nodes return expr def post_visit(self, expr: Any) -> None: self.counts[type(expr)] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: + +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -416,6 +417,7 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, return ncm.counts + def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """Returns the number of nodes in DAG *outputs*.""" @@ -480,4 +482,4 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} -# vim: fdm=marker \ No newline at end of file +# vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index cea10480c..51e5381d5 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -610,6 +610,7 @@ def test_empty_dag_count(): counts = get_node_type_counts(empty_dag) assert len(counts) == 1 + def test_single_node_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts @@ -626,7 +627,7 @@ def test_single_node_dag_count(): # Get total number of nodes total_nodes = get_num_nodes(single_node_dag) - + assert total_nodes - 1 == 1 @@ -636,15 +637,15 @@ def test_small_dag_count(): # Make a DAG using two nodes and one operation a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) b = a + 1 - dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 # Verify that get_num_nodes returns 2 for a DAG with two nodes assert get_num_nodes(dag) - 1 == 2 counts = get_node_type_counts(dag) assert len(counts) - 1 == 2 - assert counts[pt.array.Placeholder] == 1 # "a" - assert counts[pt.array.IndexLambda] == 1 # single operation + assert counts[pt.array.Placeholder] == 1 # "a" + assert counts[pt.array.IndexLambda] == 1 # single operation def test_large_dag_count(): @@ -661,7 +662,7 @@ def test_large_dag_count(): counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 - assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert counts[pt.array.IndexLambda] == 100 # 100 operations assert sum(counts.values()) - 1 == iterations + 1 @@ -670,10 +671,11 @@ def test_random_dag_count(): from pytato.analysis import get_num_nodes for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - + # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + def test_random_dag_with_comm_count(): from testlib import get_random_pt_dag_with_send_recv_nodes from pytato.analysis import get_num_nodes diff --git a/test/testlib.py b/test/testlib.py index cdf827e96..e15489c4b 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -311,29 +311,30 @@ def gen_comm(rdagc: RandomDAGContext) -> pt.Array: convert_dws_to_placeholders=convert_dws_to_placeholders, additional_generators=[(comm_fake_probability, gen_comm)]) + def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: - """ - Builds a DAG with emphasis on number of operations. - """ - import random - import operator + """ + Builds a DAG with emphasis on number of operations. + """ + import random + import operator - rng = np.random.default_rng(seed) - random.seed(seed) + rng = np.random.default_rng(seed) + random.seed(seed) - # Begin with a placeholder - a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) - current = a + # Begin with a placeholder + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + current = a - # Will randomly choose from the operators - operations = [operator.add, operator.sub, operator.mul, operator.truediv] + # Will randomly choose from the operators + operations = [operator.add, operator.sub, operator.mul, operator.truediv] - for _ in range(iterations): - operation = random.choice(operations) - value = rng.uniform(1, 10) - current = operation(current, value) + for _ in range(iterations): + operation = random.choice(operations) + value = rng.uniform(1, 10) + current = operation(current, value) - return pt.make_dict_of_named_arrays({"result": current}) + return pt.make_dict_of_named_arrays({"result": current}) # }}} @@ -393,4 +394,4 @@ class QuuxTag(TestlibTag): # }}} -# vim: foldmethod=marker \ No newline at end of file +# vim: foldmethod=marker From b122aa9a6e87419bd5ddc77ac9ca098f6091d7c0 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 17 Jun 2024 19:49:18 -0600 Subject: [PATCH 03/19] Add comments --- doc/conf.py | 1 - test/testlib.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/conf.py b/doc/conf.py index e6f7ac0c0..081642f1d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -46,7 +46,6 @@ nitpick_ignore_regex = [ ["py:class", r"numpy.(u?)int[\d]+"], ["py:class", r"typing_extensions(.+)"], - ["py:class", r"numpy.bool_"], # As of 2023-10-05, it doesn't look like there's sphinx documentation # available. ["py:class", r"immutabledict(.*)"], diff --git a/test/testlib.py b/test/testlib.py index e15489c4b..a208f0816 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -334,6 +334,7 @@ def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: value = rng.uniform(1, 10) current = operation(current, value) + # DAG should have `iterations` number of operations return pt.make_dict_of_named_arrays({"result": current}) # }}} From 4a52c8d8ce400b71461e97424bb62fa43e5ce28e Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 18 Jun 2024 10:08:34 -0600 Subject: [PATCH 04/19] Remove unnecessary test --- test/test_pytato.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 51e5381d5..962fe337e 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -584,21 +584,6 @@ def test_repr_array_is_deterministic(): assert repr(dag) == repr(dag) -def test_node_count_mapper(): - from testlib import RandomDAGContext, make_random_dag - from pytato.analysis import get_num_nodes - - axis_len = 5 - - for i in range(10): - rdagc = RandomDAGContext(np.random.default_rng(seed=i), - axis_len=axis_len, use_numpy=False) - dag = make_random_dag(rdagc) - - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag)) - - def test_empty_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts From 570eda4d72c46b4922c662c8f80bc8ae976d5262 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Sun, 23 Jun 2024 16:52:02 -0600 Subject: [PATCH 05/19] Add duplicate node functionality and tests --- pytato/analysis/__init__.py | 38 +++++++++++++++------- test/test_pytato.py | 65 +++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 12 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5da4ea70e..f009f89f6 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -385,25 +385,29 @@ class NodeCountMapper(CachedWalkMapper): """ Counts the number of nodes of a given type in a DAG. - .. attribute:: counts + .. attribute:: expr_type_counts + .. attribute:: count_duplicates + .. attribute:: expr_call_counts Dictionary mapping node types to number of nodes of that type. """ - def __init__(self) -> None: + def __init__(self, count_duplicates=False) -> None: # added parameter from collections import defaultdict super().__init__() - self.counts = defaultdict(int) # type: Dict[Type[Any], int] + self.expr_type_counts = defaultdict(int) # type: Dict[Type[Any], int] + self.count_duplicates = count_duplicates + self.expr_call_counts = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: - # does NOT account for duplicate nodes - return expr + return id(expr) if self.count_duplicates else expr # returns unique nodes only if count_duplicates is True def post_visit(self, expr: Any) -> None: - self.counts[type(expr)] += 1 + self.expr_type_counts[type(expr)] += 1 + self.expr_call_counts[expr] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -412,22 +416,32 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - ncm = NodeCountMapper() + ncm = NodeCountMapper(count_duplicates) ncm(outputs) - return ncm.counts + return ncm.expr_type_counts -def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: +def get_num_nodes(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> int: """Returns the number of nodes in DAG *outputs*.""" from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - ncm = NodeCountMapper() + ncm = NodeCountMapper(count_duplicates) + ncm(outputs) + + return sum(ncm.expr_type_counts.values()) + + +def get_expr_calls(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> Dict[Type[Any], int]: + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeCountMapper(count_duplicates) ncm(outputs) - return sum(ncm.counts.values()) + return ncm.expr_call_counts # }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index 962fe337e..a1f3bc518 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -672,6 +672,71 @@ def test_random_dag_with_comm_count(): # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) +def test_duplicate_node_count(): + from testlib import get_random_pt_dag + from pytato.analysis import get_num_nodes, get_expr_calls + for i in range(80): + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Get the number of types of expressions + node_count = get_num_nodes(dag, count_duplicates=True) + + # Get the number of expressions and the amount they're called + expr_counts = get_expr_calls(dag, count_duplicates=True) + + num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + # Check that duplicates are correctly calculated + assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + + +def test_duplicate_nodes_with_comm_count(): + from testlib import get_random_pt_dag_with_send_recv_nodes + from pytato.analysis import get_num_nodes, get_expr_calls + + rank = 0 + size = 2 + for i in range(20): + dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + + # Get the number of types of expressions + node_count = get_num_nodes(dag, count_duplicates=True) + + # Get the number of expressions and the amount they're called + expr_counts = get_expr_calls(dag, count_duplicates=True) + + num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + + # Check that duplicates are correctly calculated + assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + + +def test_large_dag_with_duplicates_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts, get_expr_calls + from testlib import make_large_dag + import pytato as pt + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Verify that the number of nodes is equal to iterations + 1 (placeholder) + node_count = get_num_nodes(dag, count_duplicates=True) + assert node_count - 1 == iterations + 1 + + # Get the number of expressions and the amount they're called + expr_counts = get_expr_calls(dag, count_duplicates=True) + + num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + + # Verify that the counts dictionary has correct counts for the complicated DAG + counts = get_node_type_counts(dag, count_duplicates=True) + assert len(counts) >= 1 + assert counts[pt.array.Placeholder] == 1 + assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert sum(counts.values()) - 1 == iterations + 1 + + # Check that duplicates are correctly calculated + assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) From d8dbe62f5a17a477b84592d8817d11b547bf59ee Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:16:20 -0600 Subject: [PATCH 06/19] Remove incrementation for DictOfNamedArrays and update tests --- pytato/analysis/__init__.py | 11 +++++++++-- test/test_pytato.py | 23 +++++++++++------------ 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5da4ea70e..4938f6794 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -400,13 +400,16 @@ def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: return expr def post_visit(self, expr: Any) -> None: - self.counts[type(expr)] += 1 + if type(expr) is not DictOfNamedArrays: + self.counts[type(expr)] += 1 def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. + + `DictOfNamedArrays` are added when *outputs* is normalized and ignored. """ from pytato.codegen import normalize_outputs @@ -419,7 +422,11 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: - """Returns the number of nodes in DAG *outputs*.""" + """ + Returns the number of nodes in DAG *outputs*. + + `DictOfNamedArrays` are added when *outputs* is normalized and ignored. + """ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) diff --git a/test/test_pytato.py b/test/test_pytato.py index 962fe337e..3d9e2a684 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -590,10 +590,10 @@ def test_empty_dag_count(): empty_dag = pt.make_dict_of_named_arrays({}) # Verify that get_num_nodes returns 0 for an empty DAG - assert get_num_nodes(empty_dag) - 1 == 0 + assert get_num_nodes(empty_dag) == 0 counts = get_node_type_counts(empty_dag) - assert len(counts) == 1 + assert len(counts) == 0 def test_single_node_dag_count(): @@ -606,14 +606,13 @@ def test_single_node_dag_count(): node_counts = get_node_type_counts(single_node_dag) # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays - # DictOfNamedArrays is automatically added - assert node_counts == {pt.DataWrapper: 1, pt.DictOfNamedArrays: 1} - assert sum(node_counts.values()) - 1 == 1 # Total node count is 1 + assert node_counts == {pt.DataWrapper: 1} + assert sum(node_counts.values()) == 1 # Total node count is 1 # Get total number of nodes total_nodes = get_num_nodes(single_node_dag) - assert total_nodes - 1 == 1 + assert total_nodes == 1 def test_small_dag_count(): @@ -625,10 +624,10 @@ def test_small_dag_count(): dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 # Verify that get_num_nodes returns 2 for a DAG with two nodes - assert get_num_nodes(dag) - 1 == 2 + assert get_num_nodes(dag) == 2 counts = get_node_type_counts(dag) - assert len(counts) - 1 == 2 + assert len(counts) == 2 assert counts[pt.array.Placeholder] == 1 # "a" assert counts[pt.array.IndexLambda] == 1 # single operation @@ -641,14 +640,14 @@ def test_large_dag_count(): dag = make_large_dag(iterations, seed=42) # Verify that the number of nodes is equal to iterations + 1 (placeholder) - assert get_num_nodes(dag) - 1 == iterations + 1 + assert get_num_nodes(dag) == iterations + 1 # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 assert counts[pt.array.IndexLambda] == 100 # 100 operations - assert sum(counts.values()) - 1 == iterations + 1 + assert sum(counts.values()) == iterations + 1 def test_random_dag_count(): @@ -658,7 +657,7 @@ def test_random_dag_count(): dag = get_random_pt_dag(seed=i, axis_len=5) # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) def test_random_dag_with_comm_count(): @@ -670,7 +669,7 @@ def test_random_dag_with_comm_count(): dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) def test_rec_get_user_nodes(): From 178127c9a7608b58901ae49f9a06c378d0b128d8 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:24:57 -0600 Subject: [PATCH 07/19] Edit tests to account for not counting DictOfNamedArrays --- test/test_pytato.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 16d1343e6..6f302caa2 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -656,7 +656,6 @@ def test_random_dag_count(): for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) @@ -668,9 +667,9 @@ def test_random_dag_with_comm_count(): for i in range(10): dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) + def test_duplicate_node_count(): from testlib import get_random_pt_dag from pytato.analysis import get_num_nodes, get_expr_calls @@ -683,9 +682,10 @@ def test_duplicate_node_count(): # Get the number of expressions and the amount they're called expr_counts = get_expr_calls(dag, count_duplicates=True) + # Get difference in duplicates num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) # Check that duplicates are correctly calculated - assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) def test_duplicate_nodes_with_comm_count(): @@ -703,10 +703,11 @@ def test_duplicate_nodes_with_comm_count(): # Get the number of expressions and the amount they're called expr_counts = get_expr_calls(dag, count_duplicates=True) + # Get difference in duplicates num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) # Check that duplicates are correctly calculated - assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) def test_large_dag_with_duplicates_count(): @@ -719,7 +720,7 @@ def test_large_dag_with_duplicates_count(): # Verify that the number of nodes is equal to iterations + 1 (placeholder) node_count = get_num_nodes(dag, count_duplicates=True) - assert node_count - 1 == iterations + 1 + assert node_count == iterations + 1 # Get the number of expressions and the amount they're called expr_counts = get_expr_calls(dag, count_duplicates=True) @@ -731,10 +732,10 @@ def test_large_dag_with_duplicates_count(): assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 assert counts[pt.array.IndexLambda] == 100 # 100 operations - assert sum(counts.values()) - 1 == iterations + 1 + assert sum(counts.values()) == iterations + 1 # Check that duplicates are correctly calculated - assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) def test_rec_get_user_nodes(): From 326045ecc2b34e425606a67c819a38c75a6d349e Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:37:55 -0600 Subject: [PATCH 08/19] Fix CI tests --- pytato/analysis/__init__.py | 30 ++++++++++++++++++++++-------- test/test_pytato.py | 35 ++++++++++++++++++++++------------- 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index aa015a760..844c9214b 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -392,15 +392,16 @@ class NodeCountMapper(CachedWalkMapper): Dictionary mapping node types to number of nodes of that type. """ - def __init__(self, count_duplicates=False) -> None: # added parameter + def __init__(self, count_duplicates: bool = False) -> None: from collections import defaultdict super().__init__() - self.expr_type_counts = defaultdict(int) # type: Dict[Type[Any], int] + self.expr_type_counts = defaultdict(int) # type: Dict[Type[Any], int] self.count_duplicates = count_duplicates - self.expr_call_counts = defaultdict(int) + self.expr_call_counts = defaultdict(int) # type: Dict[Any, int] - def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: - return id(expr) if self.count_duplicates else expr # returns unique nodes only if count_duplicates is True + def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: + # Returns unique nodes only if count_duplicates is True + return id(expr) if self.count_duplicates else expr def post_visit(self, expr: Any) -> None: if type(expr) is not DictOfNamedArrays: @@ -408,7 +409,10 @@ def post_visit(self, expr: Any) -> None: self.expr_call_counts[expr] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> Dict[Type[Any], int]: +def get_node_type_counts( + outputs: Union[Array, DictOfNamedArrays], + count_duplicates: bool = False + ) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -425,7 +429,10 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays], count_duplica return ncm.expr_type_counts -def get_num_nodes(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> int: +def get_num_nodes( + outputs: Union[Array, DictOfNamedArrays], + count_duplicates: bool = False + ) -> int: """ Returns the number of nodes in DAG *outputs*. @@ -441,7 +448,14 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays], count_duplicates=Fal return sum(ncm.expr_type_counts.values()) -def get_expr_calls(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> Dict[Type[Any], int]: +def get_expr_calls( + outputs: Union[Array, DictOfNamedArrays], + count_duplicates: bool = False + ) -> Dict[Type[Any], int]: + """ + Returns the count of calls per `expr`. + """ + from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) diff --git a/test/test_pytato.py b/test/test_pytato.py index 6f302caa2..20fffb948 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -600,12 +600,13 @@ def test_single_node_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts data = np.random.rand(4, 4) - single_node_dag = pt.make_dict_of_named_arrays({"result": pt.make_data_wrapper(data)}) + single_node_dag = pt.make_dict_of_named_arrays( + {"result": pt.make_data_wrapper(data)}) # Get counts per node type node_counts = get_node_type_counts(single_node_dag) - # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays + # Assert that there is only one node of type DataWrapper assert node_counts == {pt.DataWrapper: 1} assert sum(node_counts.values()) == 1 # Total node count is 1 @@ -642,7 +643,6 @@ def test_large_dag_count(): # Verify that the number of nodes is equal to iterations + 1 (placeholder) assert get_num_nodes(dag) == iterations + 1 - # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 @@ -665,7 +665,8 @@ def test_random_dag_with_comm_count(): rank = 0 size = 2 for i in range(10): - dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + dag = get_random_pt_dag_with_send_recv_nodes( + seed=i, rank=rank, size=size) assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) @@ -683,9 +684,11 @@ def test_duplicate_node_count(): expr_counts = get_expr_calls(dag, count_duplicates=True) # Get difference in duplicates - num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + num_duplicates = sum( + count - 1 for count in expr_counts.values() if count > 1) # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) def test_duplicate_nodes_with_comm_count(): @@ -695,7 +698,8 @@ def test_duplicate_nodes_with_comm_count(): rank = 0 size = 2 for i in range(20): - dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + dag = get_random_pt_dag_with_send_recv_nodes( + seed=i, rank=rank, size=size) # Get the number of types of expressions node_count = get_num_nodes(dag, count_duplicates=True) @@ -704,14 +708,18 @@ def test_duplicate_nodes_with_comm_count(): expr_counts = get_expr_calls(dag, count_duplicates=True) # Get difference in duplicates - num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + num_duplicates = sum( + count - 1 for count in expr_counts.values() if count > 1) # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) def test_large_dag_with_duplicates_count(): - from pytato.analysis import get_num_nodes, get_node_type_counts, get_expr_calls + from pytato.analysis import ( + get_num_nodes, get_node_type_counts, get_expr_calls + ) from testlib import make_large_dag import pytato as pt @@ -725,9 +733,9 @@ def test_large_dag_with_duplicates_count(): # Get the number of expressions and the amount they're called expr_counts = get_expr_calls(dag, count_duplicates=True) - num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + num_duplicates = sum( + count - 1 for count in expr_counts.values() if count > 1) - # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag, count_duplicates=True) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 @@ -735,7 +743,8 @@ def test_large_dag_with_duplicates_count(): assert sum(counts.values()) == iterations + 1 # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) def test_rec_get_user_nodes(): From 6a0a2a9f44f58d1d4ba7e39122a4f2577ad0b157 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:39:57 -0600 Subject: [PATCH 09/19] Fix comments --- test/test_pytato.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 3d9e2a684..e202fd2c3 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -600,12 +600,13 @@ def test_single_node_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts data = np.random.rand(4, 4) - single_node_dag = pt.make_dict_of_named_arrays({"result": pt.make_data_wrapper(data)}) + single_node_dag = pt.make_dict_of_named_arrays( + {"result": pt.make_data_wrapper(data)}) # Get counts per node type node_counts = get_node_type_counts(single_node_dag) - # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays + # Assert that there is only one node of type DataWrapper assert node_counts == {pt.DataWrapper: 1} assert sum(node_counts.values()) == 1 # Total node count is 1 @@ -642,7 +643,6 @@ def test_large_dag_count(): # Verify that the number of nodes is equal to iterations + 1 (placeholder) assert get_num_nodes(dag) == iterations + 1 - # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 @@ -656,7 +656,6 @@ def test_random_dag_count(): for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) @@ -666,9 +665,9 @@ def test_random_dag_with_comm_count(): rank = 0 size = 2 for i in range(10): - dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + dag = get_random_pt_dag_with_send_recv_nodes( + seed=i, rank=rank, size=size) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) From 0dca4d7c4295179f8e946f9411349d2dfa43512e Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 26 Jun 2024 19:50:38 -0600 Subject: [PATCH 10/19] Clarify wording and clean up --- pytato/analysis/__init__.py | 6 +++--- test/test_pytato.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 4938f6794..3a112501b 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -400,7 +400,7 @@ def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: return expr def post_visit(self, expr: Any) -> None: - if type(expr) is not DictOfNamedArrays: + if not isinstance(expr, DictOfNamedArrays): self.counts[type(expr)] += 1 @@ -409,7 +409,7 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. - `DictOfNamedArrays` are added when *outputs* is normalized and ignored. + Instances of `DictOfNamedArrays` are excluded from counting. """ from pytato.codegen import normalize_outputs @@ -425,7 +425,7 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """ Returns the number of nodes in DAG *outputs*. - `DictOfNamedArrays` are added when *outputs* is normalized and ignored. + Instances of `DictOfNamedArrays` are excluded from counting. """ from pytato.codegen import normalize_outputs diff --git a/test/test_pytato.py b/test/test_pytato.py index e202fd2c3..23aebd9e4 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -608,7 +608,6 @@ def test_single_node_dag_count(): # Assert that there is only one node of type DataWrapper assert node_counts == {pt.DataWrapper: 1} - assert sum(node_counts.values()) == 1 # Total node count is 1 # Get total number of nodes total_nodes = get_num_nodes(single_node_dag) From 9489ecf05dc9089903c9fc6b64f1252cb792f8d8 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 26 Jun 2024 20:21:44 -0600 Subject: [PATCH 11/19] Move `get_node_multiplicities` to its own mapper --- pytato/analysis/__init__.py | 41 +++++++++++++++++++++++++++---------- test/test_pytato.py | 18 ++++++++-------- 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index ad585db16..1f84aa912 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -387,7 +387,6 @@ class NodeCountMapper(CachedWalkMapper): .. attribute:: expr_type_counts .. attribute:: count_duplicates - .. attribute:: expr_call_counts Dictionary mapping node types to number of nodes of that type. """ @@ -397,7 +396,6 @@ def __init__(self, count_duplicates: bool = False) -> None: super().__init__() self.expr_type_counts = defaultdict(int) # type: Dict[Type[Any], int] self.count_duplicates = count_duplicates - self.expr_call_counts = defaultdict(int) # type: Dict[Any, int] def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: # Returns unique nodes only if count_duplicates is True @@ -406,7 +404,6 @@ def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: def post_visit(self, expr: Any) -> None: if not isinstance(expr, DictOfNamedArrays): self.expr_type_counts[type(expr)] += 1 - self.expr_call_counts[expr] += 1 def get_node_type_counts( @@ -447,21 +444,43 @@ def get_num_nodes( return sum(ncm.expr_type_counts.values()) +# }}} -def get_expr_calls( - outputs: Union[Array, DictOfNamedArrays], - count_duplicates: bool = False - ) -> Dict[Type[Any], int]: + +# {{{ NodeMultiplicityMapper + + +class NodeMultiplicityMapper(CachedWalkMapper): """ - Returns the count of calls per `expr`. + Counts the number of unique nodes by ID in a DAG. + + .. attribute:: expr_multiplicity_counts + """ + def __init__(self) -> None: + from collections import defaultdict + super().__init__() + self.expr_multiplicity_counts = defaultdict(int) # type: Dict[Any, int] + + def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: + # Returns unique nodes + return id(expr) + + def post_visit(self, expr: Any) -> None: + if not isinstance(expr, DictOfNamedArrays): + self.expr_multiplicity_counts[expr] += 1 + + +def get_node_multiplicities(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: + """ + Returns the multiplicity per `expr`. """ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - ncm = NodeCountMapper(count_duplicates) - ncm(outputs) + nmm = NodeMultiplicityMapper() + nmm(outputs) - return ncm.expr_call_counts + return nmm.expr_multiplicity_counts # }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index 4ef754a69..c34f14776 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -672,7 +672,7 @@ def test_random_dag_with_comm_count(): def test_duplicate_node_count(): from testlib import get_random_pt_dag - from pytato.analysis import get_num_nodes, get_expr_calls + from pytato.analysis import get_num_nodes, get_node_multiplicities for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) @@ -680,11 +680,11 @@ def test_duplicate_node_count(): node_count = get_num_nodes(dag, count_duplicates=True) # Get the number of expressions and the amount they're called - expr_counts = get_expr_calls(dag, count_duplicates=True) + node_multiplicity = get_node_multiplicities(dag) # Get difference in duplicates num_duplicates = sum( - count - 1 for count in expr_counts.values() if count > 1) + count - 1 for count in node_multiplicity.values() if count > 1) # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( pt.transform.DependencyMapper()(dag)) @@ -692,7 +692,7 @@ def test_duplicate_node_count(): def test_duplicate_nodes_with_comm_count(): from testlib import get_random_pt_dag_with_send_recv_nodes - from pytato.analysis import get_num_nodes, get_expr_calls + from pytato.analysis import get_num_nodes, get_node_multiplicities rank = 0 size = 2 @@ -704,11 +704,11 @@ def test_duplicate_nodes_with_comm_count(): node_count = get_num_nodes(dag, count_duplicates=True) # Get the number of expressions and the amount they're called - expr_counts = get_expr_calls(dag, count_duplicates=True) + node_multiplicity = get_node_multiplicities(dag) # Get difference in duplicates num_duplicates = sum( - count - 1 for count in expr_counts.values() if count > 1) + count - 1 for count in node_multiplicity.values() if count > 1) # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( @@ -717,7 +717,7 @@ def test_duplicate_nodes_with_comm_count(): def test_large_dag_with_duplicates_count(): from pytato.analysis import ( - get_num_nodes, get_node_type_counts, get_expr_calls + get_num_nodes, get_node_type_counts, get_node_multiplicities ) from testlib import make_large_dag import pytato as pt @@ -730,10 +730,10 @@ def test_large_dag_with_duplicates_count(): assert node_count == iterations + 1 # Get the number of expressions and the amount they're called - expr_counts = get_expr_calls(dag, count_duplicates=True) + node_multiplicity = get_node_multiplicities(dag) num_duplicates = sum( - count - 1 for count in expr_counts.values() if count > 1) + count - 1 for count in node_multiplicity.values() if count > 1) counts = get_node_type_counts(dag, count_duplicates=True) assert len(counts) >= 1 From 27d6283ba74c24ca2744a41797a8beef36f2ff8c Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 26 Jun 2024 20:26:30 -0600 Subject: [PATCH 12/19] Add autofunction --- pytato/analysis/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 1f84aa912..3bc6785a7 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -51,6 +51,8 @@ .. autofunction:: get_node_type_counts +.. autofunction:: get_node_multiplicities + .. autofunction:: get_num_call_sites .. autoclass:: DirectPredecessorsGetter From 1444c50d77badc5d947cb2aa536eaaca1861c5df Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 2 Jul 2024 18:05:48 -0600 Subject: [PATCH 13/19] Formatting --- pytato/analysis/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 3a112501b..b822d6622 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -404,7 +404,8 @@ def post_visit(self, expr: Any) -> None: self.counts[type(expr)] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays] + ) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. From 528c617b80350bfe643e4a9de82a3e3fe06be34c Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Sat, 6 Jul 2024 21:58:09 -0600 Subject: [PATCH 14/19] Add functionality to graph duplicate nodes via dot graph --- pytato/visualization/dot.py | 366 ++++++++++++++++++++++-------------- test/test_pytato.py | 73 +++++++ test/testlib.py | 16 ++ 3 files changed, 315 insertions(+), 140 deletions(-) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 0e3396df0..e02ec05c5 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -29,6 +29,7 @@ from functools import partial import html import attrs +import gc from typing import (TYPE_CHECKING, Callable, Dict, Tuple, Union, List, Mapping, Any, FrozenSet, Set, Optional) @@ -51,8 +52,7 @@ from pytato.distributed.partition import ( DistributedGraphPartition, DistributedGraphPart, PartId) -if TYPE_CHECKING: - from pytato.distributed.nodes import DistributedSendRefHolder +from pytato.distributed.nodes import DistributedSendRefHolder __doc__ = """ @@ -162,12 +162,23 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" -class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]): - def __init__(self) -> None: - super().__init__() - self.node_to_dot: Dict[ArrayOrNames, _DotNodeInfo] = {} - self.functions: Set[FunctionDefinition] = set() +def get_object_by_id(object_id): + """Find an object by its ID.""" + for obj in gc.get_objects(): + if id(obj) == object_id: + return obj + return None + + +class ArrayToDotNodeInfoMapper: + def __init__(self, count_duplicates: bool = False): + self.count_duplicates = count_duplicates + self.node_to_dot = {} + self.functions = set() + def get_cache_key(self, expr): + return id(expr) if self.count_duplicates else expr + def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: title = type(expr).__name__ fields = {"addr": hex(id(expr)), @@ -180,65 +191,83 @@ def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} return _DotNodeInfo(title, fields, edges) - # type-ignore-reason: incompatible with supertype - def handle_unsupported_array(self, # type: ignore[override] - expr: Array) -> None: - # Default handler, does its best to guess how to handle fields. - info = self.get_common_dot_info(expr) + def process_node(self, expr: Array) -> None: + if isinstance(expr, DataWrapper): + self.map_data_wrapper(expr) + elif isinstance(expr, IndexLambda): + self.map_index_lambda(expr) + elif isinstance(expr, Stack): + self.map_stack(expr) + elif isinstance(expr, (IndexBase, IndexLambda)): + self.map_basic_index(expr) + elif isinstance(expr, Einsum): + self.map_einsum(expr) + elif isinstance(expr, DictOfNamedArrays): + self.map_dict_of_named_arrays(expr) + elif isinstance(expr, LoopyCall): + self.map_loopy_call(expr) + elif isinstance(expr, DistributedSendRefHolder): + self.map_distributed_send_ref_holder(expr) + elif isinstance(expr, Call): + self.map_call(expr) + elif isinstance(expr, NamedCallResult): + self.map_named_call_result(expr) + else: + self.handle_unsupported_array(expr) - # pylint: disable=not-an-iterable + def handle_unsupported_array(self, expr: Array) -> None: + info = self.get_common_dot_info(expr) + expr_key = self.get_cache_key(expr) for field in attrs.fields(type(expr)): if field.name in info.fields: continue attr = getattr(expr, field.name) - if isinstance(attr, Array): - self.rec(attr) - info.edges[field.name] = attr - + self.process_node(attr) + key = self.get_cache_key(attr) + info.edges[field.name] = (key, attr) elif isinstance(attr, AbstractResultWithNamedArrays): - self.rec(attr) - info.edges[field.name] = attr - + self.process_node(attr) + key = self.get_cache_key(attr) + info.edges[field.name] = (key, attr) elif isinstance(attr, tuple): info.fields[field.name] = stringify_shape(attr) - else: info.fields[field.name] = str(attr) - - self.node_to_dot[expr] = info + self.node_to_dot[expr_key] = info def map_data_wrapper(self, expr: DataWrapper) -> None: info = self.get_common_dot_info(expr) if expr.name is not None: info.fields["name"] = expr.name - # Only show summarized data import numpy as np with np.printoptions(threshold=4, precision=2): info.fields["data"] = str(expr.data) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info def map_index_lambda(self, expr: IndexLambda) -> None: info = self.get_common_dot_info(expr) info.fields["expr"] = str(expr.expr) for name, val in expr.bindings.items(): - self.rec(val) - info.edges[name] = val + self.process_node(val) + key = self.get_cache_key(val) + info.edges[name] = (key, val) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info def map_stack(self, expr: Stack) -> None: info = self.get_common_dot_info(expr) info.fields["axis"] = str(expr.axis) for i, array in enumerate(expr.arrays): - self.rec(array) - info.edges[str(i)] = array + self.process_node(array) + key = self.get_cache_key(array) + info.edges[str(i)] = (key, array) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info map_concatenate = map_stack @@ -254,9 +283,10 @@ def map_basic_index(self, expr: IndexBase) -> None: elif isinstance(index, Array): label = f"i{i}" - self.rec(index) + self.process_node(index) + key = self.get_cache_key(index) indices_parts.append(label) - info.edges[label] = index + info.edges[label] = (key, index) elif index is None: indices_parts.append("newaxis") @@ -266,10 +296,11 @@ def map_basic_index(self, expr: IndexBase) -> None: info.fields["indices"] = ", ".join(indices_parts) - self.rec(expr.array) - info.edges["array"] = expr.array + self.process_node(expr.array) + key = self.get_cache_key(expr.array) + info.edges["array"] = (key, expr.array) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info map_contiguous_advanced_index = map_basic_index map_non_contiguous_advanced_index = map_basic_index @@ -279,18 +310,20 @@ def map_einsum(self, expr: Einsum) -> None: for iarg, (access_descr, val) in enumerate(zip(expr.access_descriptors, expr.args)): - self.rec(val) - info.edges[f"{iarg}: {access_descr}"] = val + self.process_node(val) + key = self.get_cache_key(val) + info.edges[f"{iarg}: {access_descr}"] = (key, val) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} for name, val in expr._data.items(): - edges[name] = val - self.rec(val) + self.process_node(val) + key = self.get_cache_key(val) + edges[name] = (key, val) - self.node_to_dot[expr] = _DotNodeInfo( + self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo( title=type(expr).__name__, fields={}, edges=edges) @@ -299,10 +332,11 @@ def map_loopy_call(self, expr: LoopyCall) -> None: edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} for name, arg in expr.bindings.items(): if isinstance(arg, Array): - edges[name] = arg - self.rec(arg) + self.process_node(arg) + key = self.get_cache_key(arg) + edges[name] = (key, arg) - self.node_to_dot[expr] = _DotNodeInfo( + self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo( title=type(expr).__name__, fields={"addr": hex(id(expr)), "entrypoint": expr.entrypoint}, edges=edges) @@ -312,29 +346,30 @@ def map_distributed_send_ref_holder( info = self.get_common_dot_info(expr) - self.rec(expr.passthrough_data) - info.edges["passthrough"] = expr.passthrough_data + self.process_node(expr.passthrough_data) + key = self.get_cache_key(expr.passthrough_data) + info.edges["passthrough"] = (key, expr.passthrough_data) - self.rec(expr.send.data) - info.edges["sent"] = expr.send.data + self.process_node(expr.send.data) + key = self.get_cache_key(expr.send.data) + info.edges["sent"] = (key, expr.send.data) info.fields["dest_rank"] = str(expr.send.dest_rank) - info.fields["comm_tag"] = str(expr.send.comm_tag) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info def map_call(self, expr: Call) -> None: self.functions.add(expr.function) for bnd in expr.bindings.values(): - self.rec(bnd) + self.process_node(bnd) - self.node_to_dot[expr] = _DotNodeInfo( + self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo( title=expr.__class__.__name__, edges={ "": expr.function, - **expr.bindings}, + **{name: (self.get_cache_key(bnd), bnd) for name, bnd in expr.bindings.items()}}, fields={ "addr": hex(id(expr)), "tags": stringify_tags(expr.tags), @@ -342,14 +377,16 @@ def map_call(self, expr: Call) -> None: ) def map_named_call_result(self, expr: NamedCallResult) -> None: - self.rec(expr._container) - self.node_to_dot[expr] = _DotNodeInfo( + self.process_node(expr._container) + key = self.get_cache_key(expr._container) + self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo( title=expr.__class__.__name__, - edges={"": expr._container}, + edges={"": (key, expr._container)}, fields={"addr": hex(id(expr)), "name": expr.name}, ) + # }}} @@ -363,6 +400,11 @@ def dot_escape_leave_space(s: str) -> str: return html.escape(s.replace("\\", "\\\\")) +def get_array_key(array, count_duplicates): + """Return a consistent key for the array.""" + return id(array) if count_duplicates and not isinstance(array, int) else array + + # {{{ emit helpers def _stringify_created_at(non_equality_tags: FrozenSet[Tag]) -> str: @@ -375,7 +417,7 @@ def _stringify_created_at(non_equality_tags: FrozenSet[Tag]) -> str: def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, Any], - dot_node_id: str, color: str = "white") -> None: + dot_node_id: str, color: str = "white") -> None: td_attrib = 'border="0"' table_attrib = 'border="0" cellborder="1" cellspacing="0"' @@ -397,11 +439,20 @@ def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, Any], f'tooltip="{tooltip}"]') +def preprocess_all_nodes(partition, array_to_id, id_gen, count_duplicates): + mapper = ArrayToDotNodeInfoMapper(count_duplicates) + for part in partition.parts.values(): + for out_name in part.output_names: + node = partition.name_to_output[out_name] + mapper.process_node(node) + + def _emit_name_cluster( emit: DotEmitter, subgraph_path: Tuple[str, ...], names: Mapping[str, ArrayOrNames], array_to_id: Mapping[ArrayOrNames, str], id_gen: Callable[[str], str], - label: str) -> None: + label: str, + count_duplicates: bool = False) -> None: edges = [] cluster_subgraph_path = subgraph_path + (f"cluster_{dot_escape(label)}",) @@ -412,7 +463,8 @@ def _emit_name_cluster( for name, array in names.items(): name_id = id_gen(dot_escape(name)) emit_cluster('%s [label="%s"]' % (name_id, dot_escape(name))) - array_id = array_to_id[array] + array_key = get_array_key(array, count_duplicates) + array_id = array_to_id[array_key] # Edges must be outside the cluster. edges.append((name_id, array_id)) @@ -425,14 +477,16 @@ def _emit_function( id_gen: UniqueNameGenerator, node_to_dot: Mapping[ArrayOrNames, _DotNodeInfo], func_to_id: Mapping[FunctionDefinition, str], - outputs: Mapping[str, Array]) -> None: + outputs: Mapping[str, Array], + count_duplicates: bool = False) -> None: input_arrays: List[Array] = [] internal_arrays: List[ArrayOrNames] = [] array_to_id: Dict[ArrayOrNames, str] = {} emit = partial(emitter, subgraph_path) for array in node_to_dot: - array_to_id[array] = id_gen("array") + key = get_array_key(array, count_duplicates) + array_to_id[key] = id_gen("array") if isinstance(array, InputArgumentBase): input_arrays.append(array) else: @@ -444,36 +498,46 @@ def _emit_function( emit_input('label="Arguments"') for array in input_arrays: + key = get_array_key(array, count_duplicates) _emit_array( emit_input, node_to_dot[array].title, node_to_dot[array].fields, - array_to_id[array]) + array_to_id[key]) # Emit non-inputs. for array in internal_arrays: + key = get_array_key(array, count_duplicates) _emit_array(emit, node_to_dot[array].title, node_to_dot[array].fields, - array_to_id[array]) + array_to_id[key]) # Emit edges. for array, node in node_to_dot.items(): - for label, tail_item in node.edges.items(): - head = array_to_id[array] + key = get_array_key(array, count_duplicates) + for label, edge_info in node.edges.items(): + if isinstance(edge_info, tuple): + tail_key, tail_item = edge_info + else: + tail_item = edge_info + tail_key = get_array_key(tail_item, count_duplicates) + + head = array_to_id[key] if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)): - tail = array_to_id[tail_item] + tail = array_to_id[tail_key] elif isinstance(tail_item, FunctionDefinition): tail = func_to_id[tail_item] else: - raise ValueError( - f"unexpected type of tail on edge: {type(tail_item)}") + raise ValueError(f"unexpected type of tail on edge: {type(tail_item)}") emit('%s -> %s [label="%s"]' % (tail, head, dot_escape(label))) # Emit output/namespace name mappings. _emit_name_cluster( - emitter, subgraph_path, outputs, array_to_id, id_gen, label="Returns") + emitter, subgraph_path, + outputs, array_to_id, id_gen, + label="Returns", count_duplicates=count_duplicates) # }}} @@ -491,7 +555,8 @@ def _get_function_name(f: FunctionDefinition) -> Optional[str]: def _gather_partition_node_information( id_gen: UniqueNameGenerator, - partition: DistributedGraphPartition + partition: DistributedGraphPartition, + count_duplicates: bool = False ) -> Tuple[ Mapping[PartId, Mapping[FunctionDefinition, str]], Mapping[Tuple[PartId, Optional[FunctionDefinition]], @@ -502,9 +567,9 @@ def _gather_partition_node_information( Dict[ArrayOrNames, _DotNodeInfo]] = {} for part in partition.parts.values(): - mapper = ArrayToDotNodeInfoMapper() + mapper = ArrayToDotNodeInfoMapper(count_duplicates) for out_name in part.output_names: - mapper(partition.name_to_output[out_name]) + mapper.process_node(partition.name_to_output[out_name]) part_id_func_to_node_info[part.pid, None] = mapper.node_to_dot part_id_to_func_to_id[part.pid] = {} @@ -519,9 +584,9 @@ def gather_function_info(f: FunctionDefinition) -> None: if key in part_id_func_to_node_info: return - mapper = ArrayToDotNodeInfoMapper() + mapper = ArrayToDotNodeInfoMapper(count_duplicates) for elem in f.returns.values(): - mapper(elem) + mapper.process_node(elem) part_id_func_to_node_info[key] = mapper.node_to_dot @@ -547,10 +612,12 @@ def gather_function_info(f: FunctionDefinition) -> None: return part_id_to_func_to_id, part_id_func_to_node_info + # }}} -def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str: +def get_dot_graph(result: Union[Array, DictOfNamedArrays], + count_duplicates: bool = False) -> str: r"""Return a string in the `dot `_ language depicting the graph of the computation of *result*. @@ -560,30 +627,32 @@ def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str: outputs: DictOfNamedArrays = normalize_outputs(result) - return get_dot_graph_from_partition( - DistributedGraphPartition( - parts={ - None: DistributedGraphPart( - pid=None, - needed_pids=frozenset(), - user_input_names=frozenset( + partition = DistributedGraphPartition( + parts={ + None: DistributedGraphPart( + pid=None, + needed_pids=frozenset(), + user_input_names=frozenset( expr.name for expr in InputGatherer()(outputs) if isinstance(expr, Placeholder) ), - partition_input_names=frozenset(), - output_names=frozenset(outputs.keys()), - name_to_recv_node={}, - name_to_send_nodes={}, - ) - }, - name_to_output=outputs._data, - overall_output_names=tuple(outputs), - )) - - -def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: - r"""Return a string in the `dot `_ language depicting the + partition_input_names=frozenset(), + output_names=frozenset(outputs.keys()), + name_to_recv_node={}, + name_to_send_nodes={}, + ) + }, + name_to_output=outputs._data, + overall_output_names=tuple(outputs), + ) + + return get_dot_graph_from_partition(partition, count_duplicates) + + +def get_dot_graph_from_partition(partition: DistributedGraphPartition, + count_duplicates: bool = False) -> str: + """Return a string in the `dot `_ language depicting the graph of the partitioned computation of *partition*. :arg partition: Outputs of :func:`~pytato.find_distributed_partition`. @@ -595,9 +664,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: # The "None" function is the body of the partition. part_id_to_func_to_id, part_id_func_to_node_info = \ - _gather_partition_node_information(id_gen, partition) - - # }}} + _gather_partition_node_information(id_gen, partition, count_duplicates) emitter = DotEmitter() emit_root = partial(emitter, ()) @@ -606,8 +673,8 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: emit_root("node [shape=rectangle]") - placeholder_to_id: Dict[ArrayOrNames, str] = {} - part_id_to_array_to_id: Dict[PartId, Dict[ArrayOrNames, str]] = {} + placeholder_to_id: Dict[Union[int, ArrayOrNames], str] = {} + part_id_to_array_to_id: Dict[PartId, Dict[Union[int, ArrayOrNames], str]] = {} part_id_to_id = {pid: dot_escape(str(pid)) for pid in partition.parts} assert len(set(part_id_to_id.values())) == len(partition.parts) @@ -617,16 +684,18 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: for part in partition.parts.values(): array_to_id = {} for array in part_id_func_to_node_info[part.pid, None].keys(): + if isinstance(array, int): # if the key is an ID + array = get_object_by_id(array) + key = get_array_key(array, count_duplicates) if isinstance(array, Placeholder): - # Placeholders are only emitted once - if array in placeholder_to_id: - node_id = placeholder_to_id[array] + if key in placeholder_to_id: + node_id = placeholder_to_id[key] else: node_id = id_gen("array") - placeholder_to_id[array] = node_id + placeholder_to_id[key] = node_id else: node_id = id_gen("array") - array_to_id[array] = node_id + array_to_id[key] = node_id part_id_to_array_to_id[part.pid] = array_to_id # }}} @@ -663,32 +732,34 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: _emit_function(emitter, func_subgraph_path, id_gen, part_id_func_to_node_info[part.pid, func], part_id_to_func_to_id[part.pid], - func.returns) + func.returns, + count_duplicates=count_duplicates) # }}} # {{{ emit receives nodes part_dist_recv_var_name_to_node_id = {} - for name, recv in ( - part.name_to_recv_node.items()): + for name, recv in part.name_to_recv_node.items(): node_id = id_gen("recv") _emit_array(emit_part, "DistributedRecv", { "shape": stringify_shape(recv.shape), "dtype": str(recv.dtype), "src_rank": str(recv.src_rank), "comm_tag": str(recv.comm_tag), - }, node_id) + }, node_id) part_dist_recv_var_name_to_node_id[name] = node_id # }}} - + part_node_to_info = part_id_func_to_node_info[part.pid, None] input_arrays: List[Array] = [] internal_arrays: List[ArrayOrNames] = [] for array in part_node_to_info.keys(): + if isinstance(array, int): # if the key is an ID + array = get_object_by_id(array) if isinstance(array, InputArgumentBase): input_arrays.append(array) else: @@ -702,26 +773,26 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: # subgraphs. for array in input_arrays: + key = array = get_array_key(array, count_duplicates) if not isinstance(array, Placeholder): _emit_array(emit_part, part_node_to_info[array].title, part_node_to_info[array].fields, - array_to_id[array], "deepskyblue") + array_to_id[key], "deepskyblue") else: # Is a Placeholder - if array in emitted_placeholders: + if key in emitted_placeholders: continue - _emit_array(emit_root, part_node_to_info[array].title, part_node_to_info[array].fields, - array_to_id[array], "deepskyblue") + array_to_id[key], "deepskyblue") # Emit cross-partition edges if array.name in part_dist_recv_var_name_to_node_id: tgt = part_dist_recv_var_name_to_node_id[array.name] - emit_root(f"{tgt} -> {array_to_id[array]} [style=dotted]") - emitted_placeholders.add(array) + emit_root(f"{tgt} -> {array_to_id[key]} [style=dotted]") + emitted_placeholders.add(key) elif array.name in part.user_input_names: # no arrows for these pass @@ -734,18 +805,22 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: break assert computing_pid is not None tgt = part_id_to_array_to_id[computing_pid][ - partition.name_to_output[array.name]] - emit_root(f"{tgt} -> {array_to_id[array]} [style=dashed]") - emitted_placeholders.add(array) + id(partition.name_to_output[array.name]) + if count_duplicates + else partition.name_to_output[array.name]] + emit_root(f"{tgt} -> {array_to_id[key]} [style=dashed]") + emitted_placeholders.add(key) # }}} # Emit internal nodes + for array in internal_arrays: + key = array = get_array_key(array, count_duplicates) _emit_array(emit_part, part_node_to_info[array].title, part_node_to_info[array].fields, - array_to_id[array]) + array_to_id[key]) # {{{ emit send nodes if distributed @@ -756,36 +831,45 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: _emit_array(emit_part, "DistributedSend", { "dest_rank": str(send.dest_rank), "comm_tag": str(send.comm_tag), - }, node_id) + }, node_id) # If an edge is emitted in a subgraph, it drags its # nodes into the subgraph, too. Not what we want. emit_root( - f"{array_to_id[send.data]} -> {node_id}" - f'[style=dotted, label="{dot_escape(name)}"]') + f"{array_to_id[id(send.data) if count_duplicates else send.data]} -> {node_id}" + f'[style=dotted, label="{dot_escape(name)}"]') # }}} # Emit intra-partition edges for array, node in part_node_to_info.items(): - for label, tail_item in node.edges.items(): - head = array_to_id[array] + key = get_array_key(array, count_duplicates) + for label, edge_info in node.edges.items(): + if isinstance(edge_info, tuple): + tail_key, tail_item = edge_info + else: + tail_item = edge_info + tail_key = get_array_key(tail_item, count_duplicates) + + head = array_to_id[key] if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)): - tail = array_to_id[tail_item] + tail = array_to_id[tail_key] elif isinstance(tail_item, FunctionDefinition): tail = part_id_to_func_to_id[part.pid][tail_item] else: raise ValueError( - f"unexpected type of tail on edge: {type(tail_item)}") + f"unexpected type of tail on edge: {type(tail_item)}") emit_root('%s -> %s [label="%s"]' % - (tail, head, dot_escape(label))) + (tail, head, dot_escape(label))) _emit_name_cluster( - emitter, part_subgraph_path, - {name: partition.name_to_output[name] for name in part.output_names}, - array_to_id, id_gen, "Part outputs") + emitter, part_subgraph_path, + {name: partition.name_to_output[name] + for name in part.output_names}, + array_to_id, id_gen, "Part outputs", + count_duplicates) # }}} @@ -795,15 +879,16 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: # {{{ draw overall outputs - combined_array_to_id: Dict[ArrayOrNames, str] = {} + combined_array_to_id: Dict[Union[int, ArrayOrNames], str] = {} for part_id in partition.parts.keys(): combined_array_to_id.update(part_id_to_array_to_id[part_id]) _emit_name_cluster( - emitter, (), - {name: partition.name_to_output[name] - for name in partition.overall_output_names}, - combined_array_to_id, id_gen, "Overall outputs") + emitter, (), + {name: partition.name_to_output[name] + for name in partition.overall_output_names}, + combined_array_to_id, id_gen, "Overall outputs", + count_duplicates) # }}} @@ -812,7 +897,8 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: def show_dot_graph(result: Union[str, Array, DictOfNamedArrays, DistributedGraphPartition], - **kwargs: Any) -> None: + count_duplicates: bool = False, + **kwargs: Any) -> None: """Show a graph representing the computation of *result* in a browser. :arg result: Outputs of the computation (cf. @@ -825,9 +911,9 @@ def show_dot_graph(result: Union[str, Array, DictOfNamedArrays, if isinstance(result, str): dot_code = result elif isinstance(result, DistributedGraphPartition): - dot_code = get_dot_graph_from_partition(result) + dot_code = get_dot_graph_from_partition(result, count_duplicates) else: - dot_code = get_dot_graph(result) + dot_code = get_dot_graph(result, count_duplicates) from pytools.graphviz import show_dot show_dot(dot_code, **kwargs) diff --git a/test/test_pytato.py b/test/test_pytato.py index c34f14776..45102158e 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -746,6 +746,79 @@ def test_large_dag_with_duplicates_count(): pt.transform.DependencyMapper()(dag)) +def test_duplicate_node_count_dot_graph(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import get_random_pt_dag + from testlib import count_dot_graph_nodes + + for i in range(80): + # print("curr i:", i) + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + +def test_duplicate_nodes_with_comm_count_dot_graph(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import get_random_pt_dag_with_send_recv_nodes + from testlib import count_dot_graph_nodes + + rank = 0 + size = 2 + for i in range(20): + dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + +def test_large_dot_graph_with_duplicates_count(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import make_large_dag + from testlib import count_dot_graph_nodes + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts with duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) diff --git a/test/testlib.py b/test/testlib.py index a208f0816..f81c46682 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -3,6 +3,7 @@ import types from typing import Any, Dict, Optional, List, Tuple, Union, Sequence, Callable import operator +import re import pyopencl as cl import numpy as np import pytato as pt @@ -337,6 +338,21 @@ def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: # DAG should have `iterations` number of operations return pt.make_dict_of_named_arrays({"result": current}) + +def count_dot_graph_nodes(dot_graph: str): + """ + Parses a dot graph and returns a dictionary with the count of each unique node identifier. + """ + + node_pattern = re.compile(r'addr:(0x[0-9a-f]+)') + nodes = node_pattern.findall(dot_graph) + + node_counts = {} + for node in nodes: + node_counts[node] = node_counts.get(node, 0) + 1 + + return node_counts + # }}} From f3ff93a835a7e555312b02338e5a4e221d212639 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 8 Jul 2024 20:56:53 -0600 Subject: [PATCH 15/19] Linting and formatting --- pytato/analysis/__init__.py | 3 +- pytato/visualization/dot.py | 87 +++++++++++++++++++++---------------- test/test_codegen.py | 77 +++++++++++++++++++++++++++++++- test/test_pytato.py | 73 ------------------------------- test/testlib.py | 10 +++-- 5 files changed, 132 insertions(+), 118 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 3bc6785a7..f025ff350 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -472,7 +472,8 @@ def post_visit(self, expr: Any) -> None: self.expr_multiplicity_counts[expr] += 1 -def get_node_multiplicities(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: +def get_node_multiplicities( + outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: """ Returns the multiplicity per `expr`. """ diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index e02ec05c5..38f482a06 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -31,7 +31,7 @@ import attrs import gc -from typing import (TYPE_CHECKING, Callable, Dict, Tuple, Union, List, +from typing import (Callable, Dict, Tuple, Union, List, Mapping, Any, FrozenSet, Set, Optional) from pytools import UniqueNameGenerator @@ -145,7 +145,11 @@ def emit_subgraph(sg: _SubgraphTree) -> None: class _DotNodeInfo: title: str fields: Dict[str, Any] - edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] + edges: Dict[str, Union[ + ArrayOrNames, + FunctionDefinition, + Tuple[Union[int, ArrayOrNames], ArrayOrNames]], + Array] def stringify_tags(tags: FrozenSet[Optional[Tag]]) -> str: @@ -162,7 +166,7 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" -def get_object_by_id(object_id): +def get_object_by_id(object_id: int) -> Union[Any, ArrayOrNames]: """Find an object by its ID.""" for obj in gc.get_objects(): if id(obj) == object_id: @@ -170,15 +174,15 @@ def get_object_by_id(object_id): return None -class ArrayToDotNodeInfoMapper: +class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]): def __init__(self, count_duplicates: bool = False): + self.node_to_dot: Dict[Union[int, ArrayOrNames], _DotNodeInfo] = {} + self.functions: Set[FunctionDefinition] = set() self.count_duplicates = count_duplicates - self.node_to_dot = {} - self.functions = set() - def get_cache_key(self, expr): + def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: return id(expr) if self.count_duplicates else expr - + def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: title = type(expr).__name__ fields = {"addr": hex(id(expr)), @@ -188,17 +192,20 @@ def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: "non_equality_tags": expr.non_equality_tags, } - edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} + edges: Dict[str, + Union[ArrayOrNames, FunctionDefinition, + Tuple[Union[int, AbstractResultWithNamedArrays, + Array], Array]]] = {} return _DotNodeInfo(title, fields, edges) - def process_node(self, expr: Array) -> None: + def process_node(self, expr: ArrayOrNames) -> None: if isinstance(expr, DataWrapper): self.map_data_wrapper(expr) elif isinstance(expr, IndexLambda): self.map_index_lambda(expr) elif isinstance(expr, Stack): self.map_stack(expr) - elif isinstance(expr, (IndexBase, IndexLambda)): + elif isinstance(expr, IndexBase): self.map_basic_index(expr) elif isinstance(expr, Einsum): self.map_einsum(expr) @@ -215,7 +222,9 @@ def process_node(self, expr: Array) -> None: else: self.handle_unsupported_array(expr) - def handle_unsupported_array(self, expr: Array) -> None: + def handle_unsupported_array(self, + expr: Array) -> None: + # Default handler, does its best to guess how to handle fields. info = self.get_common_dot_info(expr) expr_key = self.get_cache_key(expr) for field in attrs.fields(type(expr)): @@ -317,7 +326,8 @@ def map_einsum(self, expr: Einsum) -> None: self.node_to_dot[self.get_cache_key(expr)] = info def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: - edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} + edges: Dict[str, Union[ArrayOrNames, FunctionDefinition, Tuple[Union[ + int, ArrayOrNames], Array]]] = {} for name, val in expr._data.items(): self.process_node(val) key = self.get_cache_key(val) @@ -329,7 +339,8 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: edges=edges) def map_loopy_call(self, expr: LoopyCall) -> None: - edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} + edges: Dict[str, Union[ArrayOrNames, FunctionDefinition, Tuple[Union[ + int, ArrayOrNames], Array]]] = {} for name, arg in expr.bindings.items(): if isinstance(arg, Array): self.process_node(arg) @@ -369,7 +380,8 @@ def map_call(self, expr: Call) -> None: title=expr.__class__.__name__, edges={ "": expr.function, - **{name: (self.get_cache_key(bnd), bnd) for name, bnd in expr.bindings.items()}}, + **{name: (self.get_cache_key(bnd), bnd) + for name, bnd in expr.bindings.items()}}, fields={ "addr": hex(id(expr)), "tags": stringify_tags(expr.tags), @@ -400,7 +412,8 @@ def dot_escape_leave_space(s: str) -> str: return html.escape(s.replace("\\", "\\\\")) -def get_array_key(array, count_duplicates): +def get_array_key(array: Union[ArrayOrNames, FunctionDefinition, int], + count_duplicates: bool = False) -> Any: """Return a consistent key for the array.""" return id(array) if count_duplicates and not isinstance(array, int) else array @@ -439,18 +452,11 @@ def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, Any], f'tooltip="{tooltip}"]') -def preprocess_all_nodes(partition, array_to_id, id_gen, count_duplicates): - mapper = ArrayToDotNodeInfoMapper(count_duplicates) - for part in partition.parts.values(): - for out_name in part.output_names: - node = partition.name_to_output[out_name] - mapper.process_node(node) - - def _emit_name_cluster( emit: DotEmitter, subgraph_path: Tuple[str, ...], names: Mapping[str, ArrayOrNames], - array_to_id: Mapping[ArrayOrNames, str], id_gen: Callable[[str], str], + array_to_id: Mapping[ + Union[int, ArrayOrNames], str], id_gen: Callable[[str], str], label: str, count_duplicates: bool = False) -> None: edges = [] @@ -475,13 +481,13 @@ def _emit_name_cluster( def _emit_function( emitter: DotEmitter, subgraph_path: Tuple[str, ...], id_gen: UniqueNameGenerator, - node_to_dot: Mapping[ArrayOrNames, _DotNodeInfo], + node_to_dot: Mapping[Union[int, ArrayOrNames], _DotNodeInfo], func_to_id: Mapping[FunctionDefinition, str], outputs: Mapping[str, Array], count_duplicates: bool = False) -> None: input_arrays: List[Array] = [] - internal_arrays: List[ArrayOrNames] = [] - array_to_id: Dict[ArrayOrNames, str] = {} + internal_arrays: List[Union[int, ArrayOrNames]] = [] + array_to_id: Dict[Union[int, ArrayOrNames], str] = {} emit = partial(emitter, subgraph_path) for array in node_to_dot: @@ -529,7 +535,8 @@ def _emit_function( elif isinstance(tail_item, FunctionDefinition): tail = func_to_id[tail_item] else: - raise ValueError(f"unexpected type of tail on edge: {type(tail_item)}") + raise ValueError( + f"unexpected type of tail on edge: {type(tail_item)}") emit('%s -> %s [label="%s"]' % (tail, head, dot_escape(label))) @@ -558,13 +565,13 @@ def _gather_partition_node_information( partition: DistributedGraphPartition, count_duplicates: bool = False ) -> Tuple[ - Mapping[PartId, Mapping[FunctionDefinition, str]], - Mapping[Tuple[PartId, Optional[FunctionDefinition]], - Mapping[ArrayOrNames, _DotNodeInfo]] - ]: + Dict[PartId, Dict[FunctionDefinition, str]], + Dict[Tuple[PartId, Optional[FunctionDefinition]], + Dict[Union[int, ArrayOrNames], _DotNodeInfo]]]: part_id_to_func_to_id: Dict[PartId, Dict[FunctionDefinition, str]] = {} part_id_func_to_node_info: Dict[Tuple[PartId, Optional[FunctionDefinition]], - Dict[ArrayOrNames, _DotNodeInfo]] = {} + Dict[Union[int, ArrayOrNames], + _DotNodeInfo]] = {} for part in partition.parts.values(): mapper = ArrayToDotNodeInfoMapper(count_duplicates) @@ -650,7 +657,7 @@ def get_dot_graph(result: Union[Array, DictOfNamedArrays], return get_dot_graph_from_partition(partition, count_duplicates) -def get_dot_graph_from_partition(partition: DistributedGraphPartition, +def get_dot_graph_from_partition(partition: DistributedGraphPartition, count_duplicates: bool = False) -> str: """Return a string in the `dot `_ language depicting the graph of the partitioned computation of *partition*. @@ -752,7 +759,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition, part_dist_recv_var_name_to_node_id[name] = node_id # }}} - + part_node_to_info = part_id_func_to_node_info[part.pid, None] input_arrays: List[Array] = [] internal_arrays: List[ArrayOrNames] = [] @@ -814,7 +821,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition, # }}} # Emit internal nodes - + for array in internal_arrays: key = array = get_array_key(array, count_duplicates) _emit_array(emit_part, @@ -835,8 +842,9 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition, # If an edge is emitted in a subgraph, it drags its # nodes into the subgraph, too. Not what we want. + data = id(send.data) if count_duplicates else send.data emit_root( - f"{array_to_id[id(send.data) if count_duplicates else send.data]} -> {node_id}" + f"{array_to_id[data]} -> {node_id}" f'[style=dotted, label="{dot_escape(name)}"]') # }}} @@ -844,6 +852,9 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition, # Emit intra-partition edges for array, node in part_node_to_info.items(): key = get_array_key(array, count_duplicates) + + tail_item: Union[Array, AbstractResultWithNamedArrays, + FunctionDefinition] for label, edge_info in node.edges.items(): if isinstance(edge_info, tuple): tail_key, tail_item = edge_info diff --git a/test/test_codegen.py b/test/test_codegen.py index 0f1456d9b..ca43d6715 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -27,11 +27,10 @@ """ from typing import Union +import sys import itertools import operator -import sys - import loopy as lp import numpy as np import pyopencl as cl @@ -2002,6 +2001,80 @@ def call_bar(tracer, x, y): np.testing.assert_allclose(result_out[k], expect_out[k]) +def test_duplicate_node_count_dot_graph(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import get_random_pt_dag + from testlib import count_dot_graph_nodes + + for i in range(80): + # print("curr i:", i) + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + +def test_duplicate_nodes_with_comm_count_dot_graph(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import get_random_pt_dag_with_send_recv_nodes + from testlib import count_dot_graph_nodes + + rank = 0 + size = 2 + for i in range(20): + dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + +def test_large_dot_graph_with_duplicates_count(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import make_large_dag + from testlib import count_dot_graph_nodes + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts with duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) diff --git a/test/test_pytato.py b/test/test_pytato.py index 45102158e..c34f14776 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -746,79 +746,6 @@ def test_large_dag_with_duplicates_count(): pt.transform.DependencyMapper()(dag)) -def test_duplicate_node_count_dot_graph(): - from pytato.visualization.dot import get_dot_graph - from pytato.analysis import get_num_nodes - from testlib import get_random_pt_dag - from testlib import count_dot_graph_nodes - - for i in range(80): - # print("curr i:", i) - dag = get_random_pt_dag(seed=i, axis_len=5) - - # Generate dot graph with duplicates - dot_graph = get_dot_graph(dag, count_duplicates=True) - node_counts = count_dot_graph_nodes(dot_graph) - - assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) - - # Generate dot graph without duplicates - dot_graph = get_dot_graph(dag, count_duplicates=False) - node_counts = count_dot_graph_nodes(dot_graph) - - # Verify node counts without duplicates - assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) - - -def test_duplicate_nodes_with_comm_count_dot_graph(): - from pytato.visualization.dot import get_dot_graph - from pytato.analysis import get_num_nodes - from testlib import get_random_pt_dag_with_send_recv_nodes - from testlib import count_dot_graph_nodes - - rank = 0 - size = 2 - for i in range(20): - dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) - - # Generate dot graph with duplicates - dot_graph = get_dot_graph(dag, count_duplicates=True) - node_counts = count_dot_graph_nodes(dot_graph) - - assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) - - # Generate dot graph without duplicates - dot_graph = get_dot_graph(dag, count_duplicates=False) - node_counts = count_dot_graph_nodes(dot_graph) - - # Verify node counts without duplicates - assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) - - -def test_large_dot_graph_with_duplicates_count(): - from pytato.visualization.dot import get_dot_graph - from pytato.analysis import get_num_nodes - from testlib import make_large_dag - from testlib import count_dot_graph_nodes - - iterations = 100 - dag = make_large_dag(iterations, seed=42) - - # Generate dot graph with duplicates - dot_graph = get_dot_graph(dag, count_duplicates=True) - node_counts = count_dot_graph_nodes(dot_graph) - - # Verify node counts with duplicates - assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) - - # Generate dot graph without duplicates - dot_graph = get_dot_graph(dag, count_duplicates=False) - node_counts = count_dot_graph_nodes(dot_graph) - - # Verify node counts without duplicates - assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) - - def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) diff --git a/test/testlib.py b/test/testlib.py index f81c46682..7508eb203 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -339,15 +339,17 @@ def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: return pt.make_dict_of_named_arrays({"result": current}) -def count_dot_graph_nodes(dot_graph: str): +def count_dot_graph_nodes(dot_graph: str) -> Dict[Any, int]: """ - Parses a dot graph and returns a dictionary with the count of each unique node identifier. + Parses a dot graph and returns a dictionary with + the count of each unique node identifier. """ - node_pattern = re.compile(r'addr:(0x[0-9a-f]+)') + node_pattern = re.compile( + r'addr:(0x[0-9a-f]+)') nodes = node_pattern.findall(dot_graph) - node_counts = {} + node_counts: Dict[Any, int] = {} for node in nodes: node_counts[node] = node_counts.get(node, 0) + 1 From 8fc05e3d44bc91741568d6bd2d09c89150d0a267 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 21 Aug 2024 09:53:32 -0500 Subject: [PATCH 16/19] Add functionality that compresses IndexLambda representations in the DAG --- pytato/visualization/dot.py | 69 ++++++++++++++++++++++++++++++++++- test/test_codegen.py | 73 ------------------------------------- test/test_pytato.py | 72 ++++++++++++++++++++++++++++++++++++ test/testlib.py | 4 +- 4 files changed, 142 insertions(+), 76 deletions(-) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 38f482a06..c0bc6e2ed 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -30,6 +30,7 @@ import html import attrs import gc +import re from typing import (Callable, Dict, Tuple, Union, List, Mapping, Any, FrozenSet, Set, Optional) @@ -72,6 +73,70 @@ class _SubgraphTree: subgraphs: Dict[str, _SubgraphTree] +def extract_operation_symbol(expr): + + operation_replacements = { + r'NaN_if': 'if', + r'else': 'else', + r'isnan': 'is NaN', + r'<': '<', + r'>': '>', + r'\s*==\s*': '==', + r'\s*!=\s*': '!=', + r'\s*<=\s*': '<=', + r'\s*>=\s*': '>=', + r'\s*\+\s*': '+', + r'\s*\-\s*': '-', + r'\s*\*\*\s*': '**', + r'\s*\*\s*': '*', + r'\s*/\s*': '/', + r'\s*//\s*': '//', + r'\s*%\s*': '%', + r'\s*or\s*': 'or', + r'\s*and\s*': 'and', + r'\s*not\s*': 'not', + r'\s*<<\s*': '<<', + r'\s*>>\s*': '>>', + r'\s*\|\s*': '|', + r'\s*\^\s*': '^', + r'~\s*': '~', + r'\s*@\s*': '@', + r'\s*SumReductionOperation\s*': 'Σ', + r'<': '<', + r'>': '>', + r'&': '&', + } + + for pattern, replacement in operation_replacements.items(): + if re.search(pattern, expr.strip()): + return replacement + + return expr + + +def simplify_indexlambda_node_to_symbol_only(s): + if "IndexLambda" in s: + expr_match = re.search(r'expr:(.*?)', s) + if expr_match: + original_expr = expr_match.group(1) + operation_symbol = extract_operation_symbol(original_expr) + + print(operation_symbol) + + tooltip_content = [] + tooltip_matches = re.findall(r'(.*?)(.*?)', s) + for key, value in tooltip_matches: + tooltip_content.append(f"{key}: {value}") + + tooltip_text = ",\n".join(tooltip_content) + + new_label = f'{operation_symbol}' + + s = f'{new_label}> style=filled fillcolor="white" tooltip="{tooltip_text}"];' + + return s + + class DotEmitter: def __init__(self) -> None: self.subgraph_to_lines: Dict[Tuple[str, ...], List[str]] = {} @@ -86,7 +151,9 @@ def __call__(self, subgraph_path: Tuple[str, ...], s: str) -> None: s = remove_common_indentation(s) for line in s.split("\n"): - line_list.append(line) + simplified_line = simplify_indexlambda_node_to_symbol_only( + line) + line_list.append(simplified_line) def _get_subgraph_tree(self) -> _SubgraphTree: subgraph_tree = _SubgraphTree(contents=None, subgraphs={}) diff --git a/test/test_codegen.py b/test/test_codegen.py index 2dfe4964f..b89eb0aa4 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -2013,79 +2013,6 @@ def call_bar(tracer, x, y): np.testing.assert_allclose(result_out[k], expect_out[k]) -def test_duplicate_node_count_dot_graph(): - from pytato.visualization.dot import get_dot_graph - from pytato.analysis import get_num_nodes - from testlib import get_random_pt_dag - from testlib import count_dot_graph_nodes - - for i in range(80): - # print("curr i:", i) - dag = get_random_pt_dag(seed=i, axis_len=5) - - # Generate dot graph with duplicates - dot_graph = get_dot_graph(dag, count_duplicates=True) - node_counts = count_dot_graph_nodes(dot_graph) - - assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) - - # Generate dot graph without duplicates - dot_graph = get_dot_graph(dag, count_duplicates=False) - node_counts = count_dot_graph_nodes(dot_graph) - - # Verify node counts without duplicates - assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) - - -def test_duplicate_nodes_with_comm_count_dot_graph(): - from pytato.visualization.dot import get_dot_graph - from pytato.analysis import get_num_nodes - from testlib import get_random_pt_dag_with_send_recv_nodes - from testlib import count_dot_graph_nodes - - rank = 0 - size = 2 - for i in range(20): - dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) - - # Generate dot graph with duplicates - dot_graph = get_dot_graph(dag, count_duplicates=True) - node_counts = count_dot_graph_nodes(dot_graph) - - assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) - - # Generate dot graph without duplicates - dot_graph = get_dot_graph(dag, count_duplicates=False) - node_counts = count_dot_graph_nodes(dot_graph) - - # Verify node counts without duplicates - assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) - - -def test_large_dot_graph_with_duplicates_count(): - from pytato.visualization.dot import get_dot_graph - from pytato.analysis import get_num_nodes - from testlib import make_large_dag - from testlib import count_dot_graph_nodes - - iterations = 100 - dag = make_large_dag(iterations, seed=42) - - # Generate dot graph with duplicates - dot_graph = get_dot_graph(dag, count_duplicates=True) - node_counts = count_dot_graph_nodes(dot_graph) - - # Verify node counts with duplicates - assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) - - # Generate dot graph without duplicates - dot_graph = get_dot_graph(dag, count_duplicates=False) - node_counts = count_dot_graph_nodes(dot_graph) - - # Verify node counts without duplicates - assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) - - if __name__ == "__main__": if len(sys.argv) > 1: diff --git a/test/test_pytato.py b/test/test_pytato.py index 325de5106..e95827148 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1342,6 +1342,78 @@ def test_dot_visualizers(): # }}} +def test_duplicate_node_count_dot_graph(): + from pytato.visualization.dot import get_dot_graph, show_dot_graph + from pytato.analysis import get_num_nodes + from testlib import get_random_pt_dag + from testlib import count_dot_graph_nodes + + for i in range(80): + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + +def test_duplicate_nodes_with_comm_count_dot_graph(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import get_random_pt_dag_with_send_recv_nodes + from testlib import count_dot_graph_nodes + + rank = 0 + size = 2 + for i in range(20): + dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, + size=size) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + +def test_large_dot_graph_with_duplicates_count(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import make_large_dag + from testlib import count_dot_graph_nodes + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts with duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + if __name__ == "__main__": if len(sys.argv) > 1: diff --git a/test/testlib.py b/test/testlib.py index 2f826f401..e0b9e30a5 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -345,8 +345,8 @@ def count_dot_graph_nodes(dot_graph: str) -> Dict[Any, int]: the count of each unique node identifier. """ - node_pattern = re.compile( - r'addr:(0x[0-9a-f]+)') + node_pattern = re.compile(r'(\barray_\d+\b|\barray\b)') + nodes = node_pattern.findall(dot_graph) node_counts: Dict[Any, int] = {} From e4c59af1cebcae380acd0c3f709d9b754ecf4623 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 21 Aug 2024 10:28:37 -0500 Subject: [PATCH 17/19] Ruff changes --- pytato/visualization/dot.py | 84 +++++++++++++++++++++---------------- test/test_codegen.py | 1 - test/test_pytato.py | 12 +++++- test/testlib.py | 4 +- 4 files changed, 62 insertions(+), 39 deletions(-) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index c0bc6e2ed..9ff329774 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -76,35 +76,35 @@ class _SubgraphTree: def extract_operation_symbol(expr): operation_replacements = { - r'NaN_if': 'if', - r'else': 'else', - r'isnan': 'is NaN', - r'<': '<', - r'>': '>', - r'\s*==\s*': '==', - r'\s*!=\s*': '!=', - r'\s*<=\s*': '<=', - r'\s*>=\s*': '>=', - r'\s*\+\s*': '+', - r'\s*\-\s*': '-', - r'\s*\*\*\s*': '**', - r'\s*\*\s*': '*', - r'\s*/\s*': '/', - r'\s*//\s*': '//', - r'\s*%\s*': '%', - r'\s*or\s*': 'or', - r'\s*and\s*': 'and', - r'\s*not\s*': 'not', - r'\s*<<\s*': '<<', - r'\s*>>\s*': '>>', - r'\s*\|\s*': '|', - r'\s*\^\s*': '^', - r'~\s*': '~', - r'\s*@\s*': '@', - r'\s*SumReductionOperation\s*': 'Σ', - r'<': '<', - r'>': '>', - r'&': '&', + r"NaN_if": "if", + r"else": "else", + r"isnan": "is NaN", + r"<": "<", + r">": ">", + r"\s*==\s*": "==", + r"\s*!=\s*": "!=", + r"\s*<=\s*": "<=", + r"\s*>=\s*": ">=", + r"\s*\+\s*": "+", + r"\s*\-\s*": "-", + r"\s*\*\*\s*": "**", + r"\s*\*\s*": "*", + r"\s*/\s*": "/", + r"\s*//\s*": "//", + r"\s*%\s*": "%", + r"\s*or\s*": "or", + r"\s*and\s*": "and", + r"\s*not\s*": "not", + r"\s*<<\s*": "<<", + r"\s*>>\s*": ">>", + r"\s*\|\s*": "|", + r"\s*\^\s*": "^", + r"~\s*": "~", + r"\s*@\s*": "@", + r"\s*SumReductionOperation\s*": "Σ", + r"<": "<", + r">": ">", + r"&": "&", } for pattern, replacement in operation_replacements.items(): @@ -116,23 +116,37 @@ def extract_operation_symbol(expr): def simplify_indexlambda_node_to_symbol_only(s): if "IndexLambda" in s: - expr_match = re.search(r'expr:(.*?)', s) + expr_match = re.search( + r'expr:(.*?)', s + ) + if expr_match: original_expr = expr_match.group(1) operation_symbol = extract_operation_symbol(original_expr) - print(operation_symbol) - tooltip_content = [] - tooltip_matches = re.findall(r'(.*?)(.*?)', s) + tooltip_matches = re.findall( + r'(.*?)' + r'(.*?)', + s + ) + for key, value in tooltip_matches: tooltip_content.append(f"{key}: {value}") tooltip_text = ",\n".join(tooltip_content) - new_label = f'{operation_symbol}' + new_label = ( + f'' + f'{operation_symbol}' + f'' + ) - s = f'{new_label}> style=filled fillcolor="white" tooltip="{tooltip_text}"];' + s = ( + f'{new_label}> ' + f'style=filled fillcolor="white" ' + f'tooltip="{tooltip_text}"];' + ) return s diff --git a/test/test_codegen.py b/test/test_codegen.py index b89eb0aa4..d54eec8f8 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -2013,7 +2013,6 @@ def call_bar(tracer, x, y): np.testing.assert_allclose(result_out[k], expect_out[k]) - if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) diff --git a/test/test_pytato.py b/test/test_pytato.py index e95827148..62f80b5dc 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -30,6 +30,15 @@ import pytest import attrs +import os + +# Get the directory containing the script +script_dir = os.path.dirname(__file__) + +# Add the parent directory to the Python path +parent_dir = os.path.abspath(os.path.join(script_dir, os.pardir)) +sys.path.append(parent_dir) + import pytato as pt from pyopencl.tools import ( # noqa @@ -1342,8 +1351,9 @@ def test_dot_visualizers(): # }}} + def test_duplicate_node_count_dot_graph(): - from pytato.visualization.dot import get_dot_graph, show_dot_graph + from pytato.visualization.dot import get_dot_graph from pytato.analysis import get_num_nodes from testlib import get_random_pt_dag from testlib import count_dot_graph_nodes diff --git a/test/testlib.py b/test/testlib.py index e0b9e30a5..10de762af 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -341,11 +341,11 @@ def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: def count_dot_graph_nodes(dot_graph: str) -> Dict[Any, int]: """ - Parses a dot graph and returns a dictionary with + Parses a dot graph and returns a dictionary with the count of each unique node identifier. """ - node_pattern = re.compile(r'(\barray_\d+\b|\barray\b)') + node_pattern = re.compile(r"(\barray_\d+\b|\barray\b)") nodes = node_pattern.findall(dot_graph) From d96a4bf036564b6937f8ac367483ee626e0152c0 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 21 Aug 2024 11:04:37 -0500 Subject: [PATCH 18/19] Fix imports --- pytato/analysis/__init__.py | 11 ++---- test/test_pytato.py | 70 ++++++++----------------------------- test/testlib.py | 6 ++-- 3 files changed, 20 insertions(+), 67 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 8a5c2d7b3..fadf1c92d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,15 +26,8 @@ THE SOFTWARE. """ -from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, - Type, TYPE_CHECKING) -from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, - DictOfNamedArrays, NamedArray, - IndexBase, IndexRemappingBase, InputArgumentBase, - ShapeType) -from pytato.function import FunctionDefinition, Call, NamedCallResult -from pytato.transform import Mapper, ArrayOrNames, CachedWalkMapper -from pytato.loopy import LoopyCall +from typing import TYPE_CHECKING, Any, Mapping + from pymbolic.mapper.optimize import optimize_mapper from pytools import memoize_method diff --git a/test/test_pytato.py b/test/test_pytato.py index 530800db7..7a7b79cf8 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -32,18 +32,6 @@ import attrs import numpy as np import pytest -import attrs - -import os - -# Get the directory containing the script -script_dir = os.path.dirname(__file__) - -# Add the parent directory to the Python path -parent_dir = os.path.abspath(os.path.join(script_dir, os.pardir)) -sys.path.append(parent_dir) - -import pytato as pt from pyopencl.tools import ( # noqa pytest_generate_tests_for_pyopencl as pytest_generate_tests, @@ -695,7 +683,7 @@ def test_random_dag_with_comm_count(): for i in range(10): dag = get_random_pt_dag_with_send_recv_nodes( seed=i, rank=rank, size=size) - + assert get_num_nodes(dag, count_duplicates=False) == len( pt.transform.DependencyMapper()(dag)) @@ -778,7 +766,8 @@ def test_large_dag_with_duplicates_count(): def test_duplicate_node_count(): from testlib import get_random_pt_dag - from pytato.analysis import get_num_nodes, get_node_multiplicities + + from pytato.analysis import get_node_multiplicities, get_num_nodes for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) @@ -798,7 +787,8 @@ def test_duplicate_node_count(): def test_duplicate_nodes_with_comm_count(): from testlib import get_random_pt_dag_with_send_recv_nodes - from pytato.analysis import get_num_nodes, get_node_multiplicities + + from pytato.analysis import get_node_multiplicities, get_num_nodes rank = 0 size = 2 @@ -821,37 +811,6 @@ def test_duplicate_nodes_with_comm_count(): pt.transform.DependencyMapper()(dag)) -def test_large_dag_with_duplicates_count(): - from pytato.analysis import ( - get_num_nodes, get_node_type_counts, get_node_multiplicities - ) - from testlib import make_large_dag - import pytato as pt - - iterations = 100 - dag = make_large_dag(iterations, seed=42) - - # Verify that the number of nodes is equal to iterations + 1 (placeholder) - node_count = get_num_nodes(dag, count_duplicates=True) - assert node_count == iterations + 1 - - # Get the number of expressions and the amount they're called - node_multiplicity = get_node_multiplicities(dag) - - num_duplicates = sum( - count - 1 for count in node_multiplicity.values() if count > 1) - - counts = get_node_type_counts(dag, count_duplicates=True) - assert len(counts) >= 1 - assert counts[pt.array.Placeholder] == 1 - assert counts[pt.array.IndexLambda] == 100 # 100 operations - assert sum(counts.values()) == iterations + 1 - - # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len( - pt.transform.DependencyMapper()(dag)) - - def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) @@ -1423,6 +1382,7 @@ def test_rewrite_einsums_with_no_broadcasts(): def test_dot_visualizers(): + from testlib import RandomDAGContext, make_random_dag a = pt.make_placeholder("A", shape=(10, 4)) x1 = pt.make_placeholder("x1", shape=4) x2 = pt.make_placeholder("x2", shape=4) @@ -1452,10 +1412,10 @@ def test_dot_visualizers(): def test_duplicate_node_count_dot_graph(): - from pytato.visualization.dot import get_dot_graph + from testlib import count_dot_graph_nodes, get_random_pt_dag + from pytato.analysis import get_num_nodes - from testlib import get_random_pt_dag - from testlib import count_dot_graph_nodes + from pytato.visualization.dot import get_dot_graph for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) @@ -1475,10 +1435,10 @@ def test_duplicate_node_count_dot_graph(): def test_duplicate_nodes_with_comm_count_dot_graph(): - from pytato.visualization.dot import get_dot_graph + from testlib import count_dot_graph_nodes, get_random_pt_dag_with_send_recv_nodes + from pytato.analysis import get_num_nodes - from testlib import get_random_pt_dag_with_send_recv_nodes - from testlib import count_dot_graph_nodes + from pytato.visualization.dot import get_dot_graph rank = 0 size = 2 @@ -1501,10 +1461,10 @@ def test_duplicate_nodes_with_comm_count_dot_graph(): def test_large_dot_graph_with_duplicates_count(): - from pytato.visualization.dot import get_dot_graph + from testlib import count_dot_graph_nodes, make_large_dag + from pytato.analysis import get_num_nodes - from testlib import make_large_dag - from testlib import count_dot_graph_nodes + from pytato.visualization.dot import get_dot_graph iterations = 100 dag = make_large_dag(iterations, seed=42) diff --git a/test/testlib.py b/test/testlib.py index 4610d0ebf..3136a7fb4 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -2,12 +2,12 @@ import operator import random +import re import types from typing import Any, Callable, Sequence import numpy as np -import re import pyopencl as cl from pytools.tag import Tag @@ -397,7 +397,7 @@ def make_large_dag_with_duplicates(iterations: int, return pt.make_dict_of_named_arrays({"result": result}) -def count_dot_graph_nodes(dot_graph: str) -> Dict[Any, int]: +def count_dot_graph_nodes(dot_graph: str) -> dict[Any, int]: """ Parses a dot graph and returns a dictionary with the count of each unique node identifier. @@ -407,7 +407,7 @@ def count_dot_graph_nodes(dot_graph: str) -> Dict[Any, int]: nodes = node_pattern.findall(dot_graph) - node_counts: Dict[Any, int] = {} + node_counts: dict[Any, int] = {} for node in nodes: node_counts[node] = node_counts.get(node, 0) + 1 From 724c7997ea0834b93850c7fa1b47ee78796f14a0 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 21 Aug 2024 11:05:26 -0500 Subject: [PATCH 19/19] Ruff changes --- pytato/visualization/dot.py | 174 +++++++++--------------------------- 1 file changed, 41 insertions(+), 133 deletions(-) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 0d475fc9d..971f1f4ce 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -27,25 +27,16 @@ """ +import gc import html +import re from functools import partial from typing import ( - TYPE_CHECKING, Any, Callable, Mapping, - Dict, - Tuple, - Union, - List, - Any, - FrozenSet, - Set, - Optional ) -import gc -import re import attrs from pytools import UniqueNameGenerator @@ -66,6 +57,7 @@ Stack, ) from pytato.codegen import normalize_outputs +from pytato.distributed.nodes import DistributedSendRefHolder from pytato.distributed.partition import ( DistributedGraphPart, DistributedGraphPartition, @@ -75,7 +67,6 @@ from pytato.loopy import LoopyCall from pytato.tags import FunctionIdentifier from pytato.transform import ArrayOrNames, CachedMapper, InputGatherer -from pytato.distributed.nodes import DistributedSendRefHolder __doc__ = """ @@ -173,84 +164,6 @@ def simplify_indexlambda_node_to_symbol_only(s): return s -def extract_operation_symbol(expr): - - operation_replacements = { - r"NaN_if": "if", - r"else": "else", - r"isnan": "is NaN", - r"<": "<", - r">": ">", - r"\s*==\s*": "==", - r"\s*!=\s*": "!=", - r"\s*<=\s*": "<=", - r"\s*>=\s*": ">=", - r"\s*\+\s*": "+", - r"\s*\-\s*": "-", - r"\s*\*\*\s*": "**", - r"\s*\*\s*": "*", - r"\s*/\s*": "/", - r"\s*//\s*": "//", - r"\s*%\s*": "%", - r"\s*or\s*": "or", - r"\s*and\s*": "and", - r"\s*not\s*": "not", - r"\s*<<\s*": "<<", - r"\s*>>\s*": ">>", - r"\s*\|\s*": "|", - r"\s*\^\s*": "^", - r"~\s*": "~", - r"\s*@\s*": "@", - r"\s*SumReductionOperation\s*": "Σ", - r"<": "<", - r">": ">", - r"&": "&", - } - - for pattern, replacement in operation_replacements.items(): - if re.search(pattern, expr.strip()): - return replacement - - return expr - - -def simplify_indexlambda_node_to_symbol_only(s): - if "IndexLambda" in s: - expr_match = re.search( - r'expr:(.*?)', s - ) - - if expr_match: - original_expr = expr_match.group(1) - operation_symbol = extract_operation_symbol(original_expr) - - tooltip_content = [] - tooltip_matches = re.findall( - r'(.*?)' - r'(.*?)', - s - ) - - for key, value in tooltip_matches: - tooltip_content.append(f"{key}: {value}") - - tooltip_text = ",\n".join(tooltip_content) - - new_label = ( - f'' - f'{operation_symbol}' - f'' - ) - - s = ( - f'{new_label}> ' - f'style=filled fillcolor="white" ' - f'tooltip="{tooltip_text}"];' - ) - - return s - - class DotEmitter: def __init__(self) -> None: self.subgraph_to_lines: dict[tuple[str, ...], list[str]] = {} @@ -325,12 +238,10 @@ def emit_subgraph(sg: _SubgraphTree) -> None: @attrs.define class _DotNodeInfo: title: str - fields: Dict[str, Any] - edges: Dict[str, Union[ - ArrayOrNames, - FunctionDefinition, - Tuple[Union[int, ArrayOrNames], ArrayOrNames]], - Array] + fields: dict[str, Any] + edges: dict[str, ArrayOrNames | + FunctionDefinition | tuple[int | + ArrayOrNames, ArrayOrNames], Array] def stringify_tags(tags: frozenset[Tag | None]) -> str: @@ -347,7 +258,7 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" -def get_object_by_id(object_id: int) -> Union[Any, ArrayOrNames]: +def get_object_by_id(object_id: int) -> Any | ArrayOrNames: """Find an object by its ID.""" for obj in gc.get_objects(): if id(obj) == object_id: @@ -357,11 +268,11 @@ def get_object_by_id(object_id: int) -> Union[Any, ArrayOrNames]: class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]): def __init__(self, count_duplicates: bool = False): - self.node_to_dot: Dict[Union[int, ArrayOrNames], _DotNodeInfo] = {} - self.functions: Set[FunctionDefinition] = set() + self.node_to_dot: dict[int | ArrayOrNames, _DotNodeInfo] = {} + self.functions: set[FunctionDefinition] = set() self.count_duplicates = count_duplicates - def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: + def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: return id(expr) if self.count_duplicates else expr def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: @@ -373,10 +284,10 @@ def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: "non_equality_tags": expr.non_equality_tags, } - edges: Dict[str, - Union[ArrayOrNames, FunctionDefinition, - Tuple[Union[int, AbstractResultWithNamedArrays, - Array], Array]]] = {} + edges: dict[str, + ArrayOrNames | FunctionDefinition | + tuple[int | AbstractResultWithNamedArrays | + Array, Array]] = {} return _DotNodeInfo(title, fields, edges) def process_node(self, expr: ArrayOrNames) -> None: @@ -507,8 +418,8 @@ def map_einsum(self, expr: Einsum) -> None: self.node_to_dot[self.get_cache_key(expr)] = info def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: - edges: Dict[str, Union[ArrayOrNames, FunctionDefinition, Tuple[Union[ - int, ArrayOrNames], Array]]] = {} + edges: dict[str, ArrayOrNames | FunctionDefinition | + tuple[int | ArrayOrNames, Array]] = {} for name, val in expr._data.items(): self.process_node(val) key = self.get_cache_key(val) @@ -520,8 +431,8 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: edges=edges) def map_loopy_call(self, expr: LoopyCall) -> None: - edges: Dict[str, Union[ArrayOrNames, FunctionDefinition, Tuple[Union[ - int, ArrayOrNames], Array]]] = {} + edges: dict[str, ArrayOrNames | FunctionDefinition | + tuple[int | ArrayOrNames, Array]] = {} for name, arg in expr.bindings.items(): if isinstance(arg, Array): self.process_node(arg) @@ -593,7 +504,7 @@ def dot_escape_leave_space(s: str) -> str: return html.escape(s.replace("\\", "\\\\")) -def get_array_key(array: Union[ArrayOrNames, FunctionDefinition, int], +def get_array_key(array: ArrayOrNames | FunctionDefinition | int, count_duplicates: bool = False) -> Any: """Return a consistent key for the array.""" return id(array) if count_duplicates and not isinstance(array, int) else array @@ -610,7 +521,7 @@ def _stringify_created_at(non_equality_tags: frozenset[Tag]) -> str: return "" -def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, Any], +def _emit_array(emit: Callable[[str], None], title: str, fields: dict[str, Any], dot_node_id: str, color: str = "white") -> None: td_attrib = 'border="0"' table_attrib = 'border="0" cellborder="1" cellspacing="0"' @@ -637,7 +548,7 @@ def _emit_name_cluster( emit: DotEmitter, subgraph_path: tuple[str, ...], names: Mapping[str, ArrayOrNames], array_to_id: Mapping[ - Union[int, ArrayOrNames], str], id_gen: Callable[[str], str], + int | ArrayOrNames, str], id_gen: Callable[[str], str], label: str, count_duplicates: bool = False) -> None: edges = [] @@ -649,7 +560,7 @@ def _emit_name_cluster( for name, array in names.items(): name_id = id_gen(dot_escape(name)) - emit_cluster('%s [label="%s"]' % (name_id, dot_escape(name))) + emit_cluster(f'{name_id} [label="{dot_escape(name)}"]') array_key = get_array_key(array, count_duplicates) array_id = array_to_id[array_key] # Edges must be outside the cluster. @@ -662,13 +573,13 @@ def _emit_name_cluster( def _emit_function( emitter: DotEmitter, subgraph_path: tuple[str, ...], id_gen: UniqueNameGenerator, - node_to_dot: Mapping[Union[int, ArrayOrNames], _DotNodeInfo], + node_to_dot: Mapping[int | ArrayOrNames, _DotNodeInfo], func_to_id: Mapping[FunctionDefinition, str], outputs: Mapping[str, Array], count_duplicates: bool = False) -> None: - input_arrays: List[Array] = [] - internal_arrays: List[Union[int, ArrayOrNames]] = [] - array_to_id: Dict[Union[int, ArrayOrNames], str] = {} + input_arrays: list[Array] = [] + internal_arrays: list[int | ArrayOrNames] = [] + array_to_id: dict[int | ArrayOrNames, str] = {} emit = partial(emitter, subgraph_path) for array in node_to_dot: @@ -745,13 +656,13 @@ def _gather_partition_node_information( id_gen: UniqueNameGenerator, partition: DistributedGraphPartition, count_duplicates: bool = False - ) -> Tuple[ - Dict[PartId, Dict[FunctionDefinition, str]], - Dict[Tuple[PartId, Optional[FunctionDefinition]], - Dict[Union[int, ArrayOrNames], _DotNodeInfo]]]: - part_id_to_func_to_id: Dict[PartId, Dict[FunctionDefinition, str]] = {} - part_id_func_to_node_info: Dict[Tuple[PartId, Optional[FunctionDefinition]], - Dict[Union[int, ArrayOrNames], + ) -> tuple[ + dict[PartId, dict[FunctionDefinition, str]], + dict[tuple[PartId, FunctionDefinition | None], + dict[int | ArrayOrNames, _DotNodeInfo]]]: + part_id_to_func_to_id: dict[PartId, dict[FunctionDefinition, str]] = {} + part_id_func_to_node_info: dict[tuple[PartId, FunctionDefinition | None], + dict[int | ArrayOrNames, _DotNodeInfo]] = {} for part in partition.parts.values(): @@ -804,7 +715,7 @@ def gather_function_info(f: FunctionDefinition) -> None: # }}} -def get_dot_graph(result: Union[Array, DictOfNamedArrays], +def get_dot_graph(result: Array | DictOfNamedArrays, count_duplicates: bool = False) -> str: r"""Return a string in the `dot `_ language depicting the graph of the computation of *result*. @@ -861,8 +772,8 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition, emit_root("node [shape=rectangle]") - placeholder_to_id: Dict[Union[int, ArrayOrNames], str] = {} - part_id_to_array_to_id: Dict[PartId, Dict[Union[int, ArrayOrNames], str]] = {} + placeholder_to_id: dict[int | ArrayOrNames, str] = {} + part_id_to_array_to_id: dict[PartId, dict[int | ArrayOrNames, str]] = {} part_id_to_id = {pid: dot_escape(str(pid)) for pid in partition.parts} assert len(set(part_id_to_id.values())) == len(partition.parts) @@ -1034,8 +945,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition, for array, node in part_node_to_info.items(): key = get_array_key(array, count_duplicates) - tail_item: Union[Array, AbstractResultWithNamedArrays, - FunctionDefinition] + tail_item: Array | AbstractResultWithNamedArrays | FunctionDefinition for label, edge_info in node.edges.items(): if isinstance(edge_info, tuple): tail_key, tail_item = edge_info @@ -1053,8 +963,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition, raise ValueError( f"unexpected type of tail on edge: {type(tail_item)}") - emit_root('%s -> %s [label="%s"]' % - (tail, head, dot_escape(label))) + emit_root(f'{tail} -> {head} [label="{dot_escape(label)}"]') _emit_name_cluster( emitter, part_subgraph_path, @@ -1071,7 +980,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition, # {{{ draw overall outputs - combined_array_to_id: Dict[Union[int, ArrayOrNames], str] = {} + combined_array_to_id: dict[int | ArrayOrNames, str] = {} for part_id in partition.parts.keys(): combined_array_to_id.update(part_id_to_array_to_id[part_id]) @@ -1087,8 +996,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition, return emitter.generate() -def show_dot_graph(result: Union[str, Array, DictOfNamedArrays, - DistributedGraphPartition], +def show_dot_graph(result: str | Array | DictOfNamedArrays | DistributedGraphPartition, count_duplicates: bool = False, **kwargs: Any) -> None: """Show a graph representing the computation of *result* in a browser.