-
Notifications
You must be signed in to change notification settings - Fork 57
refactor[next-dace]: Add library node for reduce with skip values #2603
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
e628361
c466998
308cfc7
63306a1
ba0190a
e89572d
54a0c54
ec731af
a9db106
81bdc32
14443e2
a8187b5
7d769d6
b47e948
01a05c7
f474149
243512e
5c1cd3b
ef1bb8e
4d15987
94a7595
b9f8d92
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| # GT4Py - GridTools Framework | ||
| # | ||
| # Copyright (c) 2014-2024, ETH Zurich | ||
| # All rights reserved. | ||
| # | ||
| # Please, refer to the LICENSE file in the root directory. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
| from typing import Final | ||
|
|
||
| from dace import nodes as dace_nodes | ||
|
|
||
| from gt4py.next.program_processors.runners.dace.library_nodes.reduce_with_skip_values import ( | ||
| ReduceWithSkipValues, | ||
| ) | ||
|
|
||
|
|
||
| GTIR_LIBRARY_NODES: Final[tuple[dace_nodes.LibraryNode, ...]] = (ReduceWithSkipValues,) | ||
| """List of available GTIR library nodes.""" | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "ReduceWithSkipValues", | ||
| ] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,175 @@ | ||
| # GT4Py - GridTools Framework | ||
| # | ||
| # Copyright (c) 2014-2024, ETH Zurich | ||
| # All rights reserved. | ||
| # | ||
| # Please, refer to the LICENSE file in the root directory. | ||
| # SPDX-License-Identifier: BSD-3-Clause | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import Any, Final | ||
|
|
||
| import dace | ||
| from dace import library as dace_library, properties as dace_properties | ||
| from dace.sdfg import graph as dace_graph | ||
| from dace.transformation import transformation as dace_transform | ||
|
|
||
| from gt4py.next import common as gtx_common | ||
|
|
||
|
|
||
| @dace.library.node | ||
| class ReduceWithSkipValues(dace.sdfg.nodes.LibraryNode): | ||
| """Implements reduction with skip values.""" | ||
|
|
||
| implementations: Final[dict[str, dace_transform.ExpandTransformation]] = {} | ||
| default_implementation: Final[str | None] = "pure" | ||
|
|
||
| # Properties | ||
| wcr = dace_properties.LambdaProperty(allow_none=True) | ||
| identity = dace_properties.Property(allow_none=True) | ||
| init = dace_properties.Property(allow_none=True) | ||
| input_conn = dace_properties.Property(dtype=str) | ||
| output_conn = dace_properties.Property(dtype=str) | ||
| mask_conn = dace_properties.Property(dtype=str) | ||
|
|
||
| def __init__( | ||
| self, | ||
| name: str, | ||
| wcr: str, | ||
| identity: dace.symbolic.SymbolicType, | ||
| init: dace.symbolic.SymbolicType, | ||
| input_conn: str, | ||
| output_conn: str, | ||
| mask_conn: str, | ||
| debuginfo: dace.dtypes.DebugInfo | None = None, | ||
| ) -> None: | ||
| super().__init__(name, inputs={input_conn, mask_conn}, outputs={output_conn}) | ||
| self.wcr = wcr | ||
| self.identity = identity | ||
| self.init = init | ||
| self.input_conn = input_conn | ||
| self.output_conn = output_conn | ||
| self.mask_conn = mask_conn | ||
| self.debuginfo = debuginfo | ||
|
|
||
| def validate(self, sdfg: dace.SDFG, state: dace.SDFGState) -> None: | ||
| if len(list(state.in_edges_by_connector(self, self.input_conn))) != 1: | ||
| raise ValueError(f"Expected exactly one input edge on connector {self.input_conn}.") | ||
| inedge: dace_graph.MultiConnectorEdge = next( | ||
| state.in_edges_by_connector(self, self.input_conn) | ||
| ) | ||
| if len(list(state.out_edges_by_connector(self, self.output_conn))) != 1: | ||
| raise ValueError(f"Expected exactly one output edge on connector {self.output_conn}.") | ||
| outedge: dace_graph.MultiConnectorEdge = next( | ||
| state.out_edges_by_connector(self, self.output_conn) | ||
| ) | ||
| if len(list(state.in_edges_by_connector(self, self.mask_conn))) != 1: | ||
| raise ValueError(f"Expected exactly one input edge on connector {self.mask_conn}.") | ||
| maskedge: dace_graph.MultiConnectorEdge = next( | ||
| state.in_edges_by_connector(self, self.mask_conn) | ||
| ) | ||
|
|
||
| mask_desc = sdfg.arrays[maskedge.data.data] | ||
| if len(mask_desc.shape) != 2: | ||
| raise ValueError(f"Invalid shape {mask_desc.shape} of mask array, expected 2d array.") | ||
| max_neighbors = mask_desc.shape[1] | ||
| if not (isinstance(max_neighbors, int) or str(max_neighbors).isdigit()): | ||
| raise ValueError( | ||
| f"Invalid shape {mask_desc.shape} of mask array, expected constant neighbors size." | ||
| ) | ||
| if ( | ||
| inedge.data.num_elements() != max_neighbors | ||
| or inedge.data.src_subset.size().count(max_neighbors) != 1 | ||
| ): | ||
| raise ValueError(f"Invalid memlet on input connector {self.input_conn}.") | ||
| if ( | ||
| maskedge.data.num_elements() != max_neighbors | ||
| or maskedge.data.src_subset.size().count(max_neighbors) != 1 | ||
| ): | ||
| raise ValueError(f"Invalid memlet on input connector {self.mask_conn}.") | ||
| if outedge.data.num_elements() != 1: | ||
| raise ValueError(f"Invalid memlet on output connector {self.output_conn}.") | ||
|
|
||
|
|
||
| @dace_library.register_expansion(ReduceWithSkipValues, "pure") | ||
| class ReduceWithSkipValuesExpandInlined(dace_transform.ExpandTransformation): | ||
| """Implements pure expansion of the ReduceWithSkipValues library node.""" | ||
|
|
||
| environments: Final[list[Any]] = [] | ||
|
|
||
| @staticmethod | ||
| def expansion(node: ReduceWithSkipValues, state: dace.SDFGState, sdfg: dace.SDFG) -> dace.SDFG: | ||
| assert len(list(state.in_edges_by_connector(node, node.input_conn))) == 1 | ||
| inedge: dace_graph.MultiConnectorEdge = next( | ||
| state.in_edges_by_connector(node, node.input_conn) | ||
| ) | ||
| assert len(list(state.out_edges_by_connector(node, node.output_conn))) == 1 | ||
| outedge: dace_graph.MultiConnectorEdge = next( | ||
| state.out_edges_by_connector(node, node.output_conn) | ||
| ) | ||
| assert len(list(state.in_edges_by_connector(node, node.mask_conn))) == 1 | ||
| maskedge: dace_graph.MultiConnectorEdge = next( | ||
| state.in_edges_by_connector(node, node.mask_conn) | ||
| ) | ||
| input_desc = sdfg.arrays[inedge.data.data] | ||
| output_desc = sdfg.arrays[outedge.data.data] | ||
| mask_desc = sdfg.arrays[maskedge.data.data] | ||
| assert len(mask_desc.shape) == 2 | ||
| max_neighbors = mask_desc.shape[1] | ||
| assert isinstance(max_neighbors, int) or str(max_neighbors).isdigit() | ||
|
|
||
| # In validation, we already checked that the input subset collects exactly | ||
| # `max_neighbors` elements along one dimension. | ||
| local_dim_index = inedge.data.src_subset.size().index(max_neighbors) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Might break in cases where the number of {edges, cell, vertices} is the same as local dimensions.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have an idea. I can check that the size of the subset is 1 in all dimensions except one, the local dimension. I will push a new commit.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should work (except for the case that there is only a single neighbour, which is probably unlikely). |
||
|
|
||
| nsdfg = dace.SDFG(node.label) | ||
| inp, _ = nsdfg.add_array( | ||
| node.input_conn, | ||
| (max_neighbors,), | ||
| input_desc.dtype, | ||
| strides=(input_desc.strides[local_dim_index],), | ||
| ) | ||
| mask, _ = nsdfg.add_array( | ||
| node.mask_conn, | ||
| (max_neighbors,), | ||
| mask_desc.dtype, | ||
| strides=(mask_desc.strides[1],), | ||
| ) | ||
| outp, _ = nsdfg.add_scalar(node.output_conn, output_desc.dtype) | ||
| st_init = nsdfg.add_state("init") | ||
| init_tasklet = st_init.add_tasklet( | ||
| name="write", | ||
| inputs={}, | ||
| outputs={"__tlet_out"}, | ||
| code=f"__tlet_out = {input_desc.dtype}({node.init})", | ||
| ) | ||
| st_init.add_edge( | ||
| init_tasklet, | ||
| "__tlet_out", | ||
| st_init.add_access(outp), | ||
| None, | ||
| dace.Memlet(data=outp, subset="0"), | ||
| ) | ||
| st_reduce = nsdfg.add_state_after(st_init, "compute") | ||
| # Fill skip values in local dimension with the reduce identity value | ||
| skip_value = f"{input_desc.dtype}({node.identity})" | ||
| # Since this map operates on a pure local dimension, we explicitly set sequential | ||
| # schedule and we set the flag 'wcr_nonatomic=True' on the write memlet. | ||
| # TODO(phimuell): decide if auto-optimizer should reset `wcr_nonatomic` properties, as DaCe does. | ||
| st_reduce.add_mapped_tasklet( | ||
| name="reduce_with_skip_values", | ||
| map_ranges={"i": f"0:{max_neighbors}"}, | ||
| inputs={ | ||
| "__tlet_inp": dace.Memlet(data=inp, subset="i"), | ||
| "__tlet_mask": dace.Memlet(data=mask, subset="i"), | ||
| }, | ||
| code=f"__tlet_out = __tlet_inp if __tlet_mask != {gtx_common._DEFAULT_SKIP_VALUE} else {skip_value}", | ||
| outputs={ | ||
| "__tlet_out": dace.Memlet(data=outp, subset="0", wcr=node.wcr, wcr_nonatomic=True), | ||
| }, | ||
| external_edges=True, | ||
| schedule=dace.dtypes.ScheduleType.Sequential, | ||
| ) | ||
|
|
||
| return nsdfg | ||
Uh oh!
There was an error while loading. Please reload this page.