Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions alto/nn/decomposed_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,27 @@
import torch.nn.functional as F


# ``lora_rank`` is the contraction dim (K) of the ``lora_update @ u`` GEMM. When
# u/v are quantized (NVFP4/MXFP4), low-precision kernels quantize K in blocks, so
# lora_rank must be a positive multiple of the scheme's block size: 16 for NVFP4
# (BLOCK_SIZE_DEFAULT), 32 for MXFP4. 16 is the least strict of the two, so it is
# the scheme-agnostic floor enforced at construction, where the precision is not
# known. Without this check, e.g. lora_rank=8 fails deep in the kernel with an
# opaque torch._check error instead of here at construction.
# NOTE(review): a multiple of 16 is not necessarily a multiple of 32 (16, 48, ...),
# so this floor alone does not guarantee MXFP4 compatibility; callers that know
# they target MXFP4 should validate with ``block_multiple=32`` -- confirm against
# the MXFP4 kernel's actual K requirement.
LORA_RANK_BLOCK_MULTIPLE = 16


def _validate_lora_rank(
    lora_rank: int, block_multiple: int = LORA_RANK_BLOCK_MULTIPLE
) -> None:
    """Reject LoRA ranks that block-quantized kernels cannot consume.

    Args:
        lora_rank: Rank of the low-rank update.
        block_multiple: Block size the rank must be a multiple of. Defaults to
            ``LORA_RANK_BLOCK_MULTIPLE``; pass a larger value to enforce a
            stricter, scheme-specific requirement.

    Raises:
        ValueError: If ``lora_rank`` is not positive or is not a multiple of
            ``block_multiple``.
    """
    if lora_rank <= 0 or lora_rank % block_multiple != 0:
        raise ValueError(
            f"lora_rank must be a positive multiple of {block_multiple} "
            f"(required by NVFP4/MXFP4 block-quantized kernels), got {lora_rank}."
        )


class DecomposedLinear(nn.Module):
def __init__(self, in_features, out_features, bias=True, lora_rank=32):
super(DecomposedLinear, self).__init__()
_validate_lora_rank(lora_rank)
self.in_features = in_features
self.out_features = out_features

Expand All @@ -28,14 +46,30 @@ def forward(self, input):

@classmethod
def from_linear(cls, linear: nn.Linear, lora_rank: int = 32):
    """Build a decomposed layer that reuses ``linear``'s weight and bias.

    The returned layer holds the *same* ``weight``/``bias`` objects as
    ``linear`` (no copy) and freshly allocated ``u``/``v``/``sigma``
    parameters on the weight's device and dtype. Those three are created
    with ``torch.empty`` and therefore hold uninitialized values until the
    caller fills them (presumably via ``init_lora_weights`` -- confirm).

    Args:
        linear: Source layer; its device may be "meta".
        lora_rank: Rank of the low-rank factors.

    Raises:
        ValueError: If ``lora_rank`` fails ``_validate_lora_rank``.
    """
    _validate_lora_rank(lora_rank)
    # Build u/v/sigma directly on the source weight's device/dtype (which may
    # be "meta" during TorchTitan's meta-device model construction). We must
    # NOT allocate on CPU and then `.to(device=...)`: moving a real CPU tensor
    # onto a meta device triggers an incompatible set_data and crashes during
    # `on_convert`, before `to_empty`/`init_weights` have materialized params.
    device = linear.weight.device
    dtype = linear.weight.dtype

    # Bypass cls.__init__ on purpose: it would allocate parameters on the
    # default device first. Only nn.Module's own bookkeeping is initialized.
    # NOTE(review): any other attribute set by __init__ is not set here --
    # confirm none is needed.
    new_layer = cls.__new__(cls)
    nn.Module.__init__(new_layer)
    new_layer.in_features = linear.in_features
    new_layer.out_features = linear.out_features
    new_layer.weight = linear.weight
    new_layer.bias = linear.bias
    new_layer.u = nn.Parameter(
        torch.empty(lora_rank, linear.out_features, device=device, dtype=dtype)
    )  # transposed
    new_layer.v = nn.Parameter(
        torch.empty(linear.in_features, lora_rank, device=device, dtype=dtype)
    )
    new_layer.sigma = nn.Parameter(
        torch.empty(lora_rank, device=device, dtype=dtype)
    )
    return new_layer

def init_lora_weights(self, init_std: float = 0.02):
Expand Down
145 changes: 119 additions & 26 deletions tests/unittest/nn/test_decomposed_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,85 @@
#
# SPDX-License-Identifier: MIT

from dataclasses import dataclass

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from alto.nn import DecomposedLinear
from alto.kernels.dispatch import TrainingOpConfig, swap_params
from alto.kernels.dispatch.tensor import TrainingWeightWrapperBaseTensor


# ---------------------------------------------------------------------------
# Shared scheme spec + helpers
#
# Per-scheme quirks (block size, two-level-scaling default, SNR bar) live in a
# single table so the test bodies stay scheme-agnostic and NVFP4/MXFP4 are
# exercised apples-to-apples. Adding a future scheme = one new row here.
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class SchemeTestSpec:
    """Per-scheme test parameters; one immutable row of ``SCHEME_SPECS``."""

    # Quantization block size of the scheme; ``quant_batch`` derives the
    # activation row count from it.
    block_size: int
    # Value forwarded as ``two_level_scaling`` to ``TrainingOpConfig``.
    two_level_scaling: str
    # SNR threshold the quantized forward must exceed in the SNR test.
    min_snr: float


# One row per quantization scheme, keyed by the precision name.
SCHEME_SPECS = {
    "mxfp4": SchemeTestSpec(
        block_size=32,
        two_level_scaling="none",
        min_snr=10.0,
    ),
    "nvfp4": SchemeTestSpec(
        block_size=16,
        two_level_scaling="none",
        min_snr=10.0,
    ),
}

# Precision names in table order, used to parametrize the tests.
QUANT_PRECISIONS = list(SCHEME_SPECS.keys())


def make_quant_config(precision: str, **overrides) -> TrainingOpConfig:
    """Build the quant config for ``precision``.

    Starts from the shared test defaults, takes ``two_level_scaling`` from the
    scheme's ``SCHEME_SPECS`` row (``none`` for both nvfp4 and mxfp4 here), and
    lets any keyword in ``overrides`` replace or extend those defaults.
    """
    scheme = SCHEME_SPECS[precision]
    defaults = {
        "precision": precision,
        "use_2dblock_x": False,
        "use_2dblock_w": True,
        "use_hadamard": True,
        "use_sr_grad": False,
        "use_dge": False,
        "two_level_scaling": scheme.two_level_scaling,
    }
    return TrainingOpConfig(**{**defaults, **overrides})


def quant_batch(precision: str) -> int:
    """Return the activation row count to use for ``precision``.

    With ``use_2dblock_x=False`` the row count must be divisible by the scheme
    block size; four blocks' worth of rows gives a meaningful SNR.
    """
    rows_per_block = SCHEME_SPECS[precision].block_size
    return rows_per_block * 4


def build_decomposed_linear(
    in_features, out_features, bias, lora_rank, *, device, init_std=0.1, seed=0
):
    """Create a seeded ``DecomposedLinear`` on ``device`` with normal-initialized parameters.

    Every parameter is drawn from N(0, ``init_std``). The draw order
    (weight, bias if present, u, v, sigma) is fixed so a given ``seed``
    always produces the same layer.
    """
    torch.manual_seed(seed)
    layer = DecomposedLinear(in_features, out_features, bias, lora_rank).to(device)

    param_names = ["weight"]
    if bias:
        param_names.append("bias")
    param_names.extend(["u", "v", "sigma"])

    for name in param_names:
        getattr(layer, name).data.normal_(mean=0, std=init_std)
    return layer


def compute_snr(y_ref, y):
    """Return the signal-to-noise ratio of ``y`` against ``y_ref`` in dB."""
    signal_norm = torch.norm(y_ref)
    noise_norm = torch.norm(y_ref - y)
    return 20 * torch.log10(signal_norm / noise_norm)


# ---------------------------------------------------------------------------
# Math correctness of from_linear (scheme-agnostic, CPU)
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("in_features", [128, 256, 512])
@pytest.mark.parametrize("out_features", [128, 256, 512])
@pytest.mark.parametrize("bias", [False, True])
Expand All @@ -34,40 +105,62 @@ def test_decomposed_linear(in_features, out_features, bias, lora_rank):
assert torch.allclose(y_ref, y_decomposed, atol=1e-4, rtol=1e-4)


# ---------------------------------------------------------------------------
# Meta-device build + swap path (TorchTitan).
#
# TorchTitan builds the model on the "meta" device and on_convert wraps params
# BEFORE to_empty/init_weights materialize them. from_linear must therefore
# allocate u/v/sigma on the source device (meta) -- never on CPU then
# `.to("meta")`, which triggers an incompatible set_data and crashes -- and
# swap_params must tolerate meta tensors. CPU-only (no GPU required).
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize("lora_rank", [16, 32])
@pytest.mark.parametrize("precision", QUANT_PRECISIONS)
def test_decomposed_linear_meta_build_and_swap(bias, lora_rank, precision):
    """from_linear keeps everything on meta, and swap_params wraps meta params."""
    in_features, out_features = 128, 256
    with torch.device("meta"):
        linear = nn.Linear(in_features, out_features, bias=bias)

    dl = DecomposedLinear.from_linear(linear, lora_rank=lora_rank)
    # Nothing may have been materialized on a real device.
    assert dl.u.is_meta and dl.v.is_meta and dl.sigma.is_meta
    assert dl.weight.is_meta
    assert dl.u.shape == (lora_rank, out_features)
    assert dl.v.shape == (in_features, lora_rank)
    assert dl.sigma.shape == (lora_rank,)

    cfg = make_quant_config(precision)
    for p in ("weight", "u", "v"):
        swap_params(dl, config=cfg, target_parameter_name=p)
        assert isinstance(getattr(dl, p).data, TrainingWeightWrapperBaseTensor)

# ---------------------------------------------------------------------------
# End-to-end quantized forward SNR (CUDA), parametrized over all schemes.
# Feature sizes use only the small/large boundary [128, 512] to keep CI cost
# down while still covering the block-alignment edges.
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("in_features", [128, 512])
@pytest.mark.parametrize("out_features", [128, 512])
@pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize("lora_rank", [32, 64])
@pytest.mark.parametrize("precision", QUANT_PRECISIONS)
def test_decomposed_linear_quantization(in_features, out_features, bias, lora_rank, precision):
    """Quantized forward must stay above the scheme's SNR bar vs. the unquantized one."""
    spec = SCHEME_SPECS[precision]
    decomposed_linear = build_decomposed_linear(
        in_features, out_features, bias, lora_rank, device="cuda"
    )

    x = torch.randn(quant_batch(precision), in_features, device="cuda")
    # Reference output, taken before any parameter is swapped.
    y_ref = decomposed_linear(x)

    quant_op_config = make_quant_config(precision)
    for p in ("weight", "u", "v"):
        swap_params(decomposed_linear, config=quant_op_config, target_parameter_name=p)
    y = decomposed_linear(x)

    # Only snr is asserted; the other metrics are printed for diagnosis.
    max_diff = torch.max(torch.abs(y_ref - y))
    mean_diff = torch.mean(torch.abs(y_ref - y))
    snr = compute_snr(y_ref, y)
    cossim = torch.nn.functional.cosine_similarity(y_ref.flatten(), y.flatten(), dim=0)
    print(f"precision={precision}, snr={snr}, cossim={cossim}, max_diff={max_diff}, mean_diff={mean_diff}")
    assert snr > spec.min_snr, f"SNR too low for {precision}: {snr}"