From dfaaeeddf62e6c7015b060cb2ea62a6ec2d6f5f4 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Fri, 15 Mar 2024 17:27:36 -0500 Subject: [PATCH 1/3] Add node type counter --- pytato/analysis/__init__.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5bf374746..dce70e62f 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -413,6 +413,43 @@ def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} +# {{{ NodeTypeCountMapper + +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) +class NodeTypeCountMapper(CachedWalkMapper): + """ + Counts the number of nodes of a given type in a DAG. + + .. attribute:: counts + + Dictionary mapping node types to number of nodes of that type. + """ + + def __init__(self) -> None: + super().__init__() + self.counts = {} + + def get_cache_key(self, expr: ArrayOrNames) -> int: + return id(expr) + + def post_visit(self, expr: Any) -> None: + self.counts[type(expr)] += 1 + + +def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> int: + """Returns the number of nodes of each given type in DAG *outputs*.""" + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeTypeCountMapper() + ncm(outputs) + + return ncm.counts + +# }}} + + # {{{ CallSiteCountMapper @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) From b9edcae372e4ad4c9fdbc2b65153bfe990e4ed54 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Fri, 15 Mar 2024 17:47:43 -0500 Subject: [PATCH 2/3] Fix some linting issues. --- pytato/analysis/__init__.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index dce70e62f..f56f29489 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -26,7 +26,7 @@ """ from typing import (Mapping, Dict, Union, Set, Tuple, Any, FrozenSet, - TYPE_CHECKING) + Type, TYPE_CHECKING) from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, DictOfNamedArrays, NamedArray, IndexBase, IndexRemappingBase, InputArgumentBase, @@ -426,18 +426,24 @@ class NodeTypeCountMapper(CachedWalkMapper): """ def __init__(self) -> None: + from collections import defaultdict super().__init__() - self.counts = {} + self.counts = defaultdict(int) def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) def post_visit(self, expr: Any) -> None: + if type(expr) not in self.counts: + self.counts[type(expr)] = 0 self.counts[type(expr)] += 1 -def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> int: - """Returns the number of nodes of each given type in DAG *outputs*.""" +def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: + """ + Returns a dictionary mapping node types to node count for that type + in DAG *outputs*. + """ from pytato.codegen import normalize_outputs outputs = normalize_outputs(outputs) From 42cb31299af77d459ef7e6147cf03be5da7e3b87 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Wed, 20 Mar 2024 08:27:13 -0500 Subject: [PATCH 3/3] Roll type counter into existing NodeCountMapper --- pytato/analysis/__init__.py | 52 +++++++++---------------------------- 1 file changed, 12 insertions(+), 40 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index f56f29489..4c18079be 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -380,43 +380,6 @@ def map_named_call_result(self, expr: NamedCallResult) -> FrozenSet[Array]: @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) class NodeCountMapper(CachedWalkMapper): - """ - Counts the number of nodes in a DAG. - - .. attribute:: count - - The number of nodes. - """ - - def __init__(self) -> None: - super().__init__() - self.count = 0 - - def get_cache_key(self, expr: ArrayOrNames) -> int: - return id(expr) - - def post_visit(self, expr: Any) -> None: - self.count += 1 - - -def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: - """Returns the number of nodes in DAG *outputs*.""" - - from pytato.codegen import normalize_outputs - outputs = normalize_outputs(outputs) - - ncm = NodeCountMapper() - ncm(outputs) - - return ncm.count - -# }}} - - -# {{{ NodeTypeCountMapper - -@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) -class NodeTypeCountMapper(CachedWalkMapper): """ Counts the number of nodes of a given type in a DAG. @@ -434,12 +397,10 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) def post_visit(self, expr: Any) -> None: - if type(expr) not in self.counts: - self.counts[type(expr)] = 0 self.counts[type(expr)] += 1 -def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: +def get_node_type_counts(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, int]: """ Returns a dictionary mapping node types to node count for that type in DAG *outputs*. @@ -453,6 +414,17 @@ def get_num_node_types(outputs: Union[Array, DictOfNamedArrays]) -> Dict[Type, i return ncm.counts +def get_num_nodes(outputs: Union[Array, DictOfNamedArrays]) -> int: + """Returns the number of nodes in DAG *outputs*.""" + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + ncm = NodeCountMapper() + ncm(outputs) + + return sum(ncm.counts.values()) + # }}}