2323from executorch .backends .arm .test .tester .test_pipeline import (
2424 EthosU55PipelineINT ,
2525 TosaPipelineINT ,
26+ VgfPipeline ,
2627)
2728from executorch .backends .arm .tosa .specification import (
2829 TosaLoweringContext ,
@@ -402,6 +403,62 @@ def test_quantize_io_tosa_INT_uint8_numeric():
402403 )
403404 pipeline .quantizer .set_io (get_uint8_io_quantization_config ())
404405
406+ _run_uint8_io_numeric_pipeline (pipeline , model , calib_input , calib_other )
407+
408+
def test_quantize_io_vgf_INT_uint8_numeric():
    """Exercise the VGF flow with uint8 IO quantization and check numerics.

    Builds a quantized ``VgfPipeline`` around ``SimpleModel`` with two random
    (1, 4) calibration tensors, switches the quantizer's IO config to uint8,
    and hands the pipeline to ``_run_uint8_io_numeric_pipeline`` for the
    actual numerical verification.
    """
    model = SimpleModel().eval()
    calib_input = torch.rand(1, 4)
    calib_other = torch.rand(1, 4)
    calibration_data = (calib_input, calib_other)

    pipeline = VgfPipeline(
        model,
        calibration_data,
        aten_op=[],
        exir_op=[],
        run_on_vulkan_runtime=True,
        quantize=True,
        use_to_edge_transform_and_lower=True,
        preserve_io_quantization=True,
    )

    uint8_io_config = get_uint8_io_quantization_config()
    pipeline.quantizer.set_io(uint8_io_config)

    # Drop the "no quant nodes left" check when the pipeline has one.
    # NOTE(review): presumably needed because IO quantization is preserved
    # in this flow -- confirm against VgfPipeline.
    quant_nodes_check = "check_not.exir_quant_nodes"
    if pipeline.has_stage(quant_nodes_check):
        pipeline.pop_stage(quant_nodes_check)

    _run_uint8_io_numeric_pipeline(pipeline, model, calib_input, calib_other)
432+
433+
def test_quantize_io_u55_INT_uint8_numeric():
    """Exercise the Ethos-U55 flow with uint8 IO quantization and check numerics.

    The test is marked as an expected failure (via ``pytest.xfail``) when the
    Corstone-300 FVP or its executor runner cannot be found. Otherwise an
    ``EthosU55PipelineINT`` is built with FVP execution enabled, its quantizer
    IO config is set to uint8, and the pipeline is handed to
    ``_run_uint8_io_numeric_pipeline`` for the numerical verification.
    """
    model = SimpleModel().eval()
    calib_input = torch.rand(1, 4)
    calib_other = torch.rand(1, 4)

    # Both the FVP install and a built runner are required to execute on target.
    fvp_ready = common.corstone300_installed() and common.arm_executor_runner_exists(
        "corstone-300"
    )
    if not fvp_ready:
        pytest.xfail("Did not find Corstone-300 FVP or executor_runner on path")

    calibration_data = (calib_input, calib_other)
    pipeline = EthosU55PipelineINT(
        model,
        calibration_data,
        aten_ops=[],
        exir_ops=[],
        run_on_fvp=True,
        use_to_edge_transform_and_lower=True,
    )
    uint8_io_config = get_uint8_io_quantization_config()
    pipeline.quantizer.set_io(uint8_io_config)

    _run_uint8_io_numeric_pipeline(pipeline, model, calib_input, calib_other)
457+
458+
459+ def _run_uint8_io_numeric_pipeline ( # noqa: C901
460+ pipeline , model , calib_input , calib_other
461+ ) -> None :
405462 qparams = {}
406463
407464 def _apply_uint8_io (ep ):
@@ -483,7 +540,6 @@ def _dequantize(tensor, scale, zp, qmin, qmax, dtype):
483540 # Match TOSA's signless int8 representation of unsigned outputs.
484541 return ref_u8
485542
486- pipeline .pop_stage ("run_method_and_compare_outputs.original_model" )
487543 # Insert quantization of inputs/outputs after lowering so we can run uint8 IO.
488544 pipeline .add_stage_after (
489545 "to_edge_transform_and_lower" ,
@@ -505,9 +561,14 @@ def _dequantize(tensor, scale, zp, qmin, qmax, dtype):
505561 )
506562
507563 # Run the pipeline to get the quantization parameters without the standard comparison step
508- pipeline .pop_stage ("run_method_and_compare_outputs" )
564+ if pipeline .has_stage ("run_method_and_compare_outputs" ):
565+ pipeline .pop_stage ("run_method_and_compare_outputs" )
509566 pipeline .run ()
510567
568+ assert qparams ["in0_dtype" ] == torch .uint8
569+ assert qparams ["in1_dtype" ] == torch .uint8
570+ assert qparams ["out_dtype" ] == torch .uint8
571+
511572 # Calculate the calib inputs and outputs uint8 values given the
512573 # calibrated quantization parameters, so we can run the reference with the same quantized inputs.
513574 input_tensor = torch .ops .quantized_decomposed .quantize_per_tensor (
@@ -527,24 +588,26 @@ def _dequantize(tensor, scale, zp, qmin, qmax, dtype):
527588 qparams ["in1_dtype" ],
528589 )
529590
530- print (
531- f"input_tensor: { input_tensor } , other_input: { other_input } , qparams: { qparams } "
532- )
533-
534591 # Compare against a reference that dequantizes uint8 inputs, runs the float model,
535592 # and requantizes to match TOSA's signless int8 representation.
536593 def uint8_compare_callback (reference , output , _qparams ):
537594 # Map signless int8 to uint8
538- output = output .to (torch .uint8 )
539- diff = (output .to (torch .int16 ) - reference .to (torch .int16 )).abs ()
595+ output_u8 = output .to (torch .uint8 )
596+ reference_u8 = reference .to (torch .uint8 )
597+ diff = (output_u8 .to (torch .int16 ) - reference_u8 .to (torch .int16 )).abs ()
540598 if diff .max ().item () > 1 :
541599 raise AssertionError (
542600 "Output mismatch beyond 1 LSB after uint8 IO flow. "
543601 f"max abs diff={ diff .max ().item ()} "
544602 )
545603
604+ compare_stage = (
605+ StageType .SERIALIZE
606+ if pipeline .has_stage ("serialize" )
607+ else StageType .TO_EXECUTORCH
608+ )
546609 pipeline .tester .run_method_and_compare_outputs (
547- stage = StageType . TO_EXECUTORCH ,
610+ stage = compare_stage ,
548611 inputs = (input_tensor , other_input ),
549612 qtol = 1 ,
550613 reference_stage_type = StageType .RUN_PASSES ,
0 commit comments