# GT4Py - GridTools Framework
#
# Copyright (c) 2014-2024, ETH Zurich
# All rights reserved.
#
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

from typing import Final

from dace import nodes as dace_nodes

from gt4py.next.program_processors.runners.dace.library_nodes.reduce_with_skip_values import (
    ReduceWithSkipValues,
)


# The tuple holds the library node *classes* (not instances), therefore the element
# type is `type[LibraryNode]`. This allows to pass the tuple directly as second
# argument of `isinstance()`.
GTIR_LIBRARY_NODES: Final[tuple[type[dace_nodes.LibraryNode], ...]] = (ReduceWithSkipValues,)
"""Tuple of the available GTIR library node classes."""


__all__ = [
    "GTIR_LIBRARY_NODES",
    "ReduceWithSkipValues",
]
+# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +from typing import Any, Final + +import dace +from dace import library as dace_library, properties as dace_properties +from dace.sdfg import graph as dace_graph +from dace.transformation import transformation as dace_transform + +from gt4py.next import common as gtx_common + + +@dace.library.node +class ReduceWithSkipValues(dace.sdfg.nodes.LibraryNode): + """Implements reduction with skip values.""" + + implementations: Final[dict[str, dace_transform.ExpandTransformation]] = {} + default_implementation: Final[str | None] = "pure" + + # Properties + wcr = dace_properties.LambdaProperty(allow_none=True) + identity = dace_properties.Property(allow_none=True) + init = dace_properties.Property(allow_none=True) + input_conn = dace_properties.Property(dtype=str) + output_conn = dace_properties.Property(dtype=str) + mask_conn = dace_properties.Property(dtype=str) + + def __init__( + self, + name: str, + wcr: str, + identity: dace.symbolic.SymbolicType, + init: dace.symbolic.SymbolicType, + input_conn: str, + output_conn: str, + mask_conn: str, + debuginfo: dace.dtypes.DebugInfo | None = None, + ) -> None: + super().__init__(name, inputs={input_conn, mask_conn}, outputs={output_conn}) + self.wcr = wcr + self.identity = identity + self.init = init + self.input_conn = input_conn + self.output_conn = output_conn + self.mask_conn = mask_conn + self.debuginfo = debuginfo + + def validate(self, sdfg: dace.SDFG, state: dace.SDFGState) -> None: + if len(list(state.in_edges_by_connector(self, self.input_conn))) != 1: + raise ValueError(f"Expected exactly one input edge on connector {self.input_conn}.") + inedge: dace_graph.MultiConnectorEdge = next( + state.in_edges_by_connector(self, self.input_conn) + ) + if len(list(state.out_edges_by_connector(self, self.output_conn))) != 1: + raise ValueError(f"Expected exactly one output edge on connector {self.output_conn}.") + outedge: 
dace_graph.MultiConnectorEdge = next( + state.out_edges_by_connector(self, self.output_conn) + ) + if len(list(state.in_edges_by_connector(self, self.mask_conn))) != 1: + raise ValueError(f"Expected exactly one input edge on connector {self.mask_conn}.") + maskedge: dace_graph.MultiConnectorEdge = next( + state.in_edges_by_connector(self, self.mask_conn) + ) + + mask_desc = sdfg.arrays[maskedge.data.data] + if len(mask_desc.shape) != 2: + raise ValueError(f"Invalid shape {mask_desc.shape} of mask array, expected 2d array.") + max_neighbors = mask_desc.shape[1] + if not (isinstance(max_neighbors, int) or str(max_neighbors).isdigit()): + raise ValueError( + f"Invalid shape {mask_desc.shape} of mask array, expected constant neighbors size." + ) + if ( + inedge.data.num_elements() != max_neighbors + or inedge.data.src_subset.size().count(max_neighbors) != 1 + ): + raise ValueError(f"Invalid memlet on input connector {self.input_conn}.") + if ( + maskedge.data.num_elements() != max_neighbors + or maskedge.data.src_subset.size().count(max_neighbors) != 1 + ): + raise ValueError(f"Invalid memlet on input connector {self.mask_conn}.") + if outedge.data.num_elements() != 1: + raise ValueError(f"Invalid memlet on output connector {self.output_conn}.") + + +@dace_library.register_expansion(ReduceWithSkipValues, "pure") +class ReduceWithSkipValuesExpandInlined(dace_transform.ExpandTransformation): + """Implements pure expansion of the ReduceWithSkipValues library node.""" + + environments: Final[list[Any]] = [] + + @staticmethod + def expansion(node: ReduceWithSkipValues, state: dace.SDFGState, sdfg: dace.SDFG) -> dace.SDFG: + assert len(list(state.in_edges_by_connector(node, node.input_conn))) == 1 + inedge: dace_graph.MultiConnectorEdge = next( + state.in_edges_by_connector(node, node.input_conn) + ) + assert len(list(state.out_edges_by_connector(node, node.output_conn))) == 1 + outedge: dace_graph.MultiConnectorEdge = next( + state.out_edges_by_connector(node, 
node.output_conn) + ) + assert len(list(state.in_edges_by_connector(node, node.mask_conn))) == 1 + maskedge: dace_graph.MultiConnectorEdge = next( + state.in_edges_by_connector(node, node.mask_conn) + ) + input_desc = sdfg.arrays[inedge.data.data] + output_desc = sdfg.arrays[outedge.data.data] + mask_desc = sdfg.arrays[maskedge.data.data] + assert len(mask_desc.shape) == 2 + max_neighbors = mask_desc.shape[1] + assert isinstance(max_neighbors, int) or str(max_neighbors).isdigit() + + # In validation, we already checked that the input subset collects exactly + # `max_neighbors` elements along one dimension. + local_dim_index = inedge.data.src_subset.size().index(max_neighbors) + + nsdfg = dace.SDFG(node.label) + inp, _ = nsdfg.add_array( + node.input_conn, + (max_neighbors,), + input_desc.dtype, + strides=(input_desc.strides[local_dim_index],), + ) + mask, _ = nsdfg.add_array( + node.mask_conn, + (max_neighbors,), + mask_desc.dtype, + strides=(mask_desc.strides[1],), + ) + outp, _ = nsdfg.add_scalar(node.output_conn, output_desc.dtype) + st_init = nsdfg.add_state("init") + init_tasklet = st_init.add_tasklet( + name="write", + inputs={}, + outputs={"__tlet_out"}, + code=f"__tlet_out = {input_desc.dtype}({node.init})", + ) + st_init.add_edge( + init_tasklet, + "__tlet_out", + st_init.add_access(outp), + None, + dace.Memlet(data=outp, subset="0"), + ) + st_reduce = nsdfg.add_state_after(st_init, "compute") + # Fill skip values in local dimension with the reduce identity value + skip_value = f"{input_desc.dtype}({node.identity})" + # Since this map operates on a pure local dimension, we explicitly set sequential + # schedule and we set the flag 'wcr_nonatomic=True' on the write memlet. + # TODO(phimuell): decide if auto-optimizer should reset `wcr_nonatomic` properties, as DaCe does. 
+ st_reduce.add_mapped_tasklet( + name="reduce_with_skip_values", + map_ranges={"i": f"0:{max_neighbors}"}, + inputs={ + "__tlet_inp": dace.Memlet(data=inp, subset="i"), + "__tlet_mask": dace.Memlet(data=mask, subset="i"), + }, + code=f"__tlet_out = __tlet_inp if __tlet_mask != {gtx_common._DEFAULT_SKIP_VALUE} else {skip_value}", + outputs={ + "__tlet_out": dace.Memlet(data=outp, subset="0", wcr=node.wcr, wcr_nonatomic=True), + }, + external_edges=True, + schedule=dace.dtypes.ScheduleType.Sequential, + ) + + return nsdfg diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py index e8d1914aa8..c817d3593c 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_dataflow.py @@ -27,6 +27,7 @@ import dace from dace import nodes as dace_nodes, subsets as dace_subsets +from dace.libraries import standard as dace_stdlib from gt4py import eve from gt4py.eve.extended_typing import MaybeNestedInTuple, NestedTuple @@ -34,7 +35,10 @@ from gt4py.next.iterator import ir as gtir from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm, ir_makers as im from gt4py.next.iterator.transforms import symbol_ref_utils -from gt4py.next.program_processors.runners.dace import sdfg_args as gtx_dace_args +from gt4py.next.program_processors.runners.dace import ( + library_nodes as gtx_library_nodes, + sdfg_args as gtx_dace_args, +) from gt4py.next.program_processors.runners.dace.lowering import ( gtir_python_codegen, gtir_to_sdfg, @@ -1379,139 +1383,6 @@ def _visit_map(self, node: gtir.FunCall) -> ValueExpr: gt_dtype=ts.ListType(node.type.element_type, offset_type), ) - def _make_reduce_with_skip_values( - self, - input_expr: ValueExpr | MemletExpr, - offset_provider_type: gtx_common.NeighborConnectivityType, - reduce_init: SymbolExpr, - reduce_identity: SymbolExpr, - reduce_wcr: 
str, - result_node: dace_nodes.AccessNode, - ) -> None: - """ - Helper method to lower reduction on a local field containing skip values. - - The reduction is implemented as a nested SDFG containing 2 states. In first - state, the result (a scalar data node passed as argumet) is initialized. - In second state, a mapped tasklet uses a write-conflict resolution (wcr) - memlet to update the result. - We use the offset provider as a mask to identify skip values: the value - that is written to the result node is either the input value, when the - corresponding neighbor index in the connectivity table is valid, or the - identity value if the neighbor index is missing. - """ - origin_map_index = gtir_to_sdfg_utils.get_map_variable(offset_provider_type.source_dim) - - assert ( - isinstance(input_expr.gt_dtype, ts.ListType) - and input_expr.gt_dtype.offset_type is not None - ) - offset_type = input_expr.gt_dtype.offset_type - connectivity = gtx_dace_args.connectivity_identifier(offset_type.value) - connectivity_node = self.state.add_access(connectivity) - connectivity_desc = connectivity_node.desc(self.sdfg) - connectivity_desc.transient = False - - desc = input_expr.dc_node.desc(self.sdfg) - if isinstance(input_expr, MemletExpr): - local_dim_indices = [i for i, size in enumerate(input_expr.subset.size()) if size != 1] - else: - local_dim_indices = list(range(len(desc.shape))) - - if len(local_dim_indices) != 1: - raise NotImplementedError( - f"Found {len(local_dim_indices)} local dimensions in reduce expression, expected one." 
- ) - local_dim_index = local_dim_indices[0] - assert desc.shape[local_dim_index] == offset_provider_type.max_neighbors - - # we lower the reduction map with WCR out memlet in a nested SDFG - nsdfg = dace.SDFG(self.subgraph_builder.unique_nsdfg_name("reduce_with_skip_values")) - nsdfg.add_array( - "values", - (desc.shape[local_dim_index],), - desc.dtype, - strides=(desc.strides[local_dim_index],), - ) - nsdfg.add_array( - "neighbor_indices", - (connectivity_desc.shape[1],), - connectivity_desc.dtype, - strides=(connectivity_desc.strides[1],), - ) - nsdfg.add_scalar("acc", desc.dtype) - st_init = nsdfg.add_state(f"{nsdfg.label}_init") - init_tasklet, connector_mapping = self.subgraph_builder.add_tasklet( - name="init_acc", - sdfg=self.sdfg, - state=st_init, - inputs={}, - outputs={"val"}, - code=f"val = {reduce_init.dc_dtype}({reduce_init.value})", - ) - st_init.add_edge( - init_tasklet, - connector_mapping["val"], - st_init.add_access("acc"), - None, - dace.Memlet(data="acc", subset="0"), - ) - st_reduce = nsdfg.add_state_after(st_init, f"{nsdfg.label}_reduce") - # Fill skip values in local dimension with the reduce identity value - skip_value = f"{reduce_identity.dc_dtype}({reduce_identity.value})" - # Since this map operates on a pure local dimension, we explicitly set sequential - # schedule and we set the flag 'wcr_nonatomic=True' on the write memlet. - # TODO(phimuell): decide if auto-optimizer should reset `wcr_nonatomic` properties, as DaCe does. 
- self.subgraph_builder.add_mapped_tasklet( - name="reduce_with_skip_values", - sdfg=self.sdfg, - state=st_reduce, - map_ranges={"i": f"0:{offset_provider_type.max_neighbors}"}, - inputs={ - "val": dace.Memlet(data="values", subset="i"), - "neighbor_idx": dace.Memlet(data="neighbor_indices", subset="i"), - }, - code=f"out = val if neighbor_idx != {gtx_common._DEFAULT_SKIP_VALUE} else {skip_value}", - outputs={ - "out": dace.Memlet(data="acc", subset="0", wcr=reduce_wcr, wcr_nonatomic=True), - }, - external_edges=True, - schedule=dace.dtypes.ScheduleType.Sequential, - ) - - nsdfg_node = self.state.add_nested_sdfg( - nsdfg, inputs={"values", "neighbor_indices"}, outputs={"acc"} - ) - - if isinstance(input_expr, MemletExpr): - self._add_input_data_edge(input_expr.dc_node, input_expr.subset, nsdfg_node, "values") - else: - self.state.add_edge( - input_expr.dc_node, - None, - nsdfg_node, - "values", - self.sdfg.make_array_memlet(input_expr.dc_node.data), - ) - # The layout of connectivity tables is known. 
- assert len(offset_provider_type.domain) == 2 - assert offset_provider_type.domain[1].kind == gtx_common.DimensionKind.LOCAL - self._add_input_data_edge( - connectivity_node, - dace_subsets.Range.from_string( - f"{origin_map_index}, 0:{offset_provider_type.max_neighbors}" - ), - nsdfg_node, - "neighbor_indices", - ) - self.state.add_edge( - nsdfg_node, - "acc", - result_node, - None, - dace.Memlet(data=result_node.data, subset="0"), - ) - def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: assert isinstance(node.type, ts.ScalarType) op_name, reduce_init, reduce_identity = get_reduce_params(node) @@ -1530,28 +1401,65 @@ def _visit_reduce(self, node: gtir.FunCall) -> ValueExpr: offset_provider_type = self.subgraph_builder.get_offset_provider_type(offset_type.value) assert isinstance(offset_provider_type, gtx_common.NeighborConnectivityType) + inp_conn = "_in" + outp_conn = "_out" + mask_conn = "_mask" if offset_provider_type.has_skip_values: - self._make_reduce_with_skip_values( - input_expr, - offset_provider_type, - reduce_init, - reduce_identity, - reduce_wcr, - result_node, + assert ( + isinstance(input_expr.gt_dtype, ts.ListType) + and input_expr.gt_dtype.offset_type is not None ) + offset_type = input_expr.gt_dtype.offset_type + connectivity = gtx_dace_args.connectivity_identifier(offset_type.value) + self.sdfg.arrays[connectivity].transient = False + + reduce_node = gtx_library_nodes.ReduceWithSkipValues( + name=self.subgraph_builder.unique_lib_node_name("reduce_with_skip_values"), + wcr=reduce_wcr, + identity=reduce_identity.value, + init=reduce_init.value, + input_conn=inp_conn, + output_conn=outp_conn, + mask_conn=mask_conn, + debuginfo=gtir_to_sdfg_utils.debug_info(node), + ) + self.state.add_node(reduce_node) + origin_map_index = gtir_to_sdfg_utils.get_map_variable(offset_provider_type.source_dim) + self._add_input_data_edge( + self.state.add_access(connectivity), + dace_subsets.Range.from_string( + f"{origin_map_index}, 
0:{offset_provider_type.max_neighbors}" + ), + reduce_node, + mask_conn, + ) else: - reduce_node = self.state.add_reduce(reduce_wcr, axes=None, identity=reduce_init.value) - if isinstance(input_expr, MemletExpr): - self._add_input_data_edge(input_expr.dc_node, input_expr.subset, reduce_node) - else: - self.state.add_nedge( - input_expr.dc_node, - reduce_node, - self.sdfg.make_array_memlet(input_expr.dc_node.data), - ) - self.state.add_nedge(reduce_node, result_node, dace.Memlet(data=result, subset="0")) + reduce_node = dace_stdlib.Reduce( + name=self.subgraph_builder.unique_lib_node_name("reduce"), + wcr=reduce_wcr, + axes=None, + identity=reduce_init.value, + debuginfo=gtir_to_sdfg_utils.debug_info(node), + inputs={inp_conn}, + outputs={outp_conn}, + ) + self.state.add_node(reduce_node) + + if isinstance(input_expr, MemletExpr): + self._add_input_data_edge(input_expr.dc_node, input_expr.subset, reduce_node, inp_conn) + else: + self.state.add_edge( + input_expr.dc_node, + None, + reduce_node, + inp_conn, + self.sdfg.make_array_memlet(input_expr.dc_node.data), + ) + self.state.add_edge( + reduce_node, outp_conn, result_node, None, dace.Memlet(data=result, subset="0") + ) return ValueExpr(result_node, node.type) def _split_shift_args( diff --git a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py index cb9b9c6d65..053f603d0b 100644 --- a/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace/lowering/gtir_to_sdfg.py @@ -73,6 +73,9 @@ def unique_tasklet_name(self, name: str) -> str: ... @abc.abstractmethod def unique_temp_name(self) -> str: ... + @abc.abstractmethod + def unique_lib_node_name(self, lib_node_type: str) -> str: ... 
+ def add_temp_array( self, sdfg: dace.SDFG, shape: Sequence[Any], dtype: dace.dtypes.typeclass ) -> tuple[str, dace.data.Scalar]: @@ -118,12 +121,16 @@ def add_tasklet( code: str, language: dace.dtypes.Language = dace.dtypes.Language.Python, **kwargs: Any, - ) -> dace_nodes.Tasklet: + ) -> tuple[dace_nodes.Tasklet, dict[str, str]]: """Wrapper of `dace.SDFGState.add_tasklet` that assigns a unique name. It also modifies the tasklet connectors by adding a prefix string (see `gtir_to_sdfg_utils.get_tasklet_connector()`), in order to avoid name conflicts with SDFG data. Otherwise, SDFG validation would detect such conflicts and fail. + + Returns: + The created tasklet node and the mapping from original connector names to + modified connector names. """ if isinstance(inputs, set): inputs = {k: None for k in sorted(inputs)} @@ -161,6 +168,12 @@ def add_mapped_tasklet( """Wrapper of `dace.SDFGState.add_mapped_tasklet` that assigns a unique name. It also modifies the tasklet connectors, in the same way as `add_tasklet()`. + + Returns: + A tuple consisting of: + - The created tasklet node. + - The map entry and exit nodes of the created map. + - The mapping from original connector names to modified connector names. 
""" assert inputs.keys().isdisjoint(outputs.keys()) @@ -759,6 +772,9 @@ def unique_tasklet_name(self, name: str) -> str: def unique_temp_name(self) -> str: return f"{next(self.uids['gtir_tmp'])}" + def unique_lib_node_name(self, lib_node_type: str) -> str: + return f"{next(self.uids[lib_node_type])}" + def _make_array_shape_and_strides( self, name: str, dims: Sequence[gtx_common.Dimension] ) -> tuple[list[dace.symbolic.SymbolicType], list[dace.symbolic.SymbolicType]]: diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py index 799e8ad228..81f5bfb126 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/auto_optimize.py @@ -20,7 +20,10 @@ from dace.transformation.passes import analysis as dace_analysis from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace import transformations as gtx_transformations +from gt4py.next.program_processors.runners.dace import ( + library_nodes as gtx_library_nodes, + transformations as gtx_transformations, +) class GT4PyAutoOptHook(enum.Enum): @@ -369,6 +372,15 @@ def gt_auto_optimize( stacklevel=0, ) + # We now expand all GT4Py specific library nodes. + # We do this such that we have control over all the Maps that are there. + # TODO(edopao,phimuell): It is probably the right place, but maybe there is a better one. + for node, state in list(sdfg.all_nodes_recursive()): + if isinstance(node, gtx_library_nodes.GTIR_LIBRARY_NODES): + node.expand(state) + if validate_all: + sdfg.validate() + sdfg = _gt_auto_configure_maps_and_strides( sdfg=sdfg, gpu=gpu, @@ -751,8 +763,8 @@ def _gt_auto_process_dataflow_inside_maps( # NestedSDFGs inside the ConditionalBlocks it fuses. 
sdfg.apply_transformations_repeated( gtx_transformations.FuseHorizontalConditionBlocks(), - validate=True, - validate_all=True, + validate=False, + validate_all=validate_all, ) # Move dataflow into the branches of the `if` such that they are only evaluated diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py index 2f25d4f1c3..7109038612 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/loop_blocking.py @@ -16,6 +16,7 @@ subsets as dace_subsets, transformation as dace_transformation, ) +from dace.libraries import standard as dace_stdlib from dace.sdfg import ( graph as dace_graph, nodes as dace_nodes, @@ -25,7 +26,10 @@ from dace.transformation import helpers as dace_helpers from gt4py.next import common as gtx_common -from gt4py.next.program_processors.runners.dace import lowering as gtx_dace_lowering +from gt4py.next.program_processors.runners.dace import ( + library_nodes as gtx_lib, + lowering as gtx_dace_lowering, +) @dace_properties.make_properties @@ -492,7 +496,7 @@ def _classify_node( # set of new independent nodes. new_independent_nodes.update(map_scope.nodes()) - elif isinstance(node_to_classify, dace.libraries.standard.nodes.Reduce): + elif isinstance(node_to_classify, (dace_stdlib.Reduce, gtx_lib.ReduceWithSkipValues)): # The only checks we impose on them is the free symbols check and the # input output checks. 
pass diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py index 68a7c33201..c74c19092e 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/utils.py @@ -12,11 +12,14 @@ from typing import Optional, Sequence, TypeVar, Union import dace -from dace import data as dace_data, libraries as dace_lib, subsets as dace_sbs, symbolic as dace_sym +from dace import data as dace_data, subsets as dace_sbs, symbolic as dace_sym +from dace.libraries import standard as dace_stdlib from dace.sdfg import graph as dace_graph, nodes as dace_nodes from dace.transformation import pass_pipeline as dace_ppl from dace.transformation.passes import analysis as dace_analysis +from gt4py.next.program_processors.runners.dace import library_nodes as gtx_lib + _PassT = TypeVar("_PassT", bound=dace_ppl.Pass) @@ -555,7 +558,7 @@ def reconfigure_dataflow_after_rerouting( # the full array, but essentially slice a bit. pass - elif isinstance(other_node, dace_lib.standard.Reduce): + elif isinstance(other_node, (dace_stdlib.Reduce, gtx_lib.ReduceWithSkipValues)): # For now we only handle the case that the reduction node is writing into # `new_node`, before the data was written into `old_node`. In that case # there is nothing to do, we just do some checks.