diff --git a/backends/arm/_passes/insert_rescales_pass.py b/backends/arm/_passes/insert_rescales_pass.py index 8cab19dc551..06c27005440 100644 --- a/backends/arm/_passes/insert_rescales_pass.py +++ b/backends/arm/_passes/insert_rescales_pass.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. import math +import operator from copy import copy from typing import cast, Dict, Optional, Set, Tuple, Type @@ -34,10 +35,51 @@ class InsertRescalePass(ArmPass): _passes_required_after: Set[Type[ExportPass]] = set() + def _ensure_uint8_io_only(self, graph_module: GraphModule) -> None: + """Ensure uint8 tensors only appear at IO boundaries. + + TOSA has no true uint8 tensor type; unsigned semantics are carried via + RESCALE input/output flags. If uint8 appears for other nodes, it means + unsigned data leaked past IO. + + """ + for node in graph_module.graph.nodes: + meta_val = node.meta.get("val") + if not isinstance(meta_val, torch.Tensor): + continue + if meta_val.dtype != torch.uint8: + continue + if node.op in ("placeholder", "output"): + continue + if node.op == "call_function" and node.target == operator.getitem: + if all(user.op == "output" for user in node.users): + continue + if ( + node.op == "call_function" + and node.target + == exir_ops.edge.dim_order_ops._to_dim_order_copy.default + ): + # dim_order is a view-like transform; allow it to preserve uint8 at IO. + continue + if ( + node.op == "call_function" + and node.target == exir_ops.backend.tosa.RESCALE.default + ): + continue + raise ValueError( + f"Found internal uint8 tensor at node {node.name} " + f"({node.target}). Uint8 is only allowed at IO boundaries." 
+ ) + def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule): dq_args = QuantArgs.from_operator(node.target, node.args) q_args = QuantArgs.from_operator(user.target, user.args) new_scale = dq_args.scale / q_args.scale + input_unsigned = dq_args.dtype == torch.uint8 + output_unsigned = q_args.dtype == torch.uint8 + # TOSA has no true uint8 tensors; unsigned semantics are handled via + # the RESCALE flags, so uint8 does not propagate as a tensor dtype. + output_dtype = torch.int8 if output_unsigned else q_args.dtype with graph_module.graph.inserting_before(node): rescale_node = create_node( @@ -45,11 +87,15 @@ def fold_dq_q_to_rescale(self, node: Node, user: Node, graph_module: GraphModule exir_ops.backend.tosa.RESCALE.default, ( node.all_input_nodes[0], - q_args.dtype, + output_dtype, [new_scale], dq_args.zp, q_args.zp, ), + kwargs={ + "input_unsigned": input_unsigned, + "output_unsigned": output_unsigned, + }, ) rescale_node.meta = copy(user.meta) user.replace_all_uses_with(rescale_node) @@ -74,6 +120,9 @@ def call(self, graph_module: GraphModule) -> PassResult: graph_module.recompile() return PassResult(graph_module, modified) + def ensures(self, graph_module: GraphModule) -> None: + self._ensure_uint8_io_only(graph_module) + class InsertRescaleInt32Pass(ArmPass): """Numerous TOSA ops require inputs and outputs to be 32-bit integers in diff --git a/backends/arm/common/arm_compile_spec.py b/backends/arm/common/arm_compile_spec.py index 2d3948beeb1..adc98f09254 100644 --- a/backends/arm/common/arm_compile_spec.py +++ b/backends/arm/common/arm_compile_spec.py @@ -36,6 +36,7 @@ class DebugMode(Enum): compiler_flags: list[str] = field(default_factory=list) path_for_intermediates: str | None = None tosa_debug_mode: DebugMode | None = None + preserve_io_quantization: bool = False _TOSA_SPEC_KEY = "tosa_spec" _COMPILE_FLAGS_KEY = "compile_flags" @@ -44,6 +45,7 @@ class DebugMode(Enum): _DEBUG_MODE_KEY = "dump_debug_info" _OUTPUT_REORDER_KEY = 
"ouput_reorder_workaround" _TRANSFORM_PIPELINE_CONFIG_KEY = "transform_pipeline_config" + _PRESERVE_IO_QUANT_KEY = "preserve_io_quantization" def _set_compile_specs( self, @@ -53,6 +55,7 @@ def _set_compile_specs( tosa_debug_mode: DebugMode | None = None, output_order_workaround: bool = False, pipeline_config: ArmPassPipelineConfig | None = None, + preserve_io_quantization: bool = False, ): """Set all values of dataclass directly.""" self.tosa_spec = tosa_spec @@ -61,6 +64,8 @@ def _set_compile_specs( self.tosa_debug_mode = tosa_debug_mode self._pipeline_config = pipeline_config self.output_order_workaround = output_order_workaround + self.preserve_io_quantization = preserve_io_quantization + self._warn_if_redundant_preserve_io_quantization() if output_order_workaround: warnings.warn( "ArmCompileSpec(output_order_workaround=True) is deprecated and will be " @@ -78,6 +83,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 tosa_debug_mode: ArmCompileSpec.DebugMode | None = None output_order_workaround: bool = False pipeline_config: ArmPassPipelineConfig | None = None + preserve_io_quantization: bool = False unknown_specs: dict[str, str] = {} for spec in compile_specs: key = spec.key @@ -128,6 +134,8 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 "More than one transform pipeline entry in compile spec." 
) pipeline_config = ArmPassPipelineConfig.from_dict(json.loads(val)) + elif key == ArmCompileSpec._PRESERVE_IO_QUANT_KEY: + preserve_io_quantization = str(val).lower() in ("1", "true", "yes") else: unknown_specs[key] = val @@ -151,6 +159,7 @@ def _from_list(cls, compile_specs: list[CompileSpec]): # noqa: C901 tosa_debug_mode=tosa_debug_mode, output_order_workaround=output_order_workaround, pipeline_config=pipeline_config, + preserve_io_quantization=preserve_io_quantization, ) cls._from_list_hook(compile_spec, unknown_specs) compile_spec._validate() @@ -227,8 +236,35 @@ def _to_list(self): self._pipeline_config.serialize(), ) ) + compile_spec.append( + CompileSpec( + ArmCompileSpec._PRESERVE_IO_QUANT_KEY, + str(bool(self.preserve_io_quantization)).encode(), + ) + ) return compile_spec + def _set_preserve_io_quantization(self, enabled: bool) -> "ArmCompileSpec": + """Preserve Q/DQ nodes at IO boundaries when lowering.""" + self.preserve_io_quantization = enabled + self._warn_if_redundant_preserve_io_quantization() + return self + + def _warn_if_redundant_preserve_io_quantization(self) -> None: + """Warn when preserve_io_quantization has no effect for INT-only + specs. + """ + if ( + self.preserve_io_quantization + and self.tosa_spec.support_integer() + and not self.tosa_spec.support_float() + ): + warnings.warn( + "preserve_io_quantization=True is redundant for INT-only TOSA " + "specifications because boundary Q/DQ are already de-tagged.", + stacklevel=3, + ) + def _get_pass_pipeline_config(self) -> ArmPassPipelineConfig: """Returns configuration that controls how the Arm pass pipeline should behave. 
diff --git a/backends/arm/operators/op_tosa_rescale.py b/backends/arm/operators/op_tosa_rescale.py index 8499fc9ccd5..dfaabecb41b 100644 --- a/backends/arm/operators/op_tosa_rescale.py +++ b/backends/arm/operators/op_tosa_rescale.py @@ -161,6 +161,8 @@ def _build_rescale( rounding_mode: ts.RoundingMode, per_channel: bool = False, is_scale32: bool = True, + input_unsigned: bool = False, + output_unsigned: bool = False, ): """Insert a TOSA RESCALE operator configured for the quantized path. @@ -198,8 +200,8 @@ def _build_rescale( scale32=is_scale32, rounding_mode=rounding_mode, per_channel=per_channel, - input_unsigned=False, - output_unsigned=False, + input_unsigned=input_unsigned, + output_unsigned=output_unsigned, ) tosa_fb.addOperator( @@ -228,6 +230,14 @@ def define_node( scales = cast(list[float], node.args[2]) input_zp = cast(int, node.args[3]) output_zp = cast(int, node.args[4]) + if "input_unsigned" in node.kwargs: + input_unsigned = cast(bool, node.kwargs.get("input_unsigned", False)) + else: + input_unsigned = cast(bool, node.args[5]) if len(node.args) > 5 else False + if "output_unsigned" in node.kwargs: + output_unsigned = cast(bool, node.kwargs.get("output_unsigned", False)) + else: + output_unsigned = cast(bool, node.args[6]) if len(node.args) > 6 else False if ( input_dtype @@ -244,7 +254,6 @@ def define_node( raise ValueError( f"If output dtype is not int8 or int16, output_zp must be 0. 
Got {ts.DTypeNames[output_dtype]}, {output_zp=}" ) - _build_rescale( tosa_graph, scale=scales, @@ -255,4 +264,6 @@ def define_node( output_zp=[output_zp], rounding_mode=ts.RoundingMode.SINGLE_ROUND, per_channel=len(scales) > 1, + input_unsigned=input_unsigned, + output_unsigned=output_unsigned, ) diff --git a/backends/arm/quantizer/__init__.py b/backends/arm/quantizer/__init__.py index 270d56a68cd..5dd5687991d 100644 --- a/backends/arm/quantizer/__init__.py +++ b/backends/arm/quantizer/__init__.py @@ -15,6 +15,7 @@ EthosUQuantizer, get_symmetric_a16w8_quantization_config, get_symmetric_quantization_config, + get_uint8_io_quantization_config, TOSAQuantizer, VgfQuantizer, ) diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 3d36ec997cf..2e43b924cc8 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -105,6 +105,7 @@ "VgfQuantizer", "get_symmetric_a16w8_quantization_config", "get_symmetric_quantization_config", + "get_uint8_io_quantization_config", ] logger = logging.getLogger(__name__) @@ -234,6 +235,53 @@ def get_symmetric_quantization_config( return quantization_config +@functools.lru_cache +def get_uint8_io_quantization_config( + is_qat: bool = False, + is_dynamic: bool = False, + eps: float = 2**-16, +) -> QuantizationConfig: + """Create a uint8 IO quantization config for TOSA backends. + + This config is intended for model inputs/outputs only. Internal tensors + should remain int8 for TOSA INT lowering. 
+ + """ + extra_args: Dict[str, Any] = {"eps": eps} + if is_qat: + if is_dynamic: + act_observer_or_fake_quant_ctr = FakeQuantize + dynamic_quant_observer = MovingAverageMinMaxObserver.with_args( + averaging_constant=1 + ) + extra_args["observer"] = dynamic_quant_observer + else: + act_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize # type: ignore[assignment] + else: + if is_dynamic: + act_observer_or_fake_quant_ctr = PlaceholderObserver # type: ignore[assignment] + else: + act_observer_or_fake_quant_ctr = HistogramObserver # type: ignore[assignment] + + act_quantization_spec = QuantizationSpec( + dtype=torch.uint8, + quant_min=torch.iinfo(torch.uint8).min, + quant_max=torch.iinfo(torch.uint8).max, + qscheme=torch.per_tensor_affine, + is_dynamic=is_dynamic, + observer_or_fake_quant_ctr=act_observer_or_fake_quant_ctr.with_args( + **extra_args, + ), + ) + + return TOSAQuantizationConfig( + act_quantization_spec, + act_quantization_spec, + None, + None, + ) + + def get_symmetric_a8w4_quantization_config( is_per_channel: bool = True, is_qat: bool = True, is_dynamic: bool = False ): diff --git a/backends/arm/quantizer/arm_quantizer_utils.py b/backends/arm/quantizer/arm_quantizer_utils.py index 3deb9d00741..fb4f363d6b0 100644 --- a/backends/arm/quantizer/arm_quantizer_utils.py +++ b/backends/arm/quantizer/arm_quantizer_utils.py @@ -532,45 +532,61 @@ def _get_input_nodes_with_float_output(self, node: Node) -> list[Node]: def _get_user_nodes_with_float_input(self, node: Node) -> list[Node]: return [n for n in node.users.keys() if has_float_output(node)] + def _skip_shared_qspec_from_io(self, node: Node, qspec: QuantizationSpec) -> bool: + return node.op in ("placeholder", "output") and qspec.dtype == torch.uint8 + + def _maybe_enqueue_shared_node( + self, neighbor: Node, shared_nodes: set[Node], bfs_queue: list[Node] + ) -> None: + if neighbor.target in self.targets and neighbor not in shared_nodes: + if not self._is_annotated(neighbor): + 
bfs_queue.append(neighbor) + + def _append_output_qspec(self, node: Node, adjacent_qspecs: list[Any]) -> None: + if not self._is_annotated(node): + return + output_qspec = node.meta.get( # type: ignore[union-attr] + Q_ANNOTATION_KEY + ).output_qspec + if output_qspec is None: + return + if self._skip_shared_qspec_from_io(node, output_qspec): + return + adjacent_qspecs.append(output_qspec) + + def _append_input_qspec( + self, user_node: Node, input_node: Node, adjacent_qspecs: list[Any] + ) -> None: + if not self._is_annotated(user_node): + return + qspec_map = user_node.meta.get(Q_ANNOTATION_KEY) + if qspec_map is None: + return + if input_node not in qspec_map.input_qspec_map: + return + input_qspec = qspec_map.input_qspec_map[input_node] + if input_qspec is None: + return + if self._skip_shared_qspec_from_io(user_node, input_qspec): + return + adjacent_qspecs.append(input_qspec) + def _get_shared_clique(self, root_node: Node) -> tuple[set[Node], list[Any]]: shared_nodes = set() bfs_queue = [root_node] - adjacent_qspecs = [] + adjacent_qspecs: list[Any] = [] while bfs_queue: node = bfs_queue.pop(0) shared_nodes.add(node) for input_node in node.all_input_nodes: - if input_node.target in self.targets and input_node not in shared_nodes: - if not self._is_annotated(input_node): - bfs_queue.append(input_node) - if self._is_annotated(input_node): - output_qspec = input_node.meta.get( # type: ignore[union-attr] - Q_ANNOTATION_KEY - ).output_qspec - if output_qspec is not None: - adjacent_qspecs.append(output_qspec) + self._maybe_enqueue_shared_node(input_node, shared_nodes, bfs_queue) + self._append_output_qspec(input_node, adjacent_qspecs) for output_node in node.users.keys(): - if ( - output_node.target in self.targets - and output_node not in shared_nodes - ): - if not self._is_annotated(output_node): - bfs_queue.append(output_node) - if ( - self._is_annotated(output_node) - and node - in output_node.meta.get( # type: ignore[union-attr] - Q_ANNOTATION_KEY - 
).input_qspec_map - ): - input_qspec = output_node.meta.get( # type: ignore[union-attr] - Q_ANNOTATION_KEY - ).input_qspec_map[node] - if input_qspec is not None: - adjacent_qspecs.append(input_qspec) + self._maybe_enqueue_shared_node(output_node, shared_nodes, bfs_queue) + self._append_input_qspec(output_node, node, adjacent_qspecs) return shared_nodes, adjacent_qspecs diff --git a/backends/arm/runtime/EthosUBackend.cpp b/backends/arm/runtime/EthosUBackend.cpp index d83b04f0e8e..2b17cf2c43d 100644 --- a/backends/arm/runtime/EthosUBackend.cpp +++ b/backends/arm/runtime/EthosUBackend.cpp @@ -194,19 +194,23 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { bool supported = 0; // 32 bit int (simple non-quantised test cases) supported |= - (tensor_in.scalar_type() == ScalarType::Int and + (tensor_in.scalar_type() == ScalarType::Int && handles.inputs->io[i].elem_size == 4); // 8 bit int (IOQDQ pass prepared networks) supported |= - (tensor_in.scalar_type() == ScalarType::Char and + (tensor_in.scalar_type() == ScalarType::Char && + handles.inputs->io[i].elem_size == 1); + // 8 bit uint8 (IOQDQ pass prepared networks) + supported |= + (tensor_in.scalar_type() == ScalarType::Byte && handles.inputs->io[i].elem_size == 1); // 16 bit int (IOQDQ pass prepared networks) supported |= - (tensor_in.scalar_type() == ScalarType::Short and + (tensor_in.scalar_type() == ScalarType::Short && handles.inputs->io[i].elem_size == 2); // bool (IOQDQ pass prepared networks) supported |= - (tensor_in.scalar_type() == ScalarType::Bool and + (tensor_in.scalar_type() == ScalarType::Bool && handles.inputs->io[i].elem_size == 1); if (!supported) { ET_LOG( @@ -222,7 +226,8 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { // which require permutation. 
bool both_int = tensor_in.scalar_type() == ScalarType::Int && handles.inputs->io[i].elem_size == 4; - bool both_char = tensor_in.scalar_type() == ScalarType::Char && + bool both_char = (tensor_in.scalar_type() == ScalarType::Char || + tensor_in.scalar_type() == ScalarType::Byte) && handles.inputs->io[i].elem_size == 1; bool both_short = tensor_in.scalar_type() == ScalarType::Short && handles.inputs->io[i].elem_size == 2; diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index be1ecaa03f5..736a5ffc6b5 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -160,6 +160,7 @@ def get_vgf_compile_spec( compiler_flags: Optional[str] = "", custom_path: Optional[str] = None, tosa_debug_mode: VgfCompileSpec.DebugMode | None = None, + preserve_io_quantization: bool = False, ) -> VgfCompileSpec: """Get the ArmCompileSpec for the default VGF tests, to modify the compile spec before calling .build() to finalize it. @@ -188,6 +189,9 @@ def get_vgf_compile_spec( .dump_debug_info(tosa_debug_mode) ) + if preserve_io_quantization: + compile_spec._set_preserve_io_quantization(True) + return compile_spec diff --git a/backends/arm/test/misc/test_compile_spec.py b/backends/arm/test/misc/test_compile_spec.py index da4bcebda35..cb2f45b5382 100644 --- a/backends/arm/test/misc/test_compile_spec.py +++ b/backends/arm/test/misc/test_compile_spec.py @@ -3,11 +3,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
+import warnings + from executorch.backends.arm.common.pipeline_config import SoftmaxDecompositionConfig from executorch.backends.arm.ethosu import EthosUCompileSpec from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec from executorch.backends.arm.vgf import VgfCompileSpec -from pytest import raises +from pytest import raises, warns def test_compile_spec_u55_INT(): @@ -68,3 +70,24 @@ def test_compile_spec_tosa_INT(): assert TosaCompileSpec._from_list(spec_list) == compile_spec with raises(ValueError, match="Incorrect output format"): VgfCompileSpec._from_list(spec_list) + + +def test_preserve_io_quantization_roundtrip_vgf_FP_INT(): + compile_spec = VgfCompileSpec()._set_preserve_io_quantization(True) + roundtripped = VgfCompileSpec._from_list(compile_spec._to_list()) + assert roundtripped.preserve_io_quantization is True + + +def test_preserve_io_quantization_warns_for_u55_INT(): + with warns( + UserWarning, + match="preserve_io_quantization=True is redundant for INT-only TOSA", + ): + EthosUCompileSpec("ethos-u55-128")._set_preserve_io_quantization(True) + + +def test_preserve_io_quantization_no_warn_for_vgf_FP_INT(): + with warnings.catch_warnings(record=True) as recorded_warnings: + warnings.simplefilter("always") + VgfCompileSpec()._set_preserve_io_quantization(True) + assert len(recorded_warnings) == 0 diff --git a/backends/arm/test/misc/test_rescale_range.py b/backends/arm/test/misc/test_rescale_range.py index 1075dd4d04f..f34c58995cf 100644 --- a/backends/arm/test/misc/test_rescale_range.py +++ b/backends/arm/test/misc/test_rescale_range.py @@ -1,13 +1,16 @@ -# Copyright 2024-2025 Arm Limited and/or its affiliates. +# Copyright 2024-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
from typing import Tuple +import executorch.backends.arm.tosa.dialect # noqa: F401 + import pytest import torch +from executorch.backends.arm.tosa.dialect.lib import TosaValueError from executorch.backends.arm.tosa.specification import ( TosaLoweringContext, TosaSpecification, @@ -128,3 +131,76 @@ def test_zp_outside_range_tosa_INT(): ] ) ) + + +def test_unsigned_zp_range_tosa_INT_valid(): + # Validate unsigned zero-point ranges via explicit unsigned semantics. + # First case: uint8 input (input_unsigned=True) uses in_zp in [0,255]. + # Second case: signed int8 input but unsigned output semantics (output_unsigned=True) + # allow out_zp in [0,255]. + sample_inputs = [ + # (data, out_dtype, scale, in_zp, out_zp, input_unsigned, output_unsigned) + ( + torch.randint(low=0, high=255, size=(4, 4, 4), dtype=torch.uint8), + torch.int8, + [0.5], + 255, + 0, + True, + False, + ), + ( + torch.randint(low=-128, high=127, size=(4, 4, 4), dtype=torch.int8), + torch.int8, + [0.5], + 0, + 255, + False, + True, + ), + ] + + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + for sample_input in sample_inputs: + exir_ops.backend.tosa.RESCALE.default( + *tuple( + [ + mode.from_tensor(i) if isinstance(i, torch.Tensor) else i + for i in sample_input[:5] + ] + ), + input_unsigned=sample_input[5], + output_unsigned=sample_input[6], + ) + + +def test_unsigned_zp_range_tosa_INT_invalid(): + with TosaLoweringContext( + TosaSpecification.create_from_string("TOSA-1.0+INT") + ), FakeTensorMode() as mode: + with pytest.raises(TosaValueError, match="(in_zp|input_zp).*range"): + exir_ops.backend.tosa.RESCALE.default( + mode.from_tensor( + torch.randint(low=0, high=255, size=(4, 4, 4), dtype=torch.uint8) + ), + torch.int8, + [0.5], + 256, + 0, + input_unsigned=True, + output_unsigned=False, + ) + with pytest.raises(TosaValueError, match="(out_zp|output_zp).*range"): + exir_ops.backend.tosa.RESCALE.default( + mode.from_tensor( + 
torch.randint(low=0, high=255, size=(4, 4, 4), dtype=torch.uint8) + ), + torch.int8, + [0.5], + 0, + 256, + input_unsigned=False, + output_unsigned=True, + ) diff --git a/backends/arm/test/passes/test_ioquantization_pass.py b/backends/arm/test/passes/test_ioquantization_pass.py index 606ed6e0f01..e42f7a093cc 100644 --- a/backends/arm/test/passes/test_ioquantization_pass.py +++ b/backends/arm/test/passes/test_ioquantization_pass.py @@ -6,12 +6,38 @@ from typing import Tuple +import pytest import torch -from executorch.backends.arm.test import common +from executorch.backends.arm._passes.insert_rescales_pass import InsertRescalePass +from executorch.backends.arm.constants import DQ_OPS, Q_OPS +from executorch.backends.arm.quantizer import ( + get_symmetric_quantization_config, + get_uint8_io_quantization_config, + TOSAQuantizer, +) -from executorch.backends.arm.test.tester.test_pipeline import EthosU55PipelineINT -from executorch.exir.passes.quantize_io_pass import QuantizeInputs, QuantizeOutputs +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester, RunPasses +from executorch.backends.arm.test.tester.quantize import ArmQuantize as Quantize +from executorch.backends.arm.test.tester.test_pipeline import ( + EthosU55PipelineINT, + TosaPipelineINT, + VgfPipeline, +) +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.backends.cadence.aot.graph_builder import GraphBuilder +from executorch.backends.test.harness.stages import StageType +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.passes.quantize_io_pass import ( + quantize_input, + quantize_output, + QuantizeInputs, + QuantizeOutputs, +) input_t = Tuple[torch.Tensor, torch.Tensor] @@ -46,3 +72,546 @@ def test_quantize_io_u55_INT(test_data: input_t): edge.transform(passes=[QuantizeInputs(edge, [0, 1]), QuantizeOutputs(edge, [0])]) 
pipeline.tester.check_not(["edge__ops_quantized_decomposed_quantize_per_tensor"]) pipeline.tester.check_not(["edge__ops_quantized_decomposed_dequantize_per_tensor"]) + + +class SimpleMLP(torch.nn.Module): + def __init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(4, 8) + self.relu = torch.nn.ReLU() + self.fc2 = torch.nn.Linear(8, 4) + + def forward(self, x): + return self.fc2(self.relu(self.fc1(x))) + + +def _build_dq_q_graph( + input_tensor: torch.Tensor, + dq_dtype: torch.dtype, + q_dtype: torch.dtype, + dq_scale: float, + dq_zp: int, + q_scale: float, + q_zp: int, +): + builder = GraphBuilder() + x = builder.placeholder("x", input_tensor) + qd_ops = exir_ops.edge.quantized_decomposed + dq = builder.call_operator( + qd_ops.dequantize_per_tensor.default, + ( + x, + dq_scale, + dq_zp, + torch.iinfo(dq_dtype).min, + torch.iinfo(dq_dtype).max, + dq_dtype, + ), + ) + q = builder.call_operator( + qd_ops.quantize_per_tensor.default, + ( + dq, + q_scale, + q_zp, + torch.iinfo(q_dtype).min, + torch.iinfo(q_dtype).max, + q_dtype, + ), + ) + builder.output([q]) + return builder.get_graph_module() + + +def test_insert_rescale_tosa_INT_folds_uint8_input(): + graph_module = _build_dq_q_graph( + torch.randint(0, 255, (1, 4), dtype=torch.uint8), + torch.uint8, + torch.int8, + dq_scale=0.5, + dq_zp=0, + q_scale=0.25, + q_zp=0, + ) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")): + rescale_graph = InsertRescalePass()(graph_module).graph_module.graph + rescale_nodes = [ + node + for node in rescale_graph.nodes + if node.op == "call_function" + and node.target == exir_ops.backend.tosa.RESCALE.default + ] + assert rescale_nodes + assert rescale_nodes[0].kwargs.get("input_unsigned") is True + assert rescale_nodes[0].kwargs.get("output_unsigned") is False + + +def test_insert_rescale_tosa_INT_folds_uint8_output(): + graph_module = _build_dq_q_graph( + torch.randint(-128, 127, (1, 4), dtype=torch.int8), + torch.int8, + torch.uint8, + 
dq_scale=0.5, + dq_zp=0, + q_scale=0.25, + q_zp=0, + ) + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")): + rescale_graph = InsertRescalePass()(graph_module).graph_module.graph + rescale_nodes = [ + node + for node in rescale_graph.nodes + if node.op == "call_function" + and node.target == exir_ops.backend.tosa.RESCALE.default + ] + assert rescale_nodes + assert rescale_nodes[0].kwargs.get("input_unsigned") is False + assert rescale_nodes[0].kwargs.get("output_unsigned") is True + assert rescale_nodes[0].args[1] == torch.int8 + + +def test_quantize_io_tosa_INT_uint8_simple_mlp(): + """Float-input MLP uses uint8 IO quantization and folds to a single + delegate. + """ + model = SimpleMLP().eval() + test_data = (torch.rand(1, 4),) + compile_spec = common.get_tosa_compile_spec("TOSA-1.0+INT") + + tester = ArmTester(model, test_data, compile_spec) + quantizer = TOSAQuantizer(compile_spec) + quantizer.set_global(get_symmetric_quantization_config()) + quantizer.set_io(get_uint8_io_quantization_config()) + quant_stage = Quantize(quantizer, quantization_config=quantizer.global_config) + + tester.quantize(quant_stage) + tester.export() + tester.to_edge_transform_and_lower() + tester.check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) + tester.check( + [ + "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_default" + ] + ) + lowered_graph = tester.get_artifact().exported_program().graph_module.graph + delegate_nodes = [ + node + for node in lowered_graph.nodes + if node.op == "call_function" + and node.target == torch.ops.higher_order.executorch_call_delegate + ] + assert len(delegate_nodes) == 1 + quant_nodes = [ + node + for node in lowered_graph.nodes + if node.op == "call_function" + and node.target + == exir_ops.edge.quantized_decomposed.quantize_per_tensor.default + ] + assert len(quant_nodes) == 1 + delegate_args = delegate_nodes[0].all_input_nodes + assert ( + quant_nodes[0] in delegate_args + ), 
"Expected input quantize to feed the delegate call." + + +def test_quantize_io_tosa_INT_uint8(): + """Make sure quantizer doesn't allow uint8 internally.""" + torch_q_ops = ( + torch.ops.quantized_decomposed.quantize_per_tensor.default, + torch.ops.quantized_decomposed.quantize_per_tensor.tensor, + torch.ops.quantized_decomposed.quantize_per_channel.default, + ) + torch_dq_ops = ( + torch.ops.quantized_decomposed.dequantize_per_tensor.default, + torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, + torch.ops.quantized_decomposed.dequantize_per_channel.default, + ) + + model = SimpleMLP().eval() + test_data = (torch.rand(1, 4),) + compile_spec = common.get_tosa_compile_spec("TOSA-1.0+INT") + + tester = ArmTester(model, test_data, compile_spec) + quantizer = TOSAQuantizer(compile_spec) + quantizer.set_global(get_symmetric_quantization_config()) + quantizer.set_io(get_uint8_io_quantization_config()) + quant_stage = Quantize(quantizer, quantization_config=quantizer.global_config) + + tester.quantize(quant_stage) + tester.export() + + exported_program = tester.get_artifact() + graph = exported_program.graph_module.graph + placeholders = [node for node in graph.nodes if node.op == "placeholder"] + output_nodes = [node for node in graph.nodes if node.op == "output"] + + max_uint8_q_nodes = len(placeholders) + len(output_nodes) + uint8_q_nodes = [] + bad_nodes = [] + for node in graph.nodes: + meta_val = node.meta.get("val") + if not isinstance(meta_val, torch.Tensor): + continue + if meta_val.dtype != torch.uint8: + continue + if node.op in ("placeholder", "output"): + continue + if node.op == "call_function" and node.target in (*Q_OPS, *torch_q_ops): + uint8_q_nodes.append((node.name, node.target)) + continue + if node.op == "call_function" and node.target in (*DQ_OPS, *torch_dq_ops): + bad_nodes.append((node.name, node.target)) + continue + bad_nodes.append((node.name, node.target)) + assert not bad_nodes, ( + "Found internal uint8 tensors outside IO 
boundaries: " f"{bad_nodes}" + ) + assert len(uint8_q_nodes) <= max_uint8_q_nodes, ( + "Expected uint8 quantize nodes only at IO boundaries; " + f"found {len(uint8_q_nodes)} (max {max_uint8_q_nodes})." + ) + + +def test_quantize_io_tosa_INT_uint8_pipeline(): + """Use TOSA pipeline to build an end-to-end flow and accept uint8 IO.""" + model = SimpleMLP().eval() + test_data = (torch.rand(1, 4),) + pipeline = TosaPipelineINT( + model, + test_data, + [], + [], + run_on_tosa_ref_model=True, + use_to_edge_transform_and_lower=False, + ) + + tester = pipeline.tester + tester.quantize() + tester.export() + tester.to_edge() + edge = tester.get_artifact() + edge.transform( + passes=[ + QuantizeInputs( + edge, + { + 0: { + "scale": 1.0, + "zp": 0, + "dtype": torch.uint8, + } + }, + ), + QuantizeOutputs( + edge, + { + 0: { + "scale": 1.0, + "zp": 0, + "dtype": torch.uint8, + } + }, + ), + ] + ) + + exported_program = edge.exported_program() + graph_module = exported_program.graph_module + with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")): + graph = InsertRescalePass()(graph_module).graph_module.graph + + assert any( + node.op == "call_function" + and node.target == exir_ops.backend.tosa.RESCALE.default + and node.kwargs.get("input_unsigned") + for node in graph.nodes + ), "Expected RESCALE with input_unsigned=True in pipeline flow." 
def test_quantize_io_tosa_INT_uint8_io_add():
    """Model accepts uint8 inputs/outputs while TOSA uses int8 internally."""
    model = SimpleModel().eval()
    test_data = SimpleModel.test_data["rand_rand"]
    compile_spec = common.get_tosa_compile_spec("TOSA-1.0+INT")

    tester = ArmTester(model, test_data, compile_spec)
    quantizer = TOSAQuantizer(compile_spec)
    quantizer.set_global(get_symmetric_quantization_config())
    quantizer.set_io(get_uint8_io_quantization_config())
    quant_stage = Quantize(quantizer, quantization_config=quantizer.global_config)

    tester.quantize(quant_stage)
    tester.export()
    tester.to_edge()
    edge = tester.get_artifact()
    # Re-quantize the graph IO so uint8 tensors appear at the boundaries.
    edge.transform(
        passes=[
            QuantizeInputs(edge, [0, 1]),
            QuantizeOutputs(edge, [0]),
        ]
    )

    exported_program = edge.exported_program()
    graph = exported_program.graph_module.graph
    placeholders = [node for node in graph.nodes if node.op == "placeholder"]
    assert len(placeholders) == 2
    assert placeholders[0].meta["val"].dtype == torch.uint8
    assert placeholders[1].meta["val"].dtype == torch.uint8
    output_node = graph.output_node()
    output_val = output_node.args[0][0]
    assert output_val.meta["val"].dtype == torch.uint8

    graph_module = exported_program.graph_module
    with TosaLoweringContext(TosaSpecification.create_from_string("TOSA-1.0+INT")):
        rescale_graph = InsertRescalePass()(graph_module).graph_module.graph

    rescale_nodes = [
        node
        for node in rescale_graph.nodes
        if node.op == "call_function"
        and node.target == exir_ops.backend.tosa.RESCALE.default
    ]
    assert rescale_nodes, "Expected RESCALE ops after lowering."
    assert any(
        node.kwargs.get("input_unsigned") for node in rescale_nodes
    ), "Expected input_unsigned on IO rescale."
    assert any(
        node.kwargs.get("output_unsigned") for node in rescale_nodes
    ), "Expected output_unsigned on IO rescale."
    # RESCALE output dtype is args[1]; unsigned semantics must not leak uint8
    # past the IO boundary (see InsertRescalePass).
    assert all(
        node.args[1] == torch.int8
        for node in rescale_nodes
        if node.kwargs.get("input_unsigned") or node.kwargs.get("output_unsigned")
    ), "Unsigned IO rescales must output int8 internally."


def test_quantize_io_tosa_INT_uint8_numeric():
    """Run TOSA flow with uint8 input and verify numerical output."""
    if not TosaPipelineINT.is_tosa_ref_model_available():
        pytest.skip("TOSA reference model not available.")
    model = SimpleModel().eval()
    calib_input = torch.rand(1, 4)
    calib_other = torch.rand(1, 4)

    pipeline = TosaPipelineINT(
        model,
        (calib_input, calib_other),
        aten_op=[],
        exir_op=[],
        use_to_edge_transform_and_lower=True,
    )
    pipeline.quantizer.set_io(get_uint8_io_quantization_config())

    _run_uint8_io_numeric_pipeline(pipeline, model, calib_input, calib_other)


def test_quantize_io_vgf_INT_uint8_numeric():
    """Run VGF flow with uint8 input and verify numerical output."""

    model = SimpleModel().eval()
    calib_input = torch.rand(1, 4)
    calib_other = torch.rand(1, 4)

    pipeline = VgfPipeline(
        model,
        (calib_input, calib_other),
        aten_op=[],
        exir_op=[],
        run_on_vulkan_runtime=True,
        quantize=True,
        use_to_edge_transform_and_lower=True,
        preserve_io_quantization=True,
    )

    pipeline.quantizer.set_io(get_uint8_io_quantization_config())

    # IO Q/DQ nodes are intentionally kept outside the delegate here, so the
    # standard "no quant nodes in exir" check does not apply.
    if pipeline.has_stage("check_not.exir_quant_nodes"):
        pipeline.pop_stage("check_not.exir_quant_nodes")
    _run_uint8_io_numeric_pipeline(pipeline, model, calib_input, calib_other)


def test_quantize_io_u55_INT_uint8_numeric():
    """Run Ethos-U55 flow with uint8 input and verify numerical output."""
    model = SimpleModel().eval()
    calib_input = torch.rand(1, 4)
    calib_other = torch.rand(1, 4)

    if not (
        common.corstone300_installed()
        and common.arm_executor_runner_exists("corstone-300")
    ):
        pytest.xfail("Did not find Corstone-300 FVP or executor_runner on path")

    pipeline = EthosU55PipelineINT(
        model,
        (calib_input, calib_other),
        aten_ops=[],
        exir_ops=[],
        run_on_fvp=True,
        use_to_edge_transform_and_lower=True,
    )
    pipeline.quantizer.set_io(get_uint8_io_quantization_config())

    _run_uint8_io_numeric_pipeline(pipeline, model, calib_input, calib_other)


def _run_uint8_io_numeric_pipeline(  # noqa: C901
    pipeline, model, calib_input, calib_other
) -> None:
    """Drive `pipeline` end to end with uint8 IO and compare against a reference.

    Captures the IO quantization parameters while lowering, swaps the
    reference-run stage for one that dequantizes uint8 inputs, runs the float
    `model`, and requantizes its output, then compares outputs with a 1-LSB
    tolerance.
    """
    # Filled in by _apply_uint8_io during the pipeline run.
    qparams = {}

    def _apply_uint8_io(ep):
        # Quantize both inputs and the single output of the exported program,
        # recording the resulting quantization parameters for later use.
        in0_scale, in0_zp, in0_qmin, in0_qmax, in0_dtype = quantize_input(ep, 0)
        in1_scale, in1_zp, in1_qmin, in1_qmax, in1_dtype = quantize_input(ep, 1)
        out_scale, out_zp, out_qmin, out_qmax, out_dtype = quantize_output(ep, 0)
        qparams.update(
            {
                "in0_scale": in0_scale,
                "in0_zp": in0_zp,
                "in0_qmin": in0_qmin,
                "in0_qmax": in0_qmax,
                "in0_dtype": in0_dtype,
                "in1_scale": in1_scale,
                "in1_zp": in1_zp,
                "in1_qmin": in1_qmin,
                "in1_qmax": in1_qmax,
                "in1_dtype": in1_dtype,
                "out_scale": out_scale,
                "out_zp": out_zp,
                "out_qmin": out_qmin,
                "out_qmax": out_qmax,
                "out_dtype": out_dtype,
            }
        )
        return ep

    class _Uint8ReferenceStage:
        """Stand-in RUN_PASSES stage producing the uint8 reference output."""

        def __init__(self, reference_model):
            self.reference_model = reference_model.eval()

        def stage_type(self):
            return StageType.RUN_PASSES

        @property
        def artifact(self):
            return self.reference_model

        @property
        def graph_module(self):
            return None

        def run_artifact(self, inputs):
            def _quantize(tensor, scale, zp, qmin, qmax, dtype):
                return torch.ops.quantized_decomposed.quantize_per_tensor(
                    tensor, scale, zp, qmin, qmax, dtype
                )

            def _dequantize(tensor, scale, zp, qmin, qmax, dtype):
                return torch.ops.quantized_decomposed.dequantize_per_tensor(
                    tensor, scale, zp, qmin, qmax, dtype
                )

            # Dequantize the uint8 inputs, run the float model, requantize.
            float_x = _dequantize(
                inputs[0],
                qparams["in0_scale"],
                qparams["in0_zp"],
                qparams["in0_qmin"],
                qparams["in0_qmax"],
                qparams["in0_dtype"],
            )
            float_y = _dequantize(
                inputs[1],
                qparams["in1_scale"],
                qparams["in1_zp"],
                qparams["in1_qmin"],
                qparams["in1_qmax"],
                qparams["in1_dtype"],
            )
            float_out = self.reference_model(float_x, float_y)
            ref_u8 = _quantize(
                float_out,
                qparams["out_scale"],
                qparams["out_zp"],
                qparams["out_qmin"],
                qparams["out_qmax"],
                qparams["out_dtype"],
            )
            # Match TOSA's signless int8 representation of unsigned outputs.
            return ref_u8

    # Insert quantization of inputs/outputs after lowering so we can run uint8 IO.
    pipeline.add_stage_after(
        "to_edge_transform_and_lower",
        pipeline.tester.run_passes,
        RunPasses(pass_functions=[_apply_uint8_io]),
        suffix="uint8_io",
    )
    pipeline.add_stage_after(
        "to_executorch",
        lambda: setattr(
            pipeline.tester,
            "stages",
            {
                **pipeline.tester.stages,
                StageType.RUN_PASSES: _Uint8ReferenceStage(model),
            },
        ),
        suffix="uint8_ref",
    )

    # Run the pipeline to get the quantization parameters without the standard comparison step
    if pipeline.has_stage("run_method_and_compare_outputs"):
        pipeline.pop_stage("run_method_and_compare_outputs")
    pipeline.run()

    assert qparams["in0_dtype"] == torch.uint8
    assert qparams["in1_dtype"] == torch.uint8
    assert qparams["out_dtype"] == torch.uint8

    # Calculate the calib inputs and outputs uint8 values given the
    # calibrated quantization parameters, so we can run the reference with the same quantized inputs.
    input_tensor = torch.ops.quantized_decomposed.quantize_per_tensor(
        calib_input,
        qparams["in0_scale"],
        qparams["in0_zp"],
        qparams["in0_qmin"],
        qparams["in0_qmax"],
        qparams["in0_dtype"],
    )
    other_input = torch.ops.quantized_decomposed.quantize_per_tensor(
        calib_other,
        qparams["in1_scale"],
        qparams["in1_zp"],
        qparams["in1_qmin"],
        qparams["in1_qmax"],
        qparams["in1_dtype"],
    )

    # Compare against a reference that dequantizes uint8 inputs, runs the float model,
    # and requantizes to match TOSA's signless int8 representation.
    def uint8_compare_callback(reference, output, _qparams):
        # Map signless int8 to uint8
        output_u8 = output.to(torch.uint8)
        reference_u8 = reference.to(torch.uint8)
        diff = (output_u8.to(torch.int16) - reference_u8.to(torch.int16)).abs()
        if diff.max().item() > 1:
            raise AssertionError(
                "Output mismatch beyond 1 LSB after uint8 IO flow. "
                f"max abs diff={diff.max().item()}"
            )

    compare_stage = (
        StageType.SERIALIZE
        if pipeline.has_stage("serialize")
        else StageType.TO_EXECUTORCH
    )
    # NOTE(review): the lambda parameter "qparmams" is a typo but is unused;
    # the closure deliberately captures the outer qparams instead.
    pipeline.tester.run_method_and_compare_outputs(
        stage=compare_stage,
        inputs=(input_tensor, other_input),
        qtol=1,
        reference_stage_type=StageType.RUN_PASSES,
        compare_callback=lambda ref, test, qparmams: uint8_compare_callback(
            ref, test, qparams
        ),
    )


# ---------------------------------------------------------------------------
# New file: backends/arm/test/quantizer/test_uint8_io_quantization.py
# Copyright 2026 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch

from executorch.backends.arm.quantizer import (
    get_uint8_io_quantization_config,
    TOSAQuantizer,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.test_pipeline import QuantizationPipeline


class SimpleMLP(torch.nn.Module):
    """Minimal two-layer MLP (4 -> 8 -> 4) used as a quantization fixture."""

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, 8)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(8, 4)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


def test_uint8_io_quantization_config_tosa_INT_applies_to_io():
    """Check that set_io with the uint8 config quantizes exactly the IO nodes."""
    model = SimpleMLP().eval()
    test_data = (torch.rand(1, 4),)
    compile_spec = common.get_tosa_compile_spec("TOSA-1.0+INT")
    quantizer = TOSAQuantizer(compile_spec)
    quantizer.set_io(get_uint8_io_quantization_config())

    # Expect exactly one input and one output annotated with the uint8 qspecs.
    io_config = get_uint8_io_quantization_config()
    pipeline = QuantizationPipeline(
        model,
        test_data,
        quantizer=quantizer,
        input_qspecs={io_config.input_activation: 1},
        output_qspecs={io_config.output_activation: 1},
    )
    pipeline.run()

# NOTE(review): the original diff continues here with partial-context hunks for
# backends/arm/test/tester/test_pipeline.py (VgfPipeline.__init__ gains a
# keyword argument `preserve_io_quantization: bool = False`, forwarded to
# common.get_vgf_compile_spec) and the start of
# backends/arm/tosa/dialect/ops/rescale.py; those fragments are cut mid-definition
# and are not reproduced as code here.
@register_fake_tosa_op( - "RESCALE(Tensor input1, ScalarType dtype, float[] scale, int in_zp, int out_zp) -> Tensor", # schema + "RESCALE(Tensor input1, ScalarType dtype, float[] scale, int in_zp, int out_zp, *, bool input_unsigned=False, bool output_unsigned=False) -> Tensor", # schema TosaSpecification.all_versions_for_profile("INT"), # target TOSA specifications ) -def RESCALE( - x: torch.Tensor, dtype: torch.dtype, scales: List[float], in_zp: int, out_zp: int +def RESCALE( # noqa: C901 + x: torch.Tensor, + dtype: torch.dtype, + scales: List[float], + in_zp: int, + out_zp: int, + *, + input_unsigned: bool = False, + output_unsigned: bool = False, ) -> torch.Tensor: tosa_spec = get_context_spec() """Casts the input tensor to dtype `dtype` to produce the correct tensor @@ -35,20 +42,82 @@ def RESCALE( ) if dtype not in (torch.int32, torch.int8, torch.int16): - raise NotImplementedError( - f"tosa::rescale currently only supports int32, int16 and int8, not {dtype}" - ) - if dtype in (torch.int32, torch.int16) and out_zp != 0: - raise ValueError( - f"TOSA requires output_zp to be zero when the output dtype is {dtype}." 
+ raise TosaValueError( + f"tosa::rescale currently only supports int32, int16 and int8, not {dtype}", + op="RESCALE", ) - if x.dtype in (torch.int32, torch.int16) and in_zp != 0: - raise ValueError( - f"TOSA requires input_zp to be zero when the input dtype is {dtype}" + if input_unsigned and output_unsigned: + raise TosaValueError( + "TOSA requires input_unsigned and output_unsigned not both be true.", + op="RESCALE", ) - if x.dtype == torch.int8 and not -128 <= in_zp <= 127: - raise ValueError(f"{in_zp=} outside valid range (-128,127) for int8.") - if dtype == torch.int8 and not -128 <= out_zp <= 127: - raise ValueError(f"{out_zp=} outside valid range (-128,127) for int8.") + + if input_unsigned: + if x.dtype not in (torch.int8, torch.int16, torch.uint8): + raise TosaValueError( + f"input_unsigned requires int8/int16/uint8 input dtype, got {x.dtype}.", + op="RESCALE", + ) + if x.dtype == torch.int32: + raise TosaValueError( + "TOSA forbids input_unsigned for int32 inputs.", op="RESCALE" + ) + if x.dtype == torch.int16: + if in_zp not in (0, 32768): + raise TosaValueError( + f"{in_zp=} outside valid range (0,32768) for uint16.", + op="RESCALE", + ) + else: + if not 0 <= in_zp <= 255: + raise TosaValueError( + f"{in_zp=} outside valid range (0,255) for uint8.", + op="RESCALE", + ) + else: + if x.dtype in (torch.int32, torch.int16) and in_zp != 0: + raise TosaValueError( + f"TOSA requires input_zp to be zero when the input dtype is {x.dtype}.", + op="RESCALE", + ) + if x.dtype == torch.int8 and not -128 <= in_zp <= 127: + raise TosaValueError( + f"{in_zp=} outside valid range (-128,127) for int8.", + op="RESCALE", + ) + + if output_unsigned: + if dtype not in (torch.int8, torch.int16): + raise TosaValueError( + f"output_unsigned requires int8/int16 output dtype, got {dtype}.", + op="RESCALE", + ) + if dtype == torch.int32: + raise TosaValueError( + "TOSA forbids output_unsigned for int32 outputs.", op="RESCALE" + ) + if dtype == torch.int16: + if out_zp not in (0, 
32768): + raise TosaValueError( + f"{out_zp=} outside valid range (0,32768) for uint16.", + op="RESCALE", + ) + else: + if not 0 <= out_zp <= 255: + raise TosaValueError( + f"{out_zp=} outside valid range (0,255) for uint8.", + op="RESCALE", + ) + else: + if dtype in (torch.int32, torch.int16) and out_zp != 0: + raise TosaValueError( + f"TOSA requires output_zp to be zero when the output dtype is {dtype}.", + op="RESCALE", + ) + if dtype == torch.int8 and not -128 <= out_zp <= 127: + raise TosaValueError( + f"{out_zp=} outside valid range (-128,127) for int8.", + op="RESCALE", + ) return torch.empty_like(x, dtype=dtype) diff --git a/backends/arm/tosa/mapping.py b/backends/arm/tosa/mapping.py index 2833da4c9ad..14749d0ec44 100644 --- a/backends/arm/tosa/mapping.py +++ b/backends/arm/tosa/mapping.py @@ -26,7 +26,6 @@ torch.cfloat, torch.complex128, torch.cdouble, - torch.uint8, torch.int64, torch.long, ) @@ -100,6 +99,8 @@ def map_dtype(data_type: torch.dtype) -> Any: torch.half: ts.DType.FP16, torch.bfloat16: ts.DType.BF16, torch.int8: ts.DType.INT8, + # TOSA uses signless int8; unsigned semantics are expressed via RESCALE. 
+ torch.uint8: ts.DType.INT8, torch.int16: ts.DType.INT16, torch.short: ts.DType.INT16, torch.int32: ts.DType.INT32, diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py index a7ef79abbef..0a78f15cc3d 100644 --- a/backends/arm/tosa/partitioner.py +++ b/backends/arm/tosa/partitioner.py @@ -14,8 +14,9 @@ """ import logging +import operator from itertools import count -from typing import Callable, List, Optional, Sequence, Tuple +from typing import Callable, cast, List, Optional, Sequence, Tuple import torch from executorch.backends.arm._passes.arm_pass_utils import ( @@ -171,7 +172,11 @@ def register_custom_partition_op(self, op: torch._ops.OpOverload) -> None: self._custom_partition_ops.add(op) def _detag_boundary_nodes( - self, module: GraphModule, tag: str, reporter: WhyNoPartitionReporter + self, + module: GraphModule, + tag: str, + reporter: WhyNoPartitionReporter, + detag_first_fp_node: bool = True, ) -> None: """De-tag nodes at the partition boundary. @@ -207,7 +212,7 @@ def _detag_boundary_nodes( # Remove tag from quantize node with input outside partition, # or dequantize node with any output outside partition del node.meta["delegation_tag"] - elif not is_q_node and not is_dq_node: + elif detag_first_fp_node and not is_q_node and not is_dq_node: # For non Q/DQ nodes, remove tag from first node in partition if any input has fp dtype for input in node.all_input_nodes: if is_partitioned(input, tag): @@ -220,6 +225,60 @@ def _detag_boundary_nodes( del node.meta["delegation_tag"] break + def _preserve_io_quantization_enabled(self) -> bool: + """Return True if IO quantization should be preserved from compile + specs. 
+ """ + for spec in self.delegation_spec.compile_specs: + if spec.key != "preserve_io_quantization": + continue + raw = ( + spec.value.decode() + if isinstance(spec.value, (bytes, bytearray)) + else str(spec.value) + ) + return raw.lower() in ("1", "true", "yes") + return False + + def _partition_has_invalid_uint8(self, partition: Partition, tag: str) -> bool: + """Return True if any uint8 appears outside allowed IO nodes. + + TOSA does not have a true uint8 tensor type. Unsigned semantics are only + allowed at IO boundaries and are carried via RESCALE flags. If a + partition contains uint8 in any other node, it will fail later in + lowering, so reject the partition here. + + """ + for node in partition.nodes: + if not is_partitioned(node, tag): + # Ignore nodes that were de-tagged after boundary processing. + continue + dtype: Optional[torch.dtype] = None + meta_val = node.meta.get("val") + if isinstance(meta_val, torch.Tensor): + dtype = meta_val.dtype + else: + dtype = cast(Optional[torch.dtype], node.meta.get("dtype")) + if dtype is None: + try: + dtype = get_first_fake_tensor(node).dtype + except (AttributeError, KeyError, RuntimeError, ValueError): + dtype = None + if dtype is None: + continue + if dtype != torch.uint8: + continue + + is_allowed = node.op in ("placeholder", "output") + is_allowed = is_allowed or ( + node.op == "call_function" and node.target == operator.getitem + ) + # Allow uint8 on Q/DQ nodes that mediate IO quantization. + is_allowed = is_allowed or node.target in Q_OPS or node.target in DQ_OPS + if not is_allowed: + return True + return False + def _tag_module( # noqa self, module: GraphModule, @@ -285,6 +344,24 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: reporter, ) + if self._preserve_io_quantization_enabled(): + # Detag boundary Q/DQ to keep IO quantization outside delegate. 
+ self._detag_boundary_nodes( + module, + tag, + reporter, + detag_first_fp_node=False, + ) + + if self._partition_has_invalid_uint8(partition, tag): + reject_partition( + "Partition contained internal uint8 tensors. Uint8 is only supported at IO boundaries for TOSA backends.", + partition, + reporter, + ) + tags.remove(tag) + continue + # Check whether the partition contains only no-op or non-computational ops. Such partitions don't make sense to delegate, and in the worst case may be optimized away during lowering, which can break compilation." is_nocompute_partition = all( _is_noop_clone(node)