Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
199 changes: 156 additions & 43 deletions modelopt/torch/export/plugins/vllm_fakequant_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
# limitations under the License.
"""Export HuggingFace model to vLLM fakequant checkpoint."""

import logging
import time
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from pathlib import Path

import torch
Expand All @@ -28,6 +33,93 @@

__all__ = ["export_hf_vllm_fq_checkpoint"]

logger = logging.getLogger(__name__)


@dataclass
class _WeightQuantWork:
    """A single weight tensor to be fake-quantized during export."""

    sd_key: str  # state_dict key of the weight (e.g. "model.layers.0.mlp.up_proj.weight")
    quantizer: TensorQuantizer  # the module's weight_quantizer to apply to the tensor
    weight: torch.Tensor  # the original (unquantized) weight tensor from the state_dict
    # For optional pre_quant_scale folding:
    inp_q: TensorQuantizer | None  # input_quantizer whose _pre_quant_scale gets folded, if eligible
    inp_q_key: str | None  # unwrapped name of that input_quantizer, recorded for bookkeeping


def _collect_quant_work(
    model: nn.Module, state_dict: dict[str, torch.Tensor]
) -> list[_WeightQuantWork]:
    """Collect all weight quantization work items from the model.

    Walks every ``QuantModule`` in *model*, and for each enabled fake-quant
    ``*weight_quantizer`` child builds a :class:`_WeightQuantWork` describing
    the weight tensor to fake-quantize. When the sibling input_quantizer is
    disabled but carries a ``_pre_quant_scale``, the item also records it so
    the scale can be folded into the weight later.
    """
    work_items = []
    seen_keys: set[str] = set()
    for module_name, module in model.named_modules():
        if not isinstance(module, QuantModule):
            continue
        for attr_name, quantizer in module.named_children():
            # Only enabled fake-quant weight quantizers produce work.
            if not (
                attr_name.endswith("weight_quantizer")
                and isinstance(quantizer, TensorQuantizer)
                and quantizer.fake_quant
                and quantizer.is_enabled
            ):
                continue
            # "weight_quantizer" -> "weight"; prepend the module path.
            weight_name = attr_name.removesuffix("_quantizer")
            prefix = f"{module_name}." if module_name else ""
            sd_key = f"{prefix}{weight_name}"
            assert sd_key not in seen_keys, f"Weight {sd_key} has already been fakequantized"
            seen_keys.add(sd_key)
            if sd_key not in state_dict:
                continue
            # Check for pre_quant_scale folding eligibility.
            # Folding (x*s) @ fq(W) == x @ (fq(W)*s) is only valid when the
            # input_quantizer is disabled (does NOT fake-quant activations).
            inp_q = None
            inp_q_key = None
            inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
            if hasattr(module, inp_attr):
                candidate = getattr(module, inp_attr)
                if (
                    hasattr(candidate, "_pre_quant_scale")
                    and candidate._pre_quant_scale is not None
                    and candidate._disabled
                    # Guards against double-folding after
                    # disable_pre_quant_scale_and_resmooth was called.
                    and getattr(candidate, "_enable_pre_quant_scale", True)
                ):
                    inp_q = candidate
                    inp_q_key = get_unwrapped_name(
                        f"{module_name}.{inp_attr}" if module_name else inp_attr, model
                    )
            work_items.append(
                _WeightQuantWork(
                    sd_key=sd_key,
                    quantizer=quantizer,
                    weight=state_dict[sd_key],
                    inp_q=inp_q,
                    inp_q_key=inp_q_key,
                )
            )
    return work_items


def _process_weight(item: _WeightQuantWork) -> tuple[str, torch.Tensor, str | None]:
"""Fake-quantize a single weight tensor and optionally fold pre_quant_scale.

Returns (sd_key, quantized_weight_on_cpu, inp_q_key_or_None).
"""
w = item.weight
w_quant = item.quantizer(w.float()).to(w.dtype)
if item.inp_q is not None:
scale = item.inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
return item.sd_key, w_quant.cpu(), item.inp_q_key


def _process_device_batch(items: list[_WeightQuantWork], device: torch.device):
    """Process all weight items on a single GPU. Runs in a dedicated thread.

    NOTE: ``torch.inference_mode()`` is thread-local, so it must be re-entered
    here even though the caller already holds it — worker threads spawned by
    ``ThreadPoolExecutor`` do not inherit the caller's inference-mode state.
    """
    with torch.inference_mode(), torch.cuda.device(device):
        results = [_process_weight(item) for item in items]
        # Ensure all queued kernels for this device finished before returning.
        torch.cuda.synchronize(device)
        return results


def disable_rotate(quantizer: TensorQuantizer):
"""Return a disabled copy of the quantizer's ``_rotate`` field, preserving its type."""
Expand All @@ -41,6 +133,7 @@ def disable_rotate(quantizer: TensorQuantizer):
def export_hf_vllm_fq_checkpoint(
model: nn.Module,
export_dir: Path | str,
parallel: bool = True,
):
"""Export quantized HF weights + ``vllm_fq_modelopt_state.pth`` for vLLM fake-quant reload.

Expand All @@ -53,6 +146,9 @@ def export_hf_vllm_fq_checkpoint(
Args:
model: In-memory quantized model.
export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``.
parallel: If True, fake-quantize weights across GPUs concurrently using
one thread per GPU device. Falls back to sequential when all weights
are on the same device or on CPU. Default True.
"""
export_dir = Path(export_dir)
export_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -62,50 +158,66 @@ def export_hf_vllm_fq_checkpoint(
# parameters are never modified. Apply each weight quantizer's fake-quant
# to the corresponding weight tensor in the copy.
state_dict = model.state_dict()
fakequant_weights = set()
input_quantizers_folded_pqs = (
set()
) # keys for input_quantizers where pre_quant_scale was folded
fakequant_weights: set[str] = set()
input_quantizers_folded_pqs: set[str] = set()

work_items = _collect_quant_work(model, state_dict)

# Group work items by device for parallel dispatch.
device_groups: dict[torch.device, list[_WeightQuantWork]] = defaultdict(list)
for item in work_items:
device_groups[item.weight.device].append(item)

num_cuda_devices = sum(1 for d in device_groups if d.type == "cuda")
use_parallel = parallel and num_cuda_devices > 1

t0 = time.monotonic()
with torch.inference_mode():
for module_name, module in model.named_modules():
if not isinstance(module, QuantModule):
continue
for attr_name, quantizer in module.named_children():
if not (
attr_name.endswith("weight_quantizer")
and isinstance(quantizer, TensorQuantizer)
and quantizer.fake_quant
and quantizer.is_enabled
):
continue
weight_name = attr_name.removesuffix("_quantizer")
prefix = f"{module_name}." if module_name else ""
sd_key = f"{prefix}{weight_name}"
assert sd_key not in fakequant_weights, (
f"Weight {sd_key} has already been fakequantized"
)
if sd_key in state_dict:
w = state_dict[sd_key]
w_quant = quantizer(w.float()).to(w.dtype).cpu()
# Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s)
# Only valid when input_quantizer does NOT fake-quant activations. If it does
# fake_quant(x*s), the non-linearity prevents folding s into W.
inp_attr = attr_name.replace("weight_quantizer", "input_quantizer")
if hasattr(module, inp_attr):
inp_q = getattr(module, inp_attr)
if (
hasattr(inp_q, "_pre_quant_scale")
and inp_q._pre_quant_scale is not None
and inp_q._disabled
):
scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device)
w_quant = (w_quant * scale[None, :]).to(w_quant.dtype)
inp_q_key = get_unwrapped_name(
f"{module_name}.{inp_attr}" if module_name else inp_attr, model
)
if use_parallel:
logger.info(
"Parallel export: %d weights across %d GPUs (%s)",
len(work_items),
num_cuda_devices,
", ".join(f"{d}: {len(items)} weights" for d, items in device_groups.items()),
)
with ThreadPoolExecutor(max_workers=num_cuda_devices) as pool:
# Submit GPU batches first (non-blocking)
futures = [
pool.submit(_process_device_batch, items, device)
for device, items in device_groups.items()
if device.type == "cuda"
]
# Process CPU weights inline while GPU futures run
for device, items in device_groups.items():
if device.type != "cuda":
for sd_key, w_quant, inp_q_key in map(_process_weight, items):
state_dict[sd_key] = w_quant
fakequant_weights.add(sd_key)
if inp_q_key is not None:
input_quantizers_folded_pqs.add(inp_q_key)
# Collect GPU results
for future in futures:
for sd_key, w_quant, inp_q_key in future.result():
state_dict[sd_key] = w_quant
fakequant_weights.add(sd_key)
if inp_q_key is not None:
input_quantizers_folded_pqs.add(inp_q_key)
state_dict[sd_key] = w_quant
fakequant_weights.add(sd_key)
else:
# Sequential fallback (single GPU, CPU, or parallel=False).
for item in work_items:
sd_key, w_quant, inp_q_key = _process_weight(item)
state_dict[sd_key] = w_quant
fakequant_weights.add(sd_key)
if inp_q_key is not None:
input_quantizers_folded_pqs.add(inp_q_key)

elapsed = time.monotonic() - t0
logger.info(
"Export step 1 (%s): %d weights fake-quantized in %.1fs",
"parallel" if use_parallel else "sequential",
len(fakequant_weights),
elapsed,
)

# Filter quantizer tensors out for a clean HF checkpoint.
clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k}
Expand Down Expand Up @@ -166,4 +278,5 @@ def export_hf_vllm_fq_checkpoint(

for wq, orig_rotate in wqs_to_restore:
wq.enable()
wq._rotate = orig_rotate
if orig_rotate is not None:
wq._rotate = orig_rotate
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The if orig_rotate is not None: guard is a behavioral change unrelated to parallelization. Looking at TensorQuantizer.__init__ / set_from_attribute_config, _rotate is always set (defaults to False), so it should never be None. If this guard is protecting against some edge case, please document which scenario produces _rotate = None. Otherwise, this changes the restore semantics — the old code would always restore, the new code skips restoration if None.

139 changes: 139 additions & 0 deletions tests/gpu/torch/export/test_vllm_fakequant_hf_parallel_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test parallel vs sequential export produces identical outputs."""

import pytest
import torch
from _test_utils.torch.transformers_models import create_tiny_llama_dir
from transformers import AutoModelForCausalLM

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_vllm_fq_checkpoint


def _quantize_model(tmp_path, suffix=""):
    """Create and FP8-quantize a tiny LLaMA model; return the quantized model.

    NOTE(review): this helper is currently not called by any test in this
    file — consider removing it or routing the tests through it.
    """
    tiny_model_dir = create_tiny_llama_dir(tmp_path / f"model{suffix}", num_hidden_layers=4)
    model = AutoModelForCausalLM.from_pretrained(tiny_model_dir)
    model = model.cuda()
    model.eval()

    def forward_loop(model):
        # Single calibration pass with random token ids.
        input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda()
        with torch.no_grad():
            model(input_ids)

    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
    return model


@pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG])
def test_parallel_vs_sequential_identical(tmp_path, quant_cfg):
    """Verify parallel export produces numerically equivalent output to sequential.

    HF weights are compared with ``torch.allclose`` (tolerance-based, not
    bitwise); quantizer tensor state is compared with ``torch.equal``.
    """
    num_gpus = torch.cuda.device_count()
    if num_gpus < 2:
        pytest.skip("Need >= 2 GPUs for parallel export test")

    # Create a tiny model and spread across GPUs.
    # NOTE(review): with only 4 layers, device_map="auto" may place everything
    # on one GPU and silently exercise the sequential fallback instead of the
    # parallel path — TODO confirm multi-device placement after quantization.
    tiny_model_dir = create_tiny_llama_dir(tmp_path / "model", num_hidden_layers=4)
    model = AutoModelForCausalLM.from_pretrained(
        tiny_model_dir, device_map="auto", torch_dtype=torch.float16
    )
    model.eval()

    def forward_loop(model):
        # Single calibration pass on the device holding the first parameters.
        first_device = next(model.parameters()).device
        input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).to(first_device)
        with torch.no_grad():
            model(input_ids)

    model = mtq.quantize(model, quant_cfg, forward_loop)

    # Export sequentially.
    seq_dir = tmp_path / "export_sequential"
    export_hf_vllm_fq_checkpoint(model, export_dir=seq_dir, parallel=False)

    # No manual re-enable needed between exports: the export function itself
    # restores the weight quantizers at the end, so it can be called again.

    # Export in parallel.
    par_dir = tmp_path / "export_parallel"
    export_hf_vllm_fq_checkpoint(model, export_dir=par_dir, parallel=True)

    # Compare HF weights.
    seq_model = AutoModelForCausalLM.from_pretrained(seq_dir)
    par_model = AutoModelForCausalLM.from_pretrained(par_dir)
    seq_sd = seq_model.state_dict()
    par_sd = par_model.state_dict()

    assert seq_sd.keys() == par_sd.keys(), "Key mismatch between sequential and parallel export"
    for key in seq_sd:
        assert torch.allclose(seq_sd[key], par_sd[key]), (
            f"Weight mismatch for {key}: max diff={torch.abs(seq_sd[key] - par_sd[key]).max()}"
        )

    # Compare full modelopt state payload (weights_only=False: modelopt_state contains
    # Python objects — dicts, strings, quantizer configs — that require unpickling).
    # Safe to unpickle: both files were generated by this test above, not user-supplied.
    seq_state = torch.load(seq_dir / "vllm_fq_modelopt_state.pth", weights_only=False)
    par_state = torch.load(par_dir / "vllm_fq_modelopt_state.pth", weights_only=False)

    # Compare modelopt_state_dict (quantizer metadata including quantizer_state).
    seq_msd = seq_state.get("modelopt_state_dict", [])
    par_msd = par_state.get("modelopt_state_dict", [])
    assert len(seq_msd) == len(par_msd), "modelopt_state_dict length mismatch"
    for (seq_mode, seq_ms), (par_mode, par_ms) in zip(seq_msd, par_msd):
        assert seq_mode == par_mode, f"Mode mismatch: {seq_mode} vs {par_mode}"

    # Compare modelopt_state_weights (per-quantizer tensor state).
    seq_qsd = seq_state["modelopt_state_weights"]
    par_qsd = par_state["modelopt_state_weights"]
    assert seq_qsd.keys() == par_qsd.keys(), "Quantizer state dict key mismatch"
    for key in seq_qsd:
        seq_val = seq_qsd[key]
        par_val = par_qsd[key]
        if isinstance(seq_val, dict):
            for k in seq_val:
                if isinstance(seq_val[k], torch.Tensor):
                    assert torch.equal(seq_val[k], par_val[k]), (
                        f"Quantizer state mismatch for {key}.{k}"
                    )
                else:
                    assert seq_val[k] == par_val[k], f"Quantizer state mismatch for {key}.{k}"
        else:
            assert seq_val == par_val, f"Quantizer state mismatch for {key}"


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_single_gpu_fallback(tmp_path):
    """Verify parallel=True gracefully falls back to sequential on single GPU."""
    model_dir = create_tiny_llama_dir(tmp_path / "model", num_hidden_layers=2)
    # Place the whole model on cuda:0 so no multi-GPU parallelism is possible.
    model = AutoModelForCausalLM.from_pretrained(model_dir).cuda().eval()

    def forward_loop(m):
        # One calibration forward pass with random token ids.
        tokens = torch.randint(0, m.config.vocab_size, (1, 128)).cuda()
        with torch.no_grad():
            m(tokens)

    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

    # parallel=True with a single device must silently take the sequential path.
    out_dir = tmp_path / "export"
    export_hf_vllm_fq_checkpoint(model, export_dir=out_dir, parallel=True)

    assert (out_dir / "vllm_fq_modelopt_state.pth").exists()
    reloaded = AutoModelForCausalLM.from_pretrained(out_dir)
    assert reloaded is not None
Loading