Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions alto/models/llama3/config_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
"llama3_8b_lpt",
"llama3_1b_gptq",
"llama3_1b_awq",
"llama3_1b_mx9_wa",
"llama3_1b_mx6_wa",
"llama3_8b",
"llama3_8b_gptq",
"llama3_8b_rtn",
Expand Down Expand Up @@ -182,6 +184,26 @@ def llama3_1b_awq() -> Trainer.Config:
return config


def llama3_1b_mx9_wa() -> Trainer.Config:
config = llama3_1b()
config.training.steps = 1
config.optimizer = OptimizersContainer.Config(lr=0.0)
config.model_converters = ModelConvertersContainer.Config(converters=[
ModelOptConverter.Config(recipe="./alto/models/llama3/configs/mx9_wa_recipe.yaml",),
],)
return config


def llama3_1b_mx6_wa() -> Trainer.Config:
config = llama3_1b()
config.training.steps = 1
config.optimizer = OptimizersContainer.Config(lr=0.0)
config.model_converters = ModelConvertersContainer.Config(converters=[
ModelOptConverter.Config(recipe="./alto/models/llama3/configs/mx6_wa_recipe.yaml",),
],)
return config


LLAMA3_8B_PATH = "/workspace/Model-Optimizer/models/meta-llama/Llama-3.1-8B"


Expand Down
25 changes: 25 additions & 0 deletions alto/models/llama3/configs/mx6_wa_recipe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# MX6 W+A fully-dynamic recipe (single GPU, dense).
# Weight and input-activation are both quantized, both dynamic -> QDQ happens in
# the forward pass, no calibration / observer / weight baking required.
quantization_stage:
quantization_modifiers:
QuantizationModifier:
ignore: ["output"] # lm_head not quantized
sequential: false
config_groups:
group_0:
targets: ["Linear"]
weights: # fully dynamic: QDQ in forward, no baking
num_bits: 5
type: "int"
symmetric: true
strategy: "tensor"
dynamic: true
format: "mx6" # <- hits mx6 dispatch
input_activations: # fully dynamic
num_bits: 5
type: "int"
symmetric: true
strategy: "tensor"
dynamic: true
format: "mx6"
25 changes: 25 additions & 0 deletions alto/models/llama3/configs/mx9_wa_recipe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# MX9 W+A fully-dynamic recipe (single GPU, dense).
# Weight and input-activation are both quantized, both dynamic -> QDQ happens in
# the forward pass, no calibration / observer / weight baking required.
quantization_stage:
quantization_modifiers:
QuantizationModifier:
ignore: ["output"] # lm_head not quantized
sequential: false
config_groups:
group_0:
targets: ["Linear"]
weights: # fully dynamic: QDQ in forward, no baking
num_bits: 8
type: "int"
symmetric: true
strategy: "tensor"
dynamic: true
format: "mx9" # <- hits mx9 dispatch
input_activations: # fully dynamic
num_bits: 8
type: "int"
symmetric: true
strategy: "tensor"
dynamic: true
format: "mx9"
24 changes: 24 additions & 0 deletions alto/models/patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,30 @@ class FakeQuantizeFunction(torch.autograd.Function):

@staticmethod
def forward(ctx, x, scale, zero_point, args, g_idx, global_scale):
if getattr(args, "format", None) == "mx9":
from alto.modifiers.quantization.mx import (
BLOCK_SIZE,
MX9_QUANT_BIT,
mx9_fake_quantize,
)

return mx9_fake_quantize(
x,
block_size=(args.group_size or BLOCK_SIZE),
quant_bit=(args.num_bits or MX9_QUANT_BIT),
)
if getattr(args, "format", None) == "mx6":
from alto.modifiers.quantization.mx import (
BLOCK_SIZE,
MX6_QUANT_BIT,
mx6_fake_quantize,
)

return mx6_fake_quantize(
x,
block_size=(args.group_size or BLOCK_SIZE),
quant_bit=(args.num_bits or MX6_QUANT_BIT),
)
return original_fake_quantize(x, scale, zero_point, args, g_idx, global_scale)

@staticmethod
Expand Down
8 changes: 8 additions & 0 deletions alto/modifiers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,14 @@
#
# SPDX-License-Identifier: MIT

# Inject the QuantizationArgs.format field BEFORE importing QuantizationModifier:
# the modifier compiles its nested QuantizationScheme schema at class-definition
# time, so the field must exist first or recipes carrying ``format: mx6/mx9`` are
# rejected by the cached (format-less) schema.
from .format_registry import inject_format_field

inject_format_field()

from .base import QuantizationModifier
from .gptq import GPTQModifier
from .awq import AWQModifier
Expand Down
6 changes: 5 additions & 1 deletion alto/modifiers/quantization/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@

import torch
import tqdm
from compressed_tensors.quantization import disable_quantization, enable_quantization
from compressed_tensors.quantization import (
disable_quantization,
enable_quantization,
)
from compressed_tensors.utils import getattr_chain, match_named_modules
from pydantic import Field, PrivateAttr
from torch.nn import Module
Expand All @@ -28,6 +31,7 @@

__all__ = ["QuantizationModifier"]


DEFAULT_SUBLAYER_GROUPS = [
["wq", "wk", "wv", "q_proj", "k_proj", "v_proj"],
["wo", "o_proj"],
Expand Down
5 changes: 5 additions & 0 deletions alto/modifiers/quantization/calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ def update_weight_zp_scale(module: Module):
if getattr_chain(module, "quantization_scheme.weights", None) is None:
return

# Fully dynamic weights have no observer: their QDQ runs live in the patched
# forward on every pass, so there is no static scale to compute or bake.
if not hasattr(module, "weight_observer"):
return

if getattr(module, "quantization_status", None) != QuantizationStatus.CALIBRATION:
logger.warning("Attempting to calibrate weights of a module not in calibration mode")

Expand Down
48 changes: 48 additions & 0 deletions alto/modifiers/quantization/format_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) 2026 Advanced Micro Devices, Inc.
#
# SPDX-License-Identifier: MIT
"""Runtime patch that wires emulated formats into the standard quant path.

Importing this module injects a real ``format`` field into
``compressed_tensors.QuantizationArgs`` so recipe values like ``format: mx9``
survive pydantic parsing and become readable via ``getattr(args, "format", None)``
(by default unknown fields are silently dropped).

The actual ``fake_quantize`` dispatch (``args.format == "mx9"`` -> mx9) lives in
``alto.models.patcher.ModelPatcher.patch_fake_quantize`` where the single wrap of
``compressed_tensors...forward.fake_quantize`` already happens.

``inject_format_field()`` is called at the top of this package's ``__init__`` (before
``QuantizationModifier`` is imported) so the field exists before the modifier
compiles its nested ``QuantizationScheme`` schema.
"""

from typing import Optional

_FORMAT_FIELD_INJECTED = False


def inject_format_field() -> None:
"""Add ``format: Optional[str] = None`` to ``QuantizationArgs`` (idempotent)."""
global _FORMAT_FIELD_INJECTED
if _FORMAT_FIELD_INJECTED:
return

from pydantic.fields import FieldInfo
from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig, QuantizationScheme

if "format" not in QuantizationArgs.model_fields:
QuantizationArgs.model_fields["format"] = FieldInfo(
annotation=Optional[str], default=None
)
QuantizationArgs.model_rebuild(force=True)
# QuantizationArgs is nested inside these models. Rebuild them as well so
# recipe dictionaries with weights/input_activations.format are accepted
# instead of being rejected by the old cached schema.
QuantizationScheme.model_rebuild(force=True)
QuantizationConfig.model_rebuild(force=True)

_FORMAT_FIELD_INJECTED = True


inject_format_field()
Loading