Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backends/arm/_passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
from .decompose_var_pass import DecomposeVarPass # noqa
from .decompose_where_scalar_other_pass import DecomposeWhereScalarOtherPass # noqa
from .decorate_fp32_to_int32_casting_pass import DecorateFp32toInt32CastingPass # noqa
from .ensure_unique_output_nodes_pass import EnsureUniqueOutputNodesPass # noqa
from .fold_qdq_with_annotated_qparams_pass import ( # noqa
FoldAndAnnotateQParamsPass,
QuantizeClampArgumentsPass,
Expand Down
2 changes: 2 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@
DecomposeVarPass,
DecomposeWhereScalarOtherPass,
DecorateFp32toInt32CastingPass,
EnsureUniqueOutputNodesPass,
FoldAndAnnotateQParamsPass,
FuseBatchNorm2dPass,
FuseConsecutiveConcatShapesPass,
Expand Down Expand Up @@ -502,6 +503,7 @@ def _tosa_pipeline(
FuseEqualPlaceholdersPass(exported_program),
FuseConsecutiveConcatShapesPass(),
ToTosaMemoryFormatPass(exported_program),
EnsureUniqueOutputNodesPass(),
RemoveNoopPass(),
InsertRescalePass(),
]
Expand Down
82 changes: 82 additions & 0 deletions backends/arm/_passes/ensure_unique_output_nodes_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from collections import Counter
from typing import Any, Set, Type

import torch
from executorch.backends.arm._passes import ArmPass
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class EnsureUniqueOutputNodesPass(ArmPass):
"""Ensure each graph output leaf references a unique producer node.

If the same node appears multiple times in the output structure, insert a
``tosa.IDENTITY`` node for each occurrence and replace the repeated output
entries with those identity nodes.

"""

_passes_required_after: Set[Type[ExportPass]] = set()

@staticmethod
def _collect_output_nodes(
output_value: Any, counts: Counter[torch.fx.Node]
) -> None:
if isinstance(output_value, torch.fx.Node):
counts[output_value] += 1
return
if isinstance(output_value, (list, tuple)):
for value in output_value:
EnsureUniqueOutputNodesPass._collect_output_nodes(value, counts)

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
graph = graph_module.graph
output_node = graph.output_node()
output_value = output_node.args[0]

counts: Counter[torch.fx.Node] = Counter()
self._collect_output_nodes(output_value, counts)
repeated_nodes = {node for node, count in counts.items() if count > 1}
if not repeated_nodes:
return PassResult(graph_module, False)

modified = False

def _replace_repeated_outputs(value: Any) -> Any:
nonlocal modified
if isinstance(value, torch.fx.Node):
if value not in repeated_nodes:
return value
with graph.inserting_before(output_node):
identity_node = create_node(
graph,
exir_ops.backend.tosa.IDENTITY.default,
args=(value,),
from_node=value,
)
modified = True
return identity_node

if isinstance(value, tuple):
return tuple(_replace_repeated_outputs(v) for v in value)

if isinstance(value, list):
return [_replace_repeated_outputs(v) for v in value]

return value

new_output_value = _replace_repeated_outputs(output_value)
if modified:
output_node.args = (new_output_value,)
graph.eliminate_dead_code()
graph.lint()
graph_module.recompile()
graph_module = super().call(graph_module).graph_module

return PassResult(graph_module, modified)
10 changes: 5 additions & 5 deletions backends/arm/_passes/to_tosa_memory_format_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,6 @@ def insert_output_transpose(node, graph_module):
"""Convert a producer's output to channels-last by appending a backend
`TRANSPOSE` node and rewiring its users.
"""

rank = len(get_first_fake_tensor(node).size())
spatial_rank = node.meta["tosa_spatial_rank"]
mem_format = ToTosaMemoryFormatPass._channels_last_order(rank, spatial_rank)
Expand Down Expand Up @@ -383,17 +382,18 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
if output_dim_orders is None:
raise RuntimeError(f"{output_dim_orders=} is not supported.")

transposed_output_inputs: set[torch.fx.Node] = set()
for output_node_input, output_dim_order in zip(
outputs, output_dim_orders, strict=True
):
if output_dim_order in (
NCHW_ORDER,
NNCHW_ORDER,
NNNCHW_ORDER,
if (
output_dim_order in (NCHW_ORDER, NNCHW_ORDER, NNNCHW_ORDER)
and output_node_input not in transposed_output_inputs
):
self.insert_input_transpose(
output_node, output_node_input, graph_module
)
transposed_output_inputs.add(output_node_input)

def remove_dim_order_kwargs(
self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
Expand Down
1 change: 1 addition & 0 deletions backends/arm/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
op_tosa_conv3d,
op_tosa_depthwise_conv2d,
op_tosa_gather,
op_tosa_identity,
op_tosa_matmul,
op_tosa_pad,
op_tosa_rescale,
Expand Down
61 changes: 61 additions & 0 deletions backends/arm/operators/op_tosa_identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, List

import torch
import tosa_serializer as ts

from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.operators.operator_validation_utils import (
validate_num_inputs,
validate_same_dtype,
validate_valid_dtype,
)
from executorch.backends.arm.tosa.mapping import TosaArg


@register_node_visitor
class IdentityVisitor(NodeVisitor):
"""Lower the TOSA IDENTITY op."""

target = "tosa.IDENTITY.default"

def define_node(
self,
node: torch.fx.Node,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
validate_num_inputs(self.target, inputs, 1)
validate_same_dtype(self.target, [inputs[0], output], ts)
validate_valid_dtype(
self.target,
[inputs[0], output],
[
ts.DType.BOOL,
ts.DType.INT8,
ts.DType.INT16,
ts.DType.INT32,
ts.DType.FP16,
ts.DType.FP32,
ts.DType.BF16,
],
self.tosa_spec,
)

attr = ts.TosaSerializerAttribute()
self._serialize_operator(
node,
tosa_graph,
ts.Op.IDENTITY,
[inputs[0].name],
[output.name],
attr,
)
25 changes: 25 additions & 0 deletions backends/arm/test/misc/test_tosa_dialect_identity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import executorch.backends.arm.tosa.dialect # noqa: F401
import torch
from executorch.backends.arm.tosa.specification import (
TosaLoweringContext,
TosaSpecification,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch._subclasses.fake_tensor import FakeTensorMode


def test_identity_tosa_FP() -> None:
sample_input = torch.randn((1, 2, 3, 4), dtype=torch.float32)

with TosaLoweringContext(
TosaSpecification.create_from_string("TOSA-1.0+FP")
), FakeTensorMode() as mode:
output = exir_ops.backend.tosa.IDENTITY.default(mode.from_tensor(sample_input))

assert output.dtype == sample_input.dtype
assert tuple(output.shape) == tuple(sample_input.shape)
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.backends.arm._passes import EnsureUniqueOutputNodesPass
from executorch.backends.arm.test.tester.test_pipeline import PassPipeline
from executorch.backends.test.harness.stages import StageType
from executorch.exir.dialects._ops import ops as exir_ops


class DuplicateOutputModule(torch.nn.Module):
def forward(self, x: torch.Tensor):
y = x + 1.0
return y, y


class UniqueOutputModule(torch.nn.Module):
def forward(self, x: torch.Tensor):
y = x + 1.0
z = x + 2.0
return y, z


def test_ensure_unique_output_nodes_no_target_inserts_identity_per_repeated_output() -> (
None
):
pipeline = PassPipeline[tuple[torch.Tensor]](
DuplicateOutputModule(),
(torch.rand(2, 2),),
quantize=False,
pass_list=[EnsureUniqueOutputNodesPass],
ops_after_pass={
"executorch_exir_dialects_backend__ops_tosa_IDENTITY_default": 2,
},
)
pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()

graph_module = (
pipeline.tester.get_artifact(StageType.RUN_PASSES)
.exported_program()
.graph_module
)
output_node = graph_module.graph.output_node()
outputs = list(output_node.args[0])

assert outputs[0] is not outputs[1]
assert outputs[0].target == exir_ops.backend.tosa.IDENTITY.default
assert outputs[1].target == exir_ops.backend.tosa.IDENTITY.default
assert outputs[0].args[0] is outputs[1].args[0]


def test_ensure_unique_output_nodes_no_target_keeps_unique_outputs_unchanged() -> None:
pipeline = PassPipeline[tuple[torch.Tensor]](
UniqueOutputModule(),
(torch.rand(2, 2),),
quantize=False,
pass_list=[EnsureUniqueOutputNodesPass],
ops_not_after_pass=[
"executorch_exir_dialects_backend__ops_tosa_IDENTITY_default",
],
)
pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()
59 changes: 59 additions & 0 deletions backends/arm/test/passes/test_to_tosa_memory_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@
import torch
from executorch.backends.arm._passes import (
AnnotateOutputDimOrderPass,
EnsureUniqueOutputNodesPass,
FuseEqualPlaceholdersPass,
ToTosaMemoryFormatPass,
)

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import (
PassPipeline,
TosaPipelineFP,
TosaPipelineINT,
)
from executorch.backends.transforms.remove_getitem_op import RemoveGetItemPass
from executorch.exir.dialects._ops import ops as exir_ops

input_t = Tuple[torch.Tensor] # Input x

Expand Down Expand Up @@ -177,6 +181,26 @@ def get_inputs(self) -> input_t:
return (torch.rand(4, 4, 4, 4),)


class DuplicateConstantOutputs(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer("grid0", torch.zeros(1, 32, 32, 2))
self.register_buffer("grid1", torch.zeros(1, 32, 32, 2))

def forward(self, x: torch.Tensor):
return self.grid0, self.grid1, x


class DuplicateConstantOutputsWithAdd(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer("grid0", torch.zeros(1, 32, 32, 2))
self.register_buffer("grid1", torch.zeros(1, 32, 32, 2))

def forward(self, x: torch.Tensor):
return self.grid0, self.grid1, x + x


modules: Dict[str, ModuleMetadata] = {
"no_nhwc": NoNHWC(),
"parallel_clusters": ParallelClusters(),
Expand Down Expand Up @@ -209,3 +233,38 @@ def test_to_tosa_memory_format_tosa_INT_functional(module: ModuleMetadata) -> No
module_nn = cast(torch.nn.Module, module)
pipeline = TosaPipelineINT[input_t](module_nn, module.get_inputs(), [])
pipeline.run()


def test_to_tosa_memory_format_no_target_preserves_duplicate_output_slots() -> None:
pipeline = PassPipeline[input_t](
DuplicateConstantOutputs(),
(torch.rand(1, 2, 32, 32),),
quantize=False,
pass_list=[RemoveGetItemPass, AnnotateOutputDimOrderPass],
passes_with_exported_program=[
FuseEqualPlaceholdersPass,
ToTosaMemoryFormatPass,
EnsureUniqueOutputNodesPass,
],
)
pipeline.pop_stage("run_method_and_compare_outputs")
pipeline.run()

graph_module = pipeline.tester.get_artifact().exported_program().graph_module
output_node = graph_module.graph.output_node()
outputs = list(output_node.args[0])

assert outputs[0] is not outputs[1]
assert outputs[0].target == exir_ops.backend.tosa.IDENTITY.default
assert outputs[1].target == exir_ops.backend.tosa.IDENTITY.default
assert outputs[0].args[0] is outputs[1].args[0]


def test_to_tosa_memory_format_tosa_FP_duplicate_output_identity() -> None:
pipeline = TosaPipelineFP[input_t](
DuplicateConstantOutputsWithAdd(),
(torch.rand(1, 2, 32, 32),),
[],
[],
)
pipeline.run()
1 change: 1 addition & 0 deletions backends/arm/tosa/dialect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
conv3d,
depthwise_conv2d,
gather,
identity,
matmul,
pad,
rescale,
Expand Down
Loading
Loading