diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 1bd18de581d..90d626bcb2a 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -97,6 +97,7 @@ from .decompose_var_pass import DecomposeVarPass # noqa from .decompose_where_scalar_other_pass import DecomposeWhereScalarOtherPass # noqa from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa +from .ensure_unique_output_nodes_pass import EnsureUniqueOutputNodesPass # noqa from .fold_qdq_with_annotated_qparams_pass import ( # noqa FoldAndAnnotateQParamsPass, QuantizeClampArgumentsPass, diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 31cb7a2e2c7..a51e8dd9c65 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -96,6 +96,7 @@ DecomposeVarPass, DecomposeWhereScalarOtherPass, DecorateFp32toInt32CastingPass, + EnsureUniqueOutputNodesPass, FoldAndAnnotateQParamsPass, FuseBatchNorm2dPass, FuseConsecutiveConcatShapesPass, @@ -502,6 +503,7 @@ def _tosa_pipeline( FuseEqualPlaceholdersPass(exported_program), FuseConsecutiveConcatShapesPass(), ToTosaMemoryFormatPass(exported_program), + EnsureUniqueOutputNodesPass(), RemoveNoopPass(), InsertRescalePass(), ] diff --git a/backends/arm/_passes/ensure_unique_output_nodes_pass.py b/backends/arm/_passes/ensure_unique_output_nodes_pass.py new file mode 100644 index 00000000000..af0435da6fd --- /dev/null +++ b/backends/arm/_passes/ensure_unique_output_nodes_pass.py @@ -0,0 +1,82 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+
+from collections import Counter
+from typing import Any, Set, Type
+
+import torch
+from executorch.backends.arm._passes import ArmPass
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class EnsureUniqueOutputNodesPass(ArmPass):
+    """Ensure each graph output leaf references a unique producer node.
+
+    If the same node appears multiple times in the output structure, insert a
+    ``tosa.IDENTITY`` node for each occurrence and replace the repeated output
+    entries with those identity nodes.
+
+    """
+
+    _passes_required_after: Set[Type[ExportPass]] = set()  # Empty: no pass types are listed here.
+
+    @staticmethod
+    def _collect_output_nodes(
+        output_value: Any, counts: Counter[torch.fx.Node]
+    ) -> None:  # Tally, in place, how many times each Node occurs in output_value.
+        if isinstance(output_value, torch.fx.Node):
+            counts[output_value] += 1  # Leaf node: record one more occurrence.
+            return
+        if isinstance(output_value, (list, tuple)):  # Recurse into containers; any other value is ignored.
+            for value in output_value:
+                EnsureUniqueOutputNodesPass._collect_output_nodes(value, counts)
+
+    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:  # Give every repeated output entry its own tosa.IDENTITY node.
+        graph = graph_module.graph
+        output_node = graph.output_node()
+        output_value = output_node.args[0]  # The (possibly nested) structure the graph returns.
+
+        counts: Counter[torch.fx.Node] = Counter()
+        self._collect_output_nodes(output_value, counts)
+        repeated_nodes = {node for node, count in counts.items() if count > 1}  # Nodes referenced more than once.
+        if not repeated_nodes:
+            return PassResult(graph_module, False)  # Nothing is duplicated: report "not modified".
+
+        modified = False
+
+        def _replace_repeated_outputs(value: Any) -> Any:  # Rebuild the output structure, wrapping repeated nodes.
+            nonlocal modified
+            if isinstance(value, torch.fx.Node):
+                if value not in repeated_nodes:
+                    return value  # Node occurs once: keep it as is.
+                with graph.inserting_before(output_node):  # New node goes right before the graph's output node.
+                    identity_node = create_node(
+                        graph,
+                        exir_ops.backend.tosa.IDENTITY.default,
+                        args=(value,),
+                        from_node=value,
+                    )
+                modified = True
+                return identity_node  # A fresh node per occurrence, so each output slot ends up distinct.
+
+            if isinstance(value, tuple):
+                return tuple(_replace_repeated_outputs(v) for v in value)  # Container type is preserved.
+
+            if isinstance(value, list):
+                return [_replace_repeated_outputs(v) for v in value]
+
+            return value  # Neither a Node nor a list/tuple: passed through unchanged.
+
+        new_output_value = _replace_repeated_outputs(output_value)
+        if modified:
+            output_node.args = (new_output_value,)  # Point the output node at the rewritten structure.
+            graph.eliminate_dead_code()
+            graph.lint()
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module  # Run the parent class's call() on the rewritten module.
+
+        return PassResult(graph_module, modified)
diff --git a/backends/arm/_passes/to_tosa_memory_format_pass.py b/backends/arm/_passes/to_tosa_memory_format_pass.py
index 0f32fbb52df..c102c87ba70 100644
--- a/backends/arm/_passes/to_tosa_memory_format_pass.py
+++ b/backends/arm/_passes/to_tosa_memory_format_pass.py
@@ -264,7 +264,6 @@ def insert_output_transpose(node, graph_module):
         """Convert a producer's output to channels-last by appending a backend
         `TRANSPOSE` node and rewiring its users.
         """
-
         rank = len(get_first_fake_tensor(node).size())
         spatial_rank = node.meta["tosa_spatial_rank"]
         mem_format = ToTosaMemoryFormatPass._channels_last_order(rank, spatial_rank)
@@ -383,17 +382,18 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         if output_dim_orders is None:
             raise RuntimeError(f"{output_dim_orders=} is not supported.")
 
+        transposed_output_inputs: set[torch.fx.Node] = set()  # Output inputs already given a transpose in this loop.
         for output_node_input, output_dim_order in zip(
             outputs, output_dim_orders, strict=True
         ):
-            if output_dim_order in (
-                NCHW_ORDER,
-                NNCHW_ORDER,
-                NNNCHW_ORDER,
+            if (
+                output_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER)
+                and output_node_input not in transposed_output_inputs  # Skip a node that is listed as an output more than once.
             ):
                 self.insert_input_transpose(
                     output_node, output_node_input, graph_module
                 )
+                transposed_output_inputs.add(output_node_input)  # Remember it so a repeat entry is not transposed again.
 
     def remove_dim_order_kwargs(
         self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py
index c7f9da2ccd4..391ed0a82c2 100644
--- a/backends/arm/operators/__init__.py
+++ b/backends/arm/operators/__init__.py
@@ -53,6 +53,7 @@
     op_tosa_conv3d,
     op_tosa_depthwise_conv2d,
op_tosa_gather,
+    op_tosa_identity,
     op_tosa_matmul,
     op_tosa_pad,
     op_tosa_rescale,
diff --git a/backends/arm/operators/op_tosa_identity.py b/backends/arm/operators/op_tosa_identity.py
new file mode 100644
index 00000000000..81d31186ec4
--- /dev/null
+++ b/backends/arm/operators/op_tosa_identity.py
@@ -0,0 +1,61 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, List
+
+import torch
+import tosa_serializer as ts
+
+from executorch.backends.arm.operators.node_visitor import (
+    NodeVisitor,
+    register_node_visitor,
+)
+from executorch.backends.arm.operators.operator_validation_utils import (
+    validate_num_inputs,
+    validate_same_dtype,
+    validate_valid_dtype,
+)
+from executorch.backends.arm.tosa.mapping import TosaArg
+
+
+@register_node_visitor
+class IdentityVisitor(NodeVisitor):
+    """Lower the TOSA IDENTITY op."""
+
+    target = "tosa.IDENTITY.default"  # Op name; also passed to the validators below.
+
+    def define_node(
+        self,
+        node: torch.fx.Node,
+        tosa_graph: Any,
+        inputs: List[TosaArg],
+        output: TosaArg,
+    ) -> None:  # Validate the operands, then serialize one ts.Op.IDENTITY operator.
+        validate_num_inputs(self.target, inputs, 1)  # IDENTITY is checked against an input count of 1.
+        validate_same_dtype(self.target, [inputs[0], output], ts)  # Input and output dtypes are checked together.
+        validate_valid_dtype(
+            self.target,
+            [inputs[0], output],
+            [  # Dtypes accepted for both the input and the output.
+                ts.DType.BOOL,
+                ts.DType.INT8,
+                ts.DType.INT16,
+                ts.DType.INT32,
+                ts.DType.FP16,
+                ts.DType.FP32,
+                ts.DType.BF16,
+            ],
+            self.tosa_spec,
+        )
+
+        attr = ts.TosaSerializerAttribute()  # Default-constructed; nothing is set on it here.
+        self._serialize_operator(
+            node,
+            tosa_graph,
+            ts.Op.IDENTITY,
+            [inputs[0].name],  # Single input tensor name.
+            [output.name],  # Single output tensor name.
+            attr,
+        )
diff --git a/backends/arm/test/misc/test_tosa_dialect_identity.py b/backends/arm/test/misc/test_tosa_dialect_identity.py
new file mode 100644
index 00000000000..19461cb676c
--- /dev/null
+++ b/backends/arm/test/misc/test_tosa_dialect_identity.py
@@ -0,0 +1,25 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import executorch.backends.arm.tosa.dialect  # noqa: F401
+import torch
+from executorch.backends.arm.tosa.specification import (
+    TosaLoweringContext,
+    TosaSpecification,
+)
+from executorch.exir.dialects._ops import ops as exir_ops
+from torch._subclasses.fake_tensor import FakeTensorMode
+
+
+def test_identity_tosa_FP() -> None:  # IDENTITY on a fake tensor must keep dtype and shape.
+    sample_input = torch.randn((1, 2, 3, 4), dtype=torch.float32)
+
+    with TosaLoweringContext(
+        TosaSpecification.create_from_string("TOSA-1.0+FP")
+    ), FakeTensorMode() as mode:  # Both contexts are active while the op is called.
+        output = exir_ops.backend.tosa.IDENTITY.default(mode.from_tensor(sample_input))  # Input is converted by the fake mode first.
+
+    assert output.dtype == sample_input.dtype
+    assert tuple(output.shape) == tuple(sample_input.shape)
diff --git a/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py b/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py
new file mode 100644
index 00000000000..4dd03c1ca6e
--- /dev/null
+++ b/backends/arm/test/passes/test_ensure_unique_output_nodes_pass.py
@@ -0,0 +1,66 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm._passes import EnsureUniqueOutputNodesPass
+from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
+from executorch.backends.test.harness.stages import StageType
+from executorch.exir.dialects._ops import ops as exir_ops
+
+
+class DuplicateOutputModule(torch.nn.Module):  # Returns the same tensor in both output slots.
+    def forward(self, x: torch.Tensor):
+        y = x + 1.0
+        return y, y
+
+
+class UniqueOutputModule(torch.nn.Module):  # Returns two different tensors.
+    def forward(self, x: torch.Tensor):
+        y = x + 1.0
+        z = x + 2.0
+        return y, z
+
+
+def test_ensure_unique_output_nodes_no_target_inserts_identity_per_repeated_output() -> (
+    None
+):
+    pipeline = PassPipeline[tuple[torch.Tensor]](
+        DuplicateOutputModule(),
+        (torch.rand(2, 2),),
+        quantize=False,
+        pass_list=[EnsureUniqueOutputNodesPass],
+        ops_after_pass={
+            "executorch_exir_dialects_backend__ops_tosa_IDENTITY_default": 2,  # One IDENTITY per duplicated output slot.
+        },
+    )
+    pipeline.pop_stage("run_method_and_compare_outputs")  # NOTE(review): presumably dropped because only graph structure is checked - confirm.
+    pipeline.run()
+
+    graph_module = (
+        pipeline.tester.get_artifact(StageType.RUN_PASSES)
+        .exported_program()
+        .graph_module
+    )
+    output_node = graph_module.graph.output_node()
+    outputs = list(output_node.args[0])
+
+    assert outputs[0] is not outputs[1]  # The two output slots are now distinct nodes...
+    assert outputs[0].target == exir_ops.backend.tosa.IDENTITY.default
+    assert outputs[1].target == exir_ops.backend.tosa.IDENTITY.default
+    assert outputs[0].args[0] is outputs[1].args[0]  # ...that both wrap the same source node.
+
+
+def test_ensure_unique_output_nodes_no_target_keeps_unique_outputs_unchanged() -> None:
+    pipeline = PassPipeline[tuple[torch.Tensor]](
+        UniqueOutputModule(),
+        (torch.rand(2, 2),),
+        quantize=False,
+        pass_list=[EnsureUniqueOutputNodesPass],
+        ops_not_after_pass=[
+            "executorch_exir_dialects_backend__ops_tosa_IDENTITY_default",  # No IDENTITY expected when outputs are already unique.
+        ],
+    )
+    pipeline.pop_stage("run_method_and_compare_outputs")
+    pipeline.run()
diff --git a/backends/arm/test/passes/test_to_tosa_memory_format.py b/backends/arm/test/passes/test_to_tosa_memory_format.py
index dfd57aa7e61..a035ac61f9e 100644
---
a/backends/arm/test/passes/test_to_tosa_memory_format.py
+++ b/backends/arm/test/passes/test_to_tosa_memory_format.py
@@ -8,15 +8,19 @@
 import torch
 from executorch.backends.arm._passes import (
     AnnotateOutputDimOrderPass,
+    EnsureUniqueOutputNodesPass,
+    FuseEqualPlaceholdersPass,
     ToTosaMemoryFormatPass,
 )
 
 from executorch.backends.arm.test import common
 from executorch.backends.arm.test.tester.test_pipeline import (
     PassPipeline,
+    TosaPipelineFP,
     TosaPipelineINT,
 )
 from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
+from executorch.exir.dialects._ops import ops as exir_ops
 
 input_t = Tuple[torch.Tensor]  # Input x
 
@@ -177,6 +181,26 @@ def get_inputs(self) -> input_t:
         return (torch.rand(4, 4, 4, 4),)
 
 
+class DuplicateConstantOutputs(torch.nn.Module):  # Returns two identically-valued buffers plus the input.
+    def __init__(self) -> None:
+        super().__init__()
+        self.register_buffer("grid0", torch.zeros(1, 32, 32, 2))
+        self.register_buffer("grid1", torch.zeros(1, 32, 32, 2))  # Same shape and contents as grid0.
+
+    def forward(self, x: torch.Tensor):
+        return self.grid0, self.grid1, x
+
+
+class DuplicateConstantOutputsWithAdd(torch.nn.Module):  # Same buffers, but the third output is x + x.
+    def __init__(self) -> None:
+        super().__init__()
+        self.register_buffer("grid0", torch.zeros(1, 32, 32, 2))
+        self.register_buffer("grid1", torch.zeros(1, 32, 32, 2))
+
+    def forward(self, x: torch.Tensor):
+        return self.grid0, self.grid1, x + x
+
+
 modules: Dict[str, ModuleMetadata] = {
     "no_nhwc": NoNHWC(),
     "parallel_clusters": ParallelClusters(),
@@ -209,3 +233,38 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No
     module_nn = cast(torch.nn.Module, module)
     pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), [])
     pipeline.run()
+
+
+def test_to_tosa_memory_format_no_target_preserves_duplicate_output_slots() -> None:
+    pipeline = PassPipeline[input_t](
+        DuplicateConstantOutputs(),
+        (torch.rand(1, 2, 32, 32),),
+        quantize=False,
+        pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass],
+        passes_with_exported_program=[  # NOTE(review): presumably run in the listed order - confirm against PassPipeline.
+            FuseEqualPlaceholdersPass,
+            ToTosaMemoryFormatPass,
+            EnsureUniqueOutputNodesPass,
+        ],
+    )
+    pipeline.pop_stage("run_method_and_compare_outputs")
+    pipeline.run()
+
+    graph_module = pipeline.tester.get_artifact().exported_program().graph_module
+    output_node = graph_module.graph.output_node()
+    outputs = list(output_node.args[0])
+
+    assert outputs[0] is not outputs[1]  # The first two output slots are distinct nodes...
+    assert outputs[0].target == exir_ops.backend.tosa.IDENTITY.default
+    assert outputs[1].target == exir_ops.backend.tosa.IDENTITY.default
+    assert outputs[0].args[0] is outputs[1].args[0]  # ...that both wrap the same source node.
+
+
+def test_to_tosa_memory_format_tosa_FP_duplicate_output_identity() -> None:  # Runs the TosaPipelineFP flow on the duplicate-buffer module.
+    pipeline = TosaPipelineFP[input_t](
+        DuplicateConstantOutputsWithAdd(),
+        (torch.rand(1, 2, 32, 32),),
+        [],
+        [],
+    )
+    pipeline.run()
diff --git a/backends/arm/tosa/dialect/__init__.py b/backends/arm/tosa/dialect/__init__.py
index dffd55cc52a..364b3b0ada6 100644
--- a/backends/arm/tosa/dialect/__init__.py
+++ b/backends/arm/tosa/dialect/__init__.py
@@ -8,6 +8,7 @@
     conv3d,
     depthwise_conv2d,
     gather,
+    identity,
     matmul,
     pad,
     rescale,
diff --git a/backends/arm/tosa/dialect/ops/identity.py b/backends/arm/tosa/dialect/ops/identity.py
new file mode 100644
index 00000000000..6e26d8e8b22
--- /dev/null
+++ b/backends/arm/tosa/dialect/ops/identity.py
@@ -0,0 +1,16 @@
+# Copyright 2026 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm.tosa.dialect.ops_registration import register_fake_tosa_op
+from executorch.backends.arm.tosa.specification import TosaSpecification
+
+
+@register_fake_tosa_op(
+    "IDENTITY(Tensor input) -> Tensor",  # Op schema: one tensor in, one tensor out.
+    TosaSpecification.all_versions_and_profiles(),
+)
+def IDENTITY(a):  # Fake implementation: only the output's metadata matters, not its values.
+    return torch.empty_like(a, dtype=a.dtype)  # Uninitialized tensor with a's shape and dtype.