From 3e19358d7cf2f57a45bccd6759a61ca72890eab7 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 17 Jun 2024 19:08:57 -0600 Subject: [PATCH 01/15] Add node counter tests --- pytato/analysis/__init__.py | 35 +++++++++++---- test/test_pytato.py | 90 ++++++++++++++++++++++++++++++++++++- test/testlib.py | 26 ++++++++++- 3 files changed, 139 insertions(+), 12 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5bf374746..41c21aa9d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,7 +26,7 @@ """ from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, - TYPE_CHECKING) + Type, TYPE_CHECKING) from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, DictOfNamedArrays, NamedArray, IndexBase, IndexRemappingBase, InputArgumentBase, @@ -49,6 +49,8 @@ .. autofunction:: get_num_nodes +.. autofunction:: get_node_type_counts + .. autofunction:: get_num_call_sites .. autoclass:: DirectPredecessorsGetter @@ -381,23 +383,38 @@ def map_named_call_result(self, expr: NamedCallResult) -> FrozenSet[Array]: @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class NodeCountMapper(CachedWalkMapper): """ - Counts the number of nodes in a DAG. + Counts the number of nodes of a given type in a DAG. - .. attribute:: count + .. attribute:: counts - The number of nodes. + Dictionary mapping node types to number of nodes of that type. """ def __init__(self) -> None: + from collections import defaultdict super().__init__() - self.count = 0 + self.counts = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: - return id(expr) + # does NOT account for duplicate nodes + return expr def post_visit(self, expr: Any) -> None: - self.count += 1 + self.counts[type(expr)] += 1 + +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: + """ + Returns a dictionary mapping node types to node count for that type + in DAG *outputs*. + """ + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeCountMapper() + ncm(outputs) + return ncm.counts def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """Returns the number of nodes in DAG *outputs*.""" @@ -408,7 +425,7 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: ncm = NodeCountMapper() ncm(outputs) - return ncm.count + return sum(ncm.counts.values()) # }}} @@ -463,4 +480,4 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} -# vim: fdm=marker +# vim: fdm=marker \ No newline at end of file diff --git a/test/test_pytato.py b/test/test_pytato.py index 8939073cb..cea10480c 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -26,7 +26,6 @@ """ import sys - import numpy as np import pytest import attrs @@ -585,7 +584,7 @@ def test_repr_array_is_deterministic(): assert repr(dag) == repr(dag) -def test_nodecountmapper(): +def test_node_count_mapper(): from testlib import RandomDAGContext, make_random_dag from pytato.analysis import get_num_nodes @@ -600,6 +599,93 @@ def test_nodecountmapper(): assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag)) +def test_empty_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + empty_dag = pt.make_dict_of_named_arrays({}) + + # Verify that get_num_nodes returns 0 for an empty DAG + assert get_num_nodes(empty_dag) - 1 == 0 + + counts = get_node_type_counts(empty_dag) + assert len(counts) == 1 + +def test_single_node_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + data = np.random.rand(4, 4) + single_node_dag = pt.make_dict_of_named_arrays({"result": pt.make_data_wrapper(data)}) + + # Get counts per node type + node_counts = get_node_type_counts(single_node_dag) + + # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays + # DictOfNamedArrays is automatically added + assert node_counts == {pt.DataWrapper: 1, pt.DictOfNamedArrays: 1} + assert sum(node_counts.values()) - 1 == 1 # Total node count is 1 + + # Get total number of nodes + total_nodes = get_num_nodes(single_node_dag) + + assert total_nodes - 1 == 1 + + +def test_small_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + # Make a DAG using two nodes and one operation + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + b = a + 1 + dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + + # Verify that get_num_nodes returns 2 for a DAG with two nodes + assert get_num_nodes(dag) - 1 == 2 + + counts = get_node_type_counts(dag) + assert len(counts) - 1 == 2 + assert counts[pt.array.Placeholder] == 1 # "a" + assert counts[pt.array.IndexLambda] == 1 # single operation + + +def test_large_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + from testlib import make_large_dag + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Verify that the number of nodes is equal to iterations + 1 (placeholder) + assert get_num_nodes(dag) - 1 == iterations + 1 + + # Verify that the counts dictionary has correct counts for the complicated DAG + counts = get_node_type_counts(dag) + assert len(counts) >= 1 + assert counts[pt.array.Placeholder] == 1 + assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert sum(counts.values()) - 1 == iterations + 1 + + +def test_random_dag_count(): + from testlib import get_random_pt_dag + from pytato.analysis import get_num_nodes + for i in range(80): + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. + assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + +def test_random_dag_with_comm_count(): + from testlib import get_random_pt_dag_with_send_recv_nodes + from pytato.analysis import get_num_nodes + rank = 0 + size = 2 + for i in range(10): + dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + + # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. + assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) diff --git a/test/testlib.py b/test/testlib.py index 5cd1342d3..cdf827e96 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -311,6 +311,30 @@ def gen_comm(rdagc: RandomDAGContext) -> pt.Array: convert_dws_to_placeholders=convert_dws_to_placeholders, additional_generators=[(comm_fake_probability, gen_comm)]) +def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: + """ + Builds a DAG with emphasis on number of operations. + """ + import random + import operator + + rng = np.random.default_rng(seed) + random.seed(seed) + + # Begin with a placeholder + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + current = a + + # Will randomly choose from the operators + operations = [operator.add, operator.sub, operator.mul, operator.truediv] + + for _ in range(iterations): + operation = random.choice(operations) + value = rng.uniform(1, 10) + current = operation(current, value) + + return pt.make_dict_of_named_arrays({"result": current}) + # }}} @@ -369,4 +393,4 @@ class QuuxTag(TestlibTag): # }}} -# vim: foldmethod=marker +# vim: foldmethod=marker \ No newline at end of file From ea2402c9fe933dbf30cee8a71cf5247387461e65 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 17 Jun 2024 19:41:47 -0600 Subject: [PATCH 02/15] CI fixes --- doc/conf.py | 1 + pytato/analysis/__init__.py | 10 ++++++---- test/test_pytato.py | 14 ++++++++------ test/testlib.py | 37 +++++++++++++++++++------------------ 4 files changed, 34 insertions(+), 28 deletions(-) diff --git a/doc/conf.py b/doc/conf.py index 081642f1d..e6f7ac0c0 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -46,6 +46,7 @@ nitpick_ignore_regex = [ ["py:class", r"numpy.(u?)int[\d]+"], ["py:class", r"typing_extensions(.+)"], + ["py:class", r"numpy.bool_"], # As of 2023-10-05, it doesn't look like there's sphinx documentation # available. ["py:class", r"immutabledict(.*)"], diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 41c21aa9d..5da4ea70e 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -393,16 +393,17 @@ class NodeCountMapper(CachedWalkMapper): def __init__(self) -> None: from collections import defaultdict super().__init__() - self.counts = defaultdict(int) + self.counts = defaultdict(int) # type: Dict[Type[Any], int] - def get_cache_key(self, expr: ArrayOrNames) -> int: + def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: # does NOT account for duplicate nodes return expr def post_visit(self, expr: Any) -> None: self.counts[type(expr)] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: + +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -416,6 +417,7 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, return ncm.counts + def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """Returns the number of nodes in DAG *outputs*.""" @@ -480,4 +482,4 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} -# vim: fdm=marker \ No newline at end of file +# vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index cea10480c..51e5381d5 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -610,6 +610,7 @@ def test_empty_dag_count(): counts = get_node_type_counts(empty_dag) assert len(counts) == 1 + def test_single_node_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts @@ -626,7 +627,7 @@ def test_single_node_dag_count(): # Get total number of nodes total_nodes = get_num_nodes(single_node_dag) - + assert total_nodes - 1 == 1 @@ -636,15 +637,15 @@ def test_small_dag_count(): # Make a DAG using two nodes and one operation a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) b = a + 1 - dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 # Verify that get_num_nodes returns 2 for a DAG with two nodes assert get_num_nodes(dag) - 1 == 2 counts = get_node_type_counts(dag) assert len(counts) - 1 == 2 - assert counts[pt.array.Placeholder] == 1 # "a" - assert counts[pt.array.IndexLambda] == 1 # single operation + assert counts[pt.array.Placeholder] == 1 # "a" + assert counts[pt.array.IndexLambda] == 1 # single operation def test_large_dag_count(): @@ -661,7 +662,7 @@ def test_large_dag_count(): counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 - assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert counts[pt.array.IndexLambda] == 100 # 100 operations assert sum(counts.values()) - 1 == iterations + 1 @@ -670,10 +671,11 @@ def test_random_dag_count(): from pytato.analysis import get_num_nodes for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - + # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + def test_random_dag_with_comm_count(): from testlib import get_random_pt_dag_with_send_recv_nodes from pytato.analysis import get_num_nodes diff --git a/test/testlib.py b/test/testlib.py index cdf827e96..e15489c4b 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -311,29 +311,30 @@ def gen_comm(rdagc: RandomDAGContext) -> pt.Array: convert_dws_to_placeholders=convert_dws_to_placeholders, additional_generators=[(comm_fake_probability, gen_comm)]) + def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: - """ - Builds a DAG with emphasis on number of operations. - """ - import random - import operator + """ + Builds a DAG with emphasis on number of operations. + """ + import random + import operator - rng = np.random.default_rng(seed) - random.seed(seed) + rng = np.random.default_rng(seed) + random.seed(seed) - # Begin with a placeholder - a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) - current = a + # Begin with a placeholder + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + current = a - # Will randomly choose from the operators - operations = [operator.add, operator.sub, operator.mul, operator.truediv] + # Will randomly choose from the operators + operations = [operator.add, operator.sub, operator.mul, operator.truediv] - for _ in range(iterations): - operation = random.choice(operations) - value = rng.uniform(1, 10) - current = operation(current, value) + for _ in range(iterations): + operation = random.choice(operations) + value = rng.uniform(1, 10) + current = operation(current, value) - return pt.make_dict_of_named_arrays({"result": current}) + return pt.make_dict_of_named_arrays({"result": current}) # }}} @@ -393,4 +394,4 @@ class QuuxTag(TestlibTag): # }}} -# vim: foldmethod=marker \ No newline at end of file +# vim: foldmethod=marker From b122aa9a6e87419bd5ddc77ac9ca098f6091d7c0 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 17 Jun 2024 19:49:18 -0600 Subject: [PATCH 03/15] Add comments --- doc/conf.py | 1 - test/testlib.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/conf.py b/doc/conf.py index e6f7ac0c0..081642f1d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -46,7 +46,6 @@ nitpick_ignore_regex = [ ["py:class", r"numpy.(u?)int[\d]+"], ["py:class", r"typing_extensions(.+)"], - ["py:class", r"numpy.bool_"], # As of 2023-10-05, it doesn't look like there's sphinx documentation # available. ["py:class", r"immutabledict(.*)"], diff --git a/test/testlib.py b/test/testlib.py index e15489c4b..a208f0816 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -334,6 +334,7 @@ def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: value = rng.uniform(1, 10) current = operation(current, value) + # DAG should have `iterations` number of operations return pt.make_dict_of_named_arrays({"result": current}) # }}} From 4a52c8d8ce400b71461e97424bb62fa43e5ce28e Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 18 Jun 2024 10:08:34 -0600 Subject: [PATCH 04/15] Remove unnecessary test --- test/test_pytato.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 51e5381d5..962fe337e 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -584,21 +584,6 @@ def test_repr_array_is_deterministic(): assert repr(dag) == repr(dag) -def test_node_count_mapper(): - from testlib import RandomDAGContext, make_random_dag - from pytato.analysis import get_num_nodes - - axis_len = 5 - - for i in range(10): - rdagc = RandomDAGContext(np.random.default_rng(seed=i), - axis_len=axis_len, use_numpy=False) - dag = make_random_dag(rdagc) - - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag)) - - def test_empty_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts From 570eda4d72c46b4922c662c8f80bc8ae976d5262 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Sun, 23 Jun 2024 16:52:02 -0600 Subject: [PATCH 05/15] Add duplicate node functionality and tests --- pytato/analysis/__init__.py | 38 +++++++++++++++------- test/test_pytato.py | 65 +++++++++++++++++++++++++++++++++++++ 2 files changed, 91 insertions(+), 12 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5da4ea70e..f009f89f6 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -385,25 +385,29 @@ class NodeCountMapper(CachedWalkMapper): """ Counts the number of nodes of a given type in a DAG. - .. attribute:: counts + .. attribute:: expr_type_counts + .. attribute:: count_duplicates + .. attribute:: expr_call_counts Dictionary mapping node types to number of nodes of that type. """ - def __init__(self) -> None: + def __init__(self, count_duplicates=False) -> None: # added parameter from collections import defaultdict super().__init__() - self.counts = defaultdict(int) # type: Dict[Type[Any], int] + self.expr_type_counts = defaultdict(int) # type: Dict[Type[Any], int] + self.count_duplicates = count_duplicates + self.expr_call_counts = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: - # does NOT account for duplicate nodes - return expr + return id(expr) if self.count_duplicates else expr # returns unique nodes only if count_duplicates is True def post_visit(self, expr: Any) -> None: - self.counts[type(expr)] += 1 + self.expr_type_counts[type(expr)] += 1 + self.expr_call_counts[expr] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -412,22 +416,32 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - ncm = NodeCountMapper() + ncm = NodeCountMapper(count_duplicates) ncm(outputs) - return ncm.counts + return ncm.expr_type_counts -def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: +def get_num_nodes(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> int: """Returns the number of nodes in DAG *outputs*.""" from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - ncm = NodeCountMapper() + ncm = NodeCountMapper(count_duplicates) + ncm(outputs) + + return sum(ncm.expr_type_counts.values()) + + +def get_expr_calls(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> Dict[Type[Any], int]: + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeCountMapper(count_duplicates) ncm(outputs) - return sum(ncm.counts.values()) + return ncm.expr_call_counts # }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index 962fe337e..a1f3bc518 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -672,6 +672,71 @@ def test_random_dag_with_comm_count(): # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) +def test_duplicate_node_count(): + from testlib import get_random_pt_dag + from pytato.analysis import get_num_nodes, get_expr_calls + for i in range(80): + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Get the number of types of expressions + node_count = get_num_nodes(dag, count_duplicates=True) + + # Get the number of expressions and the amount they're called + expr_counts = get_expr_calls(dag, count_duplicates=True) + + num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + # Check that duplicates are correctly calculated + assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + + +def test_duplicate_nodes_with_comm_count(): + from testlib import get_random_pt_dag_with_send_recv_nodes + from pytato.analysis import get_num_nodes, get_expr_calls + + rank = 0 + size = 2 + for i in range(20): + dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + + # Get the number of types of expressions + node_count = get_num_nodes(dag, count_duplicates=True) + + # Get the number of expressions and the amount they're called + expr_counts = get_expr_calls(dag, count_duplicates=True) + + num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + + # Check that duplicates are correctly calculated + assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + + +def test_large_dag_with_duplicates_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts, get_expr_calls + from testlib import make_large_dag + import pytato as pt + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Verify that the number of nodes is equal to iterations + 1 (placeholder) + node_count = get_num_nodes(dag, count_duplicates=True) + assert node_count - 1 == iterations + 1 + + # Get the number of expressions and the amount they're called + expr_counts = get_expr_calls(dag, count_duplicates=True) + + num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + + # Verify that the counts dictionary has correct counts for the complicated DAG + counts = get_node_type_counts(dag, count_duplicates=True) + assert len(counts) >= 1 + assert counts[pt.array.Placeholder] == 1 + assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert sum(counts.values()) - 1 == iterations + 1 + + # Check that duplicates are correctly calculated + assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) From d8dbe62f5a17a477b84592d8817d11b547bf59ee Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:16:20 -0600 Subject: [PATCH 06/15] Remove incrementation for DictOfNamedArrays and update tests --- pytato/analysis/__init__.py | 11 +++++++++-- test/test_pytato.py | 23 +++++++++++------------ 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5da4ea70e..4938f6794 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -400,13 +400,16 @@ def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: return expr def post_visit(self, expr: Any) -> None: - self.counts[type(expr)] += 1 + if type(expr) is not DictOfNamedArrays: + self.counts[type(expr)] += 1 def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. + + `DictOfNamedArrays` are added when *outputs* is normalized and ignored. """ from pytato.codegen import normalize_outputs @@ -419,7 +422,11 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: - """Returns the number of nodes in DAG *outputs*.""" + """ + Returns the number of nodes in DAG *outputs*. + + `DictOfNamedArrays` are added when *outputs* is normalized and ignored. + """ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) diff --git a/test/test_pytato.py b/test/test_pytato.py index 962fe337e..3d9e2a684 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -590,10 +590,10 @@ def test_empty_dag_count(): empty_dag = pt.make_dict_of_named_arrays({}) # Verify that get_num_nodes returns 0 for an empty DAG - assert get_num_nodes(empty_dag) - 1 == 0 + assert get_num_nodes(empty_dag) == 0 counts = get_node_type_counts(empty_dag) - assert len(counts) == 1 + assert len(counts) == 0 def test_single_node_dag_count(): @@ -606,14 +606,13 @@ def test_single_node_dag_count(): node_counts = get_node_type_counts(single_node_dag) # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays - # DictOfNamedArrays is automatically added - assert node_counts == {pt.DataWrapper: 1, pt.DictOfNamedArrays: 1} - assert sum(node_counts.values()) - 1 == 1 # Total node count is 1 + assert node_counts == {pt.DataWrapper: 1} + assert sum(node_counts.values()) == 1 # Total node count is 1 # Get total number of nodes total_nodes = get_num_nodes(single_node_dag) - assert total_nodes - 1 == 1 + assert total_nodes == 1 def test_small_dag_count(): @@ -625,10 +624,10 @@ def test_small_dag_count(): dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 # Verify that get_num_nodes returns 2 for a DAG with two nodes - assert get_num_nodes(dag) - 1 == 2 + assert get_num_nodes(dag) == 2 counts = get_node_type_counts(dag) - assert len(counts) - 1 == 2 + assert len(counts) == 2 assert counts[pt.array.Placeholder] == 1 # "a" assert counts[pt.array.IndexLambda] == 1 # single operation @@ -641,14 +640,14 @@ def test_large_dag_count(): dag = make_large_dag(iterations, seed=42) # Verify that the number of nodes is equal to iterations + 1 (placeholder) - assert get_num_nodes(dag) - 1 == iterations + 1 + assert get_num_nodes(dag) == iterations + 1 # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 assert counts[pt.array.IndexLambda] == 100 # 100 operations - assert sum(counts.values()) - 1 == iterations + 1 + assert sum(counts.values()) == iterations + 1 def test_random_dag_count(): @@ -658,7 +657,7 @@ def test_random_dag_count(): dag = get_random_pt_dag(seed=i, axis_len=5) # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) def test_random_dag_with_comm_count(): @@ -670,7 +669,7 @@ def test_random_dag_with_comm_count(): dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag) - 1 == len(pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) def test_rec_get_user_nodes(): From 178127c9a7608b58901ae49f9a06c378d0b128d8 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:24:57 -0600 Subject: [PATCH 07/15] Edit tests to account for not counting DictOfNamedArrays --- test/test_pytato.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 16d1343e6..6f302caa2 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -656,7 +656,6 @@ def test_random_dag_count(): for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) @@ -668,9 +667,9 @@ def test_random_dag_with_comm_count(): for i in range(10): dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) + def test_duplicate_node_count(): from testlib import get_random_pt_dag from pytato.analysis import get_num_nodes, get_expr_calls @@ -683,9 +682,10 @@ def test_duplicate_node_count(): # Get the number of expressions and the amount they're called expr_counts = get_expr_calls(dag, count_duplicates=True) + # Get difference in duplicates num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) # Check that duplicates are correctly calculated - assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) def test_duplicate_nodes_with_comm_count(): @@ -703,10 +703,11 @@ def test_duplicate_nodes_with_comm_count(): # Get the number of expressions and the amount they're called expr_counts = get_expr_calls(dag, count_duplicates=True) + # Get difference in duplicates num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) # Check that duplicates are correctly calculated - assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) def test_large_dag_with_duplicates_count(): @@ -719,7 +720,7 @@ def test_large_dag_with_duplicates_count(): # Verify that the number of nodes is equal to iterations + 1 (placeholder) node_count = get_num_nodes(dag, count_duplicates=True) - assert node_count - 1 == iterations + 1 + assert node_count == iterations + 1 # Get the number of expressions and the amount they're called expr_counts = get_expr_calls(dag, count_duplicates=True) @@ -731,10 +732,10 @@ def test_large_dag_with_duplicates_count(): assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 assert counts[pt.array.IndexLambda] == 100 # 100 operations - assert sum(counts.values()) - 1 == iterations + 1 + assert sum(counts.values()) == iterations + 1 # Check that duplicates are correctly calculated - assert node_count - num_duplicates - 1 == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) def test_rec_get_user_nodes(): From 326045ecc2b34e425606a67c819a38c75a6d349e Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:37:55 -0600 Subject: [PATCH 08/15] Fix CI tests --- pytato/analysis/__init__.py | 30 ++++++++++++++++++++++-------- test/test_pytato.py | 35 ++++++++++++++++++++++------------- 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index aa015a760..844c9214b 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -392,15 +392,16 @@ class NodeCountMapper(CachedWalkMapper): Dictionary mapping node types to number of nodes of that type. """ - def __init__(self, count_duplicates=False) -> None: # added parameter + def __init__(self, count_duplicates: bool = False) -> None: from collections import defaultdict super().__init__() - self.expr_type_counts = defaultdict(int) # type: Dict[Type[Any], int] + self.expr_type_counts = defaultdict(int) # type: Dict[Type[Any], int] self.count_duplicates = count_duplicates - self.expr_call_counts = defaultdict(int) + self.expr_call_counts = defaultdict(int) # type: Dict[Any, int] - def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: - return id(expr) if self.count_duplicates else expr # returns unique nodes only if count_duplicates is True + def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: + # Returns unique nodes only if count_duplicates is True + return id(expr) if self.count_duplicates else expr def post_visit(self, expr: Any) -> None: if type(expr) is not DictOfNamedArrays: @@ -408,7 +409,10 @@ def post_visit(self, expr: Any) -> None: self.expr_call_counts[expr] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> Dict[Type[Any], int]: +def get_node_type_counts( + outputs: Union[Array, DictOfNamedArrays], + count_duplicates: bool = False + ) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -425,7 +429,10 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays], count_duplica return ncm.expr_type_counts -def get_num_nodes(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> int: +def get_num_nodes( + outputs: Union[Array, DictOfNamedArrays], + count_duplicates: bool = False + ) -> int: """ Returns the number of nodes in DAG *outputs*. @@ -441,7 +448,14 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays], count_duplicates=Fal return sum(ncm.expr_type_counts.values()) -def get_expr_calls(outputs: Union[Array, DictOfNamedArrays], count_duplicates=False) -> Dict[Type[Any], int]: +def get_expr_calls( + outputs: Union[Array, DictOfNamedArrays], + count_duplicates: bool = False + ) -> Dict[Type[Any], int]: + """ + Returns the count of calls per `expr`. + """ + from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) diff --git a/test/test_pytato.py b/test/test_pytato.py index 6f302caa2..20fffb948 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -600,12 +600,13 @@ def test_single_node_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts data = np.random.rand(4, 4) - single_node_dag = pt.make_dict_of_named_arrays({"result": pt.make_data_wrapper(data)}) + single_node_dag = pt.make_dict_of_named_arrays( + {"result": pt.make_data_wrapper(data)}) # Get counts per node type node_counts = get_node_type_counts(single_node_dag) - # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays + # Assert that there is only one node of type DataWrapper assert node_counts == {pt.DataWrapper: 1} assert sum(node_counts.values()) == 1 # Total node count is 1 @@ -642,7 +643,6 @@ def test_large_dag_count(): # Verify that the number of nodes is equal to iterations + 1 (placeholder) assert get_num_nodes(dag) == iterations + 1 - # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 @@ -665,7 +665,8 @@ def test_random_dag_with_comm_count(): rank = 0 size = 2 for i in range(10): - dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + dag = get_random_pt_dag_with_send_recv_nodes( + seed=i, rank=rank, size=size) assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) @@ -683,9 +684,11 @@ def test_duplicate_node_count(): expr_counts = get_expr_calls(dag, count_duplicates=True) # Get difference in duplicates - num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + num_duplicates = sum( + count - 1 for count in expr_counts.values() if count > 1) # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) def test_duplicate_nodes_with_comm_count(): @@ -695,7 +698,8 @@ def test_duplicate_nodes_with_comm_count(): rank = 0 size = 2 for i in range(20): - dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + dag = get_random_pt_dag_with_send_recv_nodes( + seed=i, rank=rank, size=size) # Get the number of types of expressions node_count = get_num_nodes(dag, count_duplicates=True) @@ -704,14 +708,18 @@ def test_duplicate_nodes_with_comm_count(): expr_counts = get_expr_calls(dag, count_duplicates=True) # Get difference in duplicates - num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + num_duplicates = sum( + count - 1 for count in expr_counts.values() if count > 1) # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) def test_large_dag_with_duplicates_count(): - from pytato.analysis import get_num_nodes, get_node_type_counts, get_expr_calls + from pytato.analysis import ( + get_num_nodes, get_node_type_counts, get_expr_calls + ) from testlib import make_large_dag import pytato as pt @@ -725,9 +733,9 @@ def test_large_dag_with_duplicates_count(): # Get the number of expressions and the amount they're called expr_counts = get_expr_calls(dag, count_duplicates=True) - num_duplicates = sum(count - 1 for count in expr_counts.values() if count > 1) + num_duplicates = sum( + count - 1 for count in expr_counts.values() if count > 1) - # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag, count_duplicates=True) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 @@ -735,7 +743,8 @@ def test_large_dag_with_duplicates_count(): assert sum(counts.values()) == iterations + 1 # Check that duplicates are correctly calculated - assert node_count - num_duplicates == len(pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) def test_rec_get_user_nodes(): From 6a0a2a9f44f58d1d4ba7e39122a4f2577ad0b157 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 25 Jun 2024 18:39:57 -0600 Subject: [PATCH 09/15] Fix comments --- test/test_pytato.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 3d9e2a684..e202fd2c3 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -600,12 +600,13 @@ def test_single_node_dag_count(): from pytato.analysis import get_num_nodes, get_node_type_counts data = np.random.rand(4, 4) - single_node_dag = pt.make_dict_of_named_arrays({"result": pt.make_data_wrapper(data)}) + single_node_dag = pt.make_dict_of_named_arrays( + {"result": pt.make_data_wrapper(data)}) # Get counts per node type node_counts = get_node_type_counts(single_node_dag) - # Assert that there is only one node of type DataWrapper and one node of DictOfNamedArrays + # Assert that there is only one node of type DataWrapper assert node_counts == {pt.DataWrapper: 1} assert sum(node_counts.values()) == 1 # Total node count is 1 @@ -642,7 +643,6 @@ def test_large_dag_count(): # Verify that the number of nodes is equal to iterations + 1 (placeholder) assert get_num_nodes(dag) == iterations + 1 - # Verify that the counts dictionary has correct counts for the complicated DAG counts = get_node_type_counts(dag) assert len(counts) >= 1 assert counts[pt.array.Placeholder] == 1 @@ -656,7 +656,6 @@ def test_random_dag_count(): for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) @@ -666,9 +665,9 @@ def test_random_dag_with_comm_count(): rank = 0 size = 2 for i in range(10): - dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + dag = get_random_pt_dag_with_send_recv_nodes( + seed=i, rank=rank, size=size) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) From 0dca4d7c4295179f8e946f9411349d2dfa43512e Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 26 Jun 2024 19:50:38 -0600 Subject: [PATCH 10/15] Clarify wording and clean up --- pytato/analysis/__init__.py | 6 +++--- test/test_pytato.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 4938f6794..3a112501b 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -400,7 +400,7 @@ def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: return expr def post_visit(self, expr: Any) -> None: - if type(expr) is not DictOfNamedArrays: + if not isinstance(expr, DictOfNamedArrays): self.counts[type(expr)] += 1 @@ -409,7 +409,7 @@ def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. - `DictOfNamedArrays` are added when *outputs* is normalized and ignored. + Instances of `DictOfNamedArrays` are excluded from counting. """ from pytato.codegen import normalize_outputs @@ -425,7 +425,7 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: """ Returns the number of nodes in DAG *outputs*. - `DictOfNamedArrays` are added when *outputs* is normalized and ignored. + Instances of `DictOfNamedArrays` are excluded from counting. """ from pytato.codegen import normalize_outputs diff --git a/test/test_pytato.py b/test/test_pytato.py index e202fd2c3..23aebd9e4 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -608,7 +608,6 @@ def test_single_node_dag_count(): # Assert that there is only one node of type DataWrapper assert node_counts == {pt.DataWrapper: 1} - assert sum(node_counts.values()) == 1 # Total node count is 1 # Get total number of nodes total_nodes = get_num_nodes(single_node_dag) From 9489ecf05dc9089903c9fc6b64f1252cb792f8d8 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 26 Jun 2024 20:21:44 -0600 Subject: [PATCH 11/15] Move `get_node_multiplicities` to its own mapper --- pytato/analysis/__init__.py | 41 +++++++++++++++++++++++++++---------- test/test_pytato.py | 18 ++++++++-------- 2 files changed, 39 insertions(+), 20 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index ad585db16..1f84aa912 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -387,7 +387,6 @@ class NodeCountMapper(CachedWalkMapper): .. attribute:: expr_type_counts .. attribute:: count_duplicates - .. attribute:: expr_call_counts Dictionary mapping node types to number of nodes of that type. """ @@ -397,7 +396,6 @@ def __init__(self, count_duplicates: bool = False) -> None: super().__init__() self.expr_type_counts = defaultdict(int) # type: Dict[Type[Any], int] self.count_duplicates = count_duplicates - self.expr_call_counts = defaultdict(int) # type: Dict[Any, int] def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: # Returns unique nodes only if count_duplicates is True @@ -406,7 +404,6 @@ def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: def post_visit(self, expr: Any) -> None: if not isinstance(expr, DictOfNamedArrays): self.expr_type_counts[type(expr)] += 1 - self.expr_call_counts[expr] += 1 def get_node_type_counts( @@ -447,21 +444,43 @@ def get_num_nodes( return sum(ncm.expr_type_counts.values()) +# }}} -def get_expr_calls( - outputs: Union[Array, DictOfNamedArrays], - count_duplicates: bool = False - ) -> Dict[Type[Any], int]: + +# {{{ NodeMultiplicityMapper + + +class NodeMultiplicityMapper(CachedWalkMapper): """ - Returns the count of calls per `expr`. + Counts the number of unique nodes by ID in a DAG. + + .. attribute:: expr_multiplicity_counts + """ + def __init__(self) -> None: + from collections import defaultdict + super().__init__() + self.expr_multiplicity_counts = defaultdict(int) # type: Dict[Any, int] + + def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: + # Returns unique nodes + return id(expr) + + def post_visit(self, expr: Any) -> None: + if not isinstance(expr, DictOfNamedArrays): + self.expr_multiplicity_counts[expr] += 1 + + +def get_node_multiplicities(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: + """ + Returns the multiplicity per `expr`. """ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - ncm = NodeCountMapper(count_duplicates) - ncm(outputs) + nmm = NodeMultiplicityMapper() + nmm(outputs) - return ncm.expr_call_counts + return nmm.expr_multiplicity_counts # }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index 4ef754a69..c34f14776 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -672,7 +672,7 @@ def test_random_dag_with_comm_count(): def test_duplicate_node_count(): from testlib import get_random_pt_dag - from pytato.analysis import get_num_nodes, get_expr_calls + from pytato.analysis import get_num_nodes, get_node_multiplicities for i in range(80): dag = get_random_pt_dag(seed=i, axis_len=5) @@ -680,11 +680,11 @@ def test_duplicate_node_count(): node_count = get_num_nodes(dag, count_duplicates=True) # Get the number of expressions and the amount they're called - expr_counts = get_expr_calls(dag, count_duplicates=True) + node_multiplicity = get_node_multiplicities(dag) # Get difference in duplicates num_duplicates = sum( - count - 1 for count in expr_counts.values() if count > 1) + count - 1 for count in node_multiplicity.values() if count > 1) # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( pt.transform.DependencyMapper()(dag)) @@ -692,7 +692,7 @@ def test_duplicate_node_count(): def test_duplicate_nodes_with_comm_count(): from testlib import get_random_pt_dag_with_send_recv_nodes - from pytato.analysis import get_num_nodes, get_expr_calls + from pytato.analysis import get_num_nodes, get_node_multiplicities rank = 0 size = 2 @@ -704,11 +704,11 @@ def test_duplicate_nodes_with_comm_count(): node_count = get_num_nodes(dag, count_duplicates=True) # Get the number of expressions and the amount they're called - expr_counts = get_expr_calls(dag, count_duplicates=True) + node_multiplicity = get_node_multiplicities(dag) # Get difference in duplicates num_duplicates = sum( - count - 1 for count in expr_counts.values() if count > 1) + count - 1 for count in node_multiplicity.values() if count > 1) # Check that duplicates are correctly calculated assert node_count - num_duplicates == len( @@ -717,7 +717,7 @@ def test_duplicate_nodes_with_comm_count(): def test_large_dag_with_duplicates_count(): from pytato.analysis import ( - get_num_nodes, get_node_type_counts, get_expr_calls + get_num_nodes, get_node_type_counts, get_node_multiplicities ) from testlib import make_large_dag import pytato as pt @@ -730,10 +730,10 @@ def test_large_dag_with_duplicates_count(): assert node_count == iterations + 1 # Get the number of expressions and the amount they're called - expr_counts = get_expr_calls(dag, count_duplicates=True) + node_multiplicity = get_node_multiplicities(dag) num_duplicates = sum( - count - 1 for count in expr_counts.values() if count > 1) + count - 1 for count in node_multiplicity.values() if count > 1) counts = get_node_type_counts(dag, count_duplicates=True) assert len(counts) >= 1 From 27d6283ba74c24ca2744a41797a8beef36f2ff8c Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Wed, 26 Jun 2024 20:26:30 -0600 Subject: [PATCH 12/15] Add autofunction --- pytato/analysis/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 1f84aa912..3bc6785a7 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -51,6 +51,8 @@ .. autofunction:: get_node_type_counts +.. autofunction:: get_node_multiplicities + .. autofunction:: get_num_call_sites .. autoclass:: DirectPredecessorsGetter From 1444c50d77badc5d947cb2aa536eaaca1861c5df Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Tue, 2 Jul 2024 18:05:48 -0600 Subject: [PATCH 13/15] Formatting --- pytato/analysis/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 3a112501b..b822d6622 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -404,7 +404,8 @@ def post_visit(self, expr: Any) -> None: self.counts[type(expr)] += 1 -def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays] + ) -> Dict[Type[Any], int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. From 528c617b80350bfe643e4a9de82a3e3fe06be34c Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Sat, 6 Jul 2024 21:58:09 -0600 Subject: [PATCH 14/15] Add functionality to graph duplicate nodes via dot graph --- pytato/visualization/dot.py | 366 ++++++++++++++++++++++-------------- test/test_pytato.py | 73 +++++++ test/testlib.py | 16 ++ 3 files changed, 315 insertions(+), 140 deletions(-) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 0e3396df0..e02ec05c5 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -29,6 +29,7 @@ from functools import partial import html import attrs +import gc from typing import (TYPE_CHECKING, Callable, Dict, Tuple, Union, List, Mapping, Any, FrozenSet, Set, Optional) @@ -51,8 +52,7 @@ from pytato.distributed.partition import ( DistributedGraphPartition, DistributedGraphPart, PartId) -if TYPE_CHECKING: - from pytato.distributed.nodes import DistributedSendRefHolder +from pytato.distributed.nodes import DistributedSendRefHolder __doc__ = """ @@ -162,12 +162,23 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" -class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]): - def __init__(self) -> None: - super().__init__() - self.node_to_dot: Dict[ArrayOrNames, _DotNodeInfo] = {} - self.functions: Set[FunctionDefinition] = set() +def get_object_by_id(object_id): + """Find an object by its ID.""" + for obj in gc.get_objects(): + if id(obj) == object_id: + return obj + return None + + +class ArrayToDotNodeInfoMapper: + def __init__(self, count_duplicates: bool = False): + self.count_duplicates = count_duplicates + self.node_to_dot = {} + self.functions = set() + def get_cache_key(self, expr): + return id(expr) if self.count_duplicates else expr + def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: title = type(expr).__name__ fields = {"addr": hex(id(expr)), @@ -180,65 +191,83 @@ def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} return _DotNodeInfo(title, fields, edges) - # type-ignore-reason: incompatible with supertype - def handle_unsupported_array(self, # type: ignore[override] - expr: Array) -> None: - # Default handler, does its best to guess how to handle fields. - info = self.get_common_dot_info(expr) + def process_node(self, expr: Array) -> None: + if isinstance(expr, DataWrapper): + self.map_data_wrapper(expr) + elif isinstance(expr, IndexLambda): + self.map_index_lambda(expr) + elif isinstance(expr, Stack): + self.map_stack(expr) + elif isinstance(expr, (IndexBase, IndexLambda)): + self.map_basic_index(expr) + elif isinstance(expr, Einsum): + self.map_einsum(expr) + elif isinstance(expr, DictOfNamedArrays): + self.map_dict_of_named_arrays(expr) + elif isinstance(expr, LoopyCall): + self.map_loopy_call(expr) + elif isinstance(expr, DistributedSendRefHolder): + self.map_distributed_send_ref_holder(expr) + elif isinstance(expr, Call): + self.map_call(expr) + elif isinstance(expr, NamedCallResult): + self.map_named_call_result(expr) + else: + self.handle_unsupported_array(expr) - # pylint: disable=not-an-iterable + def handle_unsupported_array(self, expr: Array) -> None: + info = self.get_common_dot_info(expr) + expr_key = self.get_cache_key(expr) for field in attrs.fields(type(expr)): if field.name in info.fields: continue attr = getattr(expr, field.name) - if isinstance(attr, Array): - self.rec(attr) - info.edges[field.name] = attr - + self.process_node(attr) + key = self.get_cache_key(attr) + info.edges[field.name] = (key, attr) elif isinstance(attr, AbstractResultWithNamedArrays): - self.rec(attr) - info.edges[field.name] = attr - + self.process_node(attr) + key = self.get_cache_key(attr) + info.edges[field.name] = (key, attr) elif isinstance(attr, tuple): info.fields[field.name] = stringify_shape(attr) - else: info.fields[field.name] = str(attr) - - self.node_to_dot[expr] = info + self.node_to_dot[expr_key] = info def map_data_wrapper(self, expr: DataWrapper) -> None: info = self.get_common_dot_info(expr) if expr.name is not None: info.fields["name"] = expr.name - # Only show summarized data import numpy as np with np.printoptions(threshold=4, precision=2): info.fields["data"] = str(expr.data) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info def map_index_lambda(self, expr: IndexLambda) -> None: info = self.get_common_dot_info(expr) info.fields["expr"] = str(expr.expr) for name, val in expr.bindings.items(): - self.rec(val) - info.edges[name] = val + self.process_node(val) + key = self.get_cache_key(val) + info.edges[name] = (key, val) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info def map_stack(self, expr: Stack) -> None: info = self.get_common_dot_info(expr) info.fields["axis"] = str(expr.axis) for i, array in enumerate(expr.arrays): - self.rec(array) - info.edges[str(i)] = array + self.process_node(array) + key = self.get_cache_key(array) + info.edges[str(i)] = (key, array) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info map_concatenate = map_stack @@ -254,9 +283,10 @@ def map_basic_index(self, expr: IndexBase) -> None: elif isinstance(index, Array): label = f"i{i}" - self.rec(index) + self.process_node(index) + key = self.get_cache_key(index) indices_parts.append(label) - info.edges[label] = index + info.edges[label] = (key, index) elif index is None: indices_parts.append("newaxis") @@ -266,10 +296,11 @@ def map_basic_index(self, expr: IndexBase) -> None: info.fields["indices"] = ", ".join(indices_parts) - self.rec(expr.array) - info.edges["array"] = expr.array + self.process_node(expr.array) + key = self.get_cache_key(expr.array) + info.edges["array"] = (key, expr.array) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info map_contiguous_advanced_index = map_basic_index map_non_contiguous_advanced_index = map_basic_index @@ -279,18 +310,20 @@ def map_einsum(self, expr: Einsum) -> None: for iarg, (access_descr, val) in enumerate(zip(expr.access_descriptors, expr.args)): - self.rec(val) - info.edges[f"{iarg}: {access_descr}"] = val + self.process_node(val) + key = self.get_cache_key(val) + info.edges[f"{iarg}: {access_descr}"] = (key, val) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} for name, val in expr._data.items(): - edges[name] = val - self.rec(val) + self.process_node(val) + key = self.get_cache_key(val) + edges[name] = (key, val) - self.node_to_dot[expr] = _DotNodeInfo( + self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo( title=type(expr).__name__, fields={}, edges=edges) @@ -299,10 +332,11 @@ def map_loopy_call(self, expr: LoopyCall) -> None: edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} for name, arg in expr.bindings.items(): if isinstance(arg, Array): - edges[name] = arg - self.rec(arg) + self.process_node(arg) + key = self.get_cache_key(arg) + edges[name] = (key, arg) - self.node_to_dot[expr] = _DotNodeInfo( + self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo( title=type(expr).__name__, fields={"addr": hex(id(expr)), "entrypoint": expr.entrypoint}, edges=edges) @@ -312,29 +346,30 @@ def map_distributed_send_ref_holder( info = self.get_common_dot_info(expr) - self.rec(expr.passthrough_data) - info.edges["passthrough"] = expr.passthrough_data + self.process_node(expr.passthrough_data) + key = self.get_cache_key(expr.passthrough_data) + info.edges["passthrough"] = (key, expr.passthrough_data) - self.rec(expr.send.data) - info.edges["sent"] = expr.send.data + self.process_node(expr.send.data) + key = self.get_cache_key(expr.send.data) + info.edges["sent"] = (key, expr.send.data) info.fields["dest_rank"] = str(expr.send.dest_rank) - info.fields["comm_tag"] = str(expr.send.comm_tag) - self.node_to_dot[expr] = info + self.node_to_dot[self.get_cache_key(expr)] = info def map_call(self, expr: Call) -> None: self.functions.add(expr.function) for bnd in expr.bindings.values(): - self.rec(bnd) + self.process_node(bnd) - self.node_to_dot[expr] = _DotNodeInfo( + self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo( title=expr.__class__.__name__, edges={ "": expr.function, - **expr.bindings}, + **{name: (self.get_cache_key(bnd), bnd) for name, bnd in expr.bindings.items()}}, fields={ "addr": hex(id(expr)), "tags": stringify_tags(expr.tags), @@ -342,14 +377,16 @@ def map_call(self, expr: Call) -> None: ) def map_named_call_result(self, expr: NamedCallResult) -> None: - self.rec(expr._container) - self.node_to_dot[expr] = _DotNodeInfo( + self.process_node(expr._container) + key = self.get_cache_key(expr._container) + self.node_to_dot[self.get_cache_key(expr)] = _DotNodeInfo( title=expr.__class__.__name__, - edges={"": expr._container}, + edges={"": (key, expr._container)}, fields={"addr": hex(id(expr)), "name": expr.name}, ) + # }}} @@ -363,6 +400,11 @@ def dot_escape_leave_space(s: str) -> str: return html.escape(s.replace("\\", "\\\\")) +def get_array_key(array, count_duplicates): + """Return a consistent key for the array.""" + return id(array) if count_duplicates and not isinstance(array, int) else array + + # {{{ emit helpers def _stringify_created_at(non_equality_tags: FrozenSet[Tag]) -> str: @@ -375,7 +417,7 @@ def _stringify_created_at(non_equality_tags: FrozenSet[Tag]) -> str: def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, Any], - dot_node_id: str, color: str = "white") -> None: + dot_node_id: str, color: str = "white") -> None: td_attrib = 'border="0"' table_attrib = 'border="0" cellborder="1" cellspacing="0"' @@ -397,11 +439,20 @@ def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, Any], f'tooltip="{tooltip}"]') +def preprocess_all_nodes(partition, array_to_id, id_gen, count_duplicates): + mapper = ArrayToDotNodeInfoMapper(count_duplicates) + for part in partition.parts.values(): + for out_name in part.output_names: + node = partition.name_to_output[out_name] + mapper.process_node(node) + + def _emit_name_cluster( emit: DotEmitter, subgraph_path: Tuple[str, ...], names: Mapping[str, ArrayOrNames], array_to_id: Mapping[ArrayOrNames, str], id_gen: Callable[[str], str], - label: str) -> None: + label: str, + count_duplicates: bool = False) -> None: edges = [] cluster_subgraph_path = subgraph_path + (f"cluster_{dot_escape(label)}",) @@ -412,7 +463,8 @@ def _emit_name_cluster( for name, array in names.items(): name_id = id_gen(dot_escape(name)) emit_cluster('%s [label="%s"]' % (name_id, dot_escape(name))) - array_id = array_to_id[array] + array_key = get_array_key(array, count_duplicates) + array_id = array_to_id[array_key] # Edges must be outside the cluster. edges.append((name_id, array_id)) @@ -425,14 +477,16 @@ def _emit_function( id_gen: UniqueNameGenerator, node_to_dot: Mapping[ArrayOrNames, _DotNodeInfo], func_to_id: Mapping[FunctionDefinition, str], - outputs: Mapping[str, Array]) -> None: + outputs: Mapping[str, Array], + count_duplicates: bool = False) -> None: input_arrays: List[Array] = [] internal_arrays: List[ArrayOrNames] = [] array_to_id: Dict[ArrayOrNames, str] = {} emit = partial(emitter, subgraph_path) for array in node_to_dot: - array_to_id[array] = id_gen("array") + key = get_array_key(array, count_duplicates) + array_to_id[key] = id_gen("array") if isinstance(array, InputArgumentBase): input_arrays.append(array) else: @@ -444,36 +498,46 @@ def _emit_function( emit_input('label="Arguments"') for array in input_arrays: + key = get_array_key(array, count_duplicates) _emit_array( emit_input, node_to_dot[array].title, node_to_dot[array].fields, - array_to_id[array]) + array_to_id[key]) # Emit non-inputs. for array in internal_arrays: + key = get_array_key(array, count_duplicates) _emit_array(emit, node_to_dot[array].title, node_to_dot[array].fields, - array_to_id[array]) + array_to_id[key]) # Emit edges. for array, node in node_to_dot.items(): - for label, tail_item in node.edges.items(): - head = array_to_id[array] + key = get_array_key(array, count_duplicates) + for label, edge_info in node.edges.items(): + if isinstance(edge_info, tuple): + tail_key, tail_item = edge_info + else: + tail_item = edge_info + tail_key = get_array_key(tail_item, count_duplicates) + + head = array_to_id[key] if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)): - tail = array_to_id[tail_item] + tail = array_to_id[tail_key] elif isinstance(tail_item, FunctionDefinition): tail = func_to_id[tail_item] else: - raise ValueError( - f"unexpected type of tail on edge: {type(tail_item)}") + raise ValueError(f"unexpected type of tail on edge: {type(tail_item)}") emit('%s -> %s [label="%s"]' % (tail, head, dot_escape(label))) # Emit output/namespace name mappings. _emit_name_cluster( - emitter, subgraph_path, outputs, array_to_id, id_gen, label="Returns") + emitter, subgraph_path, + outputs, array_to_id, id_gen, + label="Returns", count_duplicates=count_duplicates) # }}} @@ -491,7 +555,8 @@ def _get_function_name(f: FunctionDefinition) -> Optional[str]: def _gather_partition_node_information( id_gen: UniqueNameGenerator, - partition: DistributedGraphPartition + partition: DistributedGraphPartition, + count_duplicates: bool = False ) -> Tuple[ Mapping[PartId, Mapping[FunctionDefinition, str]], Mapping[Tuple[PartId, Optional[FunctionDefinition]], @@ -502,9 +567,9 @@ def _gather_partition_node_information( Dict[ArrayOrNames, _DotNodeInfo]] = {} for part in partition.parts.values(): - mapper = ArrayToDotNodeInfoMapper() + mapper = ArrayToDotNodeInfoMapper(count_duplicates) for out_name in part.output_names: - mapper(partition.name_to_output[out_name]) + mapper.process_node(partition.name_to_output[out_name]) part_id_func_to_node_info[part.pid, None] = mapper.node_to_dot part_id_to_func_to_id[part.pid] = {} @@ -519,9 +584,9 @@ def gather_function_info(f: FunctionDefinition) -> None: if key in part_id_func_to_node_info: return - mapper = ArrayToDotNodeInfoMapper() + mapper = ArrayToDotNodeInfoMapper(count_duplicates) for elem in f.returns.values(): - mapper(elem) + mapper.process_node(elem) part_id_func_to_node_info[key] = mapper.node_to_dot @@ -547,10 +612,12 @@ def gather_function_info(f: FunctionDefinition) -> None: return part_id_to_func_to_id, part_id_func_to_node_info + # }}} -def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str: +def get_dot_graph(result: Union[Array, DictOfNamedArrays], + count_duplicates: bool = False) -> str: r"""Return a string in the `dot `_ language depicting the graph of the computation of *result*. @@ -560,30 +627,32 @@ def get_dot_graph(result: Union[Array, DictOfNamedArrays]) -> str: outputs: DictOfNamedArrays = normalize_outputs(result) - return get_dot_graph_from_partition( - DistributedGraphPartition( - parts={ - None: DistributedGraphPart( - pid=None, - needed_pids=frozenset(), - user_input_names=frozenset( + partition = DistributedGraphPartition( + parts={ + None: DistributedGraphPart( + pid=None, + needed_pids=frozenset(), + user_input_names=frozenset( expr.name for expr in InputGatherer()(outputs) if isinstance(expr, Placeholder) ), - partition_input_names=frozenset(), - output_names=frozenset(outputs.keys()), - name_to_recv_node={}, - name_to_send_nodes={}, - ) - }, - name_to_output=outputs._data, - overall_output_names=tuple(outputs), - )) - - -def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: - r"""Return a string in the `dot `_ language depicting the + partition_input_names=frozenset(), + output_names=frozenset(outputs.keys()), + name_to_recv_node={}, + name_to_send_nodes={}, + ) + }, + name_to_output=outputs._data, + overall_output_names=tuple(outputs), + ) + + return get_dot_graph_from_partition(partition, count_duplicates) + + +def get_dot_graph_from_partition(partition: DistributedGraphPartition, + count_duplicates: bool = False) -> str: + """Return a string in the `dot `_ language depicting the graph of the partitioned computation of *partition*. :arg partition: Outputs of :func:`~pytato.find_distributed_partition`. @@ -595,9 +664,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: # The "None" function is the body of the partition. part_id_to_func_to_id, part_id_func_to_node_info = \ - _gather_partition_node_information(id_gen, partition) - - # }}} + _gather_partition_node_information(id_gen, partition, count_duplicates) emitter = DotEmitter() emit_root = partial(emitter, ()) @@ -606,8 +673,8 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: emit_root("node [shape=rectangle]") - placeholder_to_id: Dict[ArrayOrNames, str] = {} - part_id_to_array_to_id: Dict[PartId, Dict[ArrayOrNames, str]] = {} + placeholder_to_id: Dict[Union[int, ArrayOrNames], str] = {} + part_id_to_array_to_id: Dict[PartId, Dict[Union[int, ArrayOrNames], str]] = {} part_id_to_id = {pid: dot_escape(str(pid)) for pid in partition.parts} assert len(set(part_id_to_id.values())) == len(partition.parts) @@ -617,16 +684,18 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: for part in partition.parts.values(): array_to_id = {} for array in part_id_func_to_node_info[part.pid, None].keys(): + if isinstance(array, int): # if the key is an ID + array = get_object_by_id(array) + key = get_array_key(array, count_duplicates) if isinstance(array, Placeholder): - # Placeholders are only emitted once - if array in placeholder_to_id: - node_id = placeholder_to_id[array] + if key in placeholder_to_id: + node_id = placeholder_to_id[key] else: node_id = id_gen("array") - placeholder_to_id[array] = node_id + placeholder_to_id[key] = node_id else: node_id = id_gen("array") - array_to_id[array] = node_id + array_to_id[key] = node_id part_id_to_array_to_id[part.pid] = array_to_id # }}} @@ -663,32 +732,34 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: _emit_function(emitter, func_subgraph_path, id_gen, part_id_func_to_node_info[part.pid, func], part_id_to_func_to_id[part.pid], - func.returns) + func.returns, + count_duplicates=count_duplicates) # }}} # {{{ emit receives nodes part_dist_recv_var_name_to_node_id = {} - for name, recv in ( - part.name_to_recv_node.items()): + for name, recv in part.name_to_recv_node.items(): node_id = id_gen("recv") _emit_array(emit_part, "DistributedRecv", { "shape": stringify_shape(recv.shape), "dtype": str(recv.dtype), "src_rank": str(recv.src_rank), "comm_tag": str(recv.comm_tag), - }, node_id) + }, node_id) part_dist_recv_var_name_to_node_id[name] = node_id # }}} - + part_node_to_info = part_id_func_to_node_info[part.pid, None] input_arrays: List[Array] = [] internal_arrays: List[ArrayOrNames] = [] for array in part_node_to_info.keys(): + if isinstance(array, int): # if the key is an ID + array = get_object_by_id(array) if isinstance(array, InputArgumentBase): input_arrays.append(array) else: @@ -702,26 +773,26 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: # subgraphs. for array in input_arrays: + key = array = get_array_key(array, count_duplicates) if not isinstance(array, Placeholder): _emit_array(emit_part, part_node_to_info[array].title, part_node_to_info[array].fields, - array_to_id[array], "deepskyblue") + array_to_id[key], "deepskyblue") else: # Is a Placeholder - if array in emitted_placeholders: + if key in emitted_placeholders: continue - _emit_array(emit_root, part_node_to_info[array].title, part_node_to_info[array].fields, - array_to_id[array], "deepskyblue") + array_to_id[key], "deepskyblue") # Emit cross-partition edges if array.name in part_dist_recv_var_name_to_node_id: tgt = part_dist_recv_var_name_to_node_id[array.name] - emit_root(f"{tgt} -> {array_to_id[array]} [style=dotted]") - emitted_placeholders.add(array) + emit_root(f"{tgt} -> {array_to_id[key]} [style=dotted]") + emitted_placeholders.add(key) elif array.name in part.user_input_names: # no arrows for these pass @@ -734,18 +805,22 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: break assert computing_pid is not None tgt = part_id_to_array_to_id[computing_pid][ - partition.name_to_output[array.name]] - emit_root(f"{tgt} -> {array_to_id[array]} [style=dashed]") - emitted_placeholders.add(array) + id(partition.name_to_output[array.name]) + if count_duplicates + else partition.name_to_output[array.name]] + emit_root(f"{tgt} -> {array_to_id[key]} [style=dashed]") + emitted_placeholders.add(key) # }}} # Emit internal nodes + for array in internal_arrays: + key = array = get_array_key(array, count_duplicates) _emit_array(emit_part, part_node_to_info[array].title, part_node_to_info[array].fields, - array_to_id[array]) + array_to_id[key]) # {{{ emit send nodes if distributed @@ -756,36 +831,45 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: _emit_array(emit_part, "DistributedSend", { "dest_rank": str(send.dest_rank), "comm_tag": str(send.comm_tag), - }, node_id) + }, node_id) # If an edge is emitted in a subgraph, it drags its # nodes into the subgraph, too. Not what we want. emit_root( - f"{array_to_id[send.data]} -> {node_id}" - f'[style=dotted, label="{dot_escape(name)}"]') + f"{array_to_id[id(send.data) if count_duplicates else send.data]} -> {node_id}" + f'[style=dotted, label="{dot_escape(name)}"]') # }}} # Emit intra-partition edges for array, node in part_node_to_info.items(): - for label, tail_item in node.edges.items(): - head = array_to_id[array] + key = get_array_key(array, count_duplicates) + for label, edge_info in node.edges.items(): + if isinstance(edge_info, tuple): + tail_key, tail_item = edge_info + else: + tail_item = edge_info + tail_key = get_array_key(tail_item, count_duplicates) + + head = array_to_id[key] if isinstance(tail_item, (Array, AbstractResultWithNamedArrays)): - tail = array_to_id[tail_item] + tail = array_to_id[tail_key] elif isinstance(tail_item, FunctionDefinition): tail = part_id_to_func_to_id[part.pid][tail_item] else: raise ValueError( - f"unexpected type of tail on edge: {type(tail_item)}") + f"unexpected type of tail on edge: {type(tail_item)}") emit_root('%s -> %s [label="%s"]' % - (tail, head, dot_escape(label))) + (tail, head, dot_escape(label))) _emit_name_cluster( - emitter, part_subgraph_path, - {name: partition.name_to_output[name] for name in part.output_names}, - array_to_id, id_gen, "Part outputs") + emitter, part_subgraph_path, + {name: partition.name_to_output[name] + for name in part.output_names}, + array_to_id, id_gen, "Part outputs", + count_duplicates) # }}} @@ -795,15 +879,16 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: # {{{ draw overall outputs - combined_array_to_id: Dict[ArrayOrNames, str] = {} + combined_array_to_id: Dict[Union[int, ArrayOrNames], str] = {} for part_id in partition.parts.keys(): combined_array_to_id.update(part_id_to_array_to_id[part_id]) _emit_name_cluster( - emitter, (), - {name: partition.name_to_output[name] - for name in partition.overall_output_names}, - combined_array_to_id, id_gen, "Overall outputs") + emitter, (), + {name: partition.name_to_output[name] + for name in partition.overall_output_names}, + combined_array_to_id, id_gen, "Overall outputs", + count_duplicates) # }}} @@ -812,7 +897,8 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition) -> str: def show_dot_graph(result: Union[str, Array, DictOfNamedArrays, DistributedGraphPartition], - **kwargs: Any) -> None: + count_duplicates: bool = False, + **kwargs: Any) -> None: """Show a graph representing the computation of *result* in a browser. :arg result: Outputs of the computation (cf. @@ -825,9 +911,9 @@ def show_dot_graph(result: Union[str, Array, DictOfNamedArrays, if isinstance(result, str): dot_code = result elif isinstance(result, DistributedGraphPartition): - dot_code = get_dot_graph_from_partition(result) + dot_code = get_dot_graph_from_partition(result, count_duplicates) else: - dot_code = get_dot_graph(result) + dot_code = get_dot_graph(result, count_duplicates) from pytools.graphviz import show_dot show_dot(dot_code, **kwargs) diff --git a/test/test_pytato.py b/test/test_pytato.py index c34f14776..45102158e 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -746,6 +746,79 @@ def test_large_dag_with_duplicates_count(): pt.transform.DependencyMapper()(dag)) +def test_duplicate_node_count_dot_graph(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import get_random_pt_dag + from testlib import count_dot_graph_nodes + + for i in range(80): + # print("curr i:", i) + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + +def test_duplicate_nodes_with_comm_count_dot_graph(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import get_random_pt_dag_with_send_recv_nodes + from testlib import count_dot_graph_nodes + + rank = 0 + size = 2 + for i in range(20): + dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + +def test_large_dot_graph_with_duplicates_count(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import make_large_dag + from testlib import count_dot_graph_nodes + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts with duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) diff --git a/test/testlib.py b/test/testlib.py index a208f0816..f81c46682 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -3,6 +3,7 @@ import types from typing import Any, Dict, Optional, List, Tuple, Union, Sequence, Callable import operator +import re import pyopencl as cl import numpy as np import pytato as pt @@ -337,6 +338,21 @@ def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: # DAG should have `iterations` number of operations return pt.make_dict_of_named_arrays({"result": current}) + +def count_dot_graph_nodes(dot_graph: str): + """ + Parses a dot graph and returns a dictionary with the count of each unique node identifier. + """ + + node_pattern = re.compile(r'addr:(0x[0-9a-f]+)') + nodes = node_pattern.findall(dot_graph) + + node_counts = {} + for node in nodes: + node_counts[node] = node_counts.get(node, 0) + 1 + + return node_counts + # }}} From f3ff93a835a7e555312b02338e5a4e221d212639 Mon Sep 17 00:00:00 2001 From: kajalpatelinfo Date: Mon, 8 Jul 2024 20:56:53 -0600 Subject: [PATCH 15/15] Linting and formatting --- pytato/analysis/__init__.py | 3 +- pytato/visualization/dot.py | 87 +++++++++++++++++++++---------------- test/test_codegen.py | 77 +++++++++++++++++++++++++++++++- test/test_pytato.py | 73 ------------------------------- test/testlib.py | 10 +++-- 5 files changed, 132 insertions(+), 118 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 3bc6785a7..f025ff350 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -472,7 +472,8 @@ def post_visit(self, expr: Any) -> None: self.expr_multiplicity_counts[expr] += 1 -def get_node_multiplicities(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: +def get_node_multiplicities( + outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type[Any], int]: """ Returns the multiplicity per `expr`. """ diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index e02ec05c5..38f482a06 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -31,7 +31,7 @@ import attrs import gc -from typing import (TYPE_CHECKING, Callable, Dict, Tuple, Union, List, +from typing import (Callable, Dict, Tuple, Union, List, Mapping, Any, FrozenSet, Set, Optional) from pytools import UniqueNameGenerator @@ -145,7 +145,11 @@ def emit_subgraph(sg: _SubgraphTree) -> None: class _DotNodeInfo: title: str fields: Dict[str, Any] - edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] + edges: Dict[str, Union[ + ArrayOrNames, + FunctionDefinition, + Tuple[Union[int, ArrayOrNames], ArrayOrNames]], + Array] def stringify_tags(tags: FrozenSet[Optional[Tag]]) -> str: @@ -162,7 +166,7 @@ def stringify_shape(shape: ShapeType) -> str: return "(" + ", ".join(components) + ")" -def get_object_by_id(object_id): +def get_object_by_id(object_id: int) -> Union[Any, ArrayOrNames]: """Find an object by its ID.""" for obj in gc.get_objects(): if id(obj) == object_id: @@ -170,15 +174,15 @@ def get_object_by_id(object_id): return None -class ArrayToDotNodeInfoMapper: +class ArrayToDotNodeInfoMapper(CachedMapper[ArrayOrNames]): def __init__(self, count_duplicates: bool = False): + self.node_to_dot: Dict[Union[int, ArrayOrNames], _DotNodeInfo] = {} + self.functions: Set[FunctionDefinition] = set() self.count_duplicates = count_duplicates - self.node_to_dot = {} - self.functions = set() - def get_cache_key(self, expr): + def get_cache_key(self, expr: ArrayOrNames) -> Union[int, ArrayOrNames]: return id(expr) if self.count_duplicates else expr - + def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: title = type(expr).__name__ fields = {"addr": hex(id(expr)), @@ -188,17 +192,20 @@ def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: "non_equality_tags": expr.non_equality_tags, } - edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} + edges: Dict[str, + Union[ArrayOrNames, FunctionDefinition, + Tuple[Union[int, AbstractResultWithNamedArrays, + Array], Array]]] = {} return _DotNodeInfo(title, fields, edges) - def process_node(self, expr: Array) -> None: + def process_node(self, expr: ArrayOrNames) -> None: if isinstance(expr, DataWrapper): self.map_data_wrapper(expr) elif isinstance(expr, IndexLambda): self.map_index_lambda(expr) elif isinstance(expr, Stack): self.map_stack(expr) - elif isinstance(expr, (IndexBase, IndexLambda)): + elif isinstance(expr, IndexBase): self.map_basic_index(expr) elif isinstance(expr, Einsum): self.map_einsum(expr) @@ -215,7 +222,9 @@ def process_node(self, expr: Array) -> None: else: self.handle_unsupported_array(expr) - def handle_unsupported_array(self, expr: Array) -> None: + def handle_unsupported_array(self, + expr: Array) -> None: + # Default handler, does its best to guess how to handle fields. info = self.get_common_dot_info(expr) expr_key = self.get_cache_key(expr) for field in attrs.fields(type(expr)): @@ -317,7 +326,8 @@ def map_einsum(self, expr: Einsum) -> None: self.node_to_dot[self.get_cache_key(expr)] = info def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: - edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} + edges: Dict[str, Union[ArrayOrNames, FunctionDefinition, Tuple[Union[ + int, ArrayOrNames], Array]]] = {} for name, val in expr._data.items(): self.process_node(val) key = self.get_cache_key(val) @@ -329,7 +339,8 @@ def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> None: edges=edges) def map_loopy_call(self, expr: LoopyCall) -> None: - edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} + edges: Dict[str, Union[ArrayOrNames, FunctionDefinition, Tuple[Union[ + int, ArrayOrNames], Array]]] = {} for name, arg in expr.bindings.items(): if isinstance(arg, Array): self.process_node(arg) @@ -369,7 +380,8 @@ def map_call(self, expr: Call) -> None: title=expr.__class__.__name__, edges={ "": expr.function, - **{name: (self.get_cache_key(bnd), bnd) for name, bnd in expr.bindings.items()}}, + **{name: (self.get_cache_key(bnd), bnd) + for name, bnd in expr.bindings.items()}}, fields={ "addr": hex(id(expr)), "tags": stringify_tags(expr.tags), @@ -400,7 +412,8 @@ def dot_escape_leave_space(s: str) -> str: return html.escape(s.replace("\\", "\\\\")) -def get_array_key(array, count_duplicates): +def get_array_key(array: Union[ArrayOrNames, FunctionDefinition, int], + count_duplicates: bool = False) -> Any: """Return a consistent key for the array.""" return id(array) if count_duplicates and not isinstance(array, int) else array @@ -439,18 +452,11 @@ def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, Any], f'tooltip="{tooltip}"]') -def preprocess_all_nodes(partition, array_to_id, id_gen, count_duplicates): - mapper = ArrayToDotNodeInfoMapper(count_duplicates) - for part in partition.parts.values(): - for out_name in part.output_names: - node = partition.name_to_output[out_name] - mapper.process_node(node) - - def _emit_name_cluster( emit: DotEmitter, subgraph_path: Tuple[str, ...], names: Mapping[str, ArrayOrNames], - array_to_id: Mapping[ArrayOrNames, str], id_gen: Callable[[str], str], + array_to_id: Mapping[ + Union[int, ArrayOrNames], str], id_gen: Callable[[str], str], label: str, count_duplicates: bool = False) -> None: edges = [] @@ -475,13 +481,13 @@ def _emit_name_cluster( def _emit_function( emitter: DotEmitter, subgraph_path: Tuple[str, ...], id_gen: UniqueNameGenerator, - node_to_dot: Mapping[ArrayOrNames, _DotNodeInfo], + node_to_dot: Mapping[Union[int, ArrayOrNames], _DotNodeInfo], func_to_id: Mapping[FunctionDefinition, str], outputs: Mapping[str, Array], count_duplicates: bool = False) -> None: input_arrays: List[Array] = [] - internal_arrays: List[ArrayOrNames] = [] - array_to_id: Dict[ArrayOrNames, str] = {} + internal_arrays: List[Union[int, ArrayOrNames]] = [] + array_to_id: Dict[Union[int, ArrayOrNames], str] = {} emit = partial(emitter, subgraph_path) for array in node_to_dot: @@ -529,7 +535,8 @@ def _emit_function( elif isinstance(tail_item, FunctionDefinition): tail = func_to_id[tail_item] else: - raise ValueError(f"unexpected type of tail on edge: {type(tail_item)}") + raise ValueError( + f"unexpected type of tail on edge: {type(tail_item)}") emit('%s -> %s [label="%s"]' % (tail, head, dot_escape(label))) @@ -558,13 +565,13 @@ def _gather_partition_node_information( partition: DistributedGraphPartition, count_duplicates: bool = False ) -> Tuple[ - Mapping[PartId, Mapping[FunctionDefinition, str]], - Mapping[Tuple[PartId, Optional[FunctionDefinition]], - Mapping[ArrayOrNames, _DotNodeInfo]] - ]: + Dict[PartId, Dict[FunctionDefinition, str]], + Dict[Tuple[PartId, Optional[FunctionDefinition]], + Dict[Union[int, ArrayOrNames], _DotNodeInfo]]]: part_id_to_func_to_id: Dict[PartId, Dict[FunctionDefinition, str]] = {} part_id_func_to_node_info: Dict[Tuple[PartId, Optional[FunctionDefinition]], - Dict[ArrayOrNames, _DotNodeInfo]] = {} + Dict[Union[int, ArrayOrNames], + _DotNodeInfo]] = {} for part in partition.parts.values(): mapper = ArrayToDotNodeInfoMapper(count_duplicates) @@ -650,7 +657,7 @@ def get_dot_graph(result: Union[Array, DictOfNamedArrays], return get_dot_graph_from_partition(partition, count_duplicates) -def get_dot_graph_from_partition(partition: DistributedGraphPartition, +def get_dot_graph_from_partition(partition: DistributedGraphPartition, count_duplicates: bool = False) -> str: """Return a string in the `dot `_ language depicting the graph of the partitioned computation of *partition*. @@ -752,7 +759,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition, part_dist_recv_var_name_to_node_id[name] = node_id # }}} - + part_node_to_info = part_id_func_to_node_info[part.pid, None] input_arrays: List[Array] = [] internal_arrays: List[ArrayOrNames] = [] @@ -814,7 +821,7 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition, # }}} # Emit internal nodes - + for array in internal_arrays: key = array = get_array_key(array, count_duplicates) _emit_array(emit_part, @@ -835,8 +842,9 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition, # If an edge is emitted in a subgraph, it drags its # nodes into the subgraph, too. Not what we want. + data = id(send.data) if count_duplicates else send.data emit_root( - f"{array_to_id[id(send.data) if count_duplicates else send.data]} -> {node_id}" + f"{array_to_id[data]} -> {node_id}" f'[style=dotted, label="{dot_escape(name)}"]') # }}} @@ -844,6 +852,9 @@ def get_dot_graph_from_partition(partition: DistributedGraphPartition, # Emit intra-partition edges for array, node in part_node_to_info.items(): key = get_array_key(array, count_duplicates) + + tail_item: Union[Array, AbstractResultWithNamedArrays, + FunctionDefinition] for label, edge_info in node.edges.items(): if isinstance(edge_info, tuple): tail_key, tail_item = edge_info diff --git a/test/test_codegen.py b/test/test_codegen.py index 0f1456d9b..ca43d6715 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -27,11 +27,10 @@ """ from typing import Union +import sys import itertools import operator -import sys - import loopy as lp import numpy as np import pyopencl as cl @@ -2002,6 +2001,80 @@ def call_bar(tracer, x, y): np.testing.assert_allclose(result_out[k], expect_out[k]) +def test_duplicate_node_count_dot_graph(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import get_random_pt_dag + from testlib import count_dot_graph_nodes + + for i in range(80): + # print("curr i:", i) + dag = get_random_pt_dag(seed=i, axis_len=5) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + +def test_duplicate_nodes_with_comm_count_dot_graph(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import get_random_pt_dag_with_send_recv_nodes + from testlib import count_dot_graph_nodes + + rank = 0 + size = 2 + for i in range(20): + dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + +def test_large_dot_graph_with_duplicates_count(): + from pytato.visualization.dot import get_dot_graph + from pytato.analysis import get_num_nodes + from testlib import make_large_dag + from testlib import count_dot_graph_nodes + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Generate dot graph with duplicates + dot_graph = get_dot_graph(dag, count_duplicates=True) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts with duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) + + # Generate dot graph without duplicates + dot_graph = get_dot_graph(dag, count_duplicates=False) + node_counts = count_dot_graph_nodes(dot_graph) + + # Verify node counts without duplicates + assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) + + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) diff --git a/test/test_pytato.py b/test/test_pytato.py index 45102158e..c34f14776 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -746,79 +746,6 @@ def test_large_dag_with_duplicates_count(): pt.transform.DependencyMapper()(dag)) -def test_duplicate_node_count_dot_graph(): - from pytato.visualization.dot import get_dot_graph - from pytato.analysis import get_num_nodes - from testlib import get_random_pt_dag - from testlib import count_dot_graph_nodes - - for i in range(80): - # print("curr i:", i) - dag = get_random_pt_dag(seed=i, axis_len=5) - - # Generate dot graph with duplicates - dot_graph = get_dot_graph(dag, count_duplicates=True) - node_counts = count_dot_graph_nodes(dot_graph) - - assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) - - # Generate dot graph without duplicates - dot_graph = get_dot_graph(dag, count_duplicates=False) - node_counts = count_dot_graph_nodes(dot_graph) - - # Verify node counts without duplicates - assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) - - -def test_duplicate_nodes_with_comm_count_dot_graph(): - from pytato.visualization.dot import get_dot_graph - from pytato.analysis import get_num_nodes - from testlib import get_random_pt_dag_with_send_recv_nodes - from testlib import count_dot_graph_nodes - - rank = 0 - size = 2 - for i in range(20): - dag = get_random_pt_dag_with_send_recv_nodes(seed=i, rank=rank, size=size) - - # Generate dot graph with duplicates - dot_graph = get_dot_graph(dag, count_duplicates=True) - node_counts = count_dot_graph_nodes(dot_graph) - - assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) - - # Generate dot graph without duplicates - dot_graph = get_dot_graph(dag, count_duplicates=False) - node_counts = count_dot_graph_nodes(dot_graph) - - # Verify node counts without duplicates - assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) - - -def test_large_dot_graph_with_duplicates_count(): - from pytato.visualization.dot import get_dot_graph - from pytato.analysis import get_num_nodes - from testlib import make_large_dag - from testlib import count_dot_graph_nodes - - iterations = 100 - dag = make_large_dag(iterations, seed=42) - - # Generate dot graph with duplicates - dot_graph = get_dot_graph(dag, count_duplicates=True) - node_counts = count_dot_graph_nodes(dot_graph) - - # Verify node counts with duplicates - assert len(node_counts) == get_num_nodes(dag, count_duplicates=True) - - # Generate dot graph without duplicates - dot_graph = get_dot_graph(dag, count_duplicates=False) - node_counts = count_dot_graph_nodes(dot_graph) - - # Verify node counts without duplicates - assert len(node_counts) == get_num_nodes(dag, count_duplicates=False) - - def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) diff --git a/test/testlib.py b/test/testlib.py index f81c46682..7508eb203 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -339,15 +339,17 @@ def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: return pt.make_dict_of_named_arrays({"result": current}) -def count_dot_graph_nodes(dot_graph: str): +def count_dot_graph_nodes(dot_graph: str) -> Dict[Any, int]: """ - Parses a dot graph and returns a dictionary with the count of each unique node identifier. + Parses a dot graph and returns a dictionary with + the count of each unique node identifier. """ - node_pattern = re.compile(r'addr:(0x[0-9a-f]+)') + node_pattern = re.compile( + r'addr:(0x[0-9a-f]+)') nodes = node_pattern.findall(dot_graph) - node_counts = {} + node_counts: Dict[Any, int] = {} for node in nodes: node_counts[node] = node_counts.get(node, 0) + 1