Skip to content

Commit c4dc7e6

Browse files
committed
Arm backend: Handle +FP+INT for vgf and quantized IO
Change-Id: Ie943e1de816d981c0f09d9bd3683881c03e3000c Signed-off-by: Per Åstrand <per.astrand@arm.com>
1 parent d20ad34 commit c4dc7e6

6 files changed

Lines changed: 125 additions & 16 deletions

File tree

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,13 @@ def _ensure_uint8_io_only(self, graph_module: GraphModule) -> None:
5454
if node.op == "call_function" and node.target == operator.getitem:
5555
if all(user.op == "output" for user in node.users):
5656
continue
57+
if (
58+
node.op == "call_function"
59+
and node.target
60+
== exir_ops.edge.dim_order_ops._to_dim_order_copy.default
61+
):
62+
# dim_order is a view-like transform; allow it to preserve uint8 at IO.
63+
continue
5764
if (
5865
node.op == "call_function"
5966
and node.target == exir_ops.backend.tosa.RESCALE.default

backends/arm/runtime/EthosUBackend.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,19 +194,23 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
194194
bool supported = 0;
195195
// 32 bit int (simple non-quantised test cases)
196196
supported |=
197-
(tensor_in.scalar_type() == ScalarType::Int and
197+
(tensor_in.scalar_type() == ScalarType::Int &&
198198
handles.inputs->io[i].elem_size == 4);
199199
// 8 bit int (IOQDQ pass prepared networks)
200200
supported |=
201-
(tensor_in.scalar_type() == ScalarType::Char and
201+
(tensor_in.scalar_type() == ScalarType::Char &&
202+
handles.inputs->io[i].elem_size == 1);
203+
// 8 bit uint8 (IOQDQ pass prepared networks)
204+
supported |=
205+
(tensor_in.scalar_type() == ScalarType::Byte &&
202206
handles.inputs->io[i].elem_size == 1);
203207
// 16 bit int (IOQDQ pass prepared networks)
204208
supported |=
205-
(tensor_in.scalar_type() == ScalarType::Short and
209+
(tensor_in.scalar_type() == ScalarType::Short &&
206210
handles.inputs->io[i].elem_size == 2);
207211
// bool (IOQDQ pass prepared networks)
208212
supported |=
209-
(tensor_in.scalar_type() == ScalarType::Bool and
213+
(tensor_in.scalar_type() == ScalarType::Bool &&
210214
handles.inputs->io[i].elem_size == 1);
211215
if (!supported) {
212216
ET_LOG(
@@ -222,7 +226,8 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
222226
// which require permutation.
223227
bool both_int = tensor_in.scalar_type() == ScalarType::Int &&
224228
handles.inputs->io[i].elem_size == 4;
225-
bool both_char = tensor_in.scalar_type() == ScalarType::Char &&
229+
bool both_char = (tensor_in.scalar_type() == ScalarType::Char ||
230+
tensor_in.scalar_type() == ScalarType::Byte) &&
226231
handles.inputs->io[i].elem_size == 1;
227232
bool both_short = tensor_in.scalar_type() == ScalarType::Short &&
228233
handles.inputs->io[i].elem_size == 2;

backends/arm/test/common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def get_vgf_compile_spec(
160160
compiler_flags: Optional[str] = "",
161161
custom_path: Optional[str] = None,
162162
tosa_debug_mode: VgfCompileSpec.DebugMode | None = None,
163+
preserve_io_quantization: bool = False,
163164
) -> VgfCompileSpec:
164165
"""Get the ArmCompileSpec for the default VGF tests, to modify the compile
165166
spec before calling .build() to finalize it.
@@ -188,6 +189,9 @@ def get_vgf_compile_spec(
188189
.dump_debug_info(tosa_debug_mode)
189190
)
190191

192+
if preserve_io_quantization:
193+
compile_spec._set_preserve_io_quantization(True)
194+
191195
return compile_spec
192196

193197

backends/arm/test/passes/test_ioquantization_pass.py

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from executorch.backends.arm.test.tester.test_pipeline import (
2424
EthosU55PipelineINT,
2525
TosaPipelineINT,
26+
VgfPipeline,
2627
)
2728
from executorch.backends.arm.tosa.specification import (
2829
TosaLoweringContext,
@@ -402,6 +403,62 @@ def test_quantize_io_tosa_INT_uint8_numeric():
402403
)
403404
pipeline.quantizer.set_io(get_uint8_io_quantization_config())
404405

406+
_run_uint8_io_numeric_pipeline(pipeline, model, calib_input, calib_other)
407+
408+
409+
def test_quantize_io_vgf_INT_uint8_numeric():
410+
"""Run VGF flow with uint8 input and verify numerical output."""
411+
412+
model = SimpleModel().eval()
413+
calib_input = torch.rand(1, 4)
414+
calib_other = torch.rand(1, 4)
415+
416+
pipeline = VgfPipeline(
417+
model,
418+
(calib_input, calib_other),
419+
aten_op=[],
420+
exir_op=[],
421+
run_on_vulkan_runtime=True,
422+
quantize=True,
423+
use_to_edge_transform_and_lower=True,
424+
preserve_io_quantization=True,
425+
)
426+
427+
pipeline.quantizer.set_io(get_uint8_io_quantization_config())
428+
429+
if pipeline.has_stage("check_not.exir_quant_nodes"):
430+
pipeline.pop_stage("check_not.exir_quant_nodes")
431+
_run_uint8_io_numeric_pipeline(pipeline, model, calib_input, calib_other)
432+
433+
434+
def test_quantize_io_u55_INT_uint8_numeric():
435+
"""Run Ethos-U55 flow with uint8 input and verify numerical output."""
436+
model = SimpleModel().eval()
437+
calib_input = torch.rand(1, 4)
438+
calib_other = torch.rand(1, 4)
439+
440+
if not (
441+
common.corstone300_installed()
442+
and common.arm_executor_runner_exists("corstone-300")
443+
):
444+
pytest.xfail("Did not find Corstone-300 FVP or executor_runner on path")
445+
446+
pipeline = EthosU55PipelineINT(
447+
model,
448+
(calib_input, calib_other),
449+
aten_ops=[],
450+
exir_ops=[],
451+
run_on_fvp=True,
452+
use_to_edge_transform_and_lower=True,
453+
)
454+
pipeline.quantizer.set_io(get_uint8_io_quantization_config())
455+
456+
_run_uint8_io_numeric_pipeline(pipeline, model, calib_input, calib_other)
457+
458+
459+
def _run_uint8_io_numeric_pipeline( # noqa: C901
460+
pipeline, model, calib_input, calib_other
461+
) -> None:
405462
qparams = {}
406463

407464
def _apply_uint8_io(ep):
@@ -483,7 +540,6 @@ def _dequantize(tensor, scale, zp, qmin, qmax, dtype):
483540
# Match TOSA's signless int8 representation of unsigned outputs.
484541
return ref_u8
485542

486-
pipeline.pop_stage("run_method_and_compare_outputs.original_model")
487543
# Insert quantization of inputs/outputs after lowering so we can run uint8 IO.
488544
pipeline.add_stage_after(
489545
"to_edge_transform_and_lower",
@@ -505,9 +561,14 @@ def _dequantize(tensor, scale, zp, qmin, qmax, dtype):
505561
)
506562

507563
# Run the pipeline to get the quantization parameters without the standard comparison step
508-
pipeline.pop_stage("run_method_and_compare_outputs")
564+
if pipeline.has_stage("run_method_and_compare_outputs"):
565+
pipeline.pop_stage("run_method_and_compare_outputs")
509566
pipeline.run()
510567

568+
assert qparams["in0_dtype"] == torch.uint8
569+
assert qparams["in1_dtype"] == torch.uint8
570+
assert qparams["out_dtype"] == torch.uint8
571+
511572
# Calculate the calib inputs and outputs uint8 values given the
512573
# calibrated quantization parameters, so we can run the reference with the same quantized inputs.
513574
input_tensor = torch.ops.quantized_decomposed.quantize_per_tensor(
@@ -527,24 +588,26 @@ def _dequantize(tensor, scale, zp, qmin, qmax, dtype):
527588
qparams["in1_dtype"],
528589
)
529590

530-
print(
531-
f"input_tensor: {input_tensor}, other_input: {other_input}, qparams: {qparams}"
532-
)
533-
534591
# Compare against a reference that dequantizes uint8 inputs, runs the float model,
535592
# and requantizes to match TOSA's signless int8 representation.
536593
def uint8_compare_callback(reference, output, _qparams):
537594
# Map signless int8 to uint8
538-
output = output.to(torch.uint8)
539-
diff = (output.to(torch.int16) - reference.to(torch.int16)).abs()
595+
output_u8 = output.to(torch.uint8)
596+
reference_u8 = reference.to(torch.uint8)
597+
diff = (output_u8.to(torch.int16) - reference_u8.to(torch.int16)).abs()
540598
if diff.max().item() > 1:
541599
raise AssertionError(
542600
"Output mismatch beyond 1 LSB after uint8 IO flow. "
543601
f"max abs diff={diff.max().item()}"
544602
)
545603

604+
compare_stage = (
605+
StageType.SERIALIZE
606+
if pipeline.has_stage("serialize")
607+
else StageType.TO_EXECUTORCH
608+
)
546609
pipeline.tester.run_method_and_compare_outputs(
547-
stage=StageType.TO_EXECUTORCH,
610+
stage=compare_stage,
548611
inputs=(input_tensor, other_input),
549612
qtol=1,
550613
reference_stage_type=StageType.RUN_PASSES,

backends/arm/test/tester/test_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,7 @@ def __init__(
11841184
tosa_extensions: Optional[List[str]] = None,
11851185
tosa_spec: TosaSpecification | str | None = None,
11861186
fold_quantize: bool = True,
1187+
preserve_io_quantization: bool = False,
11871188
):
11881189
if tosa_spec is None:
11891190
if tosa_version is None:
@@ -1201,6 +1202,7 @@ def __init__(
12011202
compiler_flags=vgf_compiler_flags,
12021203
custom_path=custom_path,
12031204
tosa_debug_mode=tosa_debug_mode,
1205+
preserve_io_quantization=preserve_io_quantization,
12041206
)
12051207

12061208
super().__init__(

backends/arm/tosa/partitioner.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ def __init__(
152152
self.additional_checks = additional_checks
153153

154154
def _detag_boundary_nodes(
155-
self, module: GraphModule, tag: str, reporter: WhyNoPartitionReporter
155+
self,
156+
module: GraphModule,
157+
tag: str,
158+
reporter: WhyNoPartitionReporter,
159+
detag_first_fp_node: bool = True,
156160
) -> None:
157161
"""De-tag nodes at the partition boundary.
158162
@@ -188,7 +192,7 @@ def _detag_boundary_nodes(
188192
# Remove tag from quantize node with input outside partition,
189193
# or dequantize node with any output outside partition
190194
del node.meta["delegation_tag"]
191-
elif not is_q_node and not is_dq_node:
195+
elif detag_first_fp_node and not is_q_node and not is_dq_node:
192196
# For non Q/DQ nodes, remove tag from first node in partition if any input has fp dtype
193197
for input in node.all_input_nodes:
194198
if is_partitioned(input, tag):
@@ -201,6 +205,21 @@ def _detag_boundary_nodes(
201205
del node.meta["delegation_tag"]
202206
break
203207

208+
def _preserve_io_quantization_enabled(self) -> bool:
209+
"""Return True if IO quantization should be preserved from compile
210+
specs.
211+
"""
212+
for spec in self.delegation_spec.compile_specs:
213+
if spec.key != "preserve_io_quantization":
214+
continue
215+
raw = (
216+
spec.value.decode()
217+
if isinstance(spec.value, (bytes, bytearray))
218+
else str(spec.value)
219+
)
220+
return raw.lower() in ("1", "true", "yes")
221+
return False
222+
204223
def _partition_has_invalid_uint8(self, partition: Partition, tag: str) -> bool:
205224
"""Return True if any uint8 appears outside allowed IO nodes.
206225
@@ -295,6 +314,15 @@ def _tag_module( # noqa
295314
reporter,
296315
)
297316

317+
if self._preserve_io_quantization_enabled():
318+
# Detag boundary Q/DQ to keep IO quantization outside delegate.
319+
self._detag_boundary_nodes(
320+
module,
321+
tag,
322+
reporter,
323+
detag_first_fp_node=False,
324+
)
325+
298326
if self._partition_has_invalid_uint8(partition, tag):
299327
reject_partition(
300328
"Partition contained internal uint8 tensors. Uint8 is only supported at IO boundaries for TOSA backends.",

0 commit comments

Comments (0)