Fix export of fp8 ONNX files #52

Merged: ivanbasov merged 2 commits into main from pr-fix-fp8-export on Apr 8, 2026

@bmhowe23 commented Apr 8, 2026

Without this change, the fp8 export fails and falls back to PyTorch:

$ ONNX_WORKFLOW=2 QUANT_FORMAT=fp8 DISTANCE=13 N_ROUNDS=104 PREDECODER_INFERENCE_NUM_SAMPLES=2048 WORKFLOW=inference EXPERIMENT_NAME=predecoder_model_1 bash code/scripts/local_run.sh
...
[LER] ONNX export failed: [LER] FP8 ONNX quantization failed (fail-fast): [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Unexpected input data type. Actual: (tensor(float)) , expected: (tensor(uint8)); falling back to PyTorch.
...

With this change, the error no longer occurs and fp8 ONNX files are exported correctly.
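
For readers unfamiliar with the error text, the mismatch it reports can be reproduced without ONNX Runtime. The sketch below is only an illustration: `ORT_TO_NUMPY` and `check_feed_dtype` are hypothetical names that are not part of this repository or of ONNX Runtime. It relies on the fact that ONNX Runtime's `tensor(float)` means float32 and `tensor(uint8)` means uint8.

```python
import numpy as np

# ONNX Runtime type strings as they appear in the error message,
# mapped to the NumPy dtypes they correspond to.
ORT_TO_NUMPY = {
    "tensor(uint8)": np.uint8,
    "tensor(float)": np.float32,
}


def check_feed_dtype(expected_ort_type: str, array: np.ndarray) -> None:
    """Hypothetical guard: raise if `array` does not match the expected input type."""
    expected = np.dtype(ORT_TO_NUMPY[expected_ort_type])
    if array.dtype != expected:
        raise TypeError(
            f"Unexpected input data type. Actual: ({array.dtype}), expected: ({expected})"
        )


# Calibration data in its original uint8 form passes the check.
dets = np.zeros((4, 16), dtype=np.uint8)
check_feed_dtype("tensor(uint8)", dets)

# The same data cast to float32 is rejected, mirroring the failure above.
try:
    check_feed_dtype("tensor(uint8)", dets.astype(np.float32))
except TypeError as exc:
    message = str(exc)
    print(message)
```

A guard like this, run before calibration, would turn the failure into an immediate and local error instead of one raised from inside the quantization step.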

bmhowe23 and others added 2 commits April 7, 2026 21:07
Signed-off-by: Ben Howe <bhowe@nvidia.com>
`_collect_calibration_dets` returns uint8; casting to float32 before
passing to mq.quantize triggered an INVALID_ARGUMENT error from the
ONNX runtime ("expected: tensor(uint8), got: tensor(float)").
The new test mirrors the existing int8 variant and asserts that the
fp8 path preserves the original uint8 dtype and forwards the
FP8-specific kwargs (op_types_to_quantize, high_precision_dtype).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
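
The shape of such a regression test might look like the sketch below. It is not the test added in this PR: `collect_calibration_dets` and `export_fp8` are hypothetical stand-ins, the quantizer is replaced by a `unittest.mock.Mock`, and the keyword names and values passed to it (`onnx_path`, `quantize_mode`, `["Conv"]`, `"fp16"`) are assumptions chosen for illustration. Only `op_types_to_quantize` and `high_precision_dtype` are named in the commit message.

```python
from unittest import mock

import numpy as np


def collect_calibration_dets(num_samples: int, num_dets: int) -> np.ndarray:
    """Hypothetical stand-in for `_collect_calibration_dets`: returns uint8 data."""
    rng = np.random.default_rng(0)
    return rng.integers(0, 2, size=(num_samples, num_dets), dtype=np.uint8)


def export_fp8(quantize, onnx_path: str, calibration_data: np.ndarray):
    """Hypothetical fp8 export path: forwards calibration data unchanged."""
    # No .astype(np.float32) here; the cast is what triggered the error.
    return quantize(
        onnx_path=onnx_path,
        quantize_mode="fp8",
        calibration_data=calibration_data,
        op_types_to_quantize=["Conv"],  # assumed value, for illustration only
        high_precision_dtype="fp16",    # assumed value, for illustration only
    )


# Replace the real quantizer with a mock so the call can be inspected.
fake_quantize = mock.Mock(return_value=None)
dets = collect_calibration_dets(num_samples=8, num_dets=16)
export_fp8(fake_quantize, "model.onnx", dets)

kwargs = fake_quantize.call_args.kwargs
assert kwargs["calibration_data"].dtype == np.uint8
assert "op_types_to_quantize" in kwargs
assert "high_precision_dtype" in kwargs
```

Mocking the quantizer keeps the test fast and independent of the ONNX toolchain, at the cost of not exercising the real runtime type check.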

@IgorBaratta left a comment

LGTM

@ivanbasov ivanbasov merged commit 2cbb707 into main Apr 8, 2026
17 checks passed
@ivanbasov ivanbasov deleted the pr-fix-fp8-export branch April 8, 2026 15:56
ivanbasov added a commit that referenced this pull request Apr 10, 2026
* Fix export of fp8 ONNX files

Signed-off-by: Ben Howe <bhowe@nvidia.com>

* test: add fp8 calibration dtype regression test for #52

`_collect_calibration_dets` returns uint8; casting to float32 before
passing to mq.quantize triggered an INVALID_ARGUMENT error from the
ONNX runtime ("expected: tensor(uint8), got: tensor(float)").
The new test mirrors the existing int8 variant and asserts that the
fp8 path preserves the original uint8 dtype and forwards the
FP8-specific kwargs (op_types_to_quantize, high_precision_dtype).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Signed-off-by: Ben Howe <bhowe@nvidia.com>
Co-authored-by: Ivan Basov <ibasov@nvidia.com>
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>