Skip to content

feat: add MX6 / MX9 fake-quant kernels and W+A PTQ flow#28

Open
Jiarwang77 wants to merge 5 commits into
mainfrom
jiarwang/mx9-fake-quant
Open

feat: add MX6 / MX9 fake-quant kernels and W+A PTQ flow#28
Jiarwang77 wants to merge 5 commits into
mainfrom
jiarwang/mx9-fake-quant

Conversation

@Jiarwang77

Copy link
Copy Markdown
Collaborator

Summary

Adds emulated MX9 and MX6 fake-quantization kernels and wires them into
ALTO's PTQ quantization flow, with a fully-dynamic Weight + Activation (W+A)
recipe for Llama-3.2-1B.

MX6 and MX9 share Quark's fake_quantize_mx6_mx9 math and differ only in the
element bit width (quant_bit=8 for MX9, 5 for MX6).

Changes

  • Kernels
    • alto/kernels/mx9/: MX9 fake-quant kernel (block-wise, shared-prime-bit) +
      format constants, bit-exact with Quark.
    • alto/kernels/mx6/: MX6 wrapper reusing the MX9 math at quant_bit=5.
  • PTQ dispatch wiring
    • ModelPatcher.patch_fake_quantize routes format == "mx9" / "mx6" to the
      matching kernel; plain int8 falls through untouched.
    • registry_patch.inject_format_field() injects the QuantizationArgs.format
      field so recipes carrying format: mx6/mx9 survive pydantic validation.
    • update_weight_zp_scale becomes a no-op when a module has no
      weight_observer, so fully-dynamic weights skip static scale baking.
  • Recipes / configs
    • mx9_wa_recipe.yaml, mx6_wa_recipe.yaml (W+A fully dynamic).
    • llama3_1b_mx9_wa, llama3_1b_mx6_wa config entries.
  • E2E example: examples/llama3.2_1b_mx9.sh (Llama-3.2-1B MX9 W+A validation).
  • Tests under tests/unittest/mx9_mx6/: kernel / dispatch / W+A integration
    for both MX9 and MX6.

Test plan

  • Unit tests: pytest tests/unittest/mx9_mx6/ → 76 passed, 1 skipped
    (Quark-absent bit-exact case). Needs full ALTO env (incl. torchtitan); CPU only.
  • Bit-exact parity vs Quark 0.11.x for both MX9 and MX6.
  • E2E MX9 on Llama-3.2-1B: bash examples/llama3.2_1b_mx9.sh
    → validate loss 2.1140, Training completed.
  • E2E MX6 on Llama-3.2-1B: CONFIG=llama3_1b_mx6_wa bash examples/llama3.2_1b_mx9.sh
    → validate loss 2.1440, Training completed.

Copilot AI left a comment

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds emulated MX9/MX6 fake-quantization (QDQ) kernels and integrates them into ALTO’s PTQ flow via a new QuantizationArgs.format field plus dispatch in the patched compressed_tensors fake_quantize. It also adds W+A fully-dynamic Llama-3.2-1B recipes/config entries, an E2E example script, and a dedicated unit-test suite covering kernel correctness, dispatch wiring, and a toy integration lifecycle.

Changes:

  • Add MX9 kernel implementation and MX6 wrapper (reusing MX9 math with quant_bit=5) plus runtime pydantic schema patching for QuantizationArgs.format.
  • Wire dispatch for format == "mx9" / "mx6" into ModelPatcher.patch_fake_quantize, and adjust calibration to no-op for fully-dynamic weights without observers.
  • Add dynamic W+A recipes/configs and comprehensive unit tests + an E2E example script.

Reviewed changes

Copilot reviewed 22 out of 22 changed files in this pull request and generated 7 comments.

Show a summary per file
File Description
alto/__init__.py Ensures format field injection runs early; imports mx6/mx9 packages for wiring.
alto/models/patcher.py Adds format-based dispatch to MX9/MX6 fake-quant kernels inside patched fake_quantize.
alto/modifiers/quantization/calibration.py Skips static weight scale baking when no weight_observer exists (fully dynamic).
alto/modifiers/quantization/base.py Minor import formatting/whitespace change.
alto/kernels/mx9/registry_patch.py Injects format: Optional[str] into QuantizationArgs and rebuilds nested schemas.
alto/kernels/mx9/quantize.py Implements MX9 block-wise shared-prime-bit fake-quantization kernel.
alto/kernels/mx9/format.py Defines MX9 constants (BLOCK_SIZE, QUANT_BIT, etc.).
alto/kernels/mx9/__init__.py Imports registry_patch for side-effect wiring.
alto/kernels/mx6/quantize.py MX6 wrapper around MX9 kernel with quant_bit=5.
alto/kernels/mx6/format.py Defines MX6 constants.
alto/kernels/mx6/__init__.py Ensures format-field injection is available when importing mx6 alone.
alto/models/llama3/configs/mx9_wa_recipe.yaml Adds fully-dynamic MX9 W+A recipe.
alto/models/llama3/configs/mx6_wa_recipe.yaml Adds fully-dynamic MX6 W+A recipe.
alto/models/llama3/config_registry.py Registers new llama3_1b_mx9_wa / llama3_1b_mx6_wa configs.
examples/llama3.2_1b_mx9.sh Adds E2E validation script for MX9 W+A flow.
tests/unittest/mx9_mx6/__init__.py Adds test package marker.
tests/unittest/mx9_mx6/test_mx9_quantize.py MX9 kernel correctness tests (golden vector + optional Quark parity + invariants).
tests/unittest/mx9_mx6/test_mx6_quantize.py MX6 wrapper tests (optional Quark parity + invariants).
tests/unittest/mx9_mx6/test_mx9_dispatch.py Validates format survives validation and dispatch routes to MX9 kernel.
tests/unittest/mx9_mx6/test_mx6_dispatch.py Validates format survives validation and dispatch routes to MX6 kernel.
tests/unittest/mx9_mx6/test_mx9_wa_integration.py Toy integration test for MX9 dynamic W+A lifecycle wiring.
tests/unittest/mx9_mx6/test_mx6_wa_integration.py Toy integration test for MX6 dynamic W+A lifecycle wiring.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +125 to +127
n = SHARED_PRIME_BIT_GROUP
assert demote.shape[-1] % n == 0
flat_shape = demote.shape
Comment thread alto/models/patcher.py Outdated
Comment on lines +60 to +64
if getattr(args, "format", None) == "mx9":
from alto.kernels.mx9.format import BLOCK_SIZE
from alto.kernels.mx9.quantize import mx9_fake_quantize

return mx9_fake_quantize(x, block_size=(args.group_size or BLOCK_SIZE))
Comment thread alto/models/patcher.py Outdated
Comment on lines +65 to +69
if getattr(args, "format", None) == "mx6":
from alto.kernels.mx6.format import BLOCK_SIZE
from alto.kernels.mx6.quantize import mx6_fake_quantize

return mx6_fake_quantize(x, block_size=(args.group_size or BLOCK_SIZE))
Comment thread alto/kernels/mx6/quantize.py Outdated
quant_bit: int = QUANT_BIT,
axis: int = -1,
) -> torch.Tensor:
"""对输入执行 MX6 block-wise fake quantization(QDQ)。"""
Comment on lines +25 to +30
def _make_pkg(name: str, path: str) -> types.ModuleType:
mod = types.ModuleType(name)
mod.__path__ = [path]
mod.__package__ = name
sys.modules[name] = mod
return mod
Comment on lines +25 to +30
def _make_pkg(name: str, path: str) -> types.ModuleType:
mod = types.ModuleType(name)
mod.__path__ = [path]
mod.__package__ = name
sys.modules[name] = mod
return mod
Comment thread examples/llama3.2_1b_mx9.sh Outdated
Comment on lines +26 to +28
MODEL_PATH=${MODEL_PATH:-"/wekafs/jiarwang/Llama-3.2-1B"}
VALIDATOR_STEPS=${VALIDATOR_STEPS:-"10"}
CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-"/wekafs/jiarwang/mx9_e2e_logs/ckpt_${CONFIG}_$(date +%Y%m%d_%H%M%S)"}
…uantization

Relocate the emulated MX6/MX9 fake-quant reference and the format-field
registry patch from alto/kernels/{mx6,mx9}/ into
alto/modifiers/quantization/ (mx.py + format_registry.py), and move the
format-field injection into the quantization package __init__. Update the
mx9/mx6 dispatch and quantize test imports to the new module paths.
- mx.py: replace assert on block_size divisibility with an explicit
  ValueError (asserts can be stripped under python -O and give an
  unhelpful message).
- patcher.py: thread args.num_bits through the mx9/mx6 dispatch
  (falling back to the format default) instead of silently ignoring it,
  so an inconsistent num_bits no longer quantizes at the wrong width.
- mx9/mx6 W+A integration tests: route all sys.modules mutations through
  monkeypatch.setitem so stubs are reverted at teardown and cannot leak
  into other tests in the same process.
- examples/llama3.2_1b_mx9.sh: require MODEL_PATH instead of a hardcoded
  /wekafs path and default the checkpoint folder under ./outputs.
@Jiarwang77 Jiarwang77 force-pushed the jiarwang/mx9-fake-quant branch from 821f972 to 0891d9b Compare June 15, 2026 09:11
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants