diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 072ea8e66..fadf1c92d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -61,6 +61,10 @@ .. autofunction:: get_num_nodes +.. autofunction:: get_node_type_counts + +.. autofunction:: get_node_multiplicities + .. autofunction:: get_num_call_sites .. autoclass:: DirectPredecessorsGetter @@ -398,34 +402,115 @@ def map_named_call_result( @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class NodeCountMapper(CachedWalkMapper): """ - Counts the number of nodes in a DAG. + Counts the number of nodes of a given type in a DAG. - .. attribute:: count + .. autoattribute:: expr_type_counts + .. autoattribute:: count_duplicates - The number of nodes. + Dictionary mapping node types to number of nodes of that type. + """ + + def __init__(self, count_duplicates: bool = False) -> None: + from collections import defaultdict + super().__init__() + self.expr_type_counts: dict[type[Any], int] = defaultdict(int) + self.count_duplicates = count_duplicates + + def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: + # Returns unique nodes only if count_duplicates is False + return id(expr) if self.count_duplicates else expr + + def post_visit(self, expr: Any) -> None: + if not isinstance(expr, DictOfNamedArrays): + self.expr_type_counts[type(expr)] += 1 + + +def get_node_type_counts( + outputs: Array | DictOfNamedArrays, + count_duplicates: bool = False + ) -> dict[type[Any], int]: + """ + Returns a dictionary mapping node types to node count for that type + in DAG *outputs*. + + Instances of `DictOfNamedArrays` are excluded from counting. """ + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeCountMapper(count_duplicates) + ncm(outputs) + + return ncm.expr_type_counts + + +def get_num_nodes( + outputs: Array | DictOfNamedArrays, + count_duplicates: bool | None = None + ) -> int: + """ + Returns the number of nodes in DAG *outputs*. + Instances of `DictOfNamedArrays` are excluded from counting. + """ + if count_duplicates is None: + from warnings import warn + warn( + "The default value of 'count_duplicates' will change " + "from True to False in 2025. " + "For now, pass the desired value explicitly.", + DeprecationWarning, stacklevel=2) + count_duplicates = True + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeCountMapper(count_duplicates) + ncm(outputs) + + return sum(ncm.expr_type_counts.values()) + +# }}} + + +# {{{ NodeMultiplicityMapper + + +class NodeMultiplicityMapper(CachedWalkMapper): + """ + Computes the multiplicity of each unique node in a DAG. + + The multiplicity of a node `x` is the number of nodes with distinct `id()`\\ s + that equal `x`. + + .. autoattribute:: expr_multiplicity_counts + """ def __init__(self) -> None: + from collections import defaultdict super().__init__() - self.count = 0 + self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: + # Returns each node, including nodes that are duplicates return id(expr) def post_visit(self, expr: Any) -> None: - self.count += 1 + if not isinstance(expr, DictOfNamedArrays): + self.expr_multiplicity_counts[expr] += 1 -def get_num_nodes(outputs: Array | DictOfNamedArrays) -> int: - """Returns the number of nodes in DAG *outputs*.""" - +def get_node_multiplicities( + outputs: Array | DictOfNamedArrays) -> dict[Array, int]: + """ + Returns the multiplicity per `expr`. + """ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) - ncm = NodeCountMapper() - ncm(outputs) + nmm = NodeMultiplicityMapper() + nmm(outputs) - return ncm.count + return nmm.expr_multiplicity_counts # }}} diff --git a/pytato/distributed/verify.py b/pytato/distributed/verify.py index 6a0ed80c0..730cf346c 100644 --- a/pytato/distributed/verify.py +++ b/pytato/distributed/verify.py @@ -194,12 +194,14 @@ def _run_partition_diagnostics( from pytato.analysis import get_num_nodes num_nodes_per_part = [get_num_nodes(make_dict_of_named_arrays( - {x: gp.name_to_output[x] for x in part.output_names})) + {x: gp.name_to_output[x] for x in part.output_names}), + count_duplicates=False) for part in gp.parts.values()] - logger.info(f"find_distributed_partition: Split {get_num_nodes(outputs)} nodes " - f"into {len(gp.parts)} parts, with {num_nodes_per_part} nodes in each " - "partition.") + logger.info("find_distributed_partition: " + f"Split {get_num_nodes(outputs, count_duplicates=False)} nodes " + f"into {len(gp.parts)} parts, with {num_nodes_per_part} nodes in each " + "partition.") # }}} diff --git a/test/test_codegen.py b/test/test_codegen.py index 7e72809aa..3a6f2b7c1 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1611,7 +1611,7 @@ def get_np_input_args(): _, (pt_result,) = knl(cq) from pytato.analysis import get_num_nodes - print(get_num_nodes(pt_dag)) + print(get_num_nodes(pt_dag, count_duplicates=False)) np.testing.assert_allclose(pt_result, np_result) @@ -1637,8 +1637,9 @@ def test_zero_size_cl_array_dedup(ctx_factory): dedup_dw_out = pt.transform.deduplicate_data_wrappers(out) - num_nodes_old = pt.analysis.get_num_nodes(out) - num_nodes_new = pt.analysis.get_num_nodes(dedup_dw_out) + num_nodes_old = pt.analysis.get_num_nodes(out, count_duplicates=True) + num_nodes_new = pt.analysis.get_num_nodes( + dedup_dw_out, count_duplicates=True) # 'x2' would be merged with 'x1' as both of them point to the same data # 'x3' would be merged with 'x4' as both of them point to the same data assert num_nodes_new == (num_nodes_old - 2) diff --git a/test/test_pytato.py b/test/test_pytato.py index b22e8277e..f67e7e5f1 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -598,20 +598,171 @@ def test_repr_array_is_deterministic(): assert repr(dag) == repr(dag) -def test_nodecountmapper(): - from testlib import RandomDAGContext, make_random_dag +def test_empty_dag_count(): + from pytato.analysis import get_node_type_counts, get_num_nodes + + empty_dag = pt.make_dict_of_named_arrays({}) + + # Verify that get_num_nodes returns 0 for an empty DAG + assert get_num_nodes(empty_dag, count_duplicates=False) == 0 + + counts = get_node_type_counts(empty_dag) + assert len(counts) == 0 + + +def test_single_node_dag_count(): + from pytato.analysis import get_node_type_counts, get_num_nodes + + data = np.random.rand(4, 4) + single_node_dag = pt.make_dict_of_named_arrays( + {"result": pt.make_data_wrapper(data)}) + + # Get counts per node type + node_counts = get_node_type_counts(single_node_dag) + + # Assert that there is only one node of type DataWrapper + assert node_counts == {pt.DataWrapper: 1} + + # Get total number of nodes + total_nodes = get_num_nodes(single_node_dag, count_duplicates=False) + + assert total_nodes == 1 + + +def test_small_dag_count(): + from pytato.analysis import get_node_type_counts, get_num_nodes + + # Make a DAG using two nodes and one operation + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + b = a + 1 + dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + + # Verify that get_num_nodes returns 2 for a DAG with two nodes + assert get_num_nodes(dag, count_duplicates=False) == 2 + + counts = get_node_type_counts(dag) + assert len(counts) == 2 + assert counts[pt.array.Placeholder] == 1 # "a" + assert counts[pt.array.IndexLambda] == 1 # single operation + + +def test_large_dag_count(): + from testlib import make_large_dag + + from pytato.analysis import get_node_type_counts, get_num_nodes + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Verify that the number of nodes is equal to iterations + 1 (placeholder) + assert get_num_nodes(dag, count_duplicates=False) == iterations + 1 + + counts = get_node_type_counts(dag) + assert len(counts) >= 1 + assert counts[pt.array.Placeholder] == 1 + assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert sum(counts.values()) == iterations + 1 + + +def test_random_dag_count(): + from testlib import get_random_pt_dag from pytato.analysis import get_num_nodes + for i in range(80): + dag = get_random_pt_dag(seed=i, axis_len=5) - axis_len = 5 + assert get_num_nodes(dag, count_duplicates=False) == len( + pt.transform.DependencyMapper()(dag)) + + +def test_random_dag_with_comm_count(): + from testlib import get_random_pt_dag_with_send_recv_nodes + from pytato.analysis import get_num_nodes + rank = 0 + size = 2 for i in range(10): - rdagc = RandomDAGContext(np.random.default_rng(seed=i), - axis_len=axis_len, use_numpy=False) - dag = make_random_dag(rdagc) + dag = get_random_pt_dag_with_send_recv_nodes( + seed=i, rank=rank, size=size) + + assert get_num_nodes(dag, count_duplicates=False) == len( + pt.transform.DependencyMapper()(dag)) + + +def test_small_dag_with_duplicates_count(): + from testlib import make_small_dag_with_duplicates + + from pytato.analysis import ( + get_node_multiplicities, + get_node_type_counts, + get_num_nodes, + ) + + dag = make_small_dag_with_duplicates() + + # Get the number of expressions, including duplicates + node_count = get_num_nodes(dag, count_duplicates=True) + expected_node_count = 4 + assert node_count == expected_node_count + + # Get the number of occurrences of each unique expression + node_multiplicity = get_node_multiplicities(dag) + assert any(count > 1 for count in node_multiplicity.values()) + + # Get difference in duplicates + num_duplicates = sum(count - 1 for count in node_multiplicity.values()) + + counts = get_node_type_counts(dag, count_duplicates=True) + expected_counts = { + pt.array.Placeholder: 1, + pt.array.IndexLambda: 3 + } + + for node_type, expected_count in expected_counts.items(): + assert counts[node_type] == expected_count + + # Check that duplicates are correctly calculated + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == get_num_nodes( + dag, count_duplicates=False) + + +def test_large_dag_with_duplicates_count(): + from testlib import make_large_dag_with_duplicates + + from pytato.analysis import ( + get_node_multiplicities, + get_node_type_counts, + get_num_nodes, + ) + + iterations = 100 + dag = make_large_dag_with_duplicates(iterations, seed=42) + + # Get the number of expressions, including duplicates + node_count = get_num_nodes(dag, count_duplicates=True) + + # Get the number of occurrences of each unique expression + node_multiplicity = get_node_multiplicities(dag) + assert any(count > 1 for count in node_multiplicity.values()) + + expected_node_count = sum(count for count in node_multiplicity.values()) + assert node_count == expected_node_count + + # Get difference in duplicates + num_duplicates = sum(count - 1 for count in node_multiplicity.values()) + + counts = get_node_type_counts(dag, count_duplicates=True) + + assert counts[pt.array.Placeholder] == 1 + assert sum(counts.values()) == expected_node_count - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag)) + # Check that duplicates are correctly calculated + assert node_count - num_duplicates == len( + pt.transform.DependencyMapper()(dag)) + assert node_count - num_duplicates == get_num_nodes( + dag, count_duplicates=False) def test_rec_get_user_nodes(): diff --git a/test/testlib.py b/test/testlib.py index 3a413a225..53bf79436 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -1,6 +1,7 @@ from __future__ import annotations import operator +import random import types from typing import Any, Callable, Sequence @@ -327,6 +328,73 @@ def gen_comm(rdagc: RandomDAGContext) -> pt.Array: convert_dws_to_placeholders=convert_dws_to_placeholders, additional_generators=[(comm_fake_probability, gen_comm)]) + +def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: + """ + Builds a DAG with emphasis on number of operations. + """ + + rng = np.random.default_rng(seed) + random.seed(seed) + + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + current = a + + # Will randomly choose from the operators + operations = [operator.add, operator.sub, operator.mul, operator.truediv] + + for _ in range(iterations): + operation = random.choice(operations) + value = rng.uniform(1, 10) + current = operation(current, value) + + # DAG should have `iterations` number of operations + return pt.make_dict_of_named_arrays({"result": current}) + + +def make_small_dag_with_duplicates() -> pt.DictOfNamedArrays: + x = pt.make_placeholder(name="x", shape=(2, 2), dtype=np.float64) + + expr1 = 2 * x + expr2 = 2 * x + + y = expr1 + expr2 + + # Has duplicates of the 2*x operation + return pt.make_dict_of_named_arrays({"result": y}) + + +def make_large_dag_with_duplicates(iterations: int, + seed: int = 0) -> pt.DictOfNamedArrays: + + random.seed(seed) + rng = np.random.default_rng(seed) + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + current = a + + # Will randomly choose from the operators + operations = [operator.add, operator.sub, operator.mul, operator.truediv] + duplicates = [] + + for _ in range(iterations): + operation = random.choice(operations) + value = rng.uniform(1, 10) + current = operation(current, value) + + # Introduce duplicates intentionally + if rng.uniform() > 0.2: + dup1 = operation(a, value) + dup2 = operation(a, value) + duplicates.append(dup1) + duplicates.append(dup2) + current = operation(current, dup1) + + all_exprs = [current, *duplicates] + combined_expr = pt.stack(all_exprs, axis=0) + + result = pt.sum(combined_expr, axis=0) + return pt.make_dict_of_named_arrays({"result": result}) + # }}}