diff --git a/alto/models/llama3/config_registry.py b/alto/models/llama3/config_registry.py index e499b0c..0dd990a 100644 --- a/alto/models/llama3/config_registry.py +++ b/alto/models/llama3/config_registry.py @@ -28,6 +28,8 @@ "llama3_8b_lpt", "llama3_1b_gptq", "llama3_1b_awq", + "llama3_1b_mx9_wa", + "llama3_1b_mx6_wa", "llama3_8b", "llama3_8b_gptq", "llama3_8b_rtn", @@ -182,6 +184,26 @@ def llama3_1b_awq() -> Trainer.Config: return config +def llama3_1b_mx9_wa() -> Trainer.Config: + config = llama3_1b() + config.training.steps = 1 + config.optimizer = OptimizersContainer.Config(lr=0.0) + config.model_converters = ModelConvertersContainer.Config(converters=[ + ModelOptConverter.Config(recipe="./alto/models/llama3/configs/mx9_wa_recipe.yaml",), + ],) + return config + + +def llama3_1b_mx6_wa() -> Trainer.Config: + config = llama3_1b() + config.training.steps = 1 + config.optimizer = OptimizersContainer.Config(lr=0.0) + config.model_converters = ModelConvertersContainer.Config(converters=[ + ModelOptConverter.Config(recipe="./alto/models/llama3/configs/mx6_wa_recipe.yaml",), + ],) + return config + + LLAMA3_8B_PATH = "/workspace/Model-Optimizer/models/meta-llama/Llama-3.1-8B" diff --git a/alto/models/llama3/configs/mx6_wa_recipe.yaml b/alto/models/llama3/configs/mx6_wa_recipe.yaml new file mode 100644 index 0000000..5ab4d3b --- /dev/null +++ b/alto/models/llama3/configs/mx6_wa_recipe.yaml @@ -0,0 +1,25 @@ +# MX6 W+A fully-dynamic recipe (single GPU, dense). +# Weight and input-activation are both quantized, both dynamic -> QDQ happens in +# the forward pass, no calibration / observer / weight baking required. 
+quantization_stage: + quantization_modifiers: + QuantizationModifier: + ignore: ["output"] # lm_head not quantized + sequential: false + config_groups: + group_0: + targets: ["Linear"] + weights: # fully dynamic: QDQ in forward, no baking + num_bits: 5 + type: "int" + symmetric: true + strategy: "tensor" + dynamic: true + format: "mx6" # <- hits mx6 dispatch + input_activations: # fully dynamic + num_bits: 5 + type: "int" + symmetric: true + strategy: "tensor" + dynamic: true + format: "mx6" diff --git a/alto/models/llama3/configs/mx9_wa_recipe.yaml b/alto/models/llama3/configs/mx9_wa_recipe.yaml new file mode 100644 index 0000000..3df4841 --- /dev/null +++ b/alto/models/llama3/configs/mx9_wa_recipe.yaml @@ -0,0 +1,25 @@ +# MX9 W+A fully-dynamic recipe (single GPU, dense). +# Weight and input-activation are both quantized, both dynamic -> QDQ happens in +# the forward pass, no calibration / observer / weight baking required. +quantization_stage: + quantization_modifiers: + QuantizationModifier: + ignore: ["output"] # lm_head not quantized + sequential: false + config_groups: + group_0: + targets: ["Linear"] + weights: # fully dynamic: QDQ in forward, no baking + num_bits: 8 + type: "int" + symmetric: true + strategy: "tensor" + dynamic: true + format: "mx9" # <- hits mx9 dispatch + input_activations: # fully dynamic + num_bits: 8 + type: "int" + symmetric: true + strategy: "tensor" + dynamic: true + format: "mx9" diff --git a/alto/models/patcher.py b/alto/models/patcher.py index 6eee519..8c75135 100644 --- a/alto/models/patcher.py +++ b/alto/models/patcher.py @@ -57,6 +57,30 @@ class FakeQuantizeFunction(torch.autograd.Function): @staticmethod def forward(ctx, x, scale, zero_point, args, g_idx, global_scale): + if getattr(args, "format", None) == "mx9": + from alto.modifiers.quantization.mx import ( + BLOCK_SIZE, + MX9_QUANT_BIT, + mx9_fake_quantize, + ) + + return mx9_fake_quantize( + x, + block_size=(args.group_size or BLOCK_SIZE), + quant_bit=(args.num_bits or 
MX9_QUANT_BIT), + ) + if getattr(args, "format", None) == "mx6": + from alto.modifiers.quantization.mx import ( + BLOCK_SIZE, + MX6_QUANT_BIT, + mx6_fake_quantize, + ) + + return mx6_fake_quantize( + x, + block_size=(args.group_size or BLOCK_SIZE), + quant_bit=(args.num_bits or MX6_QUANT_BIT), + ) return original_fake_quantize(x, scale, zero_point, args, g_idx, global_scale) @staticmethod diff --git a/alto/modifiers/quantization/__init__.py b/alto/modifiers/quantization/__init__.py index 3a86389..c7cd600 100644 --- a/alto/modifiers/quantization/__init__.py +++ b/alto/modifiers/quantization/__init__.py @@ -2,6 +2,14 @@ # # SPDX-License-Identifier: MIT +# Inject the QuantizationArgs.format field BEFORE importing QuantizationModifier: +# the modifier compiles its nested QuantizationScheme schema at class-definition +# time, so the field must exist first or recipes carrying ``format: mx6/mx9`` are +# rejected by the cached (format-less) schema. +from .format_registry import inject_format_field + +inject_format_field() + from .base import QuantizationModifier from .gptq import GPTQModifier from .awq import AWQModifier diff --git a/alto/modifiers/quantization/base.py b/alto/modifiers/quantization/base.py index 4aba465..3511fdd 100644 --- a/alto/modifiers/quantization/base.py +++ b/alto/modifiers/quantization/base.py @@ -15,7 +15,10 @@ import torch import tqdm -from compressed_tensors.quantization import disable_quantization, enable_quantization +from compressed_tensors.quantization import ( + disable_quantization, + enable_quantization, +) from compressed_tensors.utils import getattr_chain, match_named_modules from pydantic import Field, PrivateAttr from torch.nn import Module @@ -28,6 +31,7 @@ __all__ = ["QuantizationModifier"] + DEFAULT_SUBLAYER_GROUPS = [ ["wq", "wk", "wv", "q_proj", "k_proj", "v_proj"], ["wo", "o_proj"], diff --git a/alto/modifiers/quantization/calibration.py b/alto/modifiers/quantization/calibration.py index f340f71..7eb54b2 100644 --- 
a/alto/modifiers/quantization/calibration.py +++ b/alto/modifiers/quantization/calibration.py @@ -135,6 +135,11 @@ def update_weight_zp_scale(module: Module): if getattr_chain(module, "quantization_scheme.weights", None) is None: return + # Fully dynamic weights have no observer: their QDQ runs live in the patched + # forward on every pass, so there is no static scale to compute or bake. + if not hasattr(module, "weight_observer"): + return + if getattr(module, "quantization_status", None) != QuantizationStatus.CALIBRATION: logger.warning("Attempting to calibrate weights of a module not in calibration mode") diff --git a/alto/modifiers/quantization/format_registry.py b/alto/modifiers/quantization/format_registry.py new file mode 100644 index 0000000..0500028 --- /dev/null +++ b/alto/modifiers/quantization/format_registry.py @@ -0,0 +1,48 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""Runtime patch that wires emulated formats into the standard quant path. + +Importing this module injects a real ``format`` field into +``compressed_tensors.QuantizationArgs`` so recipe values like ``format: mx9`` +survive pydantic parsing and become readable via ``getattr(args, "format", None)`` +(by default unknown fields are silently dropped). + +The actual ``fake_quantize`` dispatch (``args.format == "mx9"`` -> mx9) lives in +``alto.models.patcher.ModelPatcher.patch_fake_quantize`` where the single wrap of +``compressed_tensors...forward.fake_quantize`` already happens. + +``inject_format_field()`` is called at the top of this package's ``__init__`` (before +``QuantizationModifier`` is imported) so the field exists before the modifier +compiles its nested ``QuantizationScheme`` schema. 
+""" + +from typing import Optional + +_FORMAT_FIELD_INJECTED = False + + +def inject_format_field() -> None: + """Add ``format: Optional[str] = None`` to ``QuantizationArgs`` (idempotent).""" + global _FORMAT_FIELD_INJECTED + if _FORMAT_FIELD_INJECTED: + return + + from pydantic.fields import FieldInfo + from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig, QuantizationScheme + + if "format" not in QuantizationArgs.model_fields: + QuantizationArgs.model_fields["format"] = FieldInfo( + annotation=Optional[str], default=None + ) + QuantizationArgs.model_rebuild(force=True) + # QuantizationArgs is nested inside these models. Rebuild them as well so + # recipe dictionaries with weights/input_activations.format are accepted + # instead of being rejected by the old cached schema. + QuantizationScheme.model_rebuild(force=True) + QuantizationConfig.model_rebuild(force=True) + + _FORMAT_FIELD_INJECTED = True + + +inject_format_field() diff --git a/alto/modifiers/quantization/mx.py b/alto/modifiers/quantization/mx.py new file mode 100644 index 0000000..bd48d71 --- /dev/null +++ b/alto/modifiers/quantization/mx.py @@ -0,0 +1,186 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""MX6 / MX9 fake-quantize (block-wise emulation; no real packing/kernel inference). + +Port of Quark's ``fake_quantize_mx6_mx9`` with its helpers (``_t_exponent`` / +``_reshape_to_blocks`` / ``_pad_to_blocks``) inlined. As in Quark, MX6 and MX9 are +the SAME algorithm and differ only in the element integer width ``quant_bit``: + +- MX6: ``quant_bit = 5`` +- MX9: ``quant_bit = 8`` + +so a single core (``_mx_fake_quantize``) backs both, exposed via the thin +``mx6_fake_quantize`` / ``mx9_fake_quantize`` wrappers (cf. Quark's +``partial(fake_quantize_mx6_mx9, quant_bit=...)``). + +MX in a nutshell: + 1. Split the input along ``axis`` into blocks of ``block_size`` elements. + 2. 
Each block derives a shared exponent ``max_exp = floor(log2(amax))`` from its + max absolute value. + 3. Every ``SHARED_PRIME_BIT_GROUP`` (=2) adjacent elements share one prime bit: + if both elements in a pair sit at least one exponent below the block max, the + pair uses ``max_exp - 1`` (finer resolution); otherwise ``max_exp``. + 4. The shared exponent yields a power-of-two scale; the block is then + round -> clamp -> dequant. + +The "9" is not ``quant_bit=9``: elements are still quantized to 8-bit integers and +the extra "1" is the per-pair prime bit that decides the one-exponent demotion. + +This is fake quantization: the output keeps the input dtype (bf16/fp16/fp32) with +values projected onto the MX-representable grid. The scale is computed from the +data at runtime; any externally supplied compressed_tensors scale is ignored here. +""" + +import torch + +BLOCK_SIZE = 16 +SHARED_PRIME_BIT_GROUP = 2 +MX9_QUANT_BIT = 8 +MX6_QUANT_BIT = 5 + + +def _pad_to_blocks(x: torch.Tensor, block_size: int) -> torch.Tensor: + """Right-pad the last dim with zeros to a multiple of ``block_size``.""" + pad = block_size - x.size(-1) % block_size + if pad == block_size: + return x + return torch.nn.functional.pad(x, (0, pad)) + + +def _reshape_to_blocks(x: torch.Tensor, block_size: int, axis: int) -> torch.Tensor: + """Reshape into ``[num_rows, num_blocks, block_size]``. + + Blocking always happens along the last dim, so the requested ``axis`` is moved + there first (e.g. weight ``[out, in]`` with axis=-1 blocks each output channel's + input dim; activation ``[batch, seq, hidden]`` blocks each token's hidden vector). + """ + if axis > x.dim() - 1: + raise IndexError("Axis is larger than number of tensor dimensions") + + # Move the quantized axis to the last dim, flatten the rest, pad, then split + # the last dim into blocks. Padding is trimmed after quantization. 
+ x = x.transpose(axis, -1) + x = x.reshape(-1, x.size(-1)) + x = _pad_to_blocks(x, block_size) + return x.reshape(x.size(0), x.size(1) // block_size, block_size) + + +def _exponent_frexp_no_exception(t: torch.Tensor) -> torch.Tensor: + """Read the exponent straight from the float bit pattern (~ floor(log2(|x|))).""" + with torch.no_grad(): + if t.dtype == torch.float32: + return ((t.view(torch.int32) >> 23) & 0xFF) - 127 + if t.dtype == torch.bfloat16: + return ((t.view(torch.int16) >> 7) & 0xFF) - 127 + if t.dtype == torch.float16: + # Matches Quark: fp16 exponent bias treated as -15. + return ((t.view(torch.int16) >> 10) & 0x1F) - 15 + raise ValueError(f"Unsupported data type: {t.dtype}") + + +def _t_exponent(t: torch.Tensor) -> torch.Tensor: + """Per-element exponent; NaN/Inf are zeroed before extraction.""" + with torch.no_grad(): + t = torch.nan_to_num(t, nan=0, posinf=0, neginf=0) + return _exponent_frexp_no_exception(t) + + +def _mx_fake_quantize( + input_tensor: torch.Tensor, + block_size: int, + quant_bit: int, + axis: int = -1, +) -> torch.Tensor: + """Block-wise MX fake quantization (QDQ), bit-exact with Quark's + ``fake_quantize_mx6_mx9``. + + Args: + input_tensor: tensor to quantize (weight or activation). + block_size: elements per MX block. + quant_bit: element integer width (8 for MX9, 5 for MX6). + axis: axis to block along. Only ``axis=-1`` is supported on the ALTO path. + """ + if axis != -1: + raise NotImplementedError("mx fake_quantize supports axis=-1 only") + + input_dtype = input_tensor.dtype + + # Shape to restore after blocking: _reshape_to_blocks moves axis to the last dim, + # so record the post-transpose shape to rebuild before transposing back. + input_shape = list(input_tensor.shape) + input_shape[-1], input_shape[axis] = input_shape[axis], input_shape[-1] + + # Detached block view used only for the per-block max exponent (no gradient). 
+ block_x = _reshape_to_blocks(input_tensor.detach(), block_size, axis) + # Out-of-place nan_to_num: in-place breaks torch.compile/inductor functional graphs. + block_x = torch.nan_to_num(block_x, nan=0.0, posinf=0.0, neginf=0.0) + + # Per-block max abs -> shared block exponent max_exp = floor(log2(amax)). + amax, _ = torch.max(torch.abs(block_x), dim=-1, keepdim=True) + max_exp = _t_exponent(amax) # [num_rows, num_blocks, 1] + + # Block the original input for the per-element QDQ. + input_tensor = _reshape_to_blocks(input_tensor, block_size, axis) + t_exp = _t_exponent(input_tensor) # [num_rows, num_blocks, block_size] + + # Per element: is it at least one exponent below its block max? (max_exp broadcasts) + demote = max_exp - t_exp >= 1 + + # Prime bit: a pair is demoted only if BOTH its elements are below the block max. + # Reshape last dim into pairs, AND within each pair, then broadcast back. + n = SHARED_PRIME_BIT_GROUP + if demote.shape[-1] % n != 0: + raise ValueError( + f"block_size ({block_size}) must be a multiple of " + f"SHARED_PRIME_BIT_GROUP ({n})" + ) + flat_shape = demote.shape + demote = demote.reshape(*flat_shape[:-1], flat_shape[-1] // n, n) + demote = torch.sum(demote, -1, keepdim=True) == n + demote = demote.repeat(*([1] * (demote.dim() - 1)), n).reshape(flat_shape) + + # Demoted pairs drop one exponent (the prime-bit refinement). + shared_exp = max_exp - demote.long() + + # Final power-of-two scale: 2 ** (shared_exp - quant_bit + 2). + scale = torch.pow(2.0, shared_exp - quant_bit + 2) + + # Max representable integer magnitude. clamp_max guards against float overflow + # when max_exp is large: quant_max = (2**(max_exp+1) - scale) / scale. + quant_max = ( + torch.clamp_max( + torch.pow(2.0, max_exp.to(torch.float64) + 1) - scale, + torch.finfo(torch.float32).max, + ).to(torch.float32) + / scale + ) + + # QDQ: round to the integer grid, clamp to range, dequant back to float. 
+ output_tensor = torch.round(input_tensor / scale) + output_tensor = torch.clamp(output_tensor, -quant_max, quant_max) * scale + + # Flatten blocks, drop padding, restore shape/dtype, and transpose axis back. + output_tensor = output_tensor.reshape(output_tensor.size(0), -1) + output_tensor = output_tensor[:, : input_shape[-1]].reshape(input_shape).to(input_dtype) + return output_tensor.transpose(axis, -1) + + +def mx9_fake_quantize( + input_tensor: torch.Tensor, + block_size: int = BLOCK_SIZE, + quant_bit: int = MX9_QUANT_BIT, + axis: int = -1, +) -> torch.Tensor: + """MX9 block-wise fake quantization (QDQ); ``quant_bit=8``.""" + return _mx_fake_quantize(input_tensor, block_size=block_size, quant_bit=quant_bit, axis=axis) + + +def mx6_fake_quantize( + input_tensor: torch.Tensor, + block_size: int = BLOCK_SIZE, + quant_bit: int = MX6_QUANT_BIT, + axis: int = -1, +) -> torch.Tensor: + """MX6 block-wise fake quantization (QDQ); same math as MX9 with ``quant_bit=5``.""" + return _mx_fake_quantize(input_tensor, block_size=block_size, quant_bit=quant_bit, axis=axis) diff --git a/examples/llama3.2_1b_mx9.sh b/examples/llama3.2_1b_mx9.sh new file mode 100755 index 0000000..204794a --- /dev/null +++ b/examples/llama3.2_1b_mx9.sh @@ -0,0 +1,62 @@ +#!/usr/bin/bash +# Copyright (c) 2026 Advanced Micro Devices, Inc. 
+# +# SPDX-License-Identifier: MIT + +# MX9 W8A8 dynamic fake-quant validation on Llama-3.2-1B +# Weight and input activations are quantized dynamically through +# alto/models/llama3/configs/mx9_wa_recipe.yaml +# +# Usage (MODEL_PATH is required, point it at your local Llama-3.2-1B dir): +# MODEL_PATH=/path/to/Llama-3.2-1B bash examples/llama3.2_1b_mx9.sh +# MODEL_PATH=/path/to/Llama-3.2-1B VALIDATOR_STEPS=100 bash examples/llama3.2_1b_mx9.sh +# MODEL_PATH=/path/to/Llama-3.2-1B CONFIG=llama3_1b bash examples/llama3.2_1b_mx9.sh # BF16 baseline +rm -rf outputs/ +set -ex + +NGPU=${NGPU:-"1"} +export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0} +export HIP_VISIBLE_DEVICES=${HIP_VISIBLE_DEVICES:-${CUDA_VISIBLE_DEVICES}} +export LOG_RANK=${LOG_RANK:-0} +TRAIN_FILE=${TRAIN_FILE:-"alto.train"} +MODULE=${MODULE:-"llama3"} +CONFIG=${CONFIG:-"llama3_1b_mx9_wa"} +COMM_MODE=${COMM_MODE:-""} + +MODEL_PATH=${MODEL_PATH:-""} +if [ -z "${MODEL_PATH}" ]; then + echo "ERROR: MODEL_PATH must be set to your local Llama-3.2-1B directory, e.g." 
>&2 + echo " MODEL_PATH=/path/to/Llama-3.2-1B bash $0" >&2 + exit 1 +fi +VALIDATOR_STEPS=${VALIDATOR_STEPS:-"10"} +CHECKPOINT_FOLDER=${CHECKPOINT_FOLDER:-"./outputs/ckpt_${CONFIG}_$(date +%Y%m%d_%H%M%S)"} + +TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} + +if [ -n "$COMM_MODE" ]; then + echo "Running with comm_mode=${COMM_MODE}" + NGPU="${NGPU}" LOCAL_RANK=0 python3 -m ${TRAIN_FILE} \ + --module ${MODULE} \ + --config ${CONFIG} \ + --hf_assets_path "${MODEL_PATH}" \ + --checkpoint.initial_load_path "${MODEL_PATH}" \ + --checkpoint.folder "${CHECKPOINT_FOLDER}" \ + --validator.steps "${VALIDATOR_STEPS}" \ + "$@" \ + --comm.mode=${COMM_MODE} \ + --training.steps 1 +else + PYTORCH_ALLOC_CONF="expandable_segments:True" \ + TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ + torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ + --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ + -m ${TRAIN_FILE} \ + --module ${MODULE} \ + --config ${CONFIG} \ + --hf_assets_path "${MODEL_PATH}" \ + --checkpoint.initial_load_path "${MODEL_PATH}" \ + --checkpoint.folder "${CHECKPOINT_FOLDER}" \ + --validator.steps "${VALIDATOR_STEPS}" \ + "$@" +fi diff --git a/tests/unittest/mx9_mx6/__init__.py b/tests/unittest/mx9_mx6/__init__.py new file mode 100644 index 0000000..85eac52 --- /dev/null +++ b/tests/unittest/mx9_mx6/__init__.py @@ -0,0 +1,3 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT diff --git a/tests/unittest/mx9_mx6/test_mx6_dispatch.py b/tests/unittest/mx9_mx6/test_mx6_dispatch.py new file mode 100644 index 0000000..870f776 --- /dev/null +++ b/tests/unittest/mx9_mx6/test_mx6_dispatch.py @@ -0,0 +1,120 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""Tests for the MX6 dispatch wiring in ``ModelPatcher.patch_fake_quantize``. + +The MX6 kernel itself is covered by ``test_mx6_quantize.py``. 
This file tests the +*wiring* one layer up: after ``patch_fake_quantize()`` replaces +``compressed_tensors...forward.fake_quantize``, a ``QuantizationArgs`` carrying +``format == "mx6"`` must route to ``mx6_fake_quantize``, while plain int8 args +(no ``format``) must fall through to the original implementation untouched. + +Also checks that ``format_registry.inject_format_field()`` makes the ``format`` +field survive pydantic validation (otherwise the recipe value is silently +dropped and dispatch never fires). + +Run with: + pytest tests/unittest/mx9_mx6/test_mx6_dispatch.py +""" + +import pytest +import torch + +from compressed_tensors.quantization import QuantizationArgs +from compressed_tensors.quantization.lifecycle import forward as forward_module + +from alto.modifiers.quantization.mx import BLOCK_SIZE, mx6_fake_quantize + + +def _mx6_args() -> QuantizationArgs: + return QuantizationArgs.model_validate( + { + "num_bits": 5, + "type": "int", + "symmetric": True, + "strategy": "tensor", + "dynamic": True, + "format": "mx6", + } + ) + + +def _int8_args() -> QuantizationArgs: + return QuantizationArgs.model_validate( + { + "num_bits": 8, + "type": "int", + "symmetric": True, + "strategy": "tensor", + } + ) + + +# --------------------------------------------------------------------------- # +# format field injection (alto.modifiers.quantization.format_registry) +# --------------------------------------------------------------------------- # +def test_format_field_survives_validation(): + """Without inject_format_field(), pydantic drops the unknown ``format`` key + and dispatch can never see it.""" + args = _mx6_args() + assert getattr(args, "format", None) == "mx6" + + +def test_format_field_defaults_none_for_plain_int8(): + args = _int8_args() + assert getattr(args, "format", None) is None + + +# --------------------------------------------------------------------------- # +# dispatch routing +# 
--------------------------------------------------------------------------- # +def test_mx6_args_dispatch_to_mx6_kernel(): + """format == "mx6" must route fake_quantize to mx6_fake_quantize bit-exact.""" + torch.manual_seed(6) + x = torch.randn(3, 40, dtype=torch.float32) + args = _mx6_args() + + patched_out = forward_module.fake_quantize( + x=x, + scale=torch.ones(1), + zero_point=None, + args=args, + g_idx=None, + global_scale=None, + ) + expected = mx6_fake_quantize(x, block_size=BLOCK_SIZE) + + assert torch.equal(patched_out, expected) + + +def test_int8_args_fall_through_to_original(): + """Plain int8 (no format) must NOT touch the mx6 kernel; output must match a + direct INT8 per-tensor QDQ, never the mx6 result.""" + torch.manual_seed(1) + x = torch.randn(3, 40, dtype=torch.float32) + args = _int8_args() + + # per-tensor symmetric int8 scale (what the original kernel would use) + qmax = 127.0 + scale = (x.abs().max() / qmax).clamp_min(1e-12).reshape(1) + zp = torch.zeros(1, dtype=torch.int32) + + patched_out = forward_module.fake_quantize( + x=x, + scale=scale, + zero_point=zp, + args=args, + g_idx=None, + global_scale=None, + ) + + expected_int8 = (torch.round(x / scale).clamp(-128, 127)) * scale + assert torch.allclose(patched_out, expected_int8, atol=1e-5) + + # and it must differ from the mx6 path (sanity: int8 did not silently route to mx6) + mx6_out = mx6_fake_quantize(x, block_size=BLOCK_SIZE) + assert not torch.equal(patched_out, mx6_out) + + +if __name__ == "__main__": + raise SystemExit(pytest.main([__file__, "-v", "-p", "no:cacheprovider"])) diff --git a/tests/unittest/mx9_mx6/test_mx6_quantize.py b/tests/unittest/mx9_mx6/test_mx6_quantize.py new file mode 100644 index 0000000..f7769e3 --- /dev/null +++ b/tests/unittest/mx9_mx6/test_mx6_quantize.py @@ -0,0 +1,65 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""Tests for the MX6 fake-quantize wrapper. 
+ +MX6 shares Quark's ``fake_quantize_mx6_mx9`` algorithm with MX9, but uses +``quant_bit=5``. These tests focus on proving that the wrapper selects the MX6 +bit width and still preserves the basic QDQ invariants. +""" + +import pytest +import torch + +from alto.modifiers.quantization.mx import BLOCK_SIZE, MX6_QUANT_BIT as QUANT_BIT, mx6_fake_quantize + + +def _quark_mx6_reference(x: torch.Tensor, block_size: int): + """Live Quark reference; returns None if Quark is unavailable.""" + try: + from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( + fake_quantize_mx6_mx9, + ) + except Exception: + return None + return fake_quantize_mx6_mx9(x.clone(), axis=-1, block_size=block_size, quant_bit=QUANT_BIT) + + +def _rand(shape, dtype): + torch.manual_seed(6) + return torch.randn(*shape, dtype=dtype) + + +@pytest.mark.parametrize("shape", [(4, 64), (8, 16), (2, 3, 32), (3, 40)]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("block_size", [16, 32]) +def test_mx6_matches_quark_bit_exact(shape, dtype, block_size): + x = _rand(shape, dtype) + ref = _quark_mx6_reference(x, block_size) + if ref is None: + pytest.skip("Quark reference (fake_quantize_mx6_mx9) not importable") + + out = mx6_fake_quantize(x, block_size=block_size, quant_bit=QUANT_BIT) + assert out.dtype == ref.dtype + assert torch.equal(out, ref), (out - ref).abs().max().item() + + +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +def test_shape_and_dtype_preserved(dtype): + x = _rand((4, 64), dtype) + out = mx6_fake_quantize(x, block_size=BLOCK_SIZE) + assert out.shape == x.shape + assert out.dtype == dtype + + +def test_padding_path_non_divisible_last_dim(): + x = _rand((3, 40), torch.float32) + out = mx6_fake_quantize(x, block_size=BLOCK_SIZE) + assert out.shape == x.shape + assert not torch.isnan(out).any() + + +def test_axis_contract(): + x = _rand((4, 32), torch.float32) + with 
pytest.raises(NotImplementedError): + mx6_fake_quantize(x, block_size=BLOCK_SIZE, axis=0) diff --git a/tests/unittest/mx9_mx6/test_mx6_wa_integration.py b/tests/unittest/mx9_mx6/test_mx6_wa_integration.py new file mode 100644 index 0000000..a9b4b9d --- /dev/null +++ b/tests/unittest/mx9_mx6/test_mx6_wa_integration.py @@ -0,0 +1,168 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""Toy integration test for MX6 W+A dynamic quantization. + +This intentionally stays below a real llama3 validation run. It verifies the +first full ALTO wiring layer: + +recipe yaml -> QuantizationModifier -> Linear quantization_scheme -> +post_step dynamic-weight skip -> wrapped Linear forward -> MX6 W+A QDQ. +""" + +import importlib.util +import os +import sys +import types + +import torch +import yaml + + +ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) + + +def _make_pkg(monkeypatch, name: str, path: str) -> types.ModuleType: + mod = types.ModuleType(name) + mod.__path__ = [path] + mod.__package__ = name + monkeypatch.setitem(sys.modules, name, mod) + return mod + + +def _load(monkeypatch, modname: str, relpath: str) -> types.ModuleType: + spec = importlib.util.spec_from_file_location(modname, os.path.join(ROOT, relpath)) + mod = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, modname, mod) + spec.loader.exec_module(mod) + return mod + + +def _install_torchtitan_stubs(monkeypatch) -> None: + """Provide the tiny torchtitan surface needed by quantization modules.""" + _make_pkg(monkeypatch, "torchtitan", "") + _make_pkg(monkeypatch, "torchtitan.tools", "") + + logging_mod = types.ModuleType("torchtitan.tools.logging") + + class _Logger: + def debug(self, *args, **kwargs): + pass + + def info(self, *args, **kwargs): + pass + + def warning(self, *args, **kwargs): + pass + + logging_mod.logger = _Logger() + monkeypatch.setitem(sys.modules, "torchtitan.tools.logging", logging_mod) + + 
utils_mod = types.ModuleType("torchtitan.tools.utils") + utils_mod.device_type = torch.device("cpu") + monkeypatch.setitem(sys.modules, "torchtitan.tools.utils", utils_mod) + + +def _load_quantization_modifier(monkeypatch): + """Load only the ALTO modules needed for QuantizationModifier. + + Importing top-level ``alto`` currently pulls unrelated torchtitan entry + points. This test loads the required modules directly. MX6 reuses the MX9 + math (both live in ``alto.modifiers.quantization.mx``); the ``format`` field + injection lives in ``alto.modifiers.quantization.format_registry``. All + ``sys.modules`` mutations go through ``monkeypatch`` so they are reverted at + test teardown and cannot leak into other tests in the same process. + """ + _install_torchtitan_stubs(monkeypatch) + + _make_pkg(monkeypatch, "alto", os.path.join(ROOT, "alto")) + _make_pkg(monkeypatch, "alto.models", os.path.join(ROOT, "alto/models")) + + # Dynamic W/A uses no observers, but calibration.py imports Observer. 
+ observers_mod = types.ModuleType("alto.observers") + + class _Observer: + @staticmethod + def create_instance(*args, **kwargs): + raise AssertionError("MX6 dynamic W+A should not create observers") + + observers_mod.Observer = _Observer + monkeypatch.setitem(sys.modules, "alto.observers", observers_mod) + + _make_pkg(monkeypatch, "alto.modifiers", os.path.join(ROOT, "alto/modifiers")) + _make_pkg(monkeypatch, "alto.modifiers.utils", os.path.join(ROOT, "alto/modifiers/utils")) + hooks_mod = _load(monkeypatch, "alto.modifiers.utils.hooks", "alto/modifiers/utils/hooks.py") + sys.modules["alto.modifiers.utils"].HooksMixin = hooks_mod.HooksMixin + + modifier_base_mod = _load(monkeypatch, "alto.modifiers.base", "alto/modifiers/base.py") + sys.modules["alto.modifiers"].Modifier = modifier_base_mod.Modifier + + _make_pkg(monkeypatch, "alto.modifiers.quantization", os.path.join(ROOT, "alto/modifiers/quantization")) + # Inject the format field (and load the MX module) before QuantizationModifier. + _load(monkeypatch, "alto.modifiers.quantization.format_registry", "alto/modifiers/quantization/format_registry.py") + quantize_mod = _load(monkeypatch, "alto.modifiers.quantization.mx", "alto/modifiers/quantization/mx.py") + _load(monkeypatch, "alto.modifiers.quantization.calibration", "alto/modifiers/quantization/calibration.py") + _load(monkeypatch, "alto.modifiers.quantization.mixin", "alto/modifiers/quantization/mixin.py") + + # base.py imports get_layers for sequential mode; this test uses sequential=False. 
+ _make_pkg(monkeypatch, "alto.utils", os.path.join(ROOT, "alto/utils")) + _make_pkg(monkeypatch, "alto.utils.pytorch", os.path.join(ROOT, "alto/utils/pytorch")) + module_utils = types.ModuleType("alto.utils.pytorch.module") + module_utils.get_layers = lambda *args, **kwargs: {} + monkeypatch.setitem(sys.modules, "alto.utils.pytorch.module", module_utils) + + quant_base_mod = _load(monkeypatch, "alto.modifiers.quantization.base", "alto/modifiers/quantization/base.py") + + patcher_mod = _load(monkeypatch, "alto.models.patcher", "alto/models/patcher.py") + patcher_mod.ModelPatcher.patch_fake_quantize() + + return quant_base_mod.QuantizationModifier, quantize_mod + + +def _load_mx6_modifier_from_recipe(monkeypatch): + recipe_path = os.path.join(ROOT, "alto/models/llama3/configs/mx6_wa_recipe.yaml") + with open(recipe_path, "r") as f: + recipe = yaml.safe_load(f) + + mod_args = recipe["quantization_stage"]["quantization_modifiers"]["QuantizationModifier"] + QuantizationModifier, quantize_mod = _load_quantization_modifier(monkeypatch) + return QuantizationModifier(**mod_args), quantize_mod + + +def test_mx6_wa_recipe_toy_linear_lifecycle(monkeypatch): + modifier, quantize_mod = _load_mx6_modifier_from_recipe(monkeypatch) + model = torch.nn.Sequential(torch.nn.Linear(16, 16, bias=False)) + linear = model[0] + + modifier.initialize([model]) + + scheme = linear.quantization_scheme + assert getattr(scheme.weights, "format", None) == "mx6" + assert getattr(scheme.input_activations, "format", None) == "mx6" + assert scheme.weights.dynamic is True + assert scheme.input_activations.dynamic is True + assert not hasattr(linear, "weight_observer") + assert not hasattr(linear, "input_observer") + + calls = [] + original_mx6 = quantize_mod.mx6_fake_quantize + + def counted_mx6(input_tensor, *args, **kwargs): + calls.append(tuple(input_tensor.shape)) + return original_mx6(input_tensor, *args, **kwargs) + + monkeypatch.setattr(quantize_mod, "mx6_fake_quantize", counted_mx6) + + x = 
torch.randn(2, 16) + modifier.pre_step([model]) + modifier.post_step([model]) # must skip static weight baking for MX6 dynamic weight + assert not hasattr(linear, "weight_observer") + assert not hasattr(linear, "weight_scale") + + out = model(x) + assert out.shape == (2, 16) + assert len(calls) == 2 + assert (2, 16) in calls # input activation QDQ + assert (16, 16) in calls # weight QDQ + + modifier.finalize([model]) diff --git a/tests/unittest/mx9_mx6/test_mx9_dispatch.py b/tests/unittest/mx9_mx6/test_mx9_dispatch.py new file mode 100644 index 0000000..c22745a --- /dev/null +++ b/tests/unittest/mx9_mx6/test_mx9_dispatch.py @@ -0,0 +1,120 @@ +# Copyright (c) 2026 Advanced Micro Devices, Inc. +# +# SPDX-License-Identifier: MIT +"""Tests for the MX9 dispatch wiring in ``ModelPatcher.patch_fake_quantize``. + +The MX9 kernel itself is covered by ``test_mx9_quantize.py``. This file tests the +*wiring* one layer up: after ``patch_fake_quantize()`` replaces +``compressed_tensors...forward.fake_quantize``, a ``QuantizationArgs`` carrying +``format == "mx9"`` must route to ``mx9_fake_quantize``, while plain int8 args +(no ``format``) must fall through to the original implementation untouched. + +Also checks that ``format_registry.inject_format_field()`` makes the ``format`` +field survive pydantic validation (otherwise the recipe value is silently +dropped and dispatch never fires). 
def test_int8_args_fall_through_to_original():
    """Plain int8 args (no ``format``) must bypass the MX9 kernel.

    The patched ``fake_quantize`` output has to agree with a hand-computed
    per-tensor symmetric INT8 QDQ, and must not be bit-identical to what the
    MX9 kernel produces for the same input.
    """
    torch.manual_seed(1)
    sample = torch.randn(3, 40, dtype=torch.float32)

    # per-tensor symmetric int8 scale (what the original kernel would use)
    int8_max = 127.0
    abs_peak = sample.abs().max()
    step = (abs_peak / int8_max).clamp_min(1e-12).reshape(1)

    call_kwargs = {
        "x": sample,
        "scale": step,
        "zero_point": torch.zeros(1, dtype=torch.int32),
        "args": _int8_args(),
        "g_idx": None,
        "global_scale": None,
    }
    actual = forward_module.fake_quantize(**call_kwargs)

    reference = torch.clamp(torch.round(sample / step), -128, 127) * step
    assert torch.allclose(actual, reference, atol=1e-5)

    # and it must differ from the mx9 path (sanity: int8 did not silently route to mx9)
    assert not torch.equal(actual, mx9_fake_quantize(sample, block_size=BLOCK_SIZE))
+""" + +import pytest +import torch + +from alto.modifiers.quantization.mx import BLOCK_SIZE, MX9_QUANT_BIT as QUANT_BIT, mx9_fake_quantize + + +# --------------------------------------------------------------------------- # +# Helpers +# --------------------------------------------------------------------------- # +def _quark_interesting_pattern() -> torch.Tensor: + """Quark test_mx.py:228 create_4d_tensor_with_interesting_pattern().""" + result = torch.zeros(1, 2, 3, 4) + for x1 in range(2): + for x2 in range(3): + for x3 in range(4): + result[0, x1, x2, x3] = (x3 + 1.0) * (10 * (x2 + 1.0)) + 100.0 * x1 + return result + + +# Quark's published mx9 output for the pattern above @ block_size=16, axis=-1 +# (test_mx.py:341-348). MX9 is high precision, so it round-trips these integers +# exactly. +_QUARK_MX9_GOLDEN = torch.tensor( + [ + [ + [[10.0, 20.0, 30.0, 40.0], [20.0, 40.0, 60.0, 80.0], [30.0, 60.0, 90.0, 120.0]], + [[110.0, 120.0, 130.0, 140.0], [120.0, 140.0, 160.0, 180.0], [130.0, 160.0, 190.0, 220.0]], + ] + ] +) + + +def _quark_mx9_reference(x: torch.Tensor, block_size: int): + """Live Quark reference; returns None (caller skips) if Quark is unavailable.""" + try: + from quark.torch.kernel.hw_emulation.hw_emulation_interface import ( + fake_quantize_mx6_mx9, + ) + except Exception: + return None + return fake_quantize_mx6_mx9(x.clone(), axis=-1, block_size=block_size, quant_bit=QUANT_BIT) + + +def _rand(shape, dtype): + torch.manual_seed(0) + return torch.randn(*shape, dtype=dtype) + + +# --------------------------------------------------------------------------- # +# Layer 1: golden vector (always runs, no Quark needed) +# --------------------------------------------------------------------------- # +@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) +def test_mx9_matches_quark_golden_vector(dtype): + x = _quark_interesting_pattern().to(dtype) + out = mx9_fake_quantize(x, block_size=16) + expected = 
@pytest.mark.parametrize("shape", [(4, 64), (8, 16), (2, 3, 32), (3, 40)])
@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
@pytest.mark.parametrize("block_size", [16, 32])
def test_mx9_matches_quark_bit_exact(shape, dtype, block_size):
    """Compare ``mx9_fake_quantize`` bit-for-bit with a live Quark reference.

    The case is skipped when ``_quark_mx9_reference`` returns ``None``.
    """
    sample = _rand(shape, dtype)
    quark_out = _quark_mx9_reference(sample, block_size)
    if quark_out is None:
        pytest.skip("Quark reference (fake_quantize_mx6_mx9) not importable")

    ours = mx9_fake_quantize(sample, block_size=block_size, quant_bit=QUANT_BIT)
    assert ours.dtype == quark_out.dtype
    # The failure message (largest absolute deviation) is only built on mismatch.
    identical = torch.equal(ours, quark_out)
    assert identical, (ours - quark_out).abs().max().item()
def test_padding_path_non_divisible_last_dim():
    """Check the padding path when the last dim is not a multiple of the block.

    Besides shape preservation and absence of NaN, the output is compared with
    the result of quantizing the same rows explicitly zero-padded to a whole
    number of blocks. Previously the comment claimed that padding must not
    corrupt the real elements, but nothing asserted it.

    NOTE(review): the comparison assumes blocks are formed along the last axis
    and that the kernel pads the tail with zeros, as the original comment
    ("Padding zeros ...") describes -- confirm against ``mx9_fake_quantize``.
    """
    block = 16
    x = _rand((3, 40), torch.float32)  # 40 not divisible by 16
    out = mx9_fake_quantize(x, block_size=block)
    assert out.shape == x.shape
    assert not torch.isnan(out).any()

    # Padding zeros must not corrupt the real elements: quantizing the same
    # data laid out as exact blocks must reproduce ``out`` on the real columns.
    width = x.shape[-1]
    tail_pad = -width % block  # zeros needed to reach the next block multiple
    padded = torch.nn.functional.pad(x, (0, tail_pad))
    reference = mx9_fake_quantize(padded, block_size=block)[..., :width]
    assert torch.equal(out, reference)
def test_nan_in_data_propagates():
    """Lock in that a NaN element survives QDQ while its neighbours stay intact.

    Known limitation kept faithful to Quark; pinning it here makes any future
    change to the data-path sanitization show up as a test failure.
    """
    clean = torch.full((BLOCK_SIZE - 1,), 4.0)
    # A single block: NaN in slot 0, followed by BLOCK_SIZE - 1 copies of 4.0.
    poisoned = torch.cat([torch.tensor([float("nan")]), clean]).unsqueeze(0)

    result = mx9_fake_quantize(poisoned, block_size=BLOCK_SIZE)

    first, rest = result[0, 0], result[0, 1:]
    assert torch.isnan(first)
    assert torch.equal(rest, clean)
def _load_quantization_modifier(monkeypatch):
    """Load only the ALTO modules needed for QuantizationModifier.

    Importing top-level ``alto`` currently pulls unrelated torchtitan entry
    points, so the test loads the required modules directly. All ``sys.modules``
    mutations go through ``monkeypatch`` so they are reverted at test teardown
    and cannot leak into other tests in the same process.

    The statement order below matters: every stub package is registered before
    the modules that live under it are executed.

    Returns:
        A ``(QuantizationModifier, mx_module)`` pair: the modifier class taken
        from the freshly loaded ``alto.modifiers.quantization.base`` and the
        loaded ``alto.modifiers.quantization.mx`` module object.
    """
    # Stub torchtitan first, before any ALTO module is executed.
    _install_torchtitan_stubs(monkeypatch)

    # Bare package shells: real directories on disk, but no __init__ is run.
    _make_pkg(monkeypatch, "alto", os.path.join(ROOT, "alto"))
    _make_pkg(monkeypatch, "alto.models", os.path.join(ROOT, "alto/models"))

    # Dynamic W/A uses no observers, but calibration.py imports Observer.
    observers_mod = types.ModuleType("alto.observers")

    class _Observer:
        # Tripwire: any attempt to build an observer fails the test loudly.
        @staticmethod
        def create_instance(*args, **kwargs):
            raise AssertionError("MX9 dynamic W+A should not create observers")

    observers_mod.Observer = _Observer
    monkeypatch.setitem(sys.modules, "alto.observers", observers_mod)

    _make_pkg(monkeypatch, "alto.modifiers", os.path.join(ROOT, "alto/modifiers"))
    _make_pkg(monkeypatch, "alto.modifiers.utils", os.path.join(ROOT, "alto/modifiers/utils"))
    hooks_mod = _load(monkeypatch, "alto.modifiers.utils.hooks", "alto/modifiers/utils/hooks.py")
    # Expose HooksMixin on the stub package created just above.
    sys.modules["alto.modifiers.utils"].HooksMixin = hooks_mod.HooksMixin

    modifier_base_mod = _load(monkeypatch, "alto.modifiers.base", "alto/modifiers/base.py")
    # Same for Modifier on the ``alto.modifiers`` stub package.
    sys.modules["alto.modifiers"].Modifier = modifier_base_mod.Modifier

    _make_pkg(monkeypatch, "alto.modifiers.quantization", os.path.join(ROOT, "alto/modifiers/quantization"))
    # Inject the format field (and load the MX module) before QuantizationModifier.
    _load(monkeypatch, "alto.modifiers.quantization.format_registry", "alto/modifiers/quantization/format_registry.py")
    quantize_mod = _load(monkeypatch, "alto.modifiers.quantization.mx", "alto/modifiers/quantization/mx.py")
    _load(monkeypatch, "alto.modifiers.quantization.calibration", "alto/modifiers/quantization/calibration.py")
    _load(monkeypatch, "alto.modifiers.quantization.mixin", "alto/modifiers/quantization/mixin.py")

    # base.py imports get_layers for sequential mode; this test uses sequential=False.
    _make_pkg(monkeypatch, "alto.utils", os.path.join(ROOT, "alto/utils"))
    _make_pkg(monkeypatch, "alto.utils.pytorch", os.path.join(ROOT, "alto/utils/pytorch"))
    module_utils = types.ModuleType("alto.utils.pytorch.module")
    # Stand-in that always reports no layers.
    module_utils.get_layers = lambda *args, **kwargs: {}
    monkeypatch.setitem(sys.modules, "alto.utils.pytorch.module", module_utils)

    quant_base_mod = _load(monkeypatch, "alto.modifiers.quantization.base", "alto/modifiers/quantization/base.py")

    patcher_mod = _load(monkeypatch, "alto.models.patcher", "alto/models/patcher.py")
    # NOTE(review): this call is made directly, not through ``monkeypatch``. If
    # ``patch_fake_quantize`` replaces a function inside an already-imported
    # third-party module, that replacement presumably outlives this test, which
    # would weaken the "cannot leak" guarantee in the docstring -- TODO confirm.
    patcher_mod.ModelPatcher.patch_fake_quantize()

    return quant_base_mod.QuantizationModifier, quantize_mod
torch.randn(2, 16) + modifier.pre_step([model]) + modifier.post_step([model]) # must skip static weight baking for MX9 dynamic weight + assert not hasattr(linear, "weight_observer") + assert not hasattr(linear, "weight_scale") + + out = model(x) + assert out.shape == (2, 16) + assert len(calls) == 2 + assert (2, 16) in calls # input activation QDQ + assert (16, 16) in calls # weight QDQ + + modifier.finalize([model])