From 08c222380d248d47fe31aa341d4f8e30792f88d2 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 19 Jul 2023 08:20:37 -0700 Subject: [PATCH 1/4] add function to compute the max node depth of a DAG --- pytato/analysis/__init__.py | 48 +++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 6db21e863..6d8a2427f 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -66,6 +66,7 @@ .. autofunction:: is_einsum_similar_to_subscript .. autofunction:: get_num_nodes +.. autofunction:: get_max_node_depth .. autofunction:: get_node_type_counts @@ -545,6 +546,53 @@ def get_node_multiplicities( # }}} +# {{{ NodeMaxDepthMapper + +@optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) +class NodeMaxDepthMapper(CachedWalkMapper): + """ + Finds the maximum depth of a node in a DAG. + + .. attribute:: max_depth + + The depth of the deepest node. + """ + + def __init__(self) -> None: + super().__init__() + self.depth = 0 + self.max_depth = 0 + + # FIXME: Do I need this? + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] + return id(expr) + + def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> None: + """Call the mapper method of *expr* and return the result.""" + self.depth += 1 + self.max_depth = max(self.max_depth, self.depth) + + try: + super().rec(expr, *args, **kwargs) + finally: + self.depth -= 1 + + +def get_max_node_depth(outputs: Union[Array, DictOfNamedArrays]) -> int: + """Finds the maximum depth of a node in *outputs*.""" + + from pytato.codegen import normalize_outputs + outputs = normalize_outputs(outputs) + + nmdm = NodeMaxDepthMapper() + nmdm(outputs) + + return nmdm.max_depth + +# }}} + + # {{{ CallSiteCountMapper @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) From 4eaa21d23c0d22380ab5429caced80e03d007934 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 19 Jul 2023 10:52:17 -0700 Subject: [PATCH 2/4] fix off-by-one error and add test --- pytato/analysis/__init__.py | 5 +++-- test/test_pytato.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 6d8a2427f..4447a688f 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -560,8 +560,9 @@ class NodeMaxDepthMapper(CachedWalkMapper): def __init__(self) -> None: super().__init__() - self.depth = 0 - self.max_depth = 0 + # Want the first rec() call to increment to 0, so start at -1 + self.depth = -1 + self.max_depth = -1 # FIXME: Do I need this? # type-ignore-reason: dropped the extra `*args, **kwargs`. diff --git a/test/test_pytato.py b/test/test_pytato.py index 3d16e28d9..dbac5bd0a 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -765,6 +765,16 @@ def test_large_dag_with_duplicates_count(): dag, count_duplicates=False) +def test_nodemaxdepthmapper(): + from pytato.analysis import get_max_node_depth + + x = pt.make_placeholder("x", shape=(10, 4), dtype=np.float64) + for i in range(9): + x = x + 1 + + assert get_max_node_depth(x) == 10 + + def test_rec_get_user_nodes(): x1 = pt.make_placeholder("x1", shape=(10, 4)) x2 = pt.make_placeholder("x2", shape=(10, 4)) From 37e66f730cdbab6828e0323c1eb03598f5135549 Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 19 Jul 2023 10:55:39 -0700 Subject: [PATCH 3/4] flake8 --- test/test_pytato.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index dbac5bd0a..1c27be659 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -769,7 +769,7 @@ def test_nodemaxdepthmapper(): from pytato.analysis import get_max_node_depth x = pt.make_placeholder("x", shape=(10, 4), dtype=np.float64) - for i in range(9): + for _ in range(9): x = x + 1 assert get_max_node_depth(x) == 10 From 271ec0e1efa4c19e59acb53063f2e0cccff5ff1c Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Wed, 19 Jul 2023 11:29:20 -0700 Subject: [PATCH 4/4] fix more bugs --- pytato/analysis/__init__.py | 15 +++++++++------ test/test_pytato.py | 5 ++++- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 4447a688f..4cdde2c82 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -571,13 +571,16 @@ def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] def rec(self, expr: ArrayOrNames, *args: Any, **kwargs: Any) -> None: """Call the mapper method of *expr* and return the result.""" - self.depth += 1 - self.max_depth = max(self.max_depth, self.depth) - - try: + if isinstance(expr, DictOfNamedArrays): super().rec(expr, *args, **kwargs) - finally: - self.depth -= 1 + else: + self.depth += 1 + self.max_depth = max(self.max_depth, self.depth) + + try: + super().rec(expr, *args, **kwargs) + finally: + self.depth -= 1 def get_max_node_depth(outputs: Union[Array, DictOfNamedArrays]) -> int: diff --git a/test/test_pytato.py b/test/test_pytato.py index 1c27be659..da75bcc89 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -769,7 +769,10 @@ def test_nodemaxdepthmapper(): from pytato.analysis import get_max_node_depth x = pt.make_placeholder("x", shape=(10, 4), dtype=np.float64) - for _ in range(9): + + assert get_max_node_depth(x) == 0 + + for _ in range(10): x = x + 1 assert get_max_node_depth(x) == 10