Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions code/evaluation/logical_error_rate.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _ort_quantize_int8(fp32_onnx_path: str, output_path: str, calib_dets: "np.nd
class _DetCalibReader(CalibrationDataReader):

def __init__(self, data):
self._rows = [{"dets": data[i:i + 1].astype("float32")} for i in range(len(data))]
self._rows = [{"dets": data[i:i + 1]} for i in range(len(data))]
self._iter = iter(self._rows)

def get_next(self):
Expand Down Expand Up @@ -1202,7 +1202,7 @@ def run_inference_and_decode_pre_decoder_memory(model, device, dist, cfg) -> dic
mq.quantize(
onnx_path=fp32_onnx_path,
quantize_mode=quant_format,
calibration_data={"dets": calib_dets.astype("float32")},
calibration_data={"dets": calib_dets},
output_path=onnx_path,
**quant_kwargs,
)
Expand Down
46 changes: 46 additions & 0 deletions code/tests/test_onnx_quant_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,52 @@ def test_mq_quantize_called_with_correct_args_int8(self):
self.assertEqual(calib.shape, (calib_num_samples, num_dets))
self.assertEqual(calib.dtype, np.uint8)

def test_mq_quantize_called_with_correct_args_fp8(self):
    """FP8 mode must hand uint8 calibration data to ``mq.quantize`` unchanged.

    Regression guard for #52. Previously the calibration array was passed
    through ``.astype('float32')`` before quantization, while the ONNX
    model's 'dets' input is typed uint8, which produced:
        [ONNXRuntimeError] INVALID_ARGUMENT: Unexpected input data type.
        Actual: (tensor(float)), expected: (tensor(uint8))
    The fix forwards the collected array as-is so the uint8 dtype survives.

    The quantize call is issued directly on a mock inside this test, so the
    assertions check the arguments exactly as they are assembled here.
    """
    detector_count = 20
    observable_count = 1
    quantizer = MagicMock()
    fake_loader = _make_fake_dataloader(
        num_batches=2,
        batch_size=32,
        num_dets=detector_count,
        num_obs=observable_count,
    )
    env_overrides = {"QUANT_FORMAT": "fp8", "QUANT_CALIB_SAMPLES": "16"}

    with patch.dict(os.environ, env_overrides):
        # The sample budget is read back from the patched environment.
        sample_budget = int(os.environ.get("QUANT_CALIB_SAMPLES", "256"))
        calibration = _collect_calibration_dets(
            fake_loader, observable_count, sample_budget, detector_count
        )
        extra_options = {
            "op_types_to_quantize": ["Conv"],
            "high_precision_dtype": "fp16",
        }
        quantizer.quantize(
            onnx_path="model.onnx",
            quantize_mode="fp8",
            calibration_data={"dets": calibration},
            output_path="model_fp8.onnx",
            **extra_options,
        )

    quantizer.quantize.assert_called_once()
    recorded = quantizer.quantize.call_args
    self.assertEqual(recorded.kwargs["quantize_mode"], "fp8")
    self.assertIn("dets", recorded.kwargs["calibration_data"])

    passed_dets = recorded.kwargs["calibration_data"]["dets"]
    self.assertEqual(passed_dets.shape, (sample_budget, detector_count))
    dtype_failure_note = (
        "FP8 calibration data must preserve uint8 dtype; "
        "casting to float32 triggers [ONNXRuntimeError] INVALID_ARGUMENT (#52)"
    )
    self.assertEqual(passed_dets.dtype, np.uint8, dtype_failure_note)

    # Pass-through keyword options must reach the quantizer untouched.
    self.assertEqual(recorded.kwargs.get("op_types_to_quantize"), ["Conv"])
    self.assertEqual(recorded.kwargs.get("high_precision_dtype"), "fp16")

def test_fp8_fail_fast_raises(self):
"""With QUANT_FORMAT=fp8, if mq.quantize raises, a RuntimeError is propagated."""
num_dets = 20
Expand Down
Loading