From 2513b6b28078baa968fd9c138dc7554c7170d09a Mon Sep 17 00:00:00 2001 From: jiarwang Date: Mon, 8 Jun 2026 11:45:59 +0000 Subject: [PATCH 1/5] feat: mx9: add MX9 fake-quant kernel with bit-exact tests --- alto/kernels/mx9/__init__.py | 12 ++ alto/kernels/mx9/format.py | 11 ++ alto/kernels/mx9/quantize.py | 155 ++++++++++++++++ tests/unittest/mx9_mx6/__init__.py | 3 + tests/unittest/mx9_mx6/test_mx9_quantize.py | 186 ++++++++++++++++++++ 5 files changed, 367 insertions(+) create mode 100644 alto/kernels/mx9/__init__.py create mode 100644 alto/kernels/mx9/format.py create mode 100644 alto/kernels/mx9/quantize.py create mode 100644 tests/unittest/mx9_mx6/__init__.py create mode 100644 tests/unittest/mx9_mx6/test_mx9_quantize.py diff --git a/alto/kernels/mx9/__init__.py b/alto/kernels/mx9/__init__.py new file mode 100644 index 0000000..504c23d --- /dev/null +++ b/alto/kernels/mx9/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""MX9 emulated (fake-quant) format. + +NOTE: this package currently provides fake-quantization (QDQ emulation) only -- +there is no real packed kernel / GEMM. It plugs into the standard +``QuantizationModifier`` path via runtime patches, without touching the generic +calibration / observer machinery. The patches that wire it into the quant path +(``QuantizationArgs.format`` field injection + ``fake_quantize`` mx9 dispatch) +are added separately via ``registry_patch`` and ``alto.models.patcher``. +""" diff --git a/alto/kernels/mx9/format.py b/alto/kernels/mx9/format.py new file mode 100644 index 0000000..c06d354 --- /dev/null +++ b/alto/kernels/mx9/format.py @@ -0,0 +1,11 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""MX9 format constants (fake-quant emulation only; no real packed kernel).""" + +# Matches Quark mx9 reference (test_mx.py uses block_size=16; Quark leaves it +# configurable via MX9Spec.block_size). 16 is the L2 block; the shared prime bit +# is shared across every SHARED_PRIME_BIT_GROUP (=2) elements. +BLOCK_SIZE = 16 +QUANT_BIT = 8 +SHARED_PRIME_BIT_GROUP = 2 diff --git a/alto/kernels/mx9/quantize.py b/alto/kernels/mx9/quantize.py new file mode 100644 index 0000000..696890b --- /dev/null +++ b/alto/kernels/mx9/quantize.py @@ -0,0 +1,155 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""MX9 fake-quantize kernel (block-wise emulation; no real packing/kernel inference). + +This is a port of Quark's ``fake_quantize_mx6_mx9`` for the ``quant_bit=8`` (MX9) +case, with its helpers (``_t_exponent`` / ``_reshape_to_blocks`` / ``_pad_to_blocks``) +inlined here. + +MX9 in a nutshell: + 1. Split the input along ``axis`` into blocks of ``block_size`` elements. + 2. Each block derives a shared exponent ``max_exp = floor(log2(amax))`` from its + max absolute value. + 3. Every ``SHARED_PRIME_BIT_GROUP`` (=2) adjacent elements share one prime bit: + if both elements in a pair sit at least one exponent below the block max, the + pair uses ``max_exp - 1`` (finer resolution); otherwise ``max_exp``. + 4. The shared exponent yields a power-of-two scale; the block is then + round -> clamp -> dequant. + +The "9" is not ``quant_bit=9``: elements are still quantized to 8-bit integers, and +the extra "1" is the per-pair prime bit that decides the one-exponent demotion. +Intuitively: 8-bit signed integer value + a shared prime bit for exponent refinement. + - ``QUANT_BIT = 8``: elements round to an 8-bit integer grid. + - ``SHARED_PRIME_BIT_GROUP = 2``: every two adjacent elements share a prime bit. + +This is fake quantization: the output keeps the input dtype (bf16/fp16/fp32) with +values projected onto the MX9-representable grid. The scale is computed from the +data at runtime; any externally supplied compressed_tensors scale is ignored here. +""" + +import torch + +from .format import BLOCK_SIZE, QUANT_BIT, SHARED_PRIME_BIT_GROUP + + +def _pad_to_blocks(x: torch.Tensor, block_size: int) -> torch.Tensor: + """Right-pad the last dim with zeros to a multiple of ``block_size``.""" + pad = block_size - x.size(-1) % block_size + if pad == block_size: + return x + return torch.nn.functional.pad(x, (0, pad)) + + +def _reshape_to_blocks(x: torch.Tensor, block_size: int, axis: int) -> torch.Tensor: + """Reshape into ``[num_rows, num_blocks, block_size]``. + + Blocking always happens along the last dim, so the requested ``axis`` is moved + there first (e.g. weight ``[out, in]`` with axis=-1 blocks each output channel's + input dim; activation ``[batch, seq, hidden]`` blocks each token's hidden vector). + """ + if axis > x.dim() - 1: + raise IndexError("Axis is larger than number of tensor dimensions") + + # Move the quantized axis to the last dim, flatten the rest, pad, then split + # the last dim into blocks. Padding is trimmed after quantization. + x = x.transpose(axis, -1) + x = x.reshape(-1, x.size(-1)) + x = _pad_to_blocks(x, block_size) + return x.reshape(x.size(0), x.size(1) // block_size, block_size) + + +def _exponent_frexp_no_exception(t: torch.Tensor) -> torch.Tensor: + """Read the exponent straight from the float bit pattern (~ floor(log2(|x|))).""" + with torch.no_grad(): + if t.dtype == torch.float32: + return ((t.view(torch.int32) >> 23) & 0xFF) - 127 + if t.dtype == torch.bfloat16: + return ((t.view(torch.int16) >> 7) & 0xFF) - 127 + if t.dtype == torch.float16: + # Matches Quark: fp16 exponent bias treated as -15. + return ((t.view(torch.int16) >> 10) & 0x1F) - 15 + raise ValueError(f"Unsupported data type: {t.dtype}") + + +def _t_exponent(t: torch.Tensor) -> torch.Tensor: + """Per-element exponent; NaN/Inf are zeroed before extraction.""" + with torch.no_grad(): + t = torch.nan_to_num(t, nan=0, posinf=0, neginf=0) + return _exponent_frexp_no_exception(t) + + +def mx9_fake_quantize( + input_tensor: torch.Tensor, + block_size: int = BLOCK_SIZE, + quant_bit: int = QUANT_BIT, + axis: int = -1, +) -> torch.Tensor: + """Block-wise MX9 fake quantization (QDQ), bit-exact with Quark's + ``fake_quantize_mx6_mx9(quant_bit=8)``. + + Args: + input_tensor: tensor to quantize (weight or activation). + block_size: elements per MX9 block. + quant_bit: element integer width (8 for MX9, 5 for MX6). + axis: axis to block along. Only ``axis=-1`` is supported on the ALTO path. + """ + if axis != -1: + raise NotImplementedError("mx9_fake_quantize supports axis=-1 only") + + input_dtype = input_tensor.dtype + + # Shape to restore after blocking: _reshape_to_blocks moves axis to the last dim, + # so record the post-transpose shape to rebuild before transposing back. + input_shape = list(input_tensor.shape) + input_shape[-1], input_shape[axis] = input_shape[axis], input_shape[-1] + + # Detached block view used only for the per-block max exponent (no gradient). + block_x = _reshape_to_blocks(input_tensor.detach(), block_size, axis) + # Out-of-place nan_to_num: in-place breaks torch.compile/inductor functional graphs. + block_x = torch.nan_to_num(block_x, nan=0.0, posinf=0.0, neginf=0.0) + + # Per-block max abs -> shared block exponent max_exp = floor(log2(amax)). + amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True) + max_exp = _t_exponent(amax) # [num_rows, num_blocks, 1] + + # Block the original input for the per-element QDQ. + input_tensor = _reshape_to_blocks(input_tensor, block_size, axis) + t_exp = _t_exponent(input_tensor) # [num_rows, num_blocks, block_size] + + # Per element: is it at least one exponent below its block max? (max_exp broadcasts) + demote = max_exp - t_exp >= 1 + + # Prime bit: a pair is demoted only if BOTH its elements are below the block max. + # Reshape last dim into pairs, AND within each pair, then broadcast back. + n = SHARED_PRIME_BIT_GROUP + assert demote.shape[-1] % n == 0 + flat_shape = demote.shape + demote = demote.reshape(*flat_shape[:-1], flat_shape[-1] // n, n) + demote = torch.sum(demote, -1, keepdim=True) == n + demote = demote.repeat(*([1] * (demote.dim() - 1)), n).reshape(flat_shape) + + # Demoted pairs drop one exponent (the prime-bit refinement). + shared_exp = max_exp - demote.long() + + # Final power-of-two scale: 2 ** (shared_exp - quant_bit + 2) (=2**(shared_exp-6) for MX9). + scale = torch.pow(2.0, shared_exp - quant_bit + 2) + + # Max representable integer magnitude. clamp_max guards against float overflow + # when max_exp is large: quant_max = (2**(max_exp+1) - scale) / scale. + quant_max = ( + torch.clamp_max( + torch.pow(2.0, max_exp.to(torch.float64) + 1) - scale, + torch.finfo(torch.float32).max, + ).to(torch.float32) + / scale + ) + + # QDQ: round to the integer grid, clamp to range, dequant back to float. + output_tensor = torch.round(input_tensor / scale) + output_tensor = torch.clamp(output_tensor, -quant_max, quant_max) * scale + + # Flatten blocks, drop padding, restore shape/dtype, and transpose axis back. + output_tensor = output_tensor.reshape(output_tensor.size(0), -1) + output_tensor = output_tensor[:, : input_shape[-1]].reshape(input_shape).to(input_dtype) + return output_tensor.transpose(axis, -1) diff --git a/tests/unittest/mx9_mx6/__init__.py b/tests/unittest/mx9_mx6/__init__.py new file mode 100644 index 0000000..85eac52 --- /dev/null +++ b/tests/unittest/mx9_mx6/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT diff --git a/tests/unittest/mx9_mx6/test_mx9_quantize.py b/tests/unittest/mx9_mx6/test_mx9_quantize.py new file mode 100644 index 0000000..3dabedc --- /dev/null +++ b/tests/unittest/mx9_mx6/test_mx9_quantize.py @@ -0,0 +1,186 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""Tests for the MX9 fake-quantize kernel. + +Three layers: + 1. Golden-vector parity vs Quark's published mx9 output (hardcoded, so it runs + even when Quark is not installed). This is the primary correctness anchor. + 2. Optional bit-exact parity vs a live Quark import (skipped if unavailable). + 3. Format-property checks (block independence, prime-bit demotion, padding, + NaN/Inf, dtype, idempotence, axis contract). +""" + +import pytest +import torch + +from alto.kernels.mx9.format import BLOCK_SIZE, QUANT_BIT +from alto.kernels.mx9.quantize import mx9_fake_quantize + + +# --------------------------------------------------------------------------- # +# Helpers +# --------------------------------------------------------------------------- # +def _quark_interesting_pattern() -> torch.Tensor: + """Quark test_mx.py:228 create_4d_tensor_with_interesting_pattern().""" + result = torch.zeros(1, 2, 3, 4) + for x1 in range(2): + for x2 in range(3): + for x3 in range(4): + result[0, x1, x2, x3] = (x3 + 1.0) * (10 * (x2 + 1.0)) + 100.0 * x1 + return result + + +# Quark's published mx9 output for the pattern above @ block_size=16, axis=-1 +# (test_mx.py:341-348). MX9 is high precision, so it round-trips these integers +# exactly. +_QUARK_MX9_GOLDEN = torch.tensor( + [ + [ + [[10.0, 20.0, 30.0, 40.0], [20.0, 40.0, 60.0, 80.0], [30.0, 60.0, 90.0, 120.0]], + [[110.0, 120.0, 130.0, 140.0], [120.0, 140.0, 160.0, 180.0], [130.0, 160.0, 190.0, 220.0]], + ] + ] +) + + +def _quark_mx9_reference(x: torch.Tensor, block_size: int): + """Live Quark reference; returns None (caller skips) if Quark is unavailable.""" + try: + from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( + fake_quantize_mx6_mx9, + ) + except Exception: + return None + return fake_quantize_mx6_mx9(x.clone(), axis=-1, block_size=block_size, quant_bit=QUANT_BIT) + + +def _rand(shape, dtype): + torch.manual_seed(0) + return torch.randn(*shape, dtype=dtype) + + +# --------------------------------------------------------------------------- # +# Layer 1: golden vector (always runs, no Quark needed) +# --------------------------------------------------------------------------- # +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_mx9_matches_quark_golden_vector(dtype): + x = _quark_interesting_pattern().to(dtype) + out = mx9_fake_quantize(x, block_size=16) + expected = _QUARK_MX9_GOLDEN.to(dtype) + assert out.dtype == dtype + assert torch.equal(out, expected), (out - expected).abs().max().item() + + +# --------------------------------------------------------------------------- # +# Layer 2: live Quark bit-exact (bonus, skips if Quark absent) +# --------------------------------------------------------------------------- # +@pytest.mark.parametrize("shape", [(4, 64), (8, 16), (2, 3, 32), (3, 40)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("block_size", [16, 32]) +def test_mx9_matches_quark_bit_exact(shape, dtype, block_size): + x = _rand(shape, dtype) + ref = _quark_mx9_reference(x, block_size) + if ref is None: + pytest.skip("Quark reference (fake_quantize_mx6_mx9) not importable") + + out = mx9_fake_quantize(x, block_size=block_size, quant_bit=QUANT_BIT) + assert out.dtype == ref.dtype + assert torch.equal(out, ref), (out - ref).abs().max().item() + + +# --------------------------------------------------------------------------- # +# Layer 3: format-property checks +# --------------------------------------------------------------------------- # +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +def test_shape_and_dtype_preserved(dtype): + x = _rand((4, 64), dtype) + out = mx9_fake_quantize(x, block_size=BLOCK_SIZE) + assert out.shape == x.shape + assert out.dtype == dtype + + +def test_block_independence(): + # A block of huge values must not degrade the scale of a separate small block. + # Two rows = two independent blocks (last dim == block_size). + big = torch.full((1, BLOCK_SIZE), 100.0) + small = torch.full((1, BLOCK_SIZE), 0.01) + joint = mx9_fake_quantize(torch.cat([big, small], dim=0), block_size=BLOCK_SIZE) + alone = mx9_fake_quantize(small, block_size=BLOCK_SIZE) + # small block's result is identical whether or not the big block shares the tensor. + assert torch.equal(joint[1:2], alone) + + +def test_prime_bit_demotes_small_pair(): + # MX9's distinguishing feature: when BOTH elements of an adjacent pair are + # >= 1 octave below the block max, that pair's scale drops one exponent, + # giving the pair finer resolution. With block max 128 (max_exp=7) the + # non-demoted scale is 2.0 and the demoted scale is 1.0, so v=1.4 rounds to + # 2.0 (err 0.6) without demotion but to 1.0 (err 0.4) with it. + v = 1.4 + block = torch.zeros(1, BLOCK_SIZE) + block[0, 0] = 128.0 # block max -> sets shared exponent, never demoted + block[0, 1] = v # pair (idx 0,1): idx0 is the max -> NOT demoted + block[0, 2] = v # pair (idx 2,3): BOTH small -> demoted + block[0, 3] = v + out = mx9_fake_quantize(block, block_size=BLOCK_SIZE) + # Hard values: with max_exp=7 the non-demoted scale is 2.0 and the demoted + # scale is 1.0, so v=1.4 lands on exactly those grid points. + assert out[0, 1].item() == 2.0 # non-demoted pair (idx0 is the max) + assert out[0, 2].item() == 1.0 # demoted pair (both small) + assert out[0, 3].item() == 1.0 + err_non_demoted = (out[0, 1] - v).abs().item() + err_demoted = (out[0, 2] - v).abs().item() + assert err_demoted < err_non_demoted + + +def test_padding_path_non_divisible_last_dim(): + x = _rand((3, 40), torch.float32) # 40 not divisible by 16 + out = mx9_fake_quantize(x, block_size=16) + assert out.shape == x.shape + # Padding zeros must not corrupt the real elements: quantizing the same + # data laid out as exact blocks gives the matching prefix. + assert not torch.isnan(out).any() + + +def test_inf_does_not_poison_block(): + # nan_to_num is applied to the amax copy only (matching Quark), so ±Inf in + # the data does NOT blow up the shared scale: the real elements survive and + # Inf is clamped to the finite quant range. NaN in the *data* path is not + # scrubbed and would propagate -- documented limitation, tested separately. + x = torch.full((1, BLOCK_SIZE), 4.0) + x[0, 0] = float("inf") + x[0, 1] = float("-inf") + out = mx9_fake_quantize(x, block_size=BLOCK_SIZE) + assert not torch.isinf(out).any() + # The real 4.0 values are unaffected by the poisoned amax candidate. + assert torch.equal(out[0, 2:], torch.full((BLOCK_SIZE - 2,), 4.0)) + + +def test_nan_in_data_propagates(): + # Known limitation faithful to Quark: a NaN element survives QDQ as NaN. + # Locking this in so any future change to the data-path sanitization is loud. + x = torch.full((1, BLOCK_SIZE), 4.0) + x[0, 0] = float("nan") + out = mx9_fake_quantize(x, block_size=BLOCK_SIZE) + assert torch.isnan(out[0, 0]) + assert torch.equal(out[0, 1:], torch.full((BLOCK_SIZE - 1,), 4.0)) + + +def test_qdq_idempotent(): + x = _rand((8, 64), torch.float32) + once = mx9_fake_quantize(x, block_size=BLOCK_SIZE) + twice = mx9_fake_quantize(once, block_size=BLOCK_SIZE) + assert torch.equal(once, twice) + + +def test_zeros_stay_zero(): + x = torch.zeros((2, 32), dtype=torch.float32) + out = mx9_fake_quantize(x, block_size=BLOCK_SIZE) + assert torch.equal(out, x) + + +def test_axis_contract(): + x = _rand((4, 32), torch.float32) + with pytest.raises(NotImplementedError): + mx9_fake_quantize(x, block_size=BLOCK_SIZE, axis=0) From 2693afdc5935d5e26ee8b294b40b8ec975e4837a Mon Sep 17 00:00:00 2001 From: jiarwang Date: Tue, 9 Jun 2026 11:29:42 +0000 Subject: [PATCH 2/5] feat: mx9: add MX9 into the PTQ quantization flow --- alto/__init__.py | 5 + alto/kernels/mx9/__init__.py | 4 +- alto/kernels/mx9/registry_patch.py | 44 +++++ alto/models/llama3/config_registry.py | 11 ++ alto/models/llama3/configs/mx9_wa_recipe.yaml | 25 +++ alto/models/patcher.py | 5 + alto/modifiers/quantization/base.py | 6 +- alto/modifiers/quantization/calibration.py | 5 + examples/llama3.2_1b_mx9.sh | 57 ++++++ tests/unittest/mx9_mx6/test_mx9_dispatch.py | 121 +++++++++++++ .../mx9_mx6/test_mx9_wa_integration.py | 170 ++++++++++++++++++ 11 files changed, 451 insertions(+), 2 deletions(-) create mode 100644 alto/kernels/mx9/registry_patch.py create mode 100644 alto/models/llama3/configs/mx9_wa_recipe.yaml create mode 100755 examples/llama3.2_1b_mx9.sh create mode 100644 tests/unittest/mx9_mx6/test_mx9_dispatch.py create mode 100644 tests/unittest/mx9_mx6/test_mx9_wa_integration.py diff --git a/alto/__init__.py b/alto/__init__.py index 2bf2a5d..0d74e9e 100644 --- a/alto/__init__.py +++ b/alto/__init__.py @@ -7,3 +7,8 @@ from .observers import * from .modifiers import * from .models import * + +# Runtime patches with import-time side effects (kept after the API re-exports): +# importing the mx9 kernel package injects the QuantizationArgs.format field so +# recipes carrying ``format: mx9`` parse correctly. Do not remove. +from .kernels import mx9 # noqa: F401,E402 diff --git a/alto/kernels/mx9/__init__.py b/alto/kernels/mx9/__init__.py index 504c23d..bfc57fb 100644 --- a/alto/kernels/mx9/__init__.py +++ b/alto/kernels/mx9/__init__.py @@ -8,5 +8,7 @@ ``QuantizationModifier`` path via runtime patches, without touching the generic calibration / observer machinery. The patches that wire it into the quant path (``QuantizationArgs.format`` field injection + ``fake_quantize`` mx9 dispatch) -are added separately via ``registry_patch`` and ``alto.models.patcher``. +live in ``registry_patch`` and ``alto.models.patcher``. """ + +from . import registry_patch # noqa: F401 injects the QuantizationArgs.format field; do not remove diff --git a/alto/kernels/mx9/registry_patch.py b/alto/kernels/mx9/registry_patch.py new file mode 100644 index 0000000..e537764 --- /dev/null +++ b/alto/kernels/mx9/registry_patch.py @@ -0,0 +1,44 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""Runtime patches that wire emulated formats into the standard quant path. + +Importing this module injects a real ``format`` field into +``compressed_tensors.QuantizationArgs`` so recipe values like ``format: mx9`` +survive pydantic parsing and become readable via ``getattr(args, "format", None)`` +(by default unknown fields are silently dropped). + +The actual ``fake_quantize`` dispatch (``args.format == "mx9"`` -> mx9 kernel) +lives in ``alto.models.patcher.ModelPatcher.patch_fake_quantize`` where the single +wrap of ``compressed_tensors...forward.fake_quantize`` already happens. +""" + +from typing import Optional + +_FORMAT_FIELD_INJECTED = False + + +def inject_format_field() -> None: + """Add ``format: Optional[str] = None`` to ``QuantizationArgs`` (idempotent).""" + global _FORMAT_FIELD_INJECTED + if _FORMAT_FIELD_INJECTED: + return + + from pydantic.fields import FieldInfo + from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig, QuantizationScheme + + if "format" not in QuantizationArgs.model_fields: + QuantizationArgs.model_fields["format"] = FieldInfo( + annotation=Optional[str], default=None + ) + QuantizationArgs.model_rebuild(force=True) + # QuantizationArgs is nested inside these models. Rebuild them as well so + # recipe dictionaries with weights/input_activations.format are accepted + # instead of being rejected by the old cached schema. + QuantizationScheme.model_rebuild(force=True) + QuantizationConfig.model_rebuild(force=True) + + _FORMAT_FIELD_INJECTED = True + + +inject_format_field() diff --git a/alto/models/llama3/config_registry.py b/alto/models/llama3/config_registry.py index e499b0c..84aad84 100644 --- a/alto/models/llama3/config_registry.py +++ b/alto/models/llama3/config_registry.py @@ -28,6 +28,7 @@ "llama3_8b_lpt", "llama3_1b_gptq", "llama3_1b_awq", + "llama3_1b_mx9_wa", "llama3_8b", "llama3_8b_gptq", "llama3_8b_rtn", @@ -182,6 +183,16 @@ def llama3_1b_awq() -> Trainer.Config: return config +def llama3_1b_mx9_wa() -> Trainer.Config: + config = llama3_1b() + config.training.steps = 1 + config.optimizer = OptimizersContainer.Config(lr=0.0) + config.model_converters = ModelConvertersContainer.Config(converters=[ + ModelOptConverter.Config(recipe="./alto/models/llama3/configs/mx9_wa_recipe.yaml",), + ],) + return config + + LLAMA3_8B_PATH = "/workspace/Model-Optimizer/models/meta-llama/Llama-3.1-8B" diff --git a/alto/models/llama3/configs/mx9_wa_recipe.yaml b/alto/models/llama3/configs/mx9_wa_recipe.yaml new file mode 100644 index 0000000..3df4841 --- /dev/null +++ b/alto/models/llama3/configs/mx9_wa_recipe.yaml @@ -0,0 +1,25 @@ +# MX9 W+A fully-dynamic recipe (single GPU, dense). +# Weight and input-activation are both quantized, both dynamic -> QDQ happens in +# the forward pass, no calibration / observer / weight baking required. +quantization_stage: + quantization_modifiers: + QuantizationModifier: + ignore: ["output"] # lm_head not quantized + sequential: false + config_groups: + group_0: + targets: ["Linear"] + weights: # fully dynamic: QDQ in forward, no baking + num_bits: 8 + type: "int" + symmetric: true + strategy: "tensor" + dynamic: true + format: "mx9" # <- hits mx9 dispatch + input_activations: # fully dynamic + num_bits: 8 + type: "int" + symmetric: true + strategy: "tensor" + dynamic: true + format: "mx9" diff --git a/alto/models/patcher.py b/alto/models/patcher.py index 6eee519..9c5f1bf 100644 --- a/alto/models/patcher.py +++ b/alto/models/patcher.py @@ -57,6 +57,11 @@ class FakeQuantizeFunction(torch.autograd.Function): @staticmethod def forward(ctx, x, scale, zero_point, args, g_idx, global_scale): + if getattr(args, "format", None) == "mx9": + from alto.kernels.mx9.format import BLOCK_SIZE + from alto.kernels.mx9.quantize import mx9_fake_quantize + + return mx9_fake_quantize(x, block_size=(args.group_size or BLOCK_SIZE)) return original_fake_quantize(x, scale, zero_point, args, g_idx, global_scale) @staticmethod diff --git a/alto/modifiers/quantization/base.py b/alto/modifiers/quantization/base.py index 4aba465..3511fdd 100644 --- a/alto/modifiers/quantization/base.py +++ b/alto/modifiers/quantization/base.py @@ -15,7 +15,10 @@ import torch import tqdm -from compressed_tensors.quantization import disable_quantization, enable_quantization +from compressed_tensors.quantization import ( + disable_quantization, + enable_quantization, +) from compressed_tensors.utils import getattr_chain, match_named_modules from pydantic import Field, PrivateAttr from torch.nn import Module @@ -28,6 +31,7 @@ __all__ = ["QuantizationModifier"] + DEFAULT_SUBLAYER_GROUPS = [ ["wq", "wk", "wv", "q_proj", "k_proj", "v_proj"], ["wo", "o_proj"], diff --git a/alto/modifiers/quantization/calibration.py b/alto/modifiers/quantization/calibration.py index f340f71..7eb54b2 100644 --- a/alto/modifiers/quantization/calibration.py +++ b/alto/modifiers/quantization/calibration.py @@ -135,6 +135,11 @@ def update_weight_zp_scale(module: Module): if getattr_chain(module, "quantization_scheme.weights", None) is None: return + # Fully dynamic weights have no observer: their QDQ runs live in the patched + # forward on every pass, so there is no static scale to compute or bake. + if not hasattr(module, "weight_observer"): + return + if getattr(module, "quantization_status", None) != QuantizationStatus.CALIBRATION: logger.warning("Attempting to calibrate weights of a module not in calibration mode") diff --git a/examples/llama3.2_1b_mx9.sh b/examples/llama3.2_1b_mx9.sh new file mode 100755 index 0000000..6b92ba0 --- /dev/null +++ b/examples/llama3.2_1b_mx9.sh @@ -0,0 +1,57 @@ +#!/usr/bin/bash +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT + +# MX9 W8A8 dynamic fake-quant validation on Llama-3.2-1B +# Weight and input activations are quantized dynamically through +# alto/models/llama3/configs/mx9_wa_recipe.yaml +# +# Usage: +# bash examples/llama3.2_1b_mx9.sh +# VALIDATOR_STEPS=100 bash examples/llama3.2_1b_mx9.sh +# CONFIG=llama3_1b bash examples/llama3.2_1b_mx9.sh # BF16 baseline +rm -rf outputs/ +set -ex + +NGPU=${NGPU:-"1"} +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0} +export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-${CUDA_VISIBLE_DEVICES}} +export LOG_RANK=${LOG_RANK:-0} +TRAIN_FILE=${TRAIN_FILE:-"alto.train"} +MODULE=${MODULE:-"llama3"} +CONFIG=${CONFIG:-"llama3_1b_mx9_wa"} +COMM_MODE=${COMM_MODE:-""} + +MODEL_PATH=${MODEL_PATH:-"/wekafs/jiarwang/Llama-3.2-1B"} +VALIDATOR_STEPS=${VALIDATOR_STEPS:-"10"} +CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-"/wekafs/jiarwang/mx9_e2e_logs/ckpt_${CONFIG}_$(date +%Y%m%d_%H%M%S)"} + +TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} + +if [ -n "$COMM_MODE" ]; then + echo "Running with comm_mode=${COMM_MODE}" + NGPU="${NGPU}" LOCAL_RANK=0 python3 -m ${TRAIN_FILE} \ + --module ${MODULE} \ + --config ${CONFIG} \ + --hf_assets_path "${MODEL_PATH}" \ + --checkpoint.initial_load_path "${MODEL_PATH}" \ + --checkpoint.folder "${CHECKPOINT_FOLDER}" \ + --validator.steps "${VALIDATOR_STEPS}" \ + "$@" \ + --comm.mode=${COMM_MODE} \ + --training.steps 1 +else + PYTORCH_ALLOC_CONF="expandable_segments:True" \ + TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ + torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ + --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ + -m ${TRAIN_FILE} \ + --module ${MODULE} \ + --config ${CONFIG} \ + --hf_assets_path "${MODEL_PATH}" \ + --checkpoint.initial_load_path "${MODEL_PATH}" \ + --checkpoint.folder "${CHECKPOINT_FOLDER}" \ + --validator.steps "${VALIDATOR_STEPS}" \ + "$@" +fi diff --git a/tests/unittest/mx9_mx6/test_mx9_dispatch.py b/tests/unittest/mx9_mx6/test_mx9_dispatch.py new file mode 100644 index 0000000..a118d79 --- /dev/null +++ b/tests/unittest/mx9_mx6/test_mx9_dispatch.py @@ -0,0 +1,121 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""Tests for the MX9 dispatch wiring in ``ModelPatcher.patch_fake_quantize``. + +The MX9 kernel itself is covered by ``test_mx9_quantize.py``. This file tests the +*wiring* one layer up: after ``patch_fake_quantize()`` replaces +``compressed_tensors...forward.fake_quantize``, a ``QuantizationArgs`` carrying +``format == "mx9"`` must route to ``mx9_fake_quantize``, while plain int8 args +(no ``format``) must fall through to the original implementation untouched. + +Also checks that ``registry_patch.inject_format_field()`` makes the ``format`` +field survive pydantic validation (otherwise the recipe value is silently +dropped and dispatch never fires). + +Run with: + pytest tests/unittest/mx9_mx6/test_mx9_dispatch.py +""" + +import pytest +import torch + +from compressed_tensors.quantization import QuantizationArgs +from compressed_tensors.quantization.lifecycle import forward as forward_module + +from alto.kernels.mx9.format import BLOCK_SIZE +from alto.kernels.mx9.quantize import mx9_fake_quantize + + +def _mx9_args() -> QuantizationArgs: + return QuantizationArgs.model_validate( + { + "num_bits": 8, + "type": "int", + "symmetric": True, + "strategy": "tensor", + "dynamic": True, + "format": "mx9", + } + ) + + +def _int8_args() -> QuantizationArgs: + return QuantizationArgs.model_validate( + { + "num_bits": 8, + "type": "int", + "symmetric": True, + "strategy": "tensor", + } + ) + + +# --------------------------------------------------------------------------- # +# format field injection (registry_patch) +# --------------------------------------------------------------------------- # +def test_format_field_survives_validation(): + """Without inject_format_field(), pydantic drops the unknown ``format`` key + and dispatch can never see it.""" + args = _mx9_args() + assert getattr(args, "format", None) == "mx9" + + +def test_format_field_defaults_none_for_plain_int8(): + args = _int8_args() + assert getattr(args, "format", None) is None + + +# --------------------------------------------------------------------------- # +# dispatch routing +# --------------------------------------------------------------------------- # +def test_mx9_args_dispatch_to_mx9_kernel(): + """format == "mx9" must route fake_quantize to mx9_fake_quantize bit-exact.""" + torch.manual_seed(0) + x = torch.randn(3, 40, dtype=torch.float32) + args = _mx9_args() + + patched_out = forward_module.fake_quantize( + x=x, + scale=torch.ones(1), + zero_point=None, + args=args, + g_idx=None, + global_scale=None, + ) + expected = mx9_fake_quantize(x, block_size=BLOCK_SIZE) + + assert torch.equal(patched_out, expected) + + +def test_int8_args_fall_through_to_original(): + """Plain int8 (no format) must NOT touch the mx9 kernel; output must match a + direct INT8 per-tensor QDQ, never the mx9 result.""" + torch.manual_seed(1) + x = torch.randn(3, 40, dtype=torch.float32) + args = _int8_args() + + # per-tensor symmetric int8 scale (what the original kernel would use) + qmax = 127.0 + scale = (x.abs().max() / qmax).clamp_min(1e-12).reshape(1) + zp = torch.zeros(1, dtype=torch.int32) + + patched_out = forward_module.fake_quantize( + x=x, + scale=scale, + zero_point=zp, + args=args, + g_idx=None, + global_scale=None, + ) + + expected_int8 = (torch.round(x / scale).clamp(-128, 127)) * scale + assert torch.allclose(patched_out, expected_int8, atol=1e-5) + + # and it must differ from the mx9 path (sanity: int8 did not silently route to mx9) + mx9_out = mx9_fake_quantize(x, block_size=BLOCK_SIZE) + assert not torch.equal(patched_out, mx9_out) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-v", "-p", "no:cacheprovider"])) diff --git a/tests/unittest/mx9_mx6/test_mx9_wa_integration.py b/tests/unittest/mx9_mx6/test_mx9_wa_integration.py new file mode 100644 index 0000000..28fdabd --- /dev/null +++ b/tests/unittest/mx9_mx6/test_mx9_wa_integration.py @@ -0,0 +1,170 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""Toy integration test for MX9 W+A dynamic quantization. + +This intentionally stays below a real llama3 validation run. It verifies the +first full ALTO wiring layer: + +recipe yaml -> QuantizationModifier -> Linear quantization_scheme -> +post_step dynamic-weight skip -> wrapped Linear forward -> MX9 W+A QDQ. +""" + +import importlib.util +import os +import sys +import types + +import torch +import yaml + + +ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) + + +def _make_pkg(name: str, path: str) -> types.ModuleType: + mod = types.ModuleType(name) + mod.__path__ = [path] + mod.__package__ = name + sys.modules[name] = mod + return mod + + +def _load(modname: str, relpath: str) -> types.ModuleType: + spec = importlib.util.spec_from_file_location(modname, os.path.join(ROOT, relpath)) + mod = importlib.util.module_from_spec(spec) + sys.modules[modname] = mod + spec.loader.exec_module(mod) + return mod + + +def _install_torchtitan_stubs() -> None: + """Provide the tiny torchtitan surface needed by quantization modules.""" + _make_pkg("torchtitan", "") + _make_pkg("torchtitan.tools", "") + + logging_mod = types.ModuleType("torchtitan.tools.logging") + + class _Logger: + def debug(self, *args, **kwargs): + pass + + def info(self, *args, **kwargs): + pass + + def warning(self, *args, **kwargs): + pass + + logging_mod.logger = _Logger() + sys.modules["torchtitan.tools.logging"] = logging_mod + + utils_mod = types.ModuleType("torchtitan.tools.utils") + utils_mod.device_type = torch.device("cpu") + sys.modules["torchtitan.tools.utils"] = utils_mod + + +def _load_quantization_modifier(): + """Load only the ALTO modules needed for QuantizationModifier. + + Importing top-level ``alto`` currently pulls unrelated torchtitan entry + points, so the test loads the required modules directly. + """ + _install_torchtitan_stubs() + + _make_pkg("alto", os.path.join(ROOT, "alto")) + _make_pkg("alto.kernels", os.path.join(ROOT, "alto/kernels")) + _make_pkg("alto.kernels.mx9", os.path.join(ROOT, "alto/kernels/mx9")) + _make_pkg("alto.models", os.path.join(ROOT, "alto/models")) + + _load("alto.kernels.mx9.format", "alto/kernels/mx9/format.py") + quantize_mod = _load( + "alto.kernels.mx9.quantize", + "alto/kernels/mx9/quantize.py", + ) + _load("alto.kernels.mx9.registry_patch", "alto/kernels/mx9/registry_patch.py") + + # Dynamic W/A uses no observers, but calibration.py imports Observer. + observers_mod = types.ModuleType("alto.observers") + + class _Observer: + @staticmethod + def create_instance(*args, **kwargs): + raise AssertionError("MX9 dynamic W+A should not create observers") + + observers_mod.Observer = _Observer + sys.modules["alto.observers"] = observers_mod + + _make_pkg("alto.modifiers", os.path.join(ROOT, "alto/modifiers")) + _make_pkg("alto.modifiers.utils", os.path.join(ROOT, "alto/modifiers/utils")) + hooks_mod = _load("alto.modifiers.utils.hooks", "alto/modifiers/utils/hooks.py") + sys.modules["alto.modifiers.utils"].HooksMixin = hooks_mod.HooksMixin + + modifier_base_mod = _load("alto.modifiers.base", "alto/modifiers/base.py") + sys.modules["alto.modifiers"].Modifier = modifier_base_mod.Modifier + + _make_pkg("alto.modifiers.quantization", os.path.join(ROOT, "alto/modifiers/quantization")) + _load("alto.modifiers.quantization.calibration", "alto/modifiers/quantization/calibration.py") + _load("alto.modifiers.quantization.mixin", "alto/modifiers/quantization/mixin.py") + + # base.py imports get_layers for sequential mode; this test uses sequential=False. + _make_pkg("alto.utils", os.path.join(ROOT, "alto/utils")) + _make_pkg("alto.utils.pytorch", os.path.join(ROOT, "alto/utils/pytorch")) + module_utils = types.ModuleType("alto.utils.pytorch.module") + module_utils.get_layers = lambda *args, **kwargs: {} + sys.modules["alto.utils.pytorch.module"] = module_utils + + quant_base_mod = _load("alto.modifiers.quantization.base", "alto/modifiers/quantization/base.py") + + patcher_mod = _load("alto.models.patcher", "alto/models/patcher.py") + patcher_mod.ModelPatcher.patch_fake_quantize() + + return quant_base_mod.QuantizationModifier, quantize_mod + + +def _load_mx9_modifier_from_recipe(): + recipe_path = os.path.join(ROOT, "alto/models/llama3/configs/mx9_wa_recipe.yaml") + with open(recipe_path, "r") as f: + recipe = yaml.safe_load(f) + + mod_args = recipe["quantization_stage"]["quantization_modifiers"]["QuantizationModifier"] + QuantizationModifier, quantize_mod = _load_quantization_modifier() + return QuantizationModifier(**mod_args), quantize_mod + + +def test_mx9_wa_recipe_toy_linear_lifecycle(monkeypatch): + modifier, quantize_mod = _load_mx9_modifier_from_recipe() + model = torch.nn.Sequential(torch.nn.Linear(16, 16, bias=False)) + linear = model[0] + + modifier.initialize([model]) + + scheme = linear.quantization_scheme + assert getattr(scheme.weights, "format", None) == "mx9" + assert getattr(scheme.input_activations, "format", None) == "mx9" + assert scheme.weights.dynamic is True + assert scheme.input_activations.dynamic is True + assert not hasattr(linear, "weight_observer") + assert not hasattr(linear, "input_observer") + + calls = [] + original_mx9 = quantize_mod.mx9_fake_quantize + + def counted_mx9(input_tensor, *args, **kwargs): + calls.append(tuple(input_tensor.shape)) + return original_mx9(input_tensor, *args, **kwargs) + + monkeypatch.setattr(quantize_mod, "mx9_fake_quantize", counted_mx9) + + x = torch.randn(2, 16) + modifier.pre_step([model]) + modifier.post_step([model]) # must skip static weight baking for MX9 dynamic weight + assert not hasattr(linear, "weight_observer") + assert not hasattr(linear, "weight_scale") + + out = model(x) + assert out.shape == (2, 16) + assert len(calls) == 2 + assert (2, 16) in calls # input activation QDQ + assert (16, 16) in calls # weight QDQ + + modifier.finalize([model]) From b530d287bdfb21d895c42271c637f02d8b44f315 Mon Sep 17 00:00:00 2001 From: jiarwang Date: Thu, 11 Jun 2026 07:55:06 +0000 Subject: [PATCH 3/5] feat: mx6: add MX6 fake-quant kernel and the PTQ quantization flow --- alto/__init__.py | 17 +- alto/kernels/mx6/__init__.py | 14 ++ alto/kernels/mx6/format.py | 8 + alto/kernels/mx6/quantize.py | 34 ++++ alto/models/llama3/config_registry.py | 11 ++ alto/models/llama3/configs/mx6_wa_recipe.yaml | 25 +++ alto/models/patcher.py | 5 + tests/unittest/mx9_mx6/test_mx6_dispatch.py | 121 ++++++++++++ tests/unittest/mx9_mx6/test_mx6_quantize.py | 66 +++++++ .../mx9_mx6/test_mx6_wa_integration.py | 173 ++++++++++++++++++ 10 files changed, 471 insertions(+), 3 deletions(-) create mode 100644 alto/kernels/mx6/__init__.py create mode 100644 alto/kernels/mx6/format.py create mode 100644 alto/kernels/mx6/quantize.py create mode 100644 alto/models/llama3/configs/mx6_wa_recipe.yaml create mode 100644 tests/unittest/mx9_mx6/test_mx6_dispatch.py create mode 100644 tests/unittest/mx9_mx6/test_mx6_quantize.py create mode 100644 tests/unittest/mx9_mx6/test_mx6_wa_integration.py diff --git a/alto/__init__.py b/alto/__init__.py index 0d74e9e..d342a80 100644 --- a/alto/__init__.py +++ b/alto/__init__.py @@ -2,13 +2,24 @@ # # SPDX-License-Identifier: MIT +# Inject the QuantizationArgs.format field BEFORE importing modifiers: the alto +# QuantizationModifier compiles its nested QuantizationScheme schema at class +# definition time, so the field must exist first or recipes carrying +# ``format: mx6/mx9`` are rejected by the cached (format-less) schema. Calling +# inject_format_field() explicitly (it is idempotent) makes this ordering +# dependency self-documenting instead of relying on import order, which an +# import-sorter could silently reshuffle. +from .kernels.mx9.registry_patch import inject_format_field + +inject_format_field() + from .components import * from .config import * from .observers import * from .modifiers import * from .models import * -# Runtime patches with import-time side effects (kept after the API re-exports): -# importing the mx9 kernel package injects the QuantizationArgs.format field so -# recipes carrying ``format: mx9`` parse correctly. Do not remove. +# MX fake-quant kernels: import for their dispatch wiring. The format-field +# injection they depend on already ran above. +from .kernels import mx6 # noqa: F401,E402 from .kernels import mx9 # noqa: F401,E402 diff --git a/alto/kernels/mx6/__init__.py b/alto/kernels/mx6/__init__.py new file mode 100644 index 0000000..1760aa6 --- /dev/null +++ b/alto/kernels/mx6/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""MX6 emulated (fake-quant) format. + +MX6 shares the same Quark ``fake_quantize_mx6_mx9`` math as MX9 and only differs +in the element integer bit width (``quant_bit=5`` for MX6, ``8`` for MX9). + +The ``QuantizationArgs.format`` field injection (needed so ``format: mx6`` +recipes parse) lives in ``alto.kernels.mx9.registry_patch``; import it here so +this package is self-sufficient instead of relying on mx9 being imported first. +""" + +from alto.kernels.mx9 import registry_patch # noqa: F401 injects the QuantizationArgs.format field; do not remove diff --git a/alto/kernels/mx6/format.py b/alto/kernels/mx6/format.py new file mode 100644 index 0000000..6787fa6 --- /dev/null +++ b/alto/kernels/mx6/format.py @@ -0,0 +1,8 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""MX6 format constants (fake-quant emulation only; no real packed kernel).""" + +BLOCK_SIZE = 16 +QUANT_BIT = 5 +SHARED_PRIME_BIT_GROUP = 2 diff --git a/alto/kernels/mx6/quantize.py b/alto/kernels/mx6/quantize.py new file mode 100644 index 0000000..988456b --- /dev/null +++ b/alto/kernels/mx6/quantize.py @@ -0,0 +1,34 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""MX6 fake-quantize wrapper. + +Quark implements MX6 and MX9 with the same ``fake_quantize_mx6_mx9`` function. +The only format difference in this emulation path is ``quant_bit``: + +- MX6: ``quant_bit=5`` +- MX9: ``quant_bit=8`` + +The shared-prime-bit, block reshape, exponent, scale, clamp, and dequant math is +therefore reused from the MX9 implementation. +""" + +import torch + +from alto.kernels.mx9.quantize import mx9_fake_quantize +from .format import BLOCK_SIZE, QUANT_BIT + + +def mx6_fake_quantize( + input_tensor: torch.Tensor, + block_size: int = BLOCK_SIZE, + quant_bit: int = QUANT_BIT, + axis: int = -1, +) -> torch.Tensor: + """对输入执行 MX6 block-wise fake quantization(QDQ)。""" + return mx9_fake_quantize( + input_tensor=input_tensor, + block_size=block_size, + quant_bit=quant_bit, + axis=axis, + ) diff --git a/alto/models/llama3/config_registry.py b/alto/models/llama3/config_registry.py index 84aad84..0dd990a 100644 --- a/alto/models/llama3/config_registry.py +++ b/alto/models/llama3/config_registry.py @@ -29,6 +29,7 @@ "llama3_1b_gptq", "llama3_1b_awq", "llama3_1b_mx9_wa", + "llama3_1b_mx6_wa", "llama3_8b", "llama3_8b_gptq", "llama3_8b_rtn", @@ -193,6 +194,16 @@ def llama3_1b_mx9_wa() -> Trainer.Config: return config +def llama3_1b_mx6_wa() -> Trainer.Config: + config = llama3_1b() + config.training.steps = 1 + config.optimizer = OptimizersContainer.Config(lr=0.0) + config.model_converters = ModelConvertersContainer.Config(converters=[ + ModelOptConverter.Config(recipe="./alto/models/llama3/configs/mx6_wa_recipe.yaml",), + ],) + return config + + LLAMA3_8B_PATH = "/workspace/Model-Optimizer/models/meta-llama/Llama-3.1-8B" diff --git a/alto/models/llama3/configs/mx6_wa_recipe.yaml b/alto/models/llama3/configs/mx6_wa_recipe.yaml new file mode 100644 index 0000000..5ab4d3b --- /dev/null +++ b/alto/models/llama3/configs/mx6_wa_recipe.yaml @@ -0,0 +1,25 @@ +# MX6 W+A fully-dynamic recipe (single GPU, dense). +# Weight and input-activation are both quantized, both dynamic -> QDQ happens in +# the forward pass, no calibration / observer / weight baking required. +quantization_stage: + quantization_modifiers: + QuantizationModifier: + ignore: ["output"] # lm_head not quantized + sequential: false + config_groups: + group_0: + targets: ["Linear"] + weights: # fully dynamic: QDQ in forward, no baking + num_bits: 5 + type: "int" + symmetric: true + strategy: "tensor" + dynamic: true + format: "mx6" # <- hits mx6 dispatch + input_activations: # fully dynamic + num_bits: 5 + type: "int" + symmetric: true + strategy: "tensor" + dynamic: true + format: "mx6" diff --git a/alto/models/patcher.py b/alto/models/patcher.py index 9c5f1bf..3129436 100644 --- a/alto/models/patcher.py +++ b/alto/models/patcher.py @@ -62,6 +62,11 @@ def forward(ctx, x, scale, zero_point, args, g_idx, global_scale): from alto.kernels.mx9.quantize import mx9_fake_quantize return mx9_fake_quantize(x, block_size=(args.group_size or BLOCK_SIZE)) + if getattr(args, "format", None) == "mx6": + from alto.kernels.mx6.format import BLOCK_SIZE + from alto.kernels.mx6.quantize import mx6_fake_quantize + + return mx6_fake_quantize(x, block_size=(args.group_size or BLOCK_SIZE)) return original_fake_quantize(x, scale, zero_point, args, g_idx, global_scale) @staticmethod diff --git a/tests/unittest/mx9_mx6/test_mx6_dispatch.py b/tests/unittest/mx9_mx6/test_mx6_dispatch.py new file mode 100644 index 0000000..f44aa47 --- /dev/null +++ b/tests/unittest/mx9_mx6/test_mx6_dispatch.py @@ -0,0 +1,121 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""Tests for the MX6 dispatch wiring in ``ModelPatcher.patch_fake_quantize``. + +The MX6 kernel itself is covered by ``test_mx6_quantize.py``. This file tests the +*wiring* one layer up: after ``patch_fake_quantize()`` replaces +``compressed_tensors...forward.fake_quantize``, a ``QuantizationArgs`` carrying +``format == "mx6"`` must route to ``mx6_fake_quantize``, while plain int8 args +(no ``format``) must fall through to the original implementation untouched. + +Also checks that ``registry_patch.inject_format_field()`` makes the ``format`` +field survive pydantic validation (otherwise the recipe value is silently +dropped and dispatch never fires). + +Run with: + pytest tests/unittest/mx9_mx6/test_mx6_dispatch.py +""" + +import pytest +import torch + +from compressed_tensors.quantization import QuantizationArgs +from compressed_tensors.quantization.lifecycle import forward as forward_module + +from alto.kernels.mx6.format import BLOCK_SIZE +from alto.kernels.mx6.quantize import mx6_fake_quantize + + +def _mx6_args() -> QuantizationArgs: + return QuantizationArgs.model_validate( + { + "num_bits": 5, + "type": "int", + "symmetric": True, + "strategy": "tensor", + "dynamic": True, + "format": "mx6", + } + ) + + +def _int8_args() -> QuantizationArgs: + return QuantizationArgs.model_validate( + { + "num_bits": 8, + "type": "int", + "symmetric": True, + "strategy": "tensor", + } + ) + + +# --------------------------------------------------------------------------- # +# format field injection (registry_patch) +# --------------------------------------------------------------------------- # +def test_format_field_survives_validation(): + """Without inject_format_field(), pydantic drops the unknown ``format`` key + and dispatch can never see it.""" + args = _mx6_args() + assert getattr(args, "format", None) == "mx6" + + +def test_format_field_defaults_none_for_plain_int8(): + args = _int8_args() + assert getattr(args, "format", None) is None + + +# --------------------------------------------------------------------------- # +# dispatch routing +# --------------------------------------------------------------------------- # +def test_mx6_args_dispatch_to_mx6_kernel(): + """format == "mx6" must route fake_quantize to mx6_fake_quantize bit-exact.""" + torch.manual_seed(6) + x = torch.randn(3, 40, dtype=torch.float32) + args = _mx6_args() + + patched_out = forward_module.fake_quantize( + x=x, + scale=torch.ones(1), + zero_point=None, + args=args, + g_idx=None, + global_scale=None, + ) + expected = mx6_fake_quantize(x, block_size=BLOCK_SIZE) + + assert torch.equal(patched_out, expected) + + +def test_int8_args_fall_through_to_original(): + """Plain int8 (no format) must NOT touch the mx6 kernel; output must match a + direct INT8 per-tensor QDQ, never the mx6 result.""" + torch.manual_seed(1) + x = torch.randn(3, 40, dtype=torch.float32) + args = _int8_args() + + # per-tensor symmetric int8 scale (what the original kernel would use) + qmax = 127.0 + scale = (x.abs().max() / qmax).clamp_min(1e-12).reshape(1) + zp = torch.zeros(1, dtype=torch.int32) + + patched_out = forward_module.fake_quantize( + x=x, + scale=scale, + zero_point=zp, + args=args, + g_idx=None, + global_scale=None, + ) + + expected_int8 = (torch.round(x / scale).clamp(-128, 127)) * scale + assert torch.allclose(patched_out, expected_int8, atol=1e-5) + + # and it must differ from the mx6 path (sanity: int8 did not silently route to mx6) + mx6_out = mx6_fake_quantize(x, block_size=BLOCK_SIZE) + assert not torch.equal(patched_out, mx6_out) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-v", "-p", "no:cacheprovider"])) diff --git a/tests/unittest/mx9_mx6/test_mx6_quantize.py b/tests/unittest/mx9_mx6/test_mx6_quantize.py new file mode 100644 index 0000000..35841c4 --- /dev/null +++ b/tests/unittest/mx9_mx6/test_mx6_quantize.py @@ -0,0 +1,66 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""Tests for the MX6 fake-quantize wrapper. + +MX6 shares Quark's ``fake_quantize_mx6_mx9`` algorithm with MX9, but uses +``quant_bit=5``. These tests focus on proving that the wrapper selects the MX6 +bit width and still preserves the basic QDQ invariants. +""" + +import pytest +import torch + +from alto.kernels.mx6.format import BLOCK_SIZE, QUANT_BIT +from alto.kernels.mx6.quantize import mx6_fake_quantize + + +def _quark_mx6_reference(x: torch.Tensor, block_size: int): + """Live Quark reference; returns None if Quark is unavailable.""" + try: + from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( + fake_quantize_mx6_mx9, + ) + except Exception: + return None + return fake_quantize_mx6_mx9(x.clone(), axis=-1, block_size=block_size, quant_bit=QUANT_BIT) + + +def _rand(shape, dtype): + torch.manual_seed(6) + return torch.randn(*shape, dtype=dtype) + + +@pytest.mark.parametrize("shape", [(4, 64), (8, 16), (2, 3, 32), (3, 40)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("block_size", [16, 32]) +def test_mx6_matches_quark_bit_exact(shape, dtype, block_size): + x = _rand(shape, dtype) + ref = _quark_mx6_reference(x, block_size) + if ref is None: + pytest.skip("Quark reference (fake_quantize_mx6_mx9) not importable") + + out = mx6_fake_quantize(x, block_size=block_size, quant_bit=QUANT_BIT) + assert out.dtype == ref.dtype + assert torch.equal(out, ref), (out - ref).abs().max().item() + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +def test_shape_and_dtype_preserved(dtype): + x = _rand((4, 64), dtype) + out = mx6_fake_quantize(x, block_size=BLOCK_SIZE) + assert out.shape == x.shape + assert out.dtype == dtype + + +def test_padding_path_non_divisible_last_dim(): + x = _rand((3, 40), torch.float32) + out = mx6_fake_quantize(x, block_size=BLOCK_SIZE) + assert out.shape == x.shape + assert not torch.isnan(out).any() + + +def test_axis_contract(): + x = _rand((4, 32), torch.float32) + with pytest.raises(NotImplementedError): + mx6_fake_quantize(x, block_size=BLOCK_SIZE, axis=0) diff --git a/tests/unittest/mx9_mx6/test_mx6_wa_integration.py b/tests/unittest/mx9_mx6/test_mx6_wa_integration.py new file mode 100644 index 0000000..ca45c41 --- /dev/null +++ b/tests/unittest/mx9_mx6/test_mx6_wa_integration.py @@ -0,0 +1,173 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""Toy integration test for MX6 W+A dynamic quantization. + +This intentionally stays below a real llama3 validation run. It verifies the +first full ALTO wiring layer: + +recipe yaml -> QuantizationModifier -> Linear quantization_scheme -> +post_step dynamic-weight skip -> wrapped Linear forward -> MX6 W+A QDQ. +""" + +import importlib.util +import os +import sys +import types + +import torch +import yaml + + +ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) + + +def _make_pkg(name: str, path: str) -> types.ModuleType: + mod = types.ModuleType(name) + mod.__path__ = [path] + mod.__package__ = name + sys.modules[name] = mod + return mod + + +def _load(modname: str, relpath: str) -> types.ModuleType: + spec = importlib.util.spec_from_file_location(modname, os.path.join(ROOT, relpath)) + mod = importlib.util.module_from_spec(spec) + sys.modules[modname] = mod + spec.loader.exec_module(mod) + return mod + + +def _install_torchtitan_stubs() -> None: + """Provide the tiny torchtitan surface needed by quantization modules.""" + _make_pkg("torchtitan", "") + _make_pkg("torchtitan.tools", "") + + logging_mod = types.ModuleType("torchtitan.tools.logging") + + class _Logger: + def debug(self, *args, **kwargs): + pass + + def info(self, *args, **kwargs): + pass + + def warning(self, *args, **kwargs): + pass + + logging_mod.logger = _Logger() + sys.modules["torchtitan.tools.logging"] = logging_mod + + utils_mod = types.ModuleType("torchtitan.tools.utils") + utils_mod.device_type = torch.device("cpu") + sys.modules["torchtitan.tools.utils"] = utils_mod + + +def _load_quantization_modifier(): + """Load only the ALTO modules needed for QuantizationModifier. + + Importing top-level ``alto`` currently pulls unrelated torchtitan entry + points. This test loads the required modules directly. MX6 reuses the MX9 + math, so both kernel packages are loaded; the ``format`` field injection + lives in ``alto.kernels.mx9.registry_patch``. + """ + _install_torchtitan_stubs() + + _make_pkg("alto", os.path.join(ROOT, "alto")) + _make_pkg("alto.kernels", os.path.join(ROOT, "alto/kernels")) + _make_pkg("alto.kernels.mx9", os.path.join(ROOT, "alto/kernels/mx9")) + _make_pkg("alto.kernels.mx6", os.path.join(ROOT, "alto/kernels/mx6")) + _make_pkg("alto.models", os.path.join(ROOT, "alto/models")) + + _load("alto.kernels.mx9.format", "alto/kernels/mx9/format.py") + _load("alto.kernels.mx9.quantize", "alto/kernels/mx9/quantize.py") + _load("alto.kernels.mx9.registry_patch", "alto/kernels/mx9/registry_patch.py") + + _load("alto.kernels.mx6.format", "alto/kernels/mx6/format.py") + quantize_mod = _load("alto.kernels.mx6.quantize", "alto/kernels/mx6/quantize.py") + + # Dynamic W/A uses no observers, but calibration.py imports Observer. + observers_mod = types.ModuleType("alto.observers") + + class _Observer: + @staticmethod + def create_instance(*args, **kwargs): + raise AssertionError("MX6 dynamic W+A should not create observers") + + observers_mod.Observer = _Observer + sys.modules["alto.observers"] = observers_mod + + _make_pkg("alto.modifiers", os.path.join(ROOT, "alto/modifiers")) + _make_pkg("alto.modifiers.utils", os.path.join(ROOT, "alto/modifiers/utils")) + hooks_mod = _load("alto.modifiers.utils.hooks", "alto/modifiers/utils/hooks.py") + sys.modules["alto.modifiers.utils"].HooksMixin = hooks_mod.HooksMixin + + modifier_base_mod = _load("alto.modifiers.base", "alto/modifiers/base.py") + sys.modules["alto.modifiers"].Modifier = modifier_base_mod.Modifier + + _make_pkg("alto.modifiers.quantization", os.path.join(ROOT, "alto/modifiers/quantization")) + _load("alto.modifiers.quantization.calibration", "alto/modifiers/quantization/calibration.py") + _load("alto.modifiers.quantization.mixin", "alto/modifiers/quantization/mixin.py") + + # base.py imports get_layers for sequential mode; this test uses sequential=False. + _make_pkg("alto.utils", os.path.join(ROOT, "alto/utils")) + _make_pkg("alto.utils.pytorch", os.path.join(ROOT, "alto/utils/pytorch")) + module_utils = types.ModuleType("alto.utils.pytorch.module") + module_utils.get_layers = lambda *args, **kwargs: {} + sys.modules["alto.utils.pytorch.module"] = module_utils + + quant_base_mod = _load("alto.modifiers.quantization.base", "alto/modifiers/quantization/base.py") + + patcher_mod = _load("alto.models.patcher", "alto/models/patcher.py") + patcher_mod.ModelPatcher.patch_fake_quantize() + + return quant_base_mod.QuantizationModifier, quantize_mod + + +def _load_mx6_modifier_from_recipe(): + recipe_path = os.path.join(ROOT, "alto/models/llama3/configs/mx6_wa_recipe.yaml") + with open(recipe_path, "r") as f: + recipe = yaml.safe_load(f) + + mod_args = recipe["quantization_stage"]["quantization_modifiers"]["QuantizationModifier"] + QuantizationModifier, quantize_mod = _load_quantization_modifier() + return QuantizationModifier(**mod_args), quantize_mod + + +def test_mx6_wa_recipe_toy_linear_lifecycle(monkeypatch): + modifier, quantize_mod = _load_mx6_modifier_from_recipe() + model = torch.nn.Sequential(torch.nn.Linear(16, 16, bias=False)) + linear = model[0] + + modifier.initialize([model]) + + scheme = linear.quantization_scheme + assert getattr(scheme.weights, "format", None) == "mx6" + assert getattr(scheme.input_activations, "format", None) == "mx6" + assert scheme.weights.dynamic is True + assert scheme.input_activations.dynamic is True + assert not hasattr(linear, "weight_observer") + assert not hasattr(linear, "input_observer") + + calls = [] + original_mx6 = quantize_mod.mx6_fake_quantize + + def counted_mx6(input_tensor, *args, **kwargs): + calls.append(tuple(input_tensor.shape)) + return original_mx6(input_tensor, *args, **kwargs) + + monkeypatch.setattr(quantize_mod, "mx6_fake_quantize", counted_mx6) + + x = torch.randn(2, 16) + modifier.pre_step([model]) + modifier.post_step([model]) # must skip static weight baking for MX6 dynamic weight + assert not hasattr(linear, "weight_observer") + assert not hasattr(linear, "weight_scale") + + out = model(x) + assert out.shape == (2, 16) + assert len(calls) == 2 + assert (2, 16) in calls # input activation QDQ + assert (16, 16) in calls # weight QDQ + + modifier.finalize([model]) From 599ab4eaaf50d6ac2da7ca56399e9eb5fe36ccb7 Mon Sep 17 00:00:00 2001 From: jiarwang Date: Mon, 15 Jun 2026 09:03:27 +0000 Subject: [PATCH 4/5] refactor: mx: move MX6/MX9 fake-quant out of kernels into modifiers/quantization Relocate the emulated MX6/MX9 fake-quant reference and the format-field registry patch from alto/kernels/{mx6,mx9}/ into alto/modifiers/quantization/ (mx.py + format_registry.py), and move the format-field injection into the quantization package __init__. Update the mx9/mx6 dispatch and quantize test imports to the new module paths. --- alto/__init__.py | 16 ----- alto/kernels/mx6/__init__.py | 14 ---- alto/kernels/mx6/format.py | 8 --- alto/kernels/mx6/quantize.py | 34 ---------- alto/kernels/mx9/__init__.py | 14 ---- alto/kernels/mx9/format.py | 11 ---- alto/modifiers/quantization/__init__.py | 8 +++ .../quantization/format_registry.py} | 12 ++-- .../quantization/mx.py} | 65 +++++++++++++------ tests/unittest/mx9_mx6/test_mx6_dispatch.py | 7 +- tests/unittest/mx9_mx6/test_mx6_quantize.py | 3 +- tests/unittest/mx9_mx6/test_mx9_dispatch.py | 7 +- tests/unittest/mx9_mx6/test_mx9_quantize.py | 3 +- 13 files changed, 70 insertions(+), 132 deletions(-) delete mode 100644 alto/kernels/mx6/__init__.py delete mode 100644 alto/kernels/mx6/format.py delete mode 100644 alto/kernels/mx6/quantize.py delete mode 100644 alto/kernels/mx9/__init__.py delete mode 100644 alto/kernels/mx9/format.py rename alto/{kernels/mx9/registry_patch.py => modifiers/quantization/format_registry.py} (76%) rename alto/{kernels/mx9/quantize.py => modifiers/quantization/mx.py} (78%) diff --git a/alto/__init__.py b/alto/__init__.py index d342a80..2bf2a5d 100644 --- a/alto/__init__.py +++ b/alto/__init__.py @@ -2,24 +2,8 @@ # # SPDX-License-Identifier: MIT -# Inject the QuantizationArgs.format field BEFORE importing modifiers: the alto -# QuantizationModifier compiles its nested QuantizationScheme schema at class -# definition time, so the field must exist first or recipes carrying -# ``format: mx6/mx9`` are rejected by the cached (format-less) schema. Calling -# inject_format_field() explicitly (it is idempotent) makes this ordering -# dependency self-documenting instead of relying on import order, which an -# import-sorter could silently reshuffle. -from .kernels.mx9.registry_patch import inject_format_field - -inject_format_field() - from .components import * from .config import * from .observers import * from .modifiers import * from .models import * - -# MX fake-quant kernels: import for their dispatch wiring. The format-field -# injection they depend on already ran above. -from .kernels import mx6 # noqa: F401,E402 -from .kernels import mx9 # noqa: F401,E402 diff --git a/alto/kernels/mx6/__init__.py b/alto/kernels/mx6/__init__.py deleted file mode 100644 index 1760aa6..0000000 --- a/alto/kernels/mx6/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2026 Advanced Micro Devices, Inc. -# -# SPDX-License-Identifier: MIT -"""MX6 emulated (fake-quant) format. - -MX6 shares the same Quark ``fake_quantize_mx6_mx9`` math as MX9 and only differs -in the element integer bit width (``quant_bit=5`` for MX6, ``8`` for MX9). - -The ``QuantizationArgs.format`` field injection (needed so ``format: mx6`` -recipes parse) lives in ``alto.kernels.mx9.registry_patch``; import it here so -this package is self-sufficient instead of relying on mx9 being imported first. -""" - -from alto.kernels.mx9 import registry_patch # noqa: F401 injects the QuantizationArgs.format field; do not remove diff --git a/alto/kernels/mx6/format.py b/alto/kernels/mx6/format.py deleted file mode 100644 index 6787fa6..0000000 --- a/alto/kernels/mx6/format.py +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright (c) 2026 Advanced Micro Devices, Inc. -# -# SPDX-License-Identifier: MIT -"""MX6 format constants (fake-quant emulation only; no real packed kernel).""" - -BLOCK_SIZE = 16 -QUANT_BIT = 5 -SHARED_PRIME_BIT_GROUP = 2 diff --git a/alto/kernels/mx6/quantize.py b/alto/kernels/mx6/quantize.py deleted file mode 100644 index 988456b..0000000 --- a/alto/kernels/mx6/quantize.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2026 Advanced Micro Devices, Inc. -# -# SPDX-License-Identifier: MIT -"""MX6 fake-quantize wrapper. - -Quark implements MX6 and MX9 with the same ``fake_quantize_mx6_mx9`` function. -The only format difference in this emulation path is ``quant_bit``: - -- MX6: ``quant_bit=5`` -- MX9: ``quant_bit=8`` - -The shared-prime-bit, block reshape, exponent, scale, clamp, and dequant math is -therefore reused from the MX9 implementation. -""" - -import torch - -from alto.kernels.mx9.quantize import mx9_fake_quantize -from .format import BLOCK_SIZE, QUANT_BIT - - -def mx6_fake_quantize( - input_tensor: torch.Tensor, - block_size: int = BLOCK_SIZE, - quant_bit: int = QUANT_BIT, - axis: int = -1, -) -> torch.Tensor: - """对输入执行 MX6 block-wise fake quantization(QDQ)。""" - return mx9_fake_quantize( - input_tensor=input_tensor, - block_size=block_size, - quant_bit=quant_bit, - axis=axis, - ) diff --git a/alto/kernels/mx9/__init__.py b/alto/kernels/mx9/__init__.py deleted file mode 100644 index bfc57fb..0000000 --- a/alto/kernels/mx9/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2026 Advanced Micro Devices, Inc. -# -# SPDX-License-Identifier: MIT -"""MX9 emulated (fake-quant) format. - -NOTE: this package currently provides fake-quantization (QDQ emulation) only -- -there is no real packed kernel / GEMM. It plugs into the standard -``QuantizationModifier`` path via runtime patches, without touching the generic -calibration / observer machinery. The patches that wire it into the quant path -(``QuantizationArgs.format`` field injection + ``fake_quantize`` mx9 dispatch) -live in ``registry_patch`` and ``alto.models.patcher``. -""" - -from . import registry_patch # noqa: F401 injects the QuantizationArgs.format field; do not remove diff --git a/alto/kernels/mx9/format.py b/alto/kernels/mx9/format.py deleted file mode 100644 index c06d354..0000000 --- a/alto/kernels/mx9/format.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright (c) 2026 Advanced Micro Devices, Inc. -# -# SPDX-License-Identifier: MIT -"""MX9 format constants (fake-quant emulation only; no real packed kernel).""" - -# Matches Quark mx9 reference (test_mx.py uses block_size=16; Quark leaves it -# configurable via MX9Spec.block_size). 16 is the L2 block; the shared prime bit -# is shared across every SHARED_PRIME_BIT_GROUP (=2) elements. -BLOCK_SIZE = 16 -QUANT_BIT = 8 -SHARED_PRIME_BIT_GROUP = 2 diff --git a/alto/modifiers/quantization/__init__.py b/alto/modifiers/quantization/__init__.py index 3a86389..c7cd600 100644 --- a/alto/modifiers/quantization/__init__.py +++ b/alto/modifiers/quantization/__init__.py @@ -2,6 +2,14 @@ # # SPDX-License-Identifier: MIT +# Inject the QuantizationArgs.format field BEFORE importing QuantizationModifier: +# the modifier compiles its nested QuantizationScheme schema at class-definition +# time, so the field must exist first or recipes carrying ``format: mx6/mx9`` are +# rejected by the cached (format-less) schema. +from .format_registry import inject_format_field + +inject_format_field() + from .base import QuantizationModifier from .gptq import GPTQModifier from .awq import AWQModifier diff --git a/alto/kernels/mx9/registry_patch.py b/alto/modifiers/quantization/format_registry.py similarity index 76% rename from alto/kernels/mx9/registry_patch.py rename to alto/modifiers/quantization/format_registry.py index e537764..0500028 100644 --- a/alto/kernels/mx9/registry_patch.py +++ b/alto/modifiers/quantization/format_registry.py @@ -1,16 +1,20 @@ # Copyright (c) 2026 Advanced Micro Devices, Inc. # # SPDX-License-Identifier: MIT -"""Runtime patches that wire emulated formats into the standard quant path. +"""Runtime patch that wires emulated formats into the standard quant path. Importing this module injects a real ``format`` field into ``compressed_tensors.QuantizationArgs`` so recipe values like ``format: mx9`` survive pydantic parsing and become readable via ``getattr(args, "format", None)`` (by default unknown fields are silently dropped). -The actual ``fake_quantize`` dispatch (``args.format == "mx9"`` -> mx9 kernel) -lives in ``alto.models.patcher.ModelPatcher.patch_fake_quantize`` where the single -wrap of ``compressed_tensors...forward.fake_quantize`` already happens. +The actual ``fake_quantize`` dispatch (``args.format == "mx9"`` -> mx9) lives in +``alto.models.patcher.ModelPatcher.patch_fake_quantize`` where the single wrap of +``compressed_tensors...forward.fake_quantize`` already happens. + +``inject_format_field()`` is called at the top of this package's ``__init__`` (before +``QuantizationModifier`` is imported) so the field exists before the modifier +compiles its nested ``QuantizationScheme`` schema. """ from typing import Optional diff --git a/alto/kernels/mx9/quantize.py b/alto/modifiers/quantization/mx.py similarity index 78% rename from alto/kernels/mx9/quantize.py rename to alto/modifiers/quantization/mx.py index 696890b..a921d28 100644 --- a/alto/kernels/mx9/quantize.py +++ b/alto/modifiers/quantization/mx.py @@ -1,13 +1,20 @@ # Copyright (c) 2026 Advanced Micro Devices, Inc. # # SPDX-License-Identifier: MIT -"""MX9 fake-quantize kernel (block-wise emulation; no real packing/kernel inference). +"""MX6 / MX9 fake-quantize (block-wise emulation; no real packing/kernel inference). -This is a port of Quark's ``fake_quantize_mx6_mx9`` for the ``quant_bit=8`` (MX9) -case, with its helpers (``_t_exponent`` / ``_reshape_to_blocks`` / ``_pad_to_blocks``) -inlined here. +Port of Quark's ``fake_quantize_mx6_mx9`` with its helpers (``_t_exponent`` / +``_reshape_to_blocks`` / ``_pad_to_blocks``) inlined. As in Quark, MX6 and MX9 are +the SAME algorithm and differ only in the element integer width ``quant_bit``: -MX9 in a nutshell: +- MX6: ``quant_bit = 5`` +- MX9: ``quant_bit = 8`` + +so a single core (``_mx_fake_quantize``) backs both, exposed via the thin +``mx6_fake_quantize`` / ``mx9_fake_quantize`` wrappers (cf. Quark's +``partial(fake_quantize_mx6_mx9, quant_bit=...)``). + +MX in a nutshell: 1. Split the input along ``axis`` into blocks of ``block_size`` elements. 2. Each block derives a shared exponent ``max_exp = floor(log2(amax))`` from its max absolute value. @@ -17,20 +24,20 @@ 4. The shared exponent yields a power-of-two scale; the block is then round -> clamp -> dequant. -The "9" is not ``quant_bit=9``: elements are still quantized to 8-bit integers, and +The "9" is not ``quant_bit=9``: elements are still quantized to 8-bit integers and the extra "1" is the per-pair prime bit that decides the one-exponent demotion. -Intuitively: 8-bit signed integer value + a shared prime bit for exponent refinement. - - ``QUANT_BIT = 8``: elements round to an 8-bit integer grid. - - ``SHARED_PRIME_BIT_GROUP = 2``: every two adjacent elements share a prime bit. This is fake quantization: the output keeps the input dtype (bf16/fp16/fp32) with -values projected onto the MX9-representable grid. The scale is computed from the +values projected onto the MX-representable grid. The scale is computed from the data at runtime; any externally supplied compressed_tensors scale is ignored here. """ import torch -from .format import BLOCK_SIZE, QUANT_BIT, SHARED_PRIME_BIT_GROUP +BLOCK_SIZE = 16 +SHARED_PRIME_BIT_GROUP = 2 +MX9_QUANT_BIT = 8 +MX6_QUANT_BIT = 5 def _pad_to_blocks(x: torch.Tensor, block_size: int) -> torch.Tensor: @@ -79,23 +86,23 @@ def _t_exponent(t: torch.Tensor) -> torch.Tensor: return _exponent_frexp_no_exception(t) -def mx9_fake_quantize( +def _mx_fake_quantize( input_tensor: torch.Tensor, - block_size: int = BLOCK_SIZE, - quant_bit: int = QUANT_BIT, + block_size: int, + quant_bit: int, axis: int = -1, ) -> torch.Tensor: - """Block-wise MX9 fake quantization (QDQ), bit-exact with Quark's - ``fake_quantize_mx6_mx9(quant_bit=8)``. + """Block-wise MX fake quantization (QDQ), bit-exact with Quark's + ``fake_quantize_mx6_mx9``. Args: input_tensor: tensor to quantize (weight or activation). - block_size: elements per MX9 block. + block_size: elements per MX block. quant_bit: element integer width (8 for MX9, 5 for MX6). axis: axis to block along. Only ``axis=-1`` is supported on the ALTO path. """ if axis != -1: - raise NotImplementedError("mx9_fake_quantize supports axis=-1 only") + raise NotImplementedError("mx fake_quantize supports axis=-1 only") input_dtype = input_tensor.dtype @@ -132,7 +139,7 @@ def mx9_fake_quantize( # Demoted pairs drop one exponent (the prime-bit refinement). shared_exp = max_exp - demote.long() - # Final power-of-two scale: 2 ** (shared_exp - quant_bit + 2) (=2**(shared_exp-6) for MX9). + # Final power-of-two scale: 2 ** (shared_exp - quant_bit + 2). scale = torch.pow(2.0, shared_exp - quant_bit + 2) # Max representable integer magnitude. clamp_max guards against float overflow @@ -153,3 +160,23 @@ def mx9_fake_quantize( output_tensor = output_tensor.reshape(output_tensor.size(0), -1) output_tensor = output_tensor[:, : input_shape[-1]].reshape(input_shape).to(input_dtype) return output_tensor.transpose(axis, -1) + + +def mx9_fake_quantize( + input_tensor: torch.Tensor, + block_size: int = BLOCK_SIZE, + quant_bit: int = MX9_QUANT_BIT, + axis: int = -1, +) -> torch.Tensor: + """MX9 block-wise fake quantization (QDQ); ``quant_bit=8``.""" + return _mx_fake_quantize(input_tensor, block_size=block_size, quant_bit=quant_bit, axis=axis) + + +def mx6_fake_quantize( + input_tensor: torch.Tensor, + block_size: int = BLOCK_SIZE, + quant_bit: int = MX6_QUANT_BIT, + axis: int = -1, +) -> torch.Tensor: + """MX6 block-wise fake quantization (QDQ); same math as MX9 with ``quant_bit=5``.""" + return _mx_fake_quantize(input_tensor, block_size=block_size, quant_bit=quant_bit, axis=axis) diff --git a/tests/unittest/mx9_mx6/test_mx6_dispatch.py b/tests/unittest/mx9_mx6/test_mx6_dispatch.py index f44aa47..870f776 100644 --- a/tests/unittest/mx9_mx6/test_mx6_dispatch.py +++ b/tests/unittest/mx9_mx6/test_mx6_dispatch.py @@ -9,7 +9,7 @@ ``format == "mx6"`` must route to ``mx6_fake_quantize``, while plain int8 args (no ``format``) must fall through to the original implementation untouched. -Also checks that ``registry_patch.inject_format_field()`` makes the ``format`` +Also checks that ``format_registry.inject_format_field()`` makes the ``format`` field survive pydantic validation (otherwise the recipe value is silently dropped and dispatch never fires). @@ -23,8 +23,7 @@ from compressed_tensors.quantization import QuantizationArgs from compressed_tensors.quantization.lifecycle import forward as forward_module -from alto.kernels.mx6.format import BLOCK_SIZE -from alto.kernels.mx6.quantize import mx6_fake_quantize +from alto.modifiers.quantization.mx import BLOCK_SIZE, mx6_fake_quantize def _mx6_args() -> QuantizationArgs: @@ -52,7 +51,7 @@ def _int8_args() -> QuantizationArgs: # --------------------------------------------------------------------------- # -# format field injection (registry_patch) +# format field injection (alto.modifiers.quantization.format_registry) # --------------------------------------------------------------------------- # def test_format_field_survives_validation(): """Without inject_format_field(), pydantic drops the unknown ``format`` key diff --git a/tests/unittest/mx9_mx6/test_mx6_quantize.py b/tests/unittest/mx9_mx6/test_mx6_quantize.py index 35841c4..f7769e3 100644 --- a/tests/unittest/mx9_mx6/test_mx6_quantize.py +++ b/tests/unittest/mx9_mx6/test_mx6_quantize.py @@ -11,8 +11,7 @@ import pytest import torch -from alto.kernels.mx6.format import BLOCK_SIZE, QUANT_BIT -from alto.kernels.mx6.quantize import mx6_fake_quantize +from alto.modifiers.quantization.mx import BLOCK_SIZE, MX6_QUANT_BIT as QUANT_BIT, mx6_fake_quantize def _quark_mx6_reference(x: torch.Tensor, block_size: int): diff --git a/tests/unittest/mx9_mx6/test_mx9_dispatch.py b/tests/unittest/mx9_mx6/test_mx9_dispatch.py index a118d79..c22745a 100644 --- a/tests/unittest/mx9_mx6/test_mx9_dispatch.py +++ b/tests/unittest/mx9_mx6/test_mx9_dispatch.py @@ -9,7 +9,7 @@ ``format == "mx9"`` must route to ``mx9_fake_quantize``, while plain int8 args (no ``format``) must fall through to the original implementation untouched. -Also checks that ``registry_patch.inject_format_field()`` makes the ``format`` +Also checks that ``format_registry.inject_format_field()`` makes the ``format`` field survive pydantic validation (otherwise the recipe value is silently dropped and dispatch never fires). @@ -23,8 +23,7 @@ from compressed_tensors.quantization import QuantizationArgs from compressed_tensors.quantization.lifecycle import forward as forward_module -from alto.kernels.mx9.format import BLOCK_SIZE -from alto.kernels.mx9.quantize import mx9_fake_quantize +from alto.modifiers.quantization.mx import BLOCK_SIZE, mx9_fake_quantize def _mx9_args() -> QuantizationArgs: @@ -52,7 +51,7 @@ def _int8_args() -> QuantizationArgs: # --------------------------------------------------------------------------- # -# format field injection (registry_patch) +# format field injection (alto.modifiers.quantization.format_registry) # --------------------------------------------------------------------------- # def test_format_field_survives_validation(): """Without inject_format_field(), pydantic drops the unknown ``format`` key diff --git a/tests/unittest/mx9_mx6/test_mx9_quantize.py b/tests/unittest/mx9_mx6/test_mx9_quantize.py index 3dabedc..a148463 100644 --- a/tests/unittest/mx9_mx6/test_mx9_quantize.py +++ b/tests/unittest/mx9_mx6/test_mx9_quantize.py @@ -14,8 +14,7 @@ import pytest import torch -from alto.kernels.mx9.format import BLOCK_SIZE, QUANT_BIT -from alto.kernels.mx9.quantize import mx9_fake_quantize +from alto.modifiers.quantization.mx import BLOCK_SIZE, MX9_QUANT_BIT as QUANT_BIT, mx9_fake_quantize # --------------------------------------------------------------------------- # From 0891d9b243c324fe711735d18d9d38f287b9fd35 Mon Sep 17 00:00:00 2001 From: jiarwang Date: Mon, 15 Jun 2026 09:04:04 +0000 Subject: [PATCH 5/5] fix: mx: address Copilot review feedback on MX6/MX9 PTQ flow - mx.py: replace assert on block_size divisibility with an explicit ValueError (asserts can be stripped under python -O and give an unhelpful message). - patcher.py: thread args.num_bits through the mx9/mx6 dispatch (falling back to the format default) instead of silently ignoring it, so an inconsistent num_bits no longer quantizes at the wrong width. - mx9/mx6 W+A integration tests: route all sys.modules mutations through monkeypatch.setitem so stubs are reverted at teardown and cannot leak into other tests in the same process. - examples/llama3.2_1b_mx9.sh: require MODEL_PATH instead of a hardcoded /wekafs path and default the checkpoint folder under ./outputs. --- alto/models/patcher.py | 30 ++++++-- alto/modifiers/quantization/mx.py | 6 +- examples/llama3.2_1b_mx9.sh | 17 ++-- .../mx9_mx6/test_mx6_wa_integration.py | 77 +++++++++---------- .../mx9_mx6/test_mx9_wa_integration.py | 74 +++++++++--------- 5 files changed, 109 insertions(+), 95 deletions(-) diff --git a/alto/models/patcher.py b/alto/models/patcher.py index 3129436..8c75135 100644 --- a/alto/models/patcher.py +++ b/alto/models/patcher.py @@ -58,15 +58,29 @@ class FakeQuantizeFunction(torch.autograd.Function): @staticmethod def forward(ctx, x, scale, zero_point, args, g_idx, global_scale): if getattr(args, "format", None) == "mx9": - from alto.kernels.mx9.format import BLOCK_SIZE - from alto.kernels.mx9.quantize import mx9_fake_quantize - - return mx9_fake_quantize(x, block_size=(args.group_size or BLOCK_SIZE)) + from alto.modifiers.quantization.mx import ( + BLOCK_SIZE, + MX9_QUANT_BIT, + mx9_fake_quantize, + ) + + return mx9_fake_quantize( + x, + block_size=(args.group_size or BLOCK_SIZE), + quant_bit=(args.num_bits or MX9_QUANT_BIT), + ) if getattr(args, "format", None) == "mx6": - from alto.kernels.mx6.format import BLOCK_SIZE - from alto.kernels.mx6.quantize import mx6_fake_quantize - - return mx6_fake_quantize(x, block_size=(args.group_size or BLOCK_SIZE)) + from alto.modifiers.quantization.mx import ( + BLOCK_SIZE, + MX6_QUANT_BIT, + mx6_fake_quantize, + ) + + return mx6_fake_quantize( + x, + block_size=(args.group_size or BLOCK_SIZE), + quant_bit=(args.num_bits or MX6_QUANT_BIT), + ) return original_fake_quantize(x, scale, zero_point, args, g_idx, global_scale) @staticmethod diff --git a/alto/modifiers/quantization/mx.py b/alto/modifiers/quantization/mx.py index a921d28..bd48d71 100644 --- a/alto/modifiers/quantization/mx.py +++ b/alto/modifiers/quantization/mx.py @@ -130,7 +130,11 @@ def _mx_fake_quantize( # Prime bit: a pair is demoted only if BOTH its elements are below the block max. # Reshape last dim into pairs, AND within each pair, then broadcast back. n = SHARED_PRIME_BIT_GROUP - assert demote.shape[-1] % n == 0 + if demote.shape[-1] % n != 0: + raise ValueError( + f"block_size ({block_size}) must be a multiple of " + f"SHARED_PRIME_BIT_GROUP ({n})" + ) flat_shape = demote.shape demote = demote.reshape(*flat_shape[:-1], flat_shape[-1] // n, n) demote = torch.sum(demote, -1, keepdim=True) == n diff --git a/examples/llama3.2_1b_mx9.sh b/examples/llama3.2_1b_mx9.sh index 6b92ba0..204794a 100755 --- a/examples/llama3.2_1b_mx9.sh +++ b/examples/llama3.2_1b_mx9.sh @@ -7,10 +7,10 @@ # Weight and input activations are quantized dynamically through # alto/models/llama3/configs/mx9_wa_recipe.yaml # -# Usage: -# bash examples/llama3.2_1b_mx9.sh -# VALIDATOR_STEPS=100 bash examples/llama3.2_1b_mx9.sh -# CONFIG=llama3_1b bash examples/llama3.2_1b_mx9.sh # BF16 baseline +# Usage (MODEL_PATH is required, point it at your local Llama-3.2-1B dir): +# MODEL_PATH=/path/to/Llama-3.2-1B bash examples/llama3.2_1b_mx9.sh +# MODEL_PATH=/path/to/Llama-3.2-1B VALIDATOR_STEPS=100 bash examples/llama3.2_1b_mx9.sh +# MODEL_PATH=/path/to/Llama-3.2-1B CONFIG=llama3_1b bash examples/llama3.2_1b_mx9.sh # BF16 baseline rm -rf outputs/ set -ex @@ -23,9 +23,14 @@ MODULE=${MODULE:-"llama3"} CONFIG=${CONFIG:-"llama3_1b_mx9_wa"} COMM_MODE=${COMM_MODE:-""} -MODEL_PATH=${MODEL_PATH:-"/wekafs/jiarwang/Llama-3.2-1B"} +MODEL_PATH=${MODEL_PATH:-""} +if [ -z "${MODEL_PATH}" ]; then + echo "ERROR: MODEL_PATH must be set to your local Llama-3.2-1B directory, e.g." >&2 + echo " MODEL_PATH=/path/to/Llama-3.2-1B bash $0" >&2 + exit 1 +fi VALIDATOR_STEPS=${VALIDATOR_STEPS:-"10"} -CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-"/wekafs/jiarwang/mx9_e2e_logs/ckpt_${CONFIG}_$(date +%Y%m%d_%H%M%S)"} +CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-"./outputs/ckpt_${CONFIG}_$(date +%Y%m%d_%H%M%S)"} TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} diff --git a/tests/unittest/mx9_mx6/test_mx6_wa_integration.py b/tests/unittest/mx9_mx6/test_mx6_wa_integration.py index ca45c41..a9b4b9d 100644 --- a/tests/unittest/mx9_mx6/test_mx6_wa_integration.py +++ b/tests/unittest/mx9_mx6/test_mx6_wa_integration.py @@ -22,26 +22,26 @@ ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) -def _make_pkg(name: str, path: str) -> types.ModuleType: +def _make_pkg(monkeypatch, name: str, path: str) -> types.ModuleType: mod = types.ModuleType(name) mod.__path__ = [path] mod.__package__ = name - sys.modules[name] = mod + monkeypatch.setitem(sys.modules, name, mod) return mod -def _load(modname: str, relpath: str) -> types.ModuleType: +def _load(monkeypatch, modname: str, relpath: str) -> types.ModuleType: spec = importlib.util.spec_from_file_location(modname, os.path.join(ROOT, relpath)) mod = importlib.util.module_from_spec(spec) - sys.modules[modname] = mod + monkeypatch.setitem(sys.modules, modname, mod) spec.loader.exec_module(mod) return mod -def _install_torchtitan_stubs() -> None: +def _install_torchtitan_stubs(monkeypatch) -> None: """Provide the tiny torchtitan surface needed by quantization modules.""" - _make_pkg("torchtitan", "") - _make_pkg("torchtitan.tools", "") + _make_pkg(monkeypatch, "torchtitan", "") + _make_pkg(monkeypatch, "torchtitan.tools", "") logging_mod = types.ModuleType("torchtitan.tools.logging") @@ -56,35 +56,27 @@ def warning(self, *args, **kwargs): pass logging_mod.logger = _Logger() - sys.modules["torchtitan.tools.logging"] = logging_mod + monkeypatch.setitem(sys.modules, "torchtitan.tools.logging", logging_mod) utils_mod = types.ModuleType("torchtitan.tools.utils") utils_mod.device_type = torch.device("cpu") - sys.modules["torchtitan.tools.utils"] = utils_mod + monkeypatch.setitem(sys.modules, "torchtitan.tools.utils", utils_mod) -def _load_quantization_modifier(): +def _load_quantization_modifier(monkeypatch): """Load only the ALTO modules needed for QuantizationModifier. Importing top-level ``alto`` currently pulls unrelated torchtitan entry points. This test loads the required modules directly. MX6 reuses the MX9 - math, so both kernel packages are loaded; the ``format`` field injection - lives in ``alto.kernels.mx9.registry_patch``. + math (both live in ``alto.modifiers.quantization.mx``); the ``format`` field + injection lives in ``alto.modifiers.quantization.format_registry``. All + ``sys.modules`` mutations go through ``monkeypatch`` so they are reverted at + test teardown and cannot leak into other tests in the same process. """ - _install_torchtitan_stubs() + _install_torchtitan_stubs(monkeypatch) - _make_pkg("alto", os.path.join(ROOT, "alto")) - _make_pkg("alto.kernels", os.path.join(ROOT, "alto/kernels")) - _make_pkg("alto.kernels.mx9", os.path.join(ROOT, "alto/kernels/mx9")) - _make_pkg("alto.kernels.mx6", os.path.join(ROOT, "alto/kernels/mx6")) - _make_pkg("alto.models", os.path.join(ROOT, "alto/models")) - - _load("alto.kernels.mx9.format", "alto/kernels/mx9/format.py") - _load("alto.kernels.mx9.quantize", "alto/kernels/mx9/quantize.py") - _load("alto.kernels.mx9.registry_patch", "alto/kernels/mx9/registry_patch.py") - - _load("alto.kernels.mx6.format", "alto/kernels/mx6/format.py") - quantize_mod = _load("alto.kernels.mx6.quantize", "alto/kernels/mx6/quantize.py") + _make_pkg(monkeypatch, "alto", os.path.join(ROOT, "alto")) + _make_pkg(monkeypatch, "alto.models", os.path.join(ROOT, "alto/models")) # Dynamic W/A uses no observers, but calibration.py imports Observer. observers_mod = types.ModuleType("alto.observers") @@ -95,47 +87,50 @@ def create_instance(*args, **kwargs): raise AssertionError("MX6 dynamic W+A should not create observers") observers_mod.Observer = _Observer - sys.modules["alto.observers"] = observers_mod + monkeypatch.setitem(sys.modules, "alto.observers", observers_mod) - _make_pkg("alto.modifiers", os.path.join(ROOT, "alto/modifiers")) - _make_pkg("alto.modifiers.utils", os.path.join(ROOT, "alto/modifiers/utils")) - hooks_mod = _load("alto.modifiers.utils.hooks", "alto/modifiers/utils/hooks.py") + _make_pkg(monkeypatch, "alto.modifiers", os.path.join(ROOT, "alto/modifiers")) + _make_pkg(monkeypatch, "alto.modifiers.utils", os.path.join(ROOT, "alto/modifiers/utils")) + hooks_mod = _load(monkeypatch, "alto.modifiers.utils.hooks", "alto/modifiers/utils/hooks.py") sys.modules["alto.modifiers.utils"].HooksMixin = hooks_mod.HooksMixin - modifier_base_mod = _load("alto.modifiers.base", "alto/modifiers/base.py") + modifier_base_mod = _load(monkeypatch, "alto.modifiers.base", "alto/modifiers/base.py") sys.modules["alto.modifiers"].Modifier = modifier_base_mod.Modifier - _make_pkg("alto.modifiers.quantization", os.path.join(ROOT, "alto/modifiers/quantization")) - _load("alto.modifiers.quantization.calibration", "alto/modifiers/quantization/calibration.py") - _load("alto.modifiers.quantization.mixin", "alto/modifiers/quantization/mixin.py") + _make_pkg(monkeypatch, "alto.modifiers.quantization", os.path.join(ROOT, "alto/modifiers/quantization")) + # Inject the format field (and load the MX module) before QuantizationModifier. + _load(monkeypatch, "alto.modifiers.quantization.format_registry", "alto/modifiers/quantization/format_registry.py") + quantize_mod = _load(monkeypatch, "alto.modifiers.quantization.mx", "alto/modifiers/quantization/mx.py") + _load(monkeypatch, "alto.modifiers.quantization.calibration", "alto/modifiers/quantization/calibration.py") + _load(monkeypatch, "alto.modifiers.quantization.mixin", "alto/modifiers/quantization/mixin.py") # base.py imports get_layers for sequential mode; this test uses sequential=False. - _make_pkg("alto.utils", os.path.join(ROOT, "alto/utils")) - _make_pkg("alto.utils.pytorch", os.path.join(ROOT, "alto/utils/pytorch")) + _make_pkg(monkeypatch, "alto.utils", os.path.join(ROOT, "alto/utils")) + _make_pkg(monkeypatch, "alto.utils.pytorch", os.path.join(ROOT, "alto/utils/pytorch")) module_utils = types.ModuleType("alto.utils.pytorch.module") module_utils.get_layers = lambda *args, **kwargs: {} - sys.modules["alto.utils.pytorch.module"] = module_utils + monkeypatch.setitem(sys.modules, "alto.utils.pytorch.module", module_utils) - quant_base_mod = _load("alto.modifiers.quantization.base", "alto/modifiers/quantization/base.py") + quant_base_mod = _load(monkeypatch, "alto.modifiers.quantization.base", "alto/modifiers/quantization/base.py") - patcher_mod = _load("alto.models.patcher", "alto/models/patcher.py") + patcher_mod = _load(monkeypatch, "alto.models.patcher", "alto/models/patcher.py") patcher_mod.ModelPatcher.patch_fake_quantize() return quant_base_mod.QuantizationModifier, quantize_mod -def _load_mx6_modifier_from_recipe(): +def _load_mx6_modifier_from_recipe(monkeypatch): recipe_path = os.path.join(ROOT, "alto/models/llama3/configs/mx6_wa_recipe.yaml") with open(recipe_path, "r") as f: recipe = yaml.safe_load(f) mod_args = recipe["quantization_stage"]["quantization_modifiers"]["QuantizationModifier"] - QuantizationModifier, quantize_mod = _load_quantization_modifier() + QuantizationModifier, quantize_mod = _load_quantization_modifier(monkeypatch) return QuantizationModifier(**mod_args), quantize_mod def test_mx6_wa_recipe_toy_linear_lifecycle(monkeypatch): - modifier, quantize_mod = _load_mx6_modifier_from_recipe() + modifier, quantize_mod = _load_mx6_modifier_from_recipe(monkeypatch) model = torch.nn.Sequential(torch.nn.Linear(16, 16, bias=False)) linear = model[0] diff --git a/tests/unittest/mx9_mx6/test_mx9_wa_integration.py b/tests/unittest/mx9_mx6/test_mx9_wa_integration.py index 28fdabd..f090180 100644 --- a/tests/unittest/mx9_mx6/test_mx9_wa_integration.py +++ b/tests/unittest/mx9_mx6/test_mx9_wa_integration.py @@ -22,26 +22,26 @@ ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) -def _make_pkg(name: str, path: str) -> types.ModuleType: +def _make_pkg(monkeypatch, name: str, path: str) -> types.ModuleType: mod = types.ModuleType(name) mod.__path__ = [path] mod.__package__ = name - sys.modules[name] = mod + monkeypatch.setitem(sys.modules, name, mod) return mod -def _load(modname: str, relpath: str) -> types.ModuleType: +def _load(monkeypatch, modname: str, relpath: str) -> types.ModuleType: spec = importlib.util.spec_from_file_location(modname, os.path.join(ROOT, relpath)) mod = importlib.util.module_from_spec(spec) - sys.modules[modname] = mod + monkeypatch.setitem(sys.modules, modname, mod) spec.loader.exec_module(mod) return mod -def _install_torchtitan_stubs() -> None: +def _install_torchtitan_stubs(monkeypatch) -> None: """Provide the tiny torchtitan surface needed by quantization modules.""" - _make_pkg("torchtitan", "") - _make_pkg("torchtitan.tools", "") + _make_pkg(monkeypatch, "torchtitan", "") + _make_pkg(monkeypatch, "torchtitan.tools", "") logging_mod = types.ModuleType("torchtitan.tools.logging") @@ -56,32 +56,25 @@ def warning(self, *args, **kwargs): pass logging_mod.logger = _Logger() - sys.modules["torchtitan.tools.logging"] = logging_mod + monkeypatch.setitem(sys.modules, "torchtitan.tools.logging", logging_mod) utils_mod = types.ModuleType("torchtitan.tools.utils") utils_mod.device_type = torch.device("cpu") - sys.modules["torchtitan.tools.utils"] = utils_mod + monkeypatch.setitem(sys.modules, "torchtitan.tools.utils", utils_mod) -def _load_quantization_modifier(): +def _load_quantization_modifier(monkeypatch): """Load only the ALTO modules needed for QuantizationModifier. Importing top-level ``alto`` currently pulls unrelated torchtitan entry - points, so the test loads the required modules directly. + points, so the test loads the required modules directly. All ``sys.modules`` + mutations go through ``monkeypatch`` so they are reverted at test teardown + and cannot leak into other tests in the same process. """ - _install_torchtitan_stubs() + _install_torchtitan_stubs(monkeypatch) - _make_pkg("alto", os.path.join(ROOT, "alto")) - _make_pkg("alto.kernels", os.path.join(ROOT, "alto/kernels")) - _make_pkg("alto.kernels.mx9", os.path.join(ROOT, "alto/kernels/mx9")) - _make_pkg("alto.models", os.path.join(ROOT, "alto/models")) - - _load("alto.kernels.mx9.format", "alto/kernels/mx9/format.py") - quantize_mod = _load( - "alto.kernels.mx9.quantize", - "alto/kernels/mx9/quantize.py", - ) - _load("alto.kernels.mx9.registry_patch", "alto/kernels/mx9/registry_patch.py") + _make_pkg(monkeypatch, "alto", os.path.join(ROOT, "alto")) + _make_pkg(monkeypatch, "alto.models", os.path.join(ROOT, "alto/models")) # Dynamic W/A uses no observers, but calibration.py imports Observer. observers_mod = types.ModuleType("alto.observers") @@ -92,47 +85,50 @@ def create_instance(*args, **kwargs): raise AssertionError("MX9 dynamic W+A should not create observers") observers_mod.Observer = _Observer - sys.modules["alto.observers"] = observers_mod + monkeypatch.setitem(sys.modules, "alto.observers", observers_mod) - _make_pkg("alto.modifiers", os.path.join(ROOT, "alto/modifiers")) - _make_pkg("alto.modifiers.utils", os.path.join(ROOT, "alto/modifiers/utils")) - hooks_mod = _load("alto.modifiers.utils.hooks", "alto/modifiers/utils/hooks.py") + _make_pkg(monkeypatch, "alto.modifiers", os.path.join(ROOT, "alto/modifiers")) + _make_pkg(monkeypatch, "alto.modifiers.utils", os.path.join(ROOT, "alto/modifiers/utils")) + hooks_mod = _load(monkeypatch, "alto.modifiers.utils.hooks", "alto/modifiers/utils/hooks.py") sys.modules["alto.modifiers.utils"].HooksMixin = hooks_mod.HooksMixin - modifier_base_mod = _load("alto.modifiers.base", "alto/modifiers/base.py") + modifier_base_mod = _load(monkeypatch, "alto.modifiers.base", "alto/modifiers/base.py") sys.modules["alto.modifiers"].Modifier = modifier_base_mod.Modifier - _make_pkg("alto.modifiers.quantization", os.path.join(ROOT, "alto/modifiers/quantization")) - _load("alto.modifiers.quantization.calibration", "alto/modifiers/quantization/calibration.py") - _load("alto.modifiers.quantization.mixin", "alto/modifiers/quantization/mixin.py") + _make_pkg(monkeypatch, "alto.modifiers.quantization", os.path.join(ROOT, "alto/modifiers/quantization")) + # Inject the format field (and load the MX module) before QuantizationModifier. + _load(monkeypatch, "alto.modifiers.quantization.format_registry", "alto/modifiers/quantization/format_registry.py") + quantize_mod = _load(monkeypatch, "alto.modifiers.quantization.mx", "alto/modifiers/quantization/mx.py") + _load(monkeypatch, "alto.modifiers.quantization.calibration", "alto/modifiers/quantization/calibration.py") + _load(monkeypatch, "alto.modifiers.quantization.mixin", "alto/modifiers/quantization/mixin.py") # base.py imports get_layers for sequential mode; this test uses sequential=False. - _make_pkg("alto.utils", os.path.join(ROOT, "alto/utils")) - _make_pkg("alto.utils.pytorch", os.path.join(ROOT, "alto/utils/pytorch")) + _make_pkg(monkeypatch, "alto.utils", os.path.join(ROOT, "alto/utils")) + _make_pkg(monkeypatch, "alto.utils.pytorch", os.path.join(ROOT, "alto/utils/pytorch")) module_utils = types.ModuleType("alto.utils.pytorch.module") module_utils.get_layers = lambda *args, **kwargs: {} - sys.modules["alto.utils.pytorch.module"] = module_utils + monkeypatch.setitem(sys.modules, "alto.utils.pytorch.module", module_utils) - quant_base_mod = _load("alto.modifiers.quantization.base", "alto/modifiers/quantization/base.py") + quant_base_mod = _load(monkeypatch, "alto.modifiers.quantization.base", "alto/modifiers/quantization/base.py") - patcher_mod = _load("alto.models.patcher", "alto/models/patcher.py") + patcher_mod = _load(monkeypatch, "alto.models.patcher", "alto/models/patcher.py") patcher_mod.ModelPatcher.patch_fake_quantize() return quant_base_mod.QuantizationModifier, quantize_mod -def _load_mx9_modifier_from_recipe(): +def _load_mx9_modifier_from_recipe(monkeypatch): recipe_path = os.path.join(ROOT, "alto/models/llama3/configs/mx9_wa_recipe.yaml") with open(recipe_path, "r") as f: recipe = yaml.safe_load(f) mod_args = recipe["quantization_stage"]["quantization_modifiers"]["QuantizationModifier"] - QuantizationModifier, quantize_mod = _load_quantization_modifier() + QuantizationModifier, quantize_mod = _load_quantization_modifier(monkeypatch) return QuantizationModifier(**mod_args), quantize_mod def test_mx9_wa_recipe_toy_linear_lifecycle(monkeypatch): - modifier, quantize_mod = _load_mx9_modifier_from_recipe() + modifier, quantize_mod = _load_mx9_modifier_from_recipe(monkeypatch) model = torch.nn.Sequential(torch.nn.Linear(16, 16, bias=False)) linear = model[0]