From a6bded2bbb79b779d35850758b0d353633632ad1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Mon, 23 Mar 2026 15:24:39 +0100 Subject: [PATCH 1/4] Arm backend: Support uint8 IO quantization for backends MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add support for IO tensors only to be uint8. In conjuction with the QuantizeInput and QuantizeOutput pass this adds the possibility to give inputs of uint8 dtype to the model directly. Change-Id: Icc08ac242e5c980f2abd484eb0e7661418873ab7 Signed-off-by: Per Åstrand --- backends/arm/_passes/insert_rescales_pass.py | 44 +- backends/arm/operators/op_tosa_rescale.py | 17 +- backends/arm/quantizer/__init__.py | 1 + backends/arm/quantizer/arm_quantizer.py | 48 ++ backends/arm/test/misc/test_rescale_range.py | 78 ++- .../test/passes/test_ioquantization_pass.py | 512 +++++++++++++++++- .../quantizer/test_uint8_io_quantization.py | 42 ++ backends/arm/tosa/dialect/ops/rescale.py | 101 +++- backends/arm/tosa/mapping.py | 3 +- backends/arm/tosa/partitioner.py | 51 +- 10 files changed, 871 insertions(+), 26 deletions(-) create mode 100644 backends/arm/test/quantizer/test_uint8_io_quantization.py diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 8cab19dc551..5f4fec8f0c4 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import math +import operator from copy import copy from typing import cast, Dict, Optional, Set, Tuple, Type @@ -34,10 +35,44 @@ class InsertRescalePass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + def _ensure_uint8_io_only(self, graph_module: GraphModule) -> None: + """Ensure uint8 tensors only appear at IO boundaries. + + TOSA has no true uint8 tensor type; unsigned semantics are carried via + RESCALE input/output flags. If uint8 appears for other nodes, it means + unsigned data leaked past IO. + + """ + for node in graph_module.graph.nodes: + meta_val = node.meta.get("val") + if not isinstance(meta_val, torch.Tensor): + continue + if meta_val.dtype != torch.uint8: + continue + if node.op in ("placeholder", "output"): + continue + if node.op == "call_function" and node.target == operator.getitem: + if all(user.op == "output" for user in node.users): + continue + if ( + node.op == "call_function" + and node.target == exir_ops.backend.tosa.RESCALE.default + ): + continue + raise ValueError( + f"Found internal uint8 tensor at node {node.name} " + f"({node.target}). Uint8 is only allowed at IO boundaries." + ) + def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule): dq_args = QuantArgs.from_operator(node.target, node.args) q_args = QuantArgs.from_operator(user.target, user.args) new_scale = dq_args.scale / q_args.scale + input_unsigned = dq_args.dtype == torch.uint8 + output_unsigned = q_args.dtype == torch.uint8 + # TOSA has no true uint8 tensors; unsigned semantics are handled via + # the RESCALE flags, so uint8 does not propagate as a tensor dtype. + output_dtype = torch.int8 if output_unsigned else q_args.dtype with graph_module.graph.inserting_before(node): rescale_node = create_node( @@ -45,11 +80,15 @@ def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule exir_ops.backend.tosa.RESCALE.default, ( node.all_input_nodes[0], - q_args.dtype, + output_dtype, [new_scale], dq_args.zp, q_args.zp, ), + kwargs={ + "input_unsigned": input_unsigned, + "output_unsigned": output_unsigned, + }, ) rescale_node.meta = copy(user.meta) user.replace_all_uses_with(rescale_node) @@ -74,6 +113,9 @@ def call(self, graph_module: GraphModule) -> PassResult: graph_module.recompile() return PassResult(graph_module, modified) + def ensures(self, graph_module: GraphModule) -> None: + self._ensure_uint8_io_only(graph_module) + class InsertRescaleInt32Pass(ArmPass): """Numerous TOSA ops require inputs and outputs to be 32-bit integers in diff --git a/backends/arm/operators/op_tosa_rescale.py b/backends/arm/operators/op_tosa_rescale.py index 8499fc9ccd5..dfaabecb41b 100644 --- a/backends/arm/operators/op_tosa_rescale.py +++ b/backends/arm/operators/op_tosa_rescale.py @@ -161,6 +161,8 @@ def _build_rescale( rounding_mode: ts.RoundingMode, per_channel: bool = False, is_scale32: bool = True, + input_unsigned: bool = False, + output_unsigned: bool = False, ): """Insert a TOSA RESCALE operator configured for the quantized path. @@ -198,8 +200,8 @@ def _build_rescale( scale32=is_scale32, rounding_mode=rounding_mode, per_channel=per_channel, - input_unsigned=False, - output_unsigned=False, + input_unsigned=input_unsigned, + output_unsigned=output_unsigned, ) tosa_fb.addOperator( @@ -228,6 +230,14 @@ def define_node( scales = cast(list[float], node.args[2]) input_zp = cast(int, node.args[3]) output_zp = cast(int, node.args[4]) + if "input_unsigned" in node.kwargs: + input_unsigned = cast(bool, node.kwargs.get("input_unsigned", False)) + else: + input_unsigned = cast(bool, node.args[5]) if len(node.args) > 5 else False + if "output_unsigned" in node.kwargs: + output_unsigned = cast(bool, node.kwargs.get("output_unsigned", False)) + else: + output_unsigned = cast(bool, node.args[6]) if len(node.args) > 6 else False if ( input_dtype @@ -244,7 +254,6 @@ def define_node( raise ValueError( f"If output dtype is not int8 or int16, output_zp must be 0. Got {ts.DTypeNames[output_dtype]}, {output_zp=}" ) - _build_rescale( tosa_graph, scale=scales, @@ -255,4 +264,6 @@ def define_node( output_zp=[output_zp], rounding_mode=ts.RoundingMode.SINGLE_ROUND, per_channel=len(scales) > 1, + input_unsigned=input_unsigned, + output_unsigned=output_unsigned, ) diff --git a/backends/arm/quantizer/__init__.py b/backends/arm/quantizer/__init__.py index 270d56a68cd..5dd5687991d 100644 --- a/backends/arm/quantizer/__init__.py +++ b/backends/arm/quantizer/__init__.py @@ -15,6 +15,7 @@ EthosUQuantizer, get_symmetric_a16w8_quantization_config, get_symmetric_quantization_config, + get_uint8_io_quantization_config, TOSAQuantizer, VgfQuantizer, ) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 3d36ec997cf..2e43b924cc8 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -105,6 +105,7 @@ "VgfQuantizer", "get_symmetric_a16w8_quantization_config", "get_symmetric_quantization_config", + "get_uint8_io_quantization_config", ] logger = logging.getLogger(__name__) @@ -234,6 +235,53 @@ def get_symmetric_quantization_config( return quantization_config +@functools.lru_cache +def get_uint8_io_quantization_config( + is_qat: bool = False, + is_dynamic: bool = False, + eps: float = 2**-16, +) -> QuantizationConfig: + """Create a uint8 IO quantization config for TOSA backends. + + This config is intended for model inputs/outputs only. Internal tensors + should remain int8 for TOSA INT lowering. + + """ + extra_args: Dict[str, Any] = {"eps": eps} + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=torch.iinfo(torch.uint8).min, + quant_max=torch.iinfo(torch.uint8).max, + qscheme=torch.per_tensor_affine, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args, + ), + ) + + return TOSAQuantizationConfig( + act_quantization_spec, + act_quantization_spec, + None, + None, + ) + + def get_symmetric_a8w4_quantization_config( is_per_channel: bool = True, is_qat: bool = True, is_dynamic: bool = False ): diff --git a/backends/arm/test/misc/test_rescale_range.py b/backends/arm/test/misc/test_rescale_range.py index 1075dd4d04f..f34c58995cf 100644 --- a/backends/arm/test/misc/test_rescale_range.py +++ b/backends/arm/test/misc/test_rescale_range.py @@ -1,13 +1,16 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from typing import Tuple +import executorch.backends.arm.tosa.dialect # noqa: F401 + import pytest import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError from executorch.backends.arm.tosa.specification import ( TosaLoweringContext, TosaSpecification, @@ -128,3 +131,76 @@ def test_zp_outside_range_tosa_INT(): ] ) ) + + +def test_unsigned_zp_range_tosa_INT_valid(): + # Validate unsigned zero-point ranges via explicit unsigned semantics. + # First case: uint8 input (input_unsigned=True) uses in_zp in [0,255]. + # Second case: signed int8 input but unsigned output semantics (output_unsigned=True) + # allow out_zp in [0,255]. + sample_inputs = [ + # (data, out_dtype, scale, in_zp, out_zp, input_unsigned, output_unsigned) + ( + torch.randint(low=0, high=255, size=(4, 4, 4), dtype=torch.uint8), + torch.int8, + [0.5], + 255, + 0, + True, + False, + ), + ( + torch.randint(low=-128, high=127, size=(4, 4, 4), dtype=torch.int8), + torch.int8, + [0.5], + 0, + 255, + False, + True, + ), + ] + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + for sample_input in sample_inputs: + exir_ops.backend.tosa.RESCALE.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input[:5] + ] + ), + input_unsigned=sample_input[5], + output_unsigned=sample_input[6], + ) + + +def test_unsigned_zp_range_tosa_INT_invalid(): + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="(in_zp|input_zp).*range"): + exir_ops.backend.tosa.RESCALE.default( + mode.from_tensor( + torch.randint(low=0, high=255, size=(4, 4, 4), dtype=torch.uint8) + ), + torch.int8, + [0.5], + 256, + 0, + input_unsigned=True, + output_unsigned=False, + ) + with pytest.raises(TosaValueError, match="(out_zp|output_zp).*range"): + exir_ops.backend.tosa.RESCALE.default( + mode.from_tensor( + torch.randint(low=0, high=255, size=(4, 4, 4), dtype=torch.uint8) + ), + torch.int8, + [0.5], + 0, + 256, + input_unsigned=False, + output_unsigned=True, + ) diff --git a/backends/arm/test/passes/test_ioquantization_pass.py b/backends/arm/test/passes/test_ioquantization_pass.py index 606ed6e0f01..0cad2eede3c 100644 --- a/backends/arm/test/passes/test_ioquantization_pass.py +++ b/backends/arm/test/passes/test_ioquantization_pass.py @@ -6,12 +6,37 @@ from typing import Tuple +import pytest import torch -from executorch.backends.arm.test import common +from executorch.backends.arm._passes.insert_rescales_pass import InsertRescalePass +from executorch.backends.arm.constants import DQ_OPS, Q_OPS +from executorch.backends.arm.quantizer import ( + get_symmetric_quantization_config, + get_uint8_io_quantization_config, + TOSAQuantizer, +) -from executorch.backends.arm.test.tester.test_pipeline import EthosU55PipelineINT -from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses +from executorch.backends.arm.test.tester.quantize import ArmQuantize as Quantize +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + TosaPipelineINT, +) +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.backends.cadence.aot.graph_builder import GraphBuilder +from executorch.backends.test.harness.stages import StageType +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.passes.quantize_io_pass import ( + quantize_input, + quantize_output, + QuantizeInputs, + QuantizeOutputs, +) input_t = Tuple[torch.Tensor, torch.Tensor] @@ -46,3 +71,484 @@ def test_quantize_io_u55_INT(test_data: input_t): edge.transform(passes=[QuantizeInputs(edge, [0, 1]), QuantizeOutputs(edge, [0])]) pipeline.tester.check_not(["edge__ops_quantized_decomposed_quantize_per_tensor"]) pipeline.tester.check_not(["edge__ops_quantized_decomposed_dequantize_per_tensor"]) + + +class SimpleMLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(4, 8) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(8, 4) + + def forward(self, x): + return self.fc2(self.relu(self.fc1(x))) + + +def _build_dq_q_graph( + input_tensor: torch.Tensor, + dq_dtype: torch.dtype, + q_dtype: torch.dtype, + dq_scale: float, + dq_zp: int, + q_scale: float, + q_zp: int, +): + builder = GraphBuilder() + x = builder.placeholder("x", input_tensor) + qd_ops = exir_ops.edge.quantized_decomposed + dq = builder.call_operator( + qd_ops.dequantize_per_tensor.default, + ( + x, + dq_scale, + dq_zp, + torch.iinfo(dq_dtype).min, + torch.iinfo(dq_dtype).max, + dq_dtype, + ), + ) + q = builder.call_operator( + qd_ops.quantize_per_tensor.default, + ( + dq, + q_scale, + q_zp, + torch.iinfo(q_dtype).min, + torch.iinfo(q_dtype).max, + q_dtype, + ), + ) + builder.output([q]) + return builder.get_graph_module() + + +def test_insert_rescale_tosa_INT_folds_uint8_input(): + graph_module = _build_dq_q_graph( + torch.randint(0, 255, (1, 4), dtype=torch.uint8), + torch.uint8, + torch.int8, + dq_scale=0.5, + dq_zp=0, + q_scale=0.25, + q_zp=0, + ) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")): + rescale_graph = InsertRescalePass()(graph_module).graph_module.graph + rescale_nodes = [ + node + for node in rescale_graph.nodes + if node.op == "call_function" + and node.target == exir_ops.backend.tosa.RESCALE.default + ] + assert rescale_nodes + assert rescale_nodes[0].kwargs.get("input_unsigned") is True + assert rescale_nodes[0].kwargs.get("output_unsigned") is False + + +def test_insert_rescale_tosa_INT_folds_uint8_output(): + graph_module = _build_dq_q_graph( + torch.randint(-128, 127, (1, 4), dtype=torch.int8), + torch.int8, + torch.uint8, + dq_scale=0.5, + dq_zp=0, + q_scale=0.25, + q_zp=0, + ) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")): + rescale_graph = InsertRescalePass()(graph_module).graph_module.graph + rescale_nodes = [ + node + for node in rescale_graph.nodes + if node.op == "call_function" + and node.target == exir_ops.backend.tosa.RESCALE.default + ] + assert rescale_nodes + assert rescale_nodes[0].kwargs.get("input_unsigned") is False + assert rescale_nodes[0].kwargs.get("output_unsigned") is True + assert rescale_nodes[0].args[1] == torch.int8 + + +def test_quantize_io_tosa_INT_uint8_simple_mlp(): + """Float-input MLP uses uint8 IO quantization and folds to a single + delegate. + """ + model = SimpleMLP().eval() + test_data = (torch.rand(1, 4),) + compile_spec = common.get_tosa_compile_spec("TOSA-1.0+INT") + + tester = ArmTester(model, test_data, compile_spec) + quantizer = TOSAQuantizer(compile_spec) + quantizer.set_global(get_symmetric_quantization_config()) + quantizer.set_io(get_uint8_io_quantization_config()) + quant_stage = Quantize(quantizer, quantization_config=quantizer.global_config) + + tester.quantize(quant_stage) + tester.export() + tester.to_edge_transform_and_lower() + tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + tester.check( + [ + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default" + ] + ) + lowered_graph = tester.get_artifact().exported_program().graph_module.graph + delegate_nodes = [ + node + for node in lowered_graph.nodes + if node.op == "call_function" + and node.target == torch.ops.higher_order.executorch_call_delegate + ] + assert len(delegate_nodes) == 1 + quant_nodes = [ + node + for node in lowered_graph.nodes + if node.op == "call_function" + and node.target + == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + ] + assert len(quant_nodes) == 1 + delegate_args = delegate_nodes[0].all_input_nodes + assert ( + quant_nodes[0] in delegate_args + ), "Expected input quantize to feed the delegate call." + + +def test_quantize_io_tosa_INT_uint8(): + """Make sure quantizer doesn't allow uint8 internally.""" + torch_q_ops = ( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + ) + torch_dq_ops = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ) + + model = SimpleMLP().eval() + test_data = (torch.rand(1, 4),) + compile_spec = common.get_tosa_compile_spec("TOSA-1.0+INT") + + tester = ArmTester(model, test_data, compile_spec) + quantizer = TOSAQuantizer(compile_spec) + quantizer.set_global(get_symmetric_quantization_config()) + quantizer.set_io(get_uint8_io_quantization_config()) + quant_stage = Quantize(quantizer, quantization_config=quantizer.global_config) + + tester.quantize(quant_stage) + tester.export() + + exported_program = tester.get_artifact() + graph = exported_program.graph_module.graph + placeholders = [node for node in graph.nodes if node.op == "placeholder"] + output_nodes = [node for node in graph.nodes if node.op == "output"] + + max_uint8_q_nodes = len(placeholders) + len(output_nodes) + uint8_q_nodes = [] + bad_nodes = [] + for node in graph.nodes: + meta_val = node.meta.get("val") + if not isinstance(meta_val, torch.Tensor): + continue + if meta_val.dtype != torch.uint8: + continue + if node.op in ("placeholder", "output"): + continue + if node.op == "call_function" and node.target in (*Q_OPS, *torch_q_ops): + uint8_q_nodes.append((node.name, node.target)) + continue + if node.op == "call_function" and node.target in (*DQ_OPS, *torch_dq_ops): + bad_nodes.append((node.name, node.target)) + continue + bad_nodes.append((node.name, node.target)) + assert not bad_nodes, ( + "Found internal uint8 tensors outside IO boundaries: " f"{bad_nodes}" + ) + assert len(uint8_q_nodes) <= max_uint8_q_nodes, ( + "Expected uint8 quantize nodes only at IO boundaries; " + f"found {len(uint8_q_nodes)} (max {max_uint8_q_nodes})." + ) + + +def test_quantize_io_tosa_INT_uint8_pipeline(): + """Use TOSA pipeline to build an end-to-end flow and accept uint8 IO.""" + model = SimpleMLP().eval() + test_data = (torch.rand(1, 4),) + pipeline = TosaPipelineINT( + model, + test_data, + [], + [], + run_on_tosa_ref_model=True, + use_to_edge_transform_and_lower=False, + ) + + tester = pipeline.tester + tester.quantize() + tester.export() + tester.to_edge() + edge = tester.get_artifact() + edge.transform( + passes=[ + QuantizeInputs( + edge, + { + 0: { + "scale": 1.0, + "zp": 0, + "dtype": torch.uint8, + } + }, + ), + QuantizeOutputs( + edge, + { + 0: { + "scale": 1.0, + "zp": 0, + "dtype": torch.uint8, + } + }, + ), + ] + ) + + exported_program = edge.exported_program() + graph_module = exported_program.graph_module + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")): + graph = InsertRescalePass()(graph_module).graph_module.graph + + assert any( + node.op == "call_function" + and node.target == exir_ops.backend.tosa.RESCALE.default + and node.kwargs.get("input_unsigned") + for node in graph.nodes + ), "Expected RESCALE with input_unsigned=True in pipeline flow." + + +def test_quantize_io_tosa_INT_uint8_io_add(): + """Model accepts uint8 inputs/outputs while TOSA uses int8 internally.""" + model = SimpleModel().eval() + test_data = SimpleModel.test_data["rand_rand"] + compile_spec = common.get_tosa_compile_spec("TOSA-1.0+INT") + + tester = ArmTester(model, test_data, compile_spec) + quantizer = TOSAQuantizer(compile_spec) + quantizer.set_global(get_symmetric_quantization_config()) + quantizer.set_io(get_uint8_io_quantization_config()) + quant_stage = Quantize(quantizer, quantization_config=quantizer.global_config) + + tester.quantize(quant_stage) + tester.export() + tester.to_edge() + edge = tester.get_artifact() + edge.transform( + passes=[ + QuantizeInputs(edge, [0, 1]), + QuantizeOutputs(edge, [0]), + ] + ) + + exported_program = edge.exported_program() + graph = exported_program.graph_module.graph + placeholders = [node for node in graph.nodes if node.op == "placeholder"] + assert len(placeholders) == 2 + assert placeholders[0].meta["val"].dtype == torch.uint8 + assert placeholders[1].meta["val"].dtype == torch.uint8 + output_node = graph.output_node() + output_val = output_node.args[0][0] + assert output_val.meta["val"].dtype == torch.uint8 + + graph_module = exported_program.graph_module + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")): + rescale_graph = InsertRescalePass()(graph_module).graph_module.graph + + rescale_nodes = [ + node + for node in rescale_graph.nodes + if node.op == "call_function" + and node.target == exir_ops.backend.tosa.RESCALE.default + ] + assert rescale_nodes, "Expected RESCALE ops after lowering." + assert any( + node.kwargs.get("input_unsigned") for node in rescale_nodes + ), "Expected input_unsigned on IO rescale." + assert any( + node.kwargs.get("output_unsigned") for node in rescale_nodes + ), "Expected output_unsigned on IO rescale." + assert all( + node.args[1] == torch.int8 + for node in rescale_nodes + if node.kwargs.get("input_unsigned") or node.kwargs.get("output_unsigned") + ), "Unsigned IO rescales must output int8 internally." + + +def test_quantize_io_tosa_INT_uint8_numeric(): + """Run TOSA flow with uint8 input and verify numerical output.""" + if not TosaPipelineINT.is_tosa_ref_model_available(): + pytest.skip("TOSA reference model not available.") + model = SimpleModel().eval() + calib_input = torch.rand(1, 4) + calib_other = torch.rand(1, 4) + + pipeline = TosaPipelineINT( + model, + (calib_input, calib_other), + aten_op=[], + exir_op=[], + use_to_edge_transform_and_lower=True, + ) + pipeline.quantizer.set_io(get_uint8_io_quantization_config()) + + qparams = {} + + def _apply_uint8_io(ep): + in0_scale, in0_zp, in0_qmin, in0_qmax, in0_dtype = quantize_input(ep, 0) + in1_scale, in1_zp, in1_qmin, in1_qmax, in1_dtype = quantize_input(ep, 1) + out_scale, out_zp, out_qmin, out_qmax, out_dtype = quantize_output(ep, 0) + qparams.update( + { + "in0_scale": in0_scale, + "in0_zp": in0_zp, + "in0_qmin": in0_qmin, + "in0_qmax": in0_qmax, + "in0_dtype": in0_dtype, + "in1_scale": in1_scale, + "in1_zp": in1_zp, + "in1_qmin": in1_qmin, + "in1_qmax": in1_qmax, + "in1_dtype": in1_dtype, + "out_scale": out_scale, + "out_zp": out_zp, + "out_qmin": out_qmin, + "out_qmax": out_qmax, + "out_dtype": out_dtype, + } + ) + return ep + + class _Uint8ReferenceStage: + def __init__(self, reference_model): + self.reference_model = reference_model.eval() + + def stage_type(self): + return StageType.RUN_PASSES + + @property + def artifact(self): + return self.reference_model + + @property + def graph_module(self): + return None + + def run_artifact(self, inputs): + def _quantize(tensor, scale, zp, qmin, qmax, dtype): + return torch.ops.quantized_decomposed.quantize_per_tensor( + tensor, scale, zp, qmin, qmax, dtype + ) + + def _dequantize(tensor, scale, zp, qmin, qmax, dtype): + return torch.ops.quantized_decomposed.dequantize_per_tensor( + tensor, scale, zp, qmin, qmax, dtype + ) + + float_x = _dequantize( + inputs[0], + qparams["in0_scale"], + qparams["in0_zp"], + qparams["in0_qmin"], + qparams["in0_qmax"], + qparams["in0_dtype"], + ) + float_y = _dequantize( + inputs[1], + qparams["in1_scale"], + qparams["in1_zp"], + qparams["in1_qmin"], + qparams["in1_qmax"], + qparams["in1_dtype"], + ) + float_out = self.reference_model(float_x, float_y) + ref_u8 = _quantize( + float_out, + qparams["out_scale"], + qparams["out_zp"], + qparams["out_qmin"], + qparams["out_qmax"], + qparams["out_dtype"], + ) + # Match TOSA's signless int8 representation of unsigned outputs. + return ref_u8 + + pipeline.pop_stage("run_method_and_compare_outputs.original_model") + # Insert quantization of inputs/outputs after lowering so we can run uint8 IO. + pipeline.add_stage_after( + "to_edge_transform_and_lower", + pipeline.tester.run_passes, + RunPasses(pass_functions=[_apply_uint8_io]), + suffix="uint8_io", + ) + pipeline.add_stage_after( + "to_executorch", + lambda: setattr( + pipeline.tester, + "stages", + { + **pipeline.tester.stages, + StageType.RUN_PASSES: _Uint8ReferenceStage(model), + }, + ), + suffix="uint8_ref", + ) + + # Run the pipeline to get the quantization parameters without the standard comparison step + pipeline.pop_stage("run_method_and_compare_outputs") + pipeline.run() + + # Calculate the calib inputs and outputs uint8 values given the + # calibrated quantization parameters, so we can run the reference with the same quantized inputs. + input_tensor = torch.ops.quantized_decomposed.quantize_per_tensor( + calib_input, + qparams["in0_scale"], + qparams["in0_zp"], + qparams["in0_qmin"], + qparams["in0_qmax"], + qparams["in0_dtype"], + ) + other_input = torch.ops.quantized_decomposed.quantize_per_tensor( + calib_other, + qparams["in1_scale"], + qparams["in1_zp"], + qparams["in1_qmin"], + qparams["in1_qmax"], + qparams["in1_dtype"], + ) + + print( + f"input_tensor: {input_tensor}, other_input: {other_input}, qparams: {qparams}" + ) + + # Compare against a reference that dequantizes uint8 inputs, runs the float model, + # and requantizes to match TOSA's signless int8 representation. + def uint8_compare_callback(reference, output, _qparams): + # Map signless int8 to uint8 + output = output.to(torch.uint8) + diff = (output.to(torch.int16) - reference.to(torch.int16)).abs() + if diff.max().item() > 1: + raise AssertionError( + "Output mismatch beyond 1 LSB after uint8 IO flow. " + f"max abs diff={diff.max().item()}" + ) + + pipeline.tester.run_method_and_compare_outputs( + stage=StageType.TO_EXECUTORCH, + inputs=(input_tensor, other_input), + qtol=1, + reference_stage_type=StageType.RUN_PASSES, + compare_callback=lambda ref, test, qparmams: uint8_compare_callback( + ref, test, qparams + ), + ) diff --git a/backends/arm/test/quantizer/test_uint8_io_quantization.py b/backends/arm/test/quantizer/test_uint8_io_quantization.py new file mode 100644 index 00000000000..7461ca85a6f --- /dev/null +++ b/backends/arm/test/quantizer/test_uint8_io_quantization.py @@ -0,0 +1,42 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from executorch.backends.arm.quantizer import ( + get_uint8_io_quantization_config, + TOSAQuantizer, +) +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.test_pipeline import QuantizationPipeline + + +class SimpleMLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(4, 8) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(8, 4) + + def forward(self, x): + return self.fc2(self.relu(self.fc1(x))) + + +def test_uint8_io_quantization_config_tosa_INT_applies_to_io(): + model = SimpleMLP().eval() + test_data = (torch.rand(1, 4),) + compile_spec = common.get_tosa_compile_spec("TOSA-1.0+INT") + quantizer = TOSAQuantizer(compile_spec) + quantizer.set_io(get_uint8_io_quantization_config()) + + io_config = get_uint8_io_quantization_config() + pipeline = QuantizationPipeline( + model, + test_data, + quantizer=quantizer, + input_qspecs={io_config.input_activation: 1}, + output_qspecs={io_config.output_activation: 1}, + ) + pipeline.run() diff --git a/backends/arm/tosa/dialect/ops/rescale.py b/backends/arm/tosa/dialect/ops/rescale.py index 928ff72c9cc..c782ab4ae81 100644 --- a/backends/arm/tosa/dialect/ops/rescale.py +++ b/backends/arm/tosa/dialect/ops/rescale.py @@ -16,11 +16,18 @@ @register_fake_tosa_op( - "RESCALE(Tensor input1, ScalarType dtype, float[] scale, int in_zp, int out_zp) -> Tensor", # schema + "RESCALE(Tensor input1, ScalarType dtype, float[] scale, int in_zp, int out_zp, *, bool input_unsigned=False, bool output_unsigned=False) -> Tensor", # schema TosaSpecification.all_versions_for_profile("INT"), # target TOSA specifications ) -def RESCALE( - x: torch.Tensor, dtype: torch.dtype, scales: List[float], in_zp: int, out_zp: int +def RESCALE( # noqa: C901 + x: torch.Tensor, + dtype: torch.dtype, + scales: List[float], + in_zp: int, + out_zp: int, + *, + input_unsigned: bool = False, + output_unsigned: bool = False, ) -> torch.Tensor: tosa_spec = get_context_spec() """Casts the input tensor to dtype `dtype` to produce the correct tensor @@ -35,20 +42,82 @@ def RESCALE( ) if dtype not in (torch.int32, torch.int8, torch.int16): - raise NotImplementedError( - f"tosa::rescale currently only supports int32, int16 and int8, not {dtype}" - ) - if dtype in (torch.int32, torch.int16) and out_zp != 0: - raise ValueError( - f"TOSA requires output_zp to be zero when the output dtype is {dtype}." + raise TosaValueError( + f"tosa::rescale currently only supports int32, int16 and int8, not {dtype}", + op="RESCALE", ) - if x.dtype in (torch.int32, torch.int16) and in_zp != 0: - raise ValueError( - f"TOSA requires input_zp to be zero when the input dtype is {dtype}" + if input_unsigned and output_unsigned: + raise TosaValueError( + "TOSA requires input_unsigned and output_unsigned not both be true.", + op="RESCALE", ) - if x.dtype == torch.int8 and not -128 <= in_zp <= 127: - raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.") - if dtype == torch.int8 and not -128 <= out_zp <= 127: - raise ValueError(f"{out_zp=} outside valid range (-128,127) for int8.") + + if input_unsigned: + if x.dtype not in (torch.int8, torch.int16, torch.uint8): + raise TosaValueError( + f"input_unsigned requires int8/int16/uint8 input dtype, got {x.dtype}.", + op="RESCALE", + ) + if x.dtype == torch.int32: + raise TosaValueError( + "TOSA forbids input_unsigned for int32 inputs.", op="RESCALE" + ) + if x.dtype == torch.int16: + if in_zp not in (0, 32768): + raise TosaValueError( + f"{in_zp=} outside valid range (0,32768) for uint16.", + op="RESCALE", + ) + else: + if not 0 <= in_zp <= 255: + raise TosaValueError( + f"{in_zp=} outside valid range (0,255) for uint8.", + op="RESCALE", + ) + else: + if x.dtype in (torch.int32, torch.int16) and in_zp != 0: + raise TosaValueError( + f"TOSA requires input_zp to be zero when the input dtype is {x.dtype}.", + op="RESCALE", + ) + if x.dtype == torch.int8 and not -128 <= in_zp <= 127: + raise TosaValueError( + f"{in_zp=} outside valid range (-128,127) for int8.", + op="RESCALE", + ) + + if output_unsigned: + if dtype not in (torch.int8, torch.int16): + raise TosaValueError( + f"output_unsigned requires int8/int16 output dtype, got {dtype}.", + op="RESCALE", + ) + if dtype == torch.int32: + raise TosaValueError( + "TOSA forbids output_unsigned for int32 outputs.", op="RESCALE" + ) + if dtype == torch.int16: + if out_zp not in (0, 32768): + raise TosaValueError( + f"{out_zp=} outside valid range (0,32768) for uint16.", + op="RESCALE", + ) + else: + if not 0 <= out_zp <= 255: + raise TosaValueError( + f"{out_zp=} outside valid range (0,255) for uint8.", + op="RESCALE", + ) + else: + if dtype in (torch.int32, torch.int16) and out_zp != 0: + raise TosaValueError( + f"TOSA requires output_zp to be zero when the output dtype is {dtype}.", + op="RESCALE", + ) + if dtype == torch.int8 and not -128 <= out_zp <= 127: + raise TosaValueError( + f"{out_zp=} outside valid range (-128,127) for int8.", + op="RESCALE", + ) return torch.empty_like(x, dtype=dtype) diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index b0e8aee8869..da57729f6ba 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -26,7 +26,6 @@ torch.cfloat, torch.complex128, torch.cdouble, - torch.uint8, torch.int64, torch.long, ) @@ -100,6 +99,8 @@ def map_dtype(data_type: torch.dtype) -> Any: torch.half: ts.DType.FP16, torch.bfloat16: ts.DType.BF16, torch.int8: ts.DType.INT8, + # TOSA uses signless int8; unsigned semantics are expressed via RESCALE. + torch.uint8: ts.DType.INT8, torch.int16: ts.DType.INT16, torch.short: ts.DType.INT16, torch.int32: ts.DType.INT32, diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index 957e8f0e5d4..bd74f891664 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -14,8 +14,9 @@ """ import logging +import operator from itertools import count -from typing import Callable, List, Optional, Sequence, Tuple +from typing import Callable, cast, List, Optional, Sequence, Tuple import torch from executorch.backends.arm._passes.arm_pass_utils import ( @@ -200,6 +201,45 @@ def _detag_boundary_nodes( del node.meta["delegation_tag"] break + def _partition_has_invalid_uint8(self, partition: Partition, tag: str) -> bool: + """Return True if any uint8 appears outside allowed IO nodes. + + TOSA does not have a true uint8 tensor type. Unsigned semantics are only + allowed at IO boundaries and are carried via RESCALE flags. If a + partition contains uint8 in any other node, it will fail later in + lowering, so reject the partition here. + + """ + for node in partition.nodes: + if not is_partitioned(node, tag): + # Ignore nodes that were de-tagged after boundary processing. + continue + dtype: Optional[torch.dtype] = None + meta_val = node.meta.get("val") + if isinstance(meta_val, torch.Tensor): + dtype = meta_val.dtype + else: + dtype = cast(Optional[torch.dtype], node.meta.get("dtype")) + if dtype is None: + try: + dtype = get_first_fake_tensor(node).dtype + except (AttributeError, KeyError, RuntimeError, ValueError): + dtype = None + if dtype is None: + continue + if dtype != torch.uint8: + continue + + is_allowed = node.op in ("placeholder", "output") + is_allowed = is_allowed or ( + node.op == "call_function" and node.target == operator.getitem + ) + # Allow uint8 on Q/DQ nodes that mediate IO quantization. + is_allowed = is_allowed or node.target in Q_OPS or node.target in DQ_OPS + if not is_allowed: + return True + return False + def _tag_module( # noqa self, module: GraphModule, @@ -255,6 +295,15 @@ def _tag_module( # noqa reporter, ) + if self._partition_has_invalid_uint8(partition, tag): + reject_partition( + "Partition contained internal uint8 tensors. Uint8 is only supported at IO boundaries for TOSA backends.", + partition, + reporter, + ) + tags.remove(tag) + continue + # Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation." is_nocompute_partition = all( _is_noop_clone(node) From ddcc87d2e881c297111b752f233a220b93c66576 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Mon, 30 Mar 2026 17:28:07 +0200 Subject: [PATCH 2/4] Arm backend: Avoid sharing uint8 IO qspecs in shared clusters MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit SharedQspecQuantizer can propagate the IO quantization spec into internal nodes when using the composable quantizer. For uint8 IO this violates the TOSA constraint that uint8 is only allowed at IO boundaries. Skip IO-based qspec anchors for uint8 so internal nodes stay int8 while preserving shared qspec behavior elsewhere. Change-Id: Ie068de0c46426f386c86d9c295459011e906f335 Signed-off-by: Per Åstrand --- backends/arm/quantizer/arm_quantizer_utils.py | 72 +++++++++++-------- 1 file changed, 44 insertions(+), 28 deletions(-) diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index 3deb9d00741..fb4f363d6b0 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -532,45 +532,61 @@ def _get_input_nodes_with_float_output(self, node: Node) -> list[Node]: def _get_user_nodes_with_float_input(self, node: Node) -> list[Node]: return [n for n in node.users.keys() if has_float_output(node)] + def _skip_shared_qspec_from_io(self, node: Node, qspec: QuantizationSpec) -> bool: + return node.op in ("placeholder", "output") and qspec.dtype == torch.uint8 + + def _maybe_enqueue_shared_node( + self, neighbor: Node, shared_nodes: set[Node], bfs_queue: list[Node] + ) -> None: + if neighbor.target in self.targets and neighbor not in shared_nodes: + if not self._is_annotated(neighbor): + bfs_queue.append(neighbor) + + def _append_output_qspec(self, node: Node, adjacent_qspecs: list[Any]) -> None: + if not self._is_annotated(node): + return + output_qspec = node.meta.get( # type: ignore[union-attr] + Q_ANNOTATION_KEY + ).output_qspec + if output_qspec is None: + return + if self._skip_shared_qspec_from_io(node, output_qspec): + return + adjacent_qspecs.append(output_qspec) + + def _append_input_qspec( + self, user_node: Node, input_node: Node, adjacent_qspecs: list[Any] + ) -> None: + if not self._is_annotated(user_node): + return + qspec_map = user_node.meta.get(Q_ANNOTATION_KEY) + if qspec_map is None: + return + if input_node not in qspec_map.input_qspec_map: + return + input_qspec = qspec_map.input_qspec_map[input_node] + if input_qspec is None: + return + if self._skip_shared_qspec_from_io(user_node, input_qspec): + return + adjacent_qspecs.append(input_qspec) + def _get_shared_clique(self, root_node: Node) -> tuple[set[Node], list[Any]]: shared_nodes = set() bfs_queue = [root_node] - adjacent_qspecs = [] + adjacent_qspecs: list[Any] = [] while bfs_queue: node = bfs_queue.pop(0) shared_nodes.add(node) for input_node in node.all_input_nodes: - if input_node.target in self.targets and input_node not in shared_nodes: - if not self._is_annotated(input_node): - bfs_queue.append(input_node) - if self._is_annotated(input_node): - output_qspec = input_node.meta.get( # type: ignore[union-attr] - Q_ANNOTATION_KEY - ).output_qspec - if output_qspec is not None: - adjacent_qspecs.append(output_qspec) + self._maybe_enqueue_shared_node(input_node, shared_nodes, bfs_queue) + self._append_output_qspec(input_node, adjacent_qspecs) for output_node in node.users.keys(): - if ( - output_node.target in self.targets - and output_node not in shared_nodes - ): - if not self._is_annotated(output_node): - bfs_queue.append(output_node) - if ( - self._is_annotated(output_node) - and node - in output_node.meta.get( # type: ignore[union-attr] - Q_ANNOTATION_KEY - ).input_qspec_map - ): - input_qspec = output_node.meta.get( # type: ignore[union-attr] - Q_ANNOTATION_KEY - ).input_qspec_map[node] - if input_qspec is not None: - adjacent_qspecs.append(input_qspec) + self._maybe_enqueue_shared_node(output_node, shared_nodes, bfs_queue) + self._append_input_qspec(output_node, node, adjacent_qspecs) return shared_nodes, adjacent_qspecs From d20ad34a219dd6a49107a58a8f925dd51f4f13f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Tue, 31 Mar 2026 16:03:33 +0200 Subject: [PATCH 3/4] Arm backend: add preserve_io_quantization compile spec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add an option to preserve the quantization on IO. Useful for keeping input and output tensors quantized when backend supports both +INT and +FP. Change-Id: Ibf6177e70c2abd9f64151553cb94698591a77acc Signed-off-by: Per Åstrand --- backends/arm/common/arm_compile_spec.py | 36 +++++++++++++++++++++ backends/arm/test/misc/test_compile_spec.py | 25 +++++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/backends/arm/common/arm_compile_spec.py b/backends/arm/common/arm_compile_spec.py index 2d3948beeb1..adc98f09254 100644 --- a/backends/arm/common/arm_compile_spec.py +++ b/backends/arm/common/arm_compile_spec.py @@ -36,6 +36,7 @@ class DebugMode(Enum): compiler_flags: list[str] = field(default_factory=list) path_for_intermediates: str | None = None tosa_debug_mode: DebugMode | None = None + preserve_io_quantization: bool = False _TOSA_SPEC_KEY = "tosa_spec" _COMPILE_FLAGS_KEY = "compile_flags" @@ -44,6 +45,7 @@ class DebugMode(Enum): _DEBUG_MODE_KEY = "dump_debug_info" _OUTPUT_REORDER_KEY = "ouput_reorder_workaround" _TRANSFORM_PIPELINE_CONFIG_KEY = "transform_pipeline_config" + _PRESERVE_IO_QUANT_KEY = "preserve_io_quantization" def _set_compile_specs( self, @@ -53,6 +55,7 @@ def _set_compile_specs( tosa_debug_mode: DebugMode | None = None, output_order_workaround: bool = False, pipeline_config: ArmPassPipelineConfig | None = None, + preserve_io_quantization: bool = False, ): """Set all values of dataclass directly.""" self.tosa_spec = tosa_spec @@ -61,6 +64,8 @@ def _set_compile_specs( self.tosa_debug_mode = tosa_debug_mode self._pipeline_config = pipeline_config self.output_order_workaround = output_order_workaround + self.preserve_io_quantization = preserve_io_quantization + self._warn_if_redundant_preserve_io_quantization() if output_order_workaround: warnings.warn( "ArmCompileSpec(output_order_workaround=True) is deprecated and will be " @@ -78,6 +83,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 tosa_debug_mode: ArmCompileSpec.DebugMode | None = None output_order_workaround: bool = False pipeline_config: ArmPassPipelineConfig | None = None + preserve_io_quantization: bool = False unknown_specs: dict[str, str] = {} for spec in compile_specs: key = spec.key @@ -128,6 +134,8 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 "More than one transform pipeline entry in compile spec." ) pipeline_config = ArmPassPipelineConfig.from_dict(json.loads(val)) + elif key == ArmCompileSpec._PRESERVE_IO_QUANT_KEY: + preserve_io_quantization = str(val).lower() in ("1", "true", "yes") else: unknown_specs[key] = val @@ -151,6 +159,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 tosa_debug_mode=tosa_debug_mode, output_order_workaround=output_order_workaround, pipeline_config=pipeline_config, + preserve_io_quantization=preserve_io_quantization, ) cls._from_list_hook(compile_spec, unknown_specs) compile_spec._validate() @@ -227,8 +236,35 @@ def _to_list(self): self._pipeline_config.serialize(), ) ) + compile_spec.append( + CompileSpec( + ArmCompileSpec._PRESERVE_IO_QUANT_KEY, + str(bool(self.preserve_io_quantization)).encode(), + ) + ) return compile_spec + def _set_preserve_io_quantization(self, enabled: bool) -> "ArmCompileSpec": + """Preserve Q/DQ nodes at IO boundaries when lowering.""" + self.preserve_io_quantization = enabled + self._warn_if_redundant_preserve_io_quantization() + return self + + def _warn_if_redundant_preserve_io_quantization(self) -> None: + """Warn when preserve_io_quantization has no effect for INT-only + specs. + """ + if ( + self.preserve_io_quantization + and self.tosa_spec.support_integer() + and not self.tosa_spec.support_float() + ): + warnings.warn( + "preserve_io_quantization=True is redundant for INT-only TOSA " + "specifications because boundary Q/DQ are already de-tagged.", + stacklevel=3, + ) + def _get_pass_pipeline_config(self) -> ArmPassPipelineConfig: """Returns configuration that controls how the Arm pass pipeline should behave. diff --git a/backends/arm/test/misc/test_compile_spec.py b/backends/arm/test/misc/test_compile_spec.py index da4bcebda35..cb2f45b5382 100644 --- a/backends/arm/test/misc/test_compile_spec.py +++ b/backends/arm/test/misc/test_compile_spec.py @@ -3,11 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import warnings + from executorch.backends.arm.common.pipeline_config import SoftmaxDecompositionConfig from executorch.backends.arm.ethosu import EthosUCompileSpec from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.vgf import VgfCompileSpec -from pytest import raises +from pytest import raises, warns def test_compile_spec_u55_INT(): @@ -68,3 +70,24 @@ def test_compile_spec_tosa_INT(): assert TosaCompileSpec._from_list(spec_list) == compile_spec with raises(ValueError, match="Incorrect output format"): VgfCompileSpec._from_list(spec_list) + + +def test_preserve_io_quantization_roundtrip_vgf_FP_INT(): + compile_spec = VgfCompileSpec()._set_preserve_io_quantization(True) + roundtripped = VgfCompileSpec._from_list(compile_spec._to_list()) + assert roundtripped.preserve_io_quantization is True + + +def test_preserve_io_quantization_warns_for_u55_INT(): + with warns( + UserWarning, + match="preserve_io_quantization=True is redundant for INT-only TOSA", + ): + EthosUCompileSpec("ethos-u55-128")._set_preserve_io_quantization(True) + + +def test_preserve_io_quantization_no_warn_for_vgf_FP_INT(): + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") + VgfCompileSpec()._set_preserve_io_quantization(True) + assert len(recorded_warnings) == 0 From c4dc7e6d0e007df107d992e029d35fa074020ba8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Per=20=C3=85strand?= Date: Tue, 31 Mar 2026 13:41:40 +0200 Subject: [PATCH 4/4] Arm backend: Handle +FP+INT for vgf and quantized IO MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Change-Id: Ie943e1de816d981c0f09d9bd3683881c03e3000c Signed-off-by: Per Åstrand --- backends/arm/_passes/insert_rescales_pass.py | 7 ++ backends/arm/runtime/EthosUBackend.cpp | 15 ++-- backends/arm/test/common.py | 4 + .../test/passes/test_ioquantization_pass.py | 81 ++++++++++++++++--- backends/arm/test/tester/test_pipeline.py | 2 + backends/arm/tosa/partitioner.py | 32 +++++++- 6 files changed, 125 insertions(+), 16 deletions(-) diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 5f4fec8f0c4..06c27005440 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -54,6 +54,13 @@ def _ensure_uint8_io_only(self, graph_module: GraphModule) -> None: if node.op == "call_function" and node.target == operator.getitem: if all(user.op == "output" for user in node.users): continue + if ( + node.op == "call_function" + and node.target + == exir_ops.edge.dim_order_ops._to_dim_order_copy.default + ): + # dim_order is a view-like transform; allow it to preserve uint8 at IO. + continue if ( node.op == "call_function" and node.target == exir_ops.backend.tosa.RESCALE.default diff --git a/backends/arm/runtime/EthosUBackend.cpp b/backends/arm/runtime/EthosUBackend.cpp index d83b04f0e8e..2b17cf2c43d 100644 --- a/backends/arm/runtime/EthosUBackend.cpp +++ b/backends/arm/runtime/EthosUBackend.cpp @@ -194,19 +194,23 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { bool supported = 0; // 32 bit int (simple non-quantised test cases) supported |= - (tensor_in.scalar_type() == ScalarType::Int and + (tensor_in.scalar_type() == ScalarType::Int && handles.inputs->io[i].elem_size == 4); // 8 bit int (IOQDQ pass prepared networks) supported |= - (tensor_in.scalar_type() == ScalarType::Char and + (tensor_in.scalar_type() == ScalarType::Char && + handles.inputs->io[i].elem_size == 1); + // 8 bit uint8 (IOQDQ pass prepared networks) + supported |= + (tensor_in.scalar_type() == ScalarType::Byte && handles.inputs->io[i].elem_size == 1); // 16 bit int (IOQDQ pass prepared networks) supported |= - (tensor_in.scalar_type() == ScalarType::Short and + (tensor_in.scalar_type() == ScalarType::Short && handles.inputs->io[i].elem_size == 2); // bool (IOQDQ pass prepared networks) supported |= - (tensor_in.scalar_type() == ScalarType::Bool and + (tensor_in.scalar_type() == ScalarType::Bool && handles.inputs->io[i].elem_size == 1); if (!supported) { ET_LOG( @@ -222,7 +226,8 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { // which require permutation. bool both_int = tensor_in.scalar_type() == ScalarType::Int && handles.inputs->io[i].elem_size == 4; - bool both_char = tensor_in.scalar_type() == ScalarType::Char && + bool both_char = (tensor_in.scalar_type() == ScalarType::Char || + tensor_in.scalar_type() == ScalarType::Byte) && handles.inputs->io[i].elem_size == 1; bool both_short = tensor_in.scalar_type() == ScalarType::Short && handles.inputs->io[i].elem_size == 2; diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index be1ecaa03f5..736a5ffc6b5 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -160,6 +160,7 @@ def get_vgf_compile_spec( compiler_flags: Optional[str] = "", custom_path: Optional[str] = None, tosa_debug_mode: VgfCompileSpec.DebugMode | None = None, + preserve_io_quantization: bool = False, ) -> VgfCompileSpec: """Get the ArmCompileSpec for the default VGF tests, to modify the compile spec before calling .build() to finalize it. @@ -188,6 +189,9 @@ def get_vgf_compile_spec( .dump_debug_info(tosa_debug_mode) ) + if preserve_io_quantization: + compile_spec._set_preserve_io_quantization(True) + return compile_spec diff --git a/backends/arm/test/passes/test_ioquantization_pass.py b/backends/arm/test/passes/test_ioquantization_pass.py index 0cad2eede3c..e42f7a093cc 100644 --- a/backends/arm/test/passes/test_ioquantization_pass.py +++ b/backends/arm/test/passes/test_ioquantization_pass.py @@ -23,6 +23,7 @@ from executorch.backends.arm.test.tester.test_pipeline import ( EthosU55PipelineINT, TosaPipelineINT, + VgfPipeline, ) from executorch.backends.arm.tosa.specification import ( TosaLoweringContext, @@ -402,6 +403,62 @@ def test_quantize_io_tosa_INT_uint8_numeric(): ) pipeline.quantizer.set_io(get_uint8_io_quantization_config()) + _run_uint8_io_numeric_pipeline(pipeline, model, calib_input, calib_other) + + +def test_quantize_io_vgf_INT_uint8_numeric(): + """Run VGF flow with uint8 input and verify numerical output.""" + + model = SimpleModel().eval() + calib_input = torch.rand(1, 4) + calib_other = torch.rand(1, 4) + + pipeline = VgfPipeline( + model, + (calib_input, calib_other), + aten_op=[], + exir_op=[], + run_on_vulkan_runtime=True, + quantize=True, + use_to_edge_transform_and_lower=True, + preserve_io_quantization=True, + ) + + pipeline.quantizer.set_io(get_uint8_io_quantization_config()) + + if pipeline.has_stage("check_not.exir_quant_nodes"): + pipeline.pop_stage("check_not.exir_quant_nodes") + _run_uint8_io_numeric_pipeline(pipeline, model, calib_input, calib_other) + + +def test_quantize_io_u55_INT_uint8_numeric(): + """Run Ethos-U55 flow with uint8 input and verify numerical output.""" + model = SimpleModel().eval() + calib_input = torch.rand(1, 4) + calib_other = torch.rand(1, 4) + + if not ( + common.corstone300_installed() + and common.arm_executor_runner_exists("corstone-300") + ): + pytest.xfail("Did not find Corstone-300 FVP or executor_runner on path") + + pipeline = EthosU55PipelineINT( + model, + (calib_input, calib_other), + aten_ops=[], + exir_ops=[], + run_on_fvp=True, + use_to_edge_transform_and_lower=True, + ) + pipeline.quantizer.set_io(get_uint8_io_quantization_config()) + + _run_uint8_io_numeric_pipeline(pipeline, model, calib_input, calib_other) + + +def _run_uint8_io_numeric_pipeline( # noqa: C901 + pipeline, model, calib_input, calib_other +) -> None: qparams = {} def _apply_uint8_io(ep): @@ -483,7 +540,6 @@ def _dequantize(tensor, scale, zp, qmin, qmax, dtype): # Match TOSA's signless int8 representation of unsigned outputs. return ref_u8 - pipeline.pop_stage("run_method_and_compare_outputs.original_model") # Insert quantization of inputs/outputs after lowering so we can run uint8 IO. pipeline.add_stage_after( "to_edge_transform_and_lower", @@ -505,9 +561,14 @@ def _dequantize(tensor, scale, zp, qmin, qmax, dtype): ) # Run the pipeline to get the quantization parameters without the standard comparison step - pipeline.pop_stage("run_method_and_compare_outputs") + if pipeline.has_stage("run_method_and_compare_outputs"): + pipeline.pop_stage("run_method_and_compare_outputs") pipeline.run() + assert qparams["in0_dtype"] == torch.uint8 + assert qparams["in1_dtype"] == torch.uint8 + assert qparams["out_dtype"] == torch.uint8 + # Calculate the calib inputs and outputs uint8 values given the # calibrated quantization parameters, so we can run the reference with the same quantized inputs. input_tensor = torch.ops.quantized_decomposed.quantize_per_tensor( @@ -527,24 +588,26 @@ def _dequantize(tensor, scale, zp, qmin, qmax, dtype): qparams["in1_dtype"], ) - print( - f"input_tensor: {input_tensor}, other_input: {other_input}, qparams: {qparams}" - ) - # Compare against a reference that dequantizes uint8 inputs, runs the float model, # and requantizes to match TOSA's signless int8 representation. def uint8_compare_callback(reference, output, _qparams): # Map signless int8 to uint8 - output = output.to(torch.uint8) - diff = (output.to(torch.int16) - reference.to(torch.int16)).abs() + output_u8 = output.to(torch.uint8) + reference_u8 = reference.to(torch.uint8) + diff = (output_u8.to(torch.int16) - reference_u8.to(torch.int16)).abs() if diff.max().item() > 1: raise AssertionError( "Output mismatch beyond 1 LSB after uint8 IO flow. " f"max abs diff={diff.max().item()}" ) + compare_stage = ( + StageType.SERIALIZE + if pipeline.has_stage("serialize") + else StageType.TO_EXECUTORCH + ) pipeline.tester.run_method_and_compare_outputs( - stage=StageType.TO_EXECUTORCH, + stage=compare_stage, inputs=(input_tensor, other_input), qtol=1, reference_stage_type=StageType.RUN_PASSES, diff --git a/backends/arm/test/tester/test_pipeline.py b/backends/arm/test/tester/test_pipeline.py index fe38b7c9690..7e7f576e35c 100644 --- a/backends/arm/test/tester/test_pipeline.py +++ b/backends/arm/test/tester/test_pipeline.py @@ -1184,6 +1184,7 @@ def __init__( tosa_extensions: Optional[List[str]] = None, tosa_spec: TosaSpecification | str | None = None, fold_quantize: bool = True, + preserve_io_quantization: bool = False, ): if tosa_spec is None: if tosa_version is None: @@ -1201,6 +1202,7 @@ def __init__( compiler_flags=vgf_compiler_flags, custom_path=custom_path, tosa_debug_mode=tosa_debug_mode, + preserve_io_quantization=preserve_io_quantization, ) super().__init__( diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index bd74f891664..c4ce5fa5eff 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -152,7 +152,11 @@ def __init__( self.additional_checks = additional_checks def _detag_boundary_nodes( - self, module: GraphModule, tag: str, reporter: WhyNoPartitionReporter + self, + module: GraphModule, + tag: str, + reporter: WhyNoPartitionReporter, + detag_first_fp_node: bool = True, ) -> None: """De-tag nodes at the partition boundary. @@ -188,7 +192,7 @@ def _detag_boundary_nodes( # Remove tag from quantize node with input outside partition, # or dequantize node with any output outside partition del node.meta["delegation_tag"] - elif not is_q_node and not is_dq_node: + elif detag_first_fp_node and not is_q_node and not is_dq_node: # For non Q/DQ nodes, remove tag from first node in partition if any input has fp dtype for input in node.all_input_nodes: if is_partitioned(input, tag): @@ -201,6 +205,21 @@ def _detag_boundary_nodes( del node.meta["delegation_tag"] break + def _preserve_io_quantization_enabled(self) -> bool: + """Return True if IO quantization should be preserved from compile + specs. + """ + for spec in self.delegation_spec.compile_specs: + if spec.key != "preserve_io_quantization": + continue + raw = ( + spec.value.decode() + if isinstance(spec.value, (bytes, bytearray)) + else str(spec.value) + ) + return raw.lower() in ("1", "true", "yes") + return False + def _partition_has_invalid_uint8(self, partition: Partition, tag: str) -> bool: """Return True if any uint8 appears outside allowed IO nodes. @@ -295,6 +314,15 @@ def _tag_module( # noqa reporter, ) + if self._preserve_io_quantization_enabled(): + # Detag boundary Q/DQ to keep IO quantization outside delegate. + self._detag_boundary_nodes( + module, + tag, + reporter, + detag_first_fp_node=False, + ) + if self._partition_has_invalid_uint8(partition, tag): reject_partition( "Partition contained internal uint8 tensors. Uint8 is only supported at IO boundaries for TOSA backends.",