From 6ae1b6a5417ba86ae3ca58965417700ab2735e08 Mon Sep 17 00:00:00 2001 From: Baris Demir Date: Thu, 19 Mar 2026 23:12:21 +0000 Subject: [PATCH] Arm backend: Preserve duplicate output slots with TOSA identity fanout When FuseEqualPlaceholdersPass fuses equal constant placeholders, the graph output can contain the same node in multiple output slots. In this case ToTosaMemoryFormatPass was rewriting the output node with replace_input_with() while inserting output transposes. That rewrote all matching occurrences at once, so duplicated logical output slots were collapsed onto the same transpose node instead of remaining distinct. Fix this by handling duplicate outputs in the output rewrite path. For shared output nodes, create a single boundary TOSA TRANSPOSE and preserve distinct output slots by inserting TOSA IDENTITY fanout nodes for later duplicates. This keeps insert_input_transpose() focused on normal input rewrites, avoids duplicating equivalent transposes for shared outputs, and preserves the output slot structure expected by later lowering and serialization stages. Add regression coverage for FuseEqualPlaceholdersPass + ToTosaMemoryFormatPass with duplicate outputs, and add TOSA IDENTITY dialect and visitor coverage. Signed-off-by: Baris Demir Change-Id: Ie14bc88bfadaad7f993b71ef1b5332b5953b72c8 --- backends/arm/_passes/__init__.py | 1 + backends/arm/_passes/arm_pass_manager.py | 2 + .../ensure_unique_output_nodes_pass.py | 82 +++++++++++++++++++ .../arm/_passes/to_tosa_memory_format_pass.py | 10 +-- backends/arm/operators/__init__.py | 1 + backends/arm/operators/op_tosa_identity.py | 61 ++++++++++++++ .../test/misc/test_tosa_dialect_identity.py | 25 ++++++ .../test_ensure_unique_output_nodes_pass.py | 66 +++++++++++++++ .../test/passes/test_to_tosa_memory_format.py | 59 +++++++++++++ backends/arm/tosa/dialect/__init__.py | 1 + backends/arm/tosa/dialect/ops/identity.py | 16 ++++ 11 files changed, 319 insertions(+), 5 deletions(-) create mode 100644 backends/arm/_passes/ensure_unique_output_nodes_pass.py create mode 100644 backends/arm/operators/op_tosa_identity.py create mode 100644 backends/arm/test/misc/test_tosa_dialect_identity.py create mode 100644 backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py create mode 100644 backends/arm/tosa/dialect/ops/identity.py diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 1bd18de581d..90d626bcb2a 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -97,6 +97,7 @@ from .decompose_var_pass import DecomposeVarPass # noqa from .decompose_where_scalar_other_pass import DecomposeWhereScalarOtherPass # noqa from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa +from .ensure_unique_output_nodes_pass import EnsureUniqueOutputNodesPass # noqa from .fold_qdq_with_annotated_qparams_pass import ( # noqa FoldAndAnnotateQParamsPass, QuantizeClampArgumentsPass, diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 31cb7a2e2c7..a51e8dd9c65 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -96,6 +96,7 @@ DecomposeVarPass, DecomposeWhereScalarOtherPass, DecorateFp32toInt32CastingPass, + EnsureUniqueOutputNodesPass, FoldAndAnnotateQParamsPass, FuseBatchNorm2dPass, FuseConsecutiveConcatShapesPass, @@ -502,6 +503,7 @@ def _tosa_pipeline( FuseEqualPlaceholdersPass(exported_program), FuseConsecutiveConcatShapesPass(), ToTosaMemoryFormatPass(exported_program), + EnsureUniqueOutputNodesPass(), RemoveNoopPass(), InsertRescalePass(), ] diff --git a/backends/arm/_passes/ensure_unique_output_nodes_pass.py b/backends/arm/_passes/ensure_unique_output_nodes_pass.py new file mode 100644 index 00000000000..af0435da6fd --- /dev/null +++ b/backends/arm/_passes/ensure_unique_output_nodes_pass.py @@ -0,0 +1,82 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections import Counter +from typing import Any, Set, Type + +import torch +from executorch.backends.arm._passes import ArmPass +from executorch.backends.arm._passes.arm_pass_utils import create_node +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class EnsureUniqueOutputNodesPass(ArmPass): + """Ensure each graph output leaf references a unique producer node. + + If the same node appears multiple times in the output structure, insert a + ``tosa.IDENTITY`` node for each occurrence and replace the repeated output + entries with those identity nodes. + + """ + + _passes_required_after: Set[Type[ExportPass]] = set() + + @staticmethod + def _collect_output_nodes( + output_value: Any, counts: Counter[torch.fx.Node] + ) -> None: + if isinstance(output_value, torch.fx.Node): + counts[output_value] += 1 + return + if isinstance(output_value, (list, tuple)): + for value in output_value: + EnsureUniqueOutputNodesPass._collect_output_nodes(value, counts) + + def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + graph = graph_module.graph + output_node = graph.output_node() + output_value = output_node.args[0] + + counts: Counter[torch.fx.Node] = Counter() + self._collect_output_nodes(output_value, counts) + repeated_nodes = {node for node, count in counts.items() if count > 1} + if not repeated_nodes: + return PassResult(graph_module, False) + + modified = False + + def _replace_repeated_outputs(value: Any) -> Any: + nonlocal modified + if isinstance(value, torch.fx.Node): + if value not in repeated_nodes: + return value + with graph.inserting_before(output_node): + identity_node = create_node( + graph, + exir_ops.backend.tosa.IDENTITY.default, + args=(value,), + from_node=value, + ) + modified = True + return identity_node + + if isinstance(value, tuple): + return tuple(_replace_repeated_outputs(v) for v in value) + + if isinstance(value, list): + return [_replace_repeated_outputs(v) for v in value] + + return value + + new_output_value = _replace_repeated_outputs(output_value) + if modified: + output_node.args = (new_output_value,) + graph.eliminate_dead_code() + graph.lint() + graph_module.recompile() + graph_module = super().call(graph_module).graph_module + + return PassResult(graph_module, modified) diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py index 0f32fbb52df..c102c87ba70 100644 --- a/backends/arm/_passes/to_tosa_memory_format_pass.py +++ b/backends/arm/_passes/to_tosa_memory_format_pass.py @@ -264,7 +264,6 @@ def insert_output_transpose(node, graph_module): """Convert a producer's output to channels-last by appending a backend `TRANSPOSE` node and rewiring its users. """ - rank = len(get_first_fake_tensor(node).size()) spatial_rank = node.meta["tosa_spatial_rank"] mem_format = ToTosaMemoryFormatPass._channels_last_order(rank, spatial_rank) @@ -383,17 +382,18 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule): if output_dim_orders is None: raise RuntimeError(f"{output_dim_orders=} is not supported.") + transposed_output_inputs: set[torch.fx.Node] = set() for output_node_input, output_dim_order in zip( outputs, output_dim_orders, strict=True ): - if output_dim_order in ( - NCHW_ORDER, - NNCHW_ORDER, - NNNCHW_ORDER, + if ( + output_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER) + and output_node_input not in transposed_output_inputs ): self.insert_input_transpose( output_node, output_node_input, graph_module ) + transposed_output_inputs.add(output_node_input) def remove_dim_order_kwargs( self, graph_module: torch.fx.GraphModule, node: torch.fx.Node diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index c7f9da2ccd4..391ed0a82c2 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -53,6 +53,7 @@ op_tosa_conv3d, op_tosa_depthwise_conv2d, op_tosa_gather, + op_tosa_identity, op_tosa_matmul, op_tosa_pad, op_tosa_rescale, diff --git a/backends/arm/operators/op_tosa_identity.py b/backends/arm/operators/op_tosa_identity.py new file mode 100644 index 00000000000..81d31186ec4 --- /dev/null +++ b/backends/arm/operators/op_tosa_identity.py @@ -0,0 +1,61 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List + +import torch +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, +) +from executorch.backends.arm.tosa.mapping import TosaArg + + +@register_node_visitor +class IdentityVisitor(NodeVisitor): + """Lower the TOSA IDENTITY op.""" + + target = "tosa.IDENTITY.default" + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, 1) + validate_same_dtype(self.target, [inputs[0], output], ts) + validate_valid_dtype( + self.target, + [inputs[0], output], + [ + ts.DType.BOOL, + ts.DType.INT8, + ts.DType.INT16, + ts.DType.INT32, + ts.DType.FP16, + ts.DType.FP32, + ts.DType.BF16, + ], + self.tosa_spec, + ) + + attr = ts.TosaSerializerAttribute() + self._serialize_operator( + node, + tosa_graph, + ts.Op.IDENTITY, + [inputs[0].name], + [output.name], + attr, + ) diff --git a/backends/arm/test/misc/test_tosa_dialect_identity.py b/backends/arm/test/misc/test_tosa_dialect_identity.py new file mode 100644 index 00000000000..19461cb676c --- /dev/null +++ b/backends/arm/test/misc/test_tosa_dialect_identity.py @@ -0,0 +1,25 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import executorch.backends.arm.tosa.dialect # noqa: F401 +import torch +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops +from torch._subclasses.fake_tensor import FakeTensorMode + + +def test_identity_tosa_FP() -> None: + sample_input = torch.randn((1, 2, 3, 4), dtype=torch.float32) + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+FP") + ), FakeTensorMode() as mode: + output = exir_ops.backend.tosa.IDENTITY.default(mode.from_tensor(sample_input)) + + assert output.dtype == sample_input.dtype + assert tuple(output.shape) == tuple(sample_input.shape) diff --git a/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py b/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py new file mode 100644 index 00000000000..4dd03c1ca6e --- /dev/null +++ b/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py @@ -0,0 +1,66 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm._passes import EnsureUniqueOutputNodesPass +from executorch.backends.arm.test.tester.test_pipeline import PassPipeline +from executorch.backends.test.harness.stages import StageType +from executorch.exir.dialects._ops import ops as exir_ops + + +class DuplicateOutputModule(torch.nn.Module): + def forward(self, x: torch.Tensor): + y = x + 1.0 + return y, y + + +class UniqueOutputModule(torch.nn.Module): + def forward(self, x: torch.Tensor): + y = x + 1.0 + z = x + 2.0 + return y, z + + +def test_ensure_unique_output_nodes_no_target_inserts_identity_per_repeated_output() -> ( + None +): + pipeline = PassPipeline[tuple[torch.Tensor]]( + DuplicateOutputModule(), + (torch.rand(2, 2),), + quantize=False, + pass_list=[EnsureUniqueOutputNodesPass], + ops_after_pass={ + "executorch_exir_dialects_backend__ops_tosa_IDENTITY_default": 2, + }, + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + graph_module = ( + pipeline.tester.get_artifact(StageType.RUN_PASSES) + .exported_program() + .graph_module + ) + output_node = graph_module.graph.output_node() + outputs = list(output_node.args[0]) + + assert outputs[0] is not outputs[1] + assert outputs[0].target == exir_ops.backend.tosa.IDENTITY.default + assert outputs[1].target == exir_ops.backend.tosa.IDENTITY.default + assert outputs[0].args[0] is outputs[1].args[0] + + +def test_ensure_unique_output_nodes_no_target_keeps_unique_outputs_unchanged() -> None: + pipeline = PassPipeline[tuple[torch.Tensor]]( + UniqueOutputModule(), + (torch.rand(2, 2),), + quantize=False, + pass_list=[EnsureUniqueOutputNodesPass], + ops_not_after_pass=[ + "executorch_exir_dialects_backend__ops_tosa_IDENTITY_default", + ], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() diff --git a/backends/arm/test/passes/test_to_tosa_memory_format.py b/backends/arm/test/passes/test_to_tosa_memory_format.py index dfd57aa7e61..a035ac61f9e 100644 --- a/backends/arm/test/passes/test_to_tosa_memory_format.py +++ b/backends/arm/test/passes/test_to_tosa_memory_format.py @@ -8,15 +8,19 @@ import torch from executorch.backends.arm._passes import ( AnnotateOutputDimOrderPass, + EnsureUniqueOutputNodesPass, + FuseEqualPlaceholdersPass, ToTosaMemoryFormatPass, ) from executorch.backends.arm.test import common from executorch.backends.arm.test.tester.test_pipeline import ( PassPipeline, + TosaPipelineFP, TosaPipelineINT, ) from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass +from executorch.exir.dialects._ops import ops as exir_ops input_t = Tuple[torch.Tensor] # Input x @@ -177,6 +181,26 @@ def get_inputs(self) -> input_t: return (torch.rand(4, 4, 4, 4),) +class DuplicateConstantOutputs(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("grid0", torch.zeros(1, 32, 32, 2)) + self.register_buffer("grid1", torch.zeros(1, 32, 32, 2)) + + def forward(self, x: torch.Tensor): + return self.grid0, self.grid1, x + + +class DuplicateConstantOutputsWithAdd(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.register_buffer("grid0", torch.zeros(1, 32, 32, 2)) + self.register_buffer("grid1", torch.zeros(1, 32, 32, 2)) + + def forward(self, x: torch.Tensor): + return self.grid0, self.grid1, x + x + + modules: Dict[str, ModuleMetadata] = { "no_nhwc": NoNHWC(), "parallel_clusters": ParallelClusters(), @@ -209,3 +233,38 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No module_nn = cast(torch.nn.Module, module) pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), []) pipeline.run() + + +def test_to_tosa_memory_format_no_target_preserves_duplicate_output_slots() -> None: + pipeline = PassPipeline[input_t]( + DuplicateConstantOutputs(), + (torch.rand(1, 2, 32, 32),), + quantize=False, + pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass], + passes_with_exported_program=[ + FuseEqualPlaceholdersPass, + ToTosaMemoryFormatPass, + EnsureUniqueOutputNodesPass, + ], + ) + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + graph_module = pipeline.tester.get_artifact().exported_program().graph_module + output_node = graph_module.graph.output_node() + outputs = list(output_node.args[0]) + + assert outputs[0] is not outputs[1] + assert outputs[0].target == exir_ops.backend.tosa.IDENTITY.default + assert outputs[1].target == exir_ops.backend.tosa.IDENTITY.default + assert outputs[0].args[0] is outputs[1].args[0] + + +def test_to_tosa_memory_format_tosa_FP_duplicate_output_identity() -> None: + pipeline = TosaPipelineFP[input_t]( + DuplicateConstantOutputsWithAdd(), + (torch.rand(1, 2, 32, 32),), + [], + [], + ) + pipeline.run() diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py index dffd55cc52a..364b3b0ada6 100644 --- a/backends/arm/tosa/dialect/__init__.py +++ b/backends/arm/tosa/dialect/__init__.py @@ -8,6 +8,7 @@ conv3d, depthwise_conv2d, gather, + identity, matmul, pad, rescale, diff --git a/backends/arm/tosa/dialect/ops/identity.py b/backends/arm/tosa/dialect/ops/identity.py new file mode 100644 index 00000000000..6e26d8e8b22 --- /dev/null +++ b/backends/arm/tosa/dialect/ops/identity.py @@ -0,0 +1,16 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op +from executorch.backends.arm.tosa.specification import TosaSpecification + + +@register_fake_tosa_op( + "IDENTITY(Tensor input) -> Tensor", + TosaSpecification.all_versions_and_profiles(), +) +def IDENTITY(a): + return torch.empty_like(a, dtype=a.dtype)