feat: add MX6 / MX9 fake-quant kernels and W+A PTQ flow #28
Open
Jiarwang77 wants to merge 5 commits into
Pull request overview
This PR adds emulated MX9/MX6 fake-quantization (QDQ) kernels and integrates them into ALTO’s PTQ flow via a new QuantizationArgs.format field plus dispatch in the patched compressed_tensors fake_quantize. It also adds W+A fully-dynamic Llama-3.2-1B recipes/config entries, an E2E example script, and a dedicated unit-test suite covering kernel correctness, dispatch wiring, and a toy integration lifecycle.
Changes:
- Add MX9 kernel implementation and MX6 wrapper (reusing MX9 math with `quant_bit=5`) plus runtime pydantic schema patching for `QuantizationArgs.format`.
- Wire dispatch for `format == "mx9"` / `"mx6"` into `ModelPatcher.patch_fake_quantize`, and adjust calibration to no-op for fully-dynamic weights without observers.
- Add dynamic W+A recipes/configs and comprehensive unit tests + an E2E example script.
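
The format-based routing in the second bullet can be pictured with a small standalone sketch. This is an illustration only, not the code in `patcher.py`: `ArgsSketch`, `MX_KERNELS`, `fake_quantize_sketch`, and the stub kernels are hypothetical names, and the default block size of 16 is a placeholder rather than the value of the real `BLOCK_SIZE` constant.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ArgsSketch:
    """Hypothetical stand-in for QuantizationArgs; only the fields used here."""
    num_bits: int = 8
    group_size: Optional[int] = None
    format: Optional[str] = None


def _mx9_stub(x, block_size):
    # Stand-in for the MX9 kernel; returns a tag instead of doing any math.
    return ("mx9", block_size)


def _mx6_stub(x, block_size):
    # Stand-in for the MX6 kernel.
    return ("mx6", block_size)


def _default_stub(x):
    # Stand-in for the unpatched fake_quantize path (e.g. plain int8).
    return ("default", None)


# format name -> (kernel, default block size); 16 is a placeholder value.
MX_KERNELS = {"mx9": (_mx9_stub, 16), "mx6": (_mx6_stub, 16)}


def fake_quantize_sketch(x, args):
    """Route to an MX kernel when args.format names one, else fall through."""
    fmt = getattr(args, "format", None)  # tolerate args objects without the field
    if fmt in MX_KERNELS:
        kernel, default_block = MX_KERNELS[fmt]
        return kernel(x, block_size=(args.group_size or default_block))
    return _default_stub(x)
```

Reading `format` through `getattr` with a `None` default keeps the fall-through path working for argument objects that predate the injected field.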
Reviewed changes
Copilot reviewed 22 out of 22 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| `alto/__init__.py` | Ensures format field injection runs early; imports mx6/mx9 packages for wiring. |
| `alto/models/patcher.py` | Adds format-based dispatch to MX9/MX6 fake-quant kernels inside patched fake_quantize. |
| `alto/modifiers/quantization/calibration.py` | Skips static weight scale baking when no weight_observer exists (fully dynamic). |
| `alto/modifiers/quantization/base.py` | Minor import formatting/whitespace change. |
| `alto/kernels/mx9/registry_patch.py` | Injects `format: Optional[str]` into QuantizationArgs and rebuilds nested schemas. |
| `alto/kernels/mx9/quantize.py` | Implements MX9 block-wise shared-prime-bit fake-quantization kernel. |
| `alto/kernels/mx9/format.py` | Defines MX9 constants (BLOCK_SIZE, QUANT_BIT, etc.). |
| `alto/kernels/mx9/__init__.py` | Imports registry_patch for side-effect wiring. |
| `alto/kernels/mx6/quantize.py` | MX6 wrapper around MX9 kernel with quant_bit=5. |
| `alto/kernels/mx6/format.py` | Defines MX6 constants. |
| `alto/kernels/mx6/__init__.py` | Ensures format-field injection is available when importing mx6 alone. |
| `alto/models/llama3/configs/mx9_wa_recipe.yaml` | Adds fully-dynamic MX9 W+A recipe. |
| `alto/models/llama3/configs/mx6_wa_recipe.yaml` | Adds fully-dynamic MX6 W+A recipe. |
| `alto/models/llama3/config_registry.py` | Registers new llama3_1b_mx9_wa / llama3_1b_mx6_wa configs. |
| `examples/llama3.2_1b_mx9.sh` | Adds E2E validation script for MX9 W+A flow. |
| `tests/unittest/mx9_mx6/__init__.py` | Adds test package marker. |
| `tests/unittest/mx9_mx6/test_mx9_quantize.py` | MX9 kernel correctness tests (golden vector + optional Quark parity + invariants). |
| `tests/unittest/mx9_mx6/test_mx6_quantize.py` | MX6 wrapper tests (optional Quark parity + invariants). |
| `tests/unittest/mx9_mx6/test_mx9_dispatch.py` | Validates format survives validation and dispatch routes to MX9 kernel. |
| `tests/unittest/mx9_mx6/test_mx6_dispatch.py` | Validates format survives validation and dispatch routes to MX6 kernel. |
| `tests/unittest/mx9_mx6/test_mx9_wa_integration.py` | Toy integration test for MX9 dynamic W+A lifecycle wiring. |
| `tests/unittest/mx9_mx6/test_mx6_wa_integration.py` | Toy integration test for MX6 dynamic W+A lifecycle wiring. |
Comment on lines +125 to +127

    n = SHARED_PRIME_BIT_GROUP
    assert demote.shape[-1] % n == 0
    flat_shape = demote.shape

Comment on lines +60 to +64

    if getattr(args, "format", None) == "mx9":
        from alto.kernels.mx9.format import BLOCK_SIZE
        from alto.kernels.mx9.quantize import mx9_fake_quantize

        return mx9_fake_quantize(x, block_size=(args.group_size or BLOCK_SIZE))

Comment on lines +65 to +69

    if getattr(args, "format", None) == "mx6":
        from alto.kernels.mx6.format import BLOCK_SIZE
        from alto.kernels.mx6.quantize import mx6_fake_quantize

        return mx6_fake_quantize(x, block_size=(args.group_size or BLOCK_SIZE))

        quant_bit: int = QUANT_BIT,
        axis: int = -1,
    ) -> torch.Tensor:
        """Apply MX6 block-wise fake quantization (QDQ) to the input."""

Comment on lines +25 to +30

    def _make_pkg(name: str, path: str) -> types.ModuleType:
        mod = types.ModuleType(name)
        mod.__path__ = [path]
        mod.__package__ = name
        sys.modules[name] = mod
        return mod

Comment on lines +25 to +30

    def _make_pkg(name: str, path: str) -> types.ModuleType:
        mod = types.ModuleType(name)
        mod.__path__ = [path]
        mod.__package__ = name
        sys.modules[name] = mod
        return mod

Comment on lines +26 to +28

    MODEL_PATH=${MODEL_PATH:-"/wekafs/jiarwang/Llama-3.2-1B"}
    VALIDATOR_STEPS=${VALIDATOR_STEPS:-"10"}
    CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-"/wekafs/jiarwang/mx9_e2e_logs/ckpt_${CONFIG}_$(date +%Y%m%d_%H%M%S)"}

…uantization
Relocate the emulated MX6/MX9 fake-quant reference and the format-field
registry patch from alto/kernels/{mx6,mx9}/ into
alto/modifiers/quantization/ (mx.py + format_registry.py), and move the
format-field injection into the quantization package __init__. Update the
mx9/mx6 dispatch and quantize test imports to the new module paths.
- mx.py: replace assert on block_size divisibility with an explicit ValueError (asserts can be stripped under python -O and give an unhelpful message).
- patcher.py: thread args.num_bits through the mx9/mx6 dispatch (falling back to the format default) instead of silently ignoring it, so an inconsistent num_bits no longer quantizes at the wrong width.
- mx9/mx6 W+A integration tests: route all sys.modules mutations through monkeypatch.setitem so stubs are reverted at teardown and cannot leak into other tests in the same process.
- examples/llama3.2_1b_mx9.sh: require MODEL_PATH instead of a hardcoded /wekafs path and default the checkpoint folder under ./outputs.
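
The first two bullets can be sketched as two small helpers. This is a minimal illustration under stated assumptions, not the code in `mx.py` or `patcher.py`: the names `check_block_divisible`, `resolve_quant_bit`, and `DEFAULT_QUANT_BIT` are hypothetical, and the defaults of 8 and 5 are taken from the PR description's `quant_bit` values for MX9 and MX6.

```python
from typing import Optional

# Per-format default element width, as stated in the PR description.
DEFAULT_QUANT_BIT = {"mx9": 8, "mx6": 5}


def check_block_divisible(last_dim: int, block_size: int) -> None:
    """Raise ValueError (not AssertionError) when blocks do not tile the axis.

    Unlike an assert, this check still runs under `python -O` and carries
    a message that names both offending values.
    """
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")
    if last_dim % block_size != 0:
        raise ValueError(
            f"last dimension {last_dim} is not divisible by block_size {block_size}"
        )


def resolve_quant_bit(fmt: str, num_bits: Optional[int]) -> int:
    """Use the caller's num_bits when given, else the format's default width."""
    if fmt not in DEFAULT_QUANT_BIT:
        raise ValueError(f"unknown MX format: {fmt!r}")
    return num_bits if num_bits is not None else DEFAULT_QUANT_BIT[fmt]
```

How the real dispatch decides that `num_bits` is "unset" is not shown in this PR text; the `None` sentinel above is one possible convention.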
821f972 to 0891d9b
Summary
Adds emulated MX9 and MX6 fake-quantization kernels and wires them into
ALTO's PTQ quantization flow, with a fully-dynamic Weight + Activation (W+A)
recipe for Llama-3.2-1B.
MX6 and MX9 share Quark's `fake_quantize_mx6_mx9` math and differ only in the element bit width (`quant_bit=8` for MX9, `5` for MX6).
Changes
- `alto/kernels/mx9/`: MX9 fake-quant kernel (block-wise, shared-prime-bit) + format constants, bit-exact with Quark.
- `alto/kernels/mx6/`: MX6 wrapper reusing the MX9 math at `quant_bit=5`.
- `ModelPatcher.patch_fake_quantize` routes `format == "mx9"` / `"mx6"` to the matching kernel; plain int8 falls through untouched.
- `registry_patch.inject_format_field()` injects the `QuantizationArgs.format` field so recipes carrying `format: mx6/mx9` survive pydantic validation.
- `update_weight_zp_scale` becomes a no-op when a module has no `weight_observer`, so fully-dynamic weights skip static scale baking.
- `mx9_wa_recipe.yaml`, `mx6_wa_recipe.yaml` (W+A fully dynamic).
- `llama3_1b_mx9_wa`, `llama3_1b_mx6_wa` config entries.
- `examples/llama3.2_1b_mx9.sh` (Llama-3.2-1B MX9 W+A validation).
- `tests/unittest/mx9_mx6/`: kernel / dispatch / W+A integration for both MX9 and MX6.
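
For readers unfamiliar with block-wise fake quantization (QDQ), the general idea can be sketched in pure Python. This is a deliberately simplified illustration and not the kernel in this PR: it uses one shared power-of-two scale per block, omits the shared-prime-bit refinement entirely, makes no claim of parity with Quark, and treats `quant_bit` as a signed integer width, which is an assumption. The function name `block_fake_quantize` and the default block size of 16 are hypothetical.

```python
import math
from typing import List, Sequence


def block_fake_quantize(
    values: Sequence[float], block_size: int = 16, quant_bit: int = 8
) -> List[float]:
    """Quantize-dequantize each block with one shared power-of-two scale.

    Simplified sketch: no shared-prime-bit handling, signed symmetric range.
    """
    if len(values) % block_size != 0:
        raise ValueError(
            f"length {len(values)} is not divisible by block_size {block_size}"
        )
    qmax = 2 ** (quant_bit - 1) - 1  # largest magnitude code, e.g. 127 for 8 bits
    out: List[float] = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        amax = max(abs(v) for v in block)
        if amax == 0.0:
            out.extend(0.0 for _ in block)  # all-zero block needs no scale
            continue
        # Smallest power-of-two scale such that amax / scale fits within qmax.
        scale = 2.0 ** math.ceil(math.log2(amax / qmax))
        for v in block:
            q = max(-qmax, min(qmax, round(v / scale)))  # quantize and clamp
            out.append(q * scale)  # dequantize back to a float
    return out
```

In this simplified scheme a narrower `quant_bit` shrinks `qmax` and so forces a coarser shared scale for the same block; the actual MX9 and MX6 kernels share one code path with a different element width, but their exact math lives in the PR's kernel files.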
Test plan
- `pytest tests/unittest/mx9_mx6/` → 76 passed, 1 skipped (Quark-absent bit-exact case). Needs full ALTO env (incl. torchtitan); CPU only.
- `bash examples/llama3.2_1b_mx9.sh` → validate loss 2.1140, Training completed.
- `CONFIG=llama3_1b_mx6_wa bash examples/llama3.2_1b_mx9.sh` → validate loss 2.1440, Training completed.