-
Notifications
You must be signed in to change notification settings - Fork 16
Add duplicate node counter functionality and tests #508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
34 commits
Select commit
Hold shift + click to select a range
3e19358
Add node counter tests
kajalpatelinfo ea2402c
CI fixes
kajalpatelinfo b122aa9
Add comments
kajalpatelinfo 4a52c8d
Remove unnecessary test
kajalpatelinfo 570eda4
Add duplicate node functionality and tests
kajalpatelinfo d8dbe62
Remove incrementation for DictOfNamedArrays and update tests
kajalpatelinfo 84262cc
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo 178127c
Edit tests to account for not counting DictOfNamedArrays
kajalpatelinfo 326045e
Fix CI tests
kajalpatelinfo 6a0a2a9
Fix comments
kajalpatelinfo e235f8f
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo 0dca4d7
Clarify wording and clean up
kajalpatelinfo d695c9f
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo 9489ecf
Move `get_node_multiplicities` to its own mapper
kajalpatelinfo 27d6283
Add autofunction
kajalpatelinfo a89bf52
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo e3a2986
Linting
kajalpatelinfo 25c79a6
Add Dict typedef and format
kajalpatelinfo 0b56ea4
Format further
kajalpatelinfo 7f2e3ef
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo 6fdcfe5
Fix CI errors
kajalpatelinfo b4a8cb8
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo 275c609
Fix wording
kajalpatelinfo 4ca47b2
Implement new DAG generator with guaranteed duplicates
kajalpatelinfo 02917e8
Apply suggestions from code review
kajalpatelinfo 2c39189
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo 7e24f46
Ruff fixes
kajalpatelinfo 900937b
remove prints
majosm 00436f1
Apply suggestions from code review
kajalpatelinfo 8d8066f
Add explicit bool for count_duplicates
kajalpatelinfo f2b8e02
Merge branch 'main' into duplicate_node_counts
kajalpatelinfo 59ec433
Update test/testlib.py
kajalpatelinfo d8469b3
Seed random
kajalpatelinfo 572f382
Merge branch 'main' into duplicate_node_counts
majosm File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -61,6 +61,10 @@ | |
|
|
||
| .. autofunction:: get_num_nodes | ||
|
|
||
| .. autofunction:: get_node_type_counts | ||
|
|
||
| .. autofunction:: get_node_multiplicities | ||
|
|
||
| .. autofunction:: get_num_call_sites | ||
|
|
||
| .. autoclass:: DirectPredecessorsGetter | ||
|
|
@@ -398,34 +402,115 @@ def map_named_call_result( | |
| @optimize_mapper(drop_args=True, drop_kwargs=True, inline_get_cache_key=True) | ||
| class NodeCountMapper(CachedWalkMapper): | ||
| """ | ||
| Counts the number of nodes in a DAG. | ||
| Counts the number of nodes of a given type in a DAG. | ||
|
|
||
| .. attribute:: count | ||
| .. autoattribute:: expr_type_counts | ||
| .. autoattribute:: count_duplicates | ||
|
|
||
| The number of nodes. | ||
| Dictionary mapping node types to number of nodes of that type. | ||
| """ | ||
|
|
||
| def __init__(self, count_duplicates: bool = False) -> None: | ||
| from collections import defaultdict | ||
| super().__init__() | ||
| self.expr_type_counts: dict[type[Any], int] = defaultdict(int) | ||
| self.count_duplicates = count_duplicates | ||
|
|
||
| def get_cache_key(self, expr: ArrayOrNames) -> int | ArrayOrNames: | ||
| # Returns unique nodes only if count_duplicates is False | ||
| return id(expr) if self.count_duplicates else expr | ||
|
|
||
| def post_visit(self, expr: Any) -> None: | ||
| if not isinstance(expr, DictOfNamedArrays): | ||
| self.expr_type_counts[type(expr)] += 1 | ||
|
|
||
|
|
||
| def get_node_type_counts( | ||
| outputs: Array | DictOfNamedArrays, | ||
| count_duplicates: bool = False | ||
| ) -> dict[type[Any], int]: | ||
| """ | ||
| Returns a dictionary mapping node types to node count for that type | ||
| in DAG *outputs*. | ||
|
|
||
| Instances of `DictOfNamedArrays` are excluded from counting. | ||
| """ | ||
|
|
||
| from pytato.codegen import normalize_outputs | ||
| outputs = normalize_outputs(outputs) | ||
|
|
||
| ncm = NodeCountMapper(count_duplicates) | ||
| ncm(outputs) | ||
|
|
||
| return ncm.expr_type_counts | ||
|
|
||
|
|
||
| def get_num_nodes( | ||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This breaks compatibility ( Maybe deprecate not specifying the argument. (@majosm can explain) |
||
| outputs: Array | DictOfNamedArrays, | ||
| count_duplicates: bool | None = None | ||
| ) -> int: | ||
| """ | ||
| Returns the number of nodes in DAG *outputs*. | ||
| Instances of `DictOfNamedArrays` are excluded from counting. | ||
| """ | ||
| if count_duplicates is None: | ||
| from warnings import warn | ||
| warn( | ||
| "The default value of 'count_duplicates' will change " | ||
| "from True to False in 2025. " | ||
| "For now, pass the desired value explicitly.", | ||
| DeprecationWarning, stacklevel=2) | ||
| count_duplicates = True | ||
|
|
||
| from pytato.codegen import normalize_outputs | ||
| outputs = normalize_outputs(outputs) | ||
|
|
||
| ncm = NodeCountMapper(count_duplicates) | ||
| ncm(outputs) | ||
|
|
||
| return sum(ncm.expr_type_counts.values()) | ||
|
|
||
| # }}} | ||
|
|
||
|
|
||
| # {{{ NodeMultiplicityMapper | ||
|
|
||
|
|
||
| class NodeMultiplicityMapper(CachedWalkMapper): | ||
| """ | ||
| Computes the multiplicity of each unique node in a DAG. | ||
|
|
||
| The multiplicity of a node `x` is the number of nodes with distinct `id()`\\ s | ||
| that equal `x`. | ||
|
|
||
| .. autoattribute:: expr_multiplicity_counts | ||
| """ | ||
| def __init__(self) -> None: | ||
| from collections import defaultdict | ||
| super().__init__() | ||
| self.count = 0 | ||
| self.expr_multiplicity_counts: dict[Array, int] = defaultdict(int) | ||
|
|
||
| def get_cache_key(self, expr: ArrayOrNames) -> int: | ||
| # Returns each node, including nodes that are duplicates | ||
| return id(expr) | ||
|
|
||
| def post_visit(self, expr: Any) -> None: | ||
| self.count += 1 | ||
| if not isinstance(expr, DictOfNamedArrays): | ||
| self.expr_multiplicity_counts[expr] += 1 | ||
|
|
||
|
|
||
| def get_num_nodes(outputs: Array | DictOfNamedArrays) -> int: | ||
| """Returns the number of nodes in DAG *outputs*.""" | ||
|
|
||
| def get_node_multiplicities( | ||
| outputs: Array | DictOfNamedArrays) -> dict[Array, int]: | ||
| """ | ||
| Returns the multiplicity per `expr`. | ||
| """ | ||
| from pytato.codegen import normalize_outputs | ||
| outputs = normalize_outputs(outputs) | ||
|
|
||
| ncm = NodeCountMapper() | ||
| ncm(outputs) | ||
| nmm = NodeMultiplicityMapper() | ||
| nmm(outputs) | ||
|
|
||
| return ncm.count | ||
| return nmm.expr_multiplicity_counts | ||
|
|
||
| # }}} | ||
|
|
||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.