diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5bf374746..b822d6622 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,7 +26,7 @@ """ from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, - TYPE_CHECKING) + Type, TYPE_CHECKING) from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, DictOfNamedArrays, NamedArray, IndexBase, IndexRemappingBase, InputArgumentBase, @@ -49,6 +49,8 @@ .. autofunction:: get_num_nodes +.. autofunction:: get_node_type_counts + .. autofunction:: get_num_call_sites .. autoclass:: DirectPredecessorsGetter @@ -381,26 +383,51 @@ def map_named_call_result(self, expr: NamedCallResult) -> FrozenSet[Array]: @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class NodeCountMapper(CachedWalkMapper): """ - Counts the number of nodes in a DAG. + Counts the number of nodes of a given type in a DAG. - .. attribute:: count + .. attribute:: counts - The number of nodes. + Dictionary mapping node types to number of nodes of that type. """ def __init__(self) -> None: + from collections import defaultdict super().__init__() - self.count = 0 + self.counts = defaultdict(int) # type: Dict[Type[Any], int] - def get_cache_key(self, expr: ArrayOrNames) -> int: - return id(expr) + def get_cache_key(self, expr: ArrayOrNames) -> ArrayOrNames: + # does NOT account for duplicate nodes + return expr def post_visit(self, expr: Any) -> None: - self.count += 1 + if not isinstance(expr, DictOfNamedArrays): + self.counts[type(expr)] += 1 + + +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays] + ) -> Dict[Type[Any], int]: + """ + Returns a dictionary mapping node types to node count for that type + in DAG *outputs*. + + Instances of `DictOfNamedArrays` are excluded from counting. + """ + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeCountMapper() + ncm(outputs) + + return ncm.counts def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: - """Returns the number of nodes in DAG *outputs*.""" + """ + Returns the number of nodes in DAG *outputs*. + + Instances of `DictOfNamedArrays` are excluded from counting. + """ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) @@ -408,7 +435,7 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: ncm = NodeCountMapper() ncm(outputs) - return ncm.count + return sum(ncm.counts.values()) # }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index b9fc732b4..1484abf06 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -26,7 +26,6 @@ """ import sys - import numpy as np import pytest import attrs @@ -585,19 +584,90 @@ def test_repr_array_is_deterministic(): assert repr(dag) == repr(dag) -def test_nodecountmapper(): - from testlib import RandomDAGContext, make_random_dag +def test_empty_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + empty_dag = pt.make_dict_of_named_arrays({}) + + # Verify that get_num_nodes returns 0 for an empty DAG + assert get_num_nodes(empty_dag) == 0 + + counts = get_node_type_counts(empty_dag) + assert len(counts) == 0 + + +def test_single_node_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + data = np.random.rand(4, 4) + single_node_dag = pt.make_dict_of_named_arrays( + {"result": pt.make_data_wrapper(data)}) + + # Get counts per node type + node_counts = get_node_type_counts(single_node_dag) + + # Assert that there is only one node of type DataWrapper + assert node_counts == {pt.DataWrapper: 1} + + # Get total number of nodes + total_nodes = get_num_nodes(single_node_dag) + + assert total_nodes == 1 + + +def test_small_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + + # Make a DAG using two nodes and one operation + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + b = a + 1 + dag = pt.make_dict_of_named_arrays({"result": b}) # b = a + 1 + + # Verify that get_num_nodes returns 2 for a DAG with two nodes + assert get_num_nodes(dag) == 2 + + counts = get_node_type_counts(dag) + assert len(counts) == 2 + assert counts[pt.array.Placeholder] == 1 # "a" + assert counts[pt.array.IndexLambda] == 1 # single operation + + +def test_large_dag_count(): + from pytato.analysis import get_num_nodes, get_node_type_counts + from testlib import make_large_dag + + iterations = 100 + dag = make_large_dag(iterations, seed=42) + + # Verify that the number of nodes is equal to iterations + 1 (placeholder) + assert get_num_nodes(dag) == iterations + 1 + + counts = get_node_type_counts(dag) + assert len(counts) >= 1 + assert counts[pt.array.Placeholder] == 1 + assert counts[pt.array.IndexLambda] == 100 # 100 operations + assert sum(counts.values()) == iterations + 1 + + +def test_random_dag_count(): + from testlib import get_random_pt_dag from pytato.analysis import get_num_nodes + for i in range(80): + dag = get_random_pt_dag(seed=i, axis_len=5) - axis_len = 5 + assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) + +def test_random_dag_with_comm_count(): + from testlib import get_random_pt_dag_with_send_recv_nodes + from pytato.analysis import get_num_nodes + rank = 0 + size = 2 for i in range(10): - rdagc = RandomDAGContext(np.random.default_rng(seed=i), - axis_len=axis_len, use_numpy=False) - dag = make_random_dag(rdagc) + dag = get_random_pt_dag_with_send_recv_nodes( + seed=i, rank=rank, size=size) - # Subtract 1 since NodeCountMapper counts an extra one for DictOfNamedArrays. - assert get_num_nodes(dag)-1 == len(pt.transform.DependencyMapper()(dag)) + assert get_num_nodes(dag) == len(pt.transform.DependencyMapper()(dag)) def test_rec_get_user_nodes(): diff --git a/test/testlib.py b/test/testlib.py index 73daf101f..3ec4fc6d7 100644 --- a/test/testlib.py +++ b/test/testlib.py @@ -311,6 +311,32 @@ def gen_comm(rdagc: RandomDAGContext) -> pt.Array: convert_dws_to_placeholders=convert_dws_to_placeholders, additional_generators=[(comm_fake_probability, gen_comm)]) + +def make_large_dag(iterations: int, seed: int = 0) -> pt.DictOfNamedArrays: + """ + Builds a DAG with emphasis on number of operations. + """ + import random + import operator + + rng = np.random.default_rng(seed) + random.seed(seed) + + # Begin with a placeholder + a = pt.make_placeholder(name="a", shape=(2, 2), dtype=np.float64) + current = a + + # Will randomly choose from the operators + operations = [operator.add, operator.sub, operator.mul, operator.truediv] + + for _ in range(iterations): + operation = random.choice(operations) + value = rng.uniform(1, 10) + current = operation(current, value) + + # DAG should have `iterations` number of operations + return pt.make_dict_of_named_arrays({"result": current}) + # }}}