-
Notifications
You must be signed in to change notification settings - Fork 353
feat: parallelize fakequant export across GPUs via ThreadPoolExecutor #1241
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
15832ec
0b4d48b
07f9626
29b9888
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,6 +14,11 @@ | |
| # limitations under the License. | ||
| """Export HuggingFace model to vLLM fakequant checkpoint.""" | ||
|
|
||
| import logging | ||
| import time | ||
| from collections import defaultdict | ||
| from concurrent.futures import ThreadPoolExecutor | ||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
|
|
||
| import torch | ||
|
|
@@ -28,6 +33,93 @@ | |
|
|
||
| __all__ = ["export_hf_vllm_fq_checkpoint"] | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| @dataclass | ||
| class _WeightQuantWork: | ||
| """A single weight tensor to be fake-quantized during export.""" | ||
|
|
||
| sd_key: str | ||
| quantizer: TensorQuantizer | ||
| weight: torch.Tensor | ||
| # For optional pre_quant_scale folding: | ||
| inp_q: TensorQuantizer | None | ||
| inp_q_key: str | None | ||
|
|
||
|
|
||
| def _collect_quant_work( | ||
| model: nn.Module, state_dict: dict[str, torch.Tensor] | ||
| ) -> list[_WeightQuantWork]: | ||
| """Collect all weight quantization work items from the model.""" | ||
| work_items = [] | ||
| seen_keys: set[str] = set() | ||
| for module_name, module in model.named_modules(): | ||
| if not isinstance(module, QuantModule): | ||
| continue | ||
| for attr_name, quantizer in module.named_children(): | ||
| if not ( | ||
| attr_name.endswith("weight_quantizer") | ||
| and isinstance(quantizer, TensorQuantizer) | ||
| and quantizer.fake_quant | ||
| and quantizer.is_enabled | ||
| ): | ||
| continue | ||
| weight_name = attr_name.removesuffix("_quantizer") | ||
| prefix = f"{module_name}." if module_name else "" | ||
| sd_key = f"{prefix}{weight_name}" | ||
| assert sd_key not in seen_keys, f"Weight {sd_key} has already been fakequantized" | ||
| seen_keys.add(sd_key) | ||
| if sd_key not in state_dict: | ||
| continue | ||
| # Check for pre_quant_scale folding eligibility. | ||
| inp_q = None | ||
| inp_q_key = None | ||
| inp_attr = attr_name.replace("weight_quantizer", "input_quantizer") | ||
| if hasattr(module, inp_attr): | ||
| candidate = getattr(module, inp_attr) | ||
| if ( | ||
| hasattr(candidate, "_pre_quant_scale") | ||
| and candidate._pre_quant_scale is not None | ||
| and candidate._disabled | ||
| and getattr(candidate, "_enable_pre_quant_scale", True) | ||
| ): | ||
| inp_q = candidate | ||
| inp_q_key = get_unwrapped_name( | ||
| f"{module_name}.{inp_attr}" if module_name else inp_attr, model | ||
| ) | ||
| work_items.append( | ||
| _WeightQuantWork( | ||
| sd_key=sd_key, | ||
| quantizer=quantizer, | ||
| weight=state_dict[sd_key], | ||
| inp_q=inp_q, | ||
| inp_q_key=inp_q_key, | ||
| ) | ||
| ) | ||
| return work_items | ||
|
|
||
|
|
||
| def _process_weight(item: _WeightQuantWork) -> tuple[str, torch.Tensor, str | None]: | ||
| """Fake-quantize a single weight tensor and optionally fold pre_quant_scale. | ||
|
|
||
| Returns (sd_key, quantized_weight_on_cpu, inp_q_key_or_None). | ||
| """ | ||
| w = item.weight | ||
| w_quant = item.quantizer(w.float()).to(w.dtype) | ||
| if item.inp_q is not None: | ||
| scale = item.inp_q._pre_quant_scale.squeeze().to(device=w_quant.device) | ||
| w_quant = (w_quant * scale[None, :]).to(w_quant.dtype) | ||
| return item.sd_key, w_quant.cpu(), item.inp_q_key | ||
|
|
||
|
|
||
| def _process_device_batch(items: list[_WeightQuantWork], device: torch.device): | ||
| """Process all weight items on a single GPU. Runs in a dedicated thread.""" | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Nit: |
||
| with torch.inference_mode(), torch.cuda.device(device): | ||
| results = [_process_weight(item) for item in items] | ||
| torch.cuda.synchronize(device) | ||
| return results | ||
|
|
||
|
|
||
| def disable_rotate(quantizer: TensorQuantizer): | ||
| """Return a disabled copy of the quantizer's ``_rotate`` field, preserving its type.""" | ||
|
|
@@ -41,6 +133,7 @@ def disable_rotate(quantizer: TensorQuantizer): | |
| def export_hf_vllm_fq_checkpoint( | ||
| model: nn.Module, | ||
| export_dir: Path | str, | ||
| parallel: bool = True, | ||
| ): | ||
| """Export quantized HF weights + ``vllm_fq_modelopt_state.pth`` for vLLM fake-quant reload. | ||
|
|
||
|
|
@@ -53,6 +146,9 @@ def export_hf_vllm_fq_checkpoint( | |
| Args: | ||
| model: In-memory quantized model. | ||
| export_dir: Output dir for HF files and ``vllm_fq_modelopt_state.pth``. | ||
| parallel: If True, fake-quantize weights across GPUs concurrently using | ||
| one thread per GPU device. Falls back to sequential when all weights | ||
| are on the same device or on CPU. Default True. | ||
| """ | ||
| export_dir = Path(export_dir) | ||
| export_dir.mkdir(parents=True, exist_ok=True) | ||
|
|
@@ -62,50 +158,66 @@ def export_hf_vllm_fq_checkpoint( | |
| # parameters are never modified. Apply each weight quantizer's fake-quant | ||
| # to the corresponding weight tensor in the copy. | ||
| state_dict = model.state_dict() | ||
| fakequant_weights = set() | ||
| input_quantizers_folded_pqs = ( | ||
| set() | ||
| ) # keys for input_quantizers where pre_quant_scale was folded | ||
| fakequant_weights: set[str] = set() | ||
| input_quantizers_folded_pqs: set[str] = set() | ||
|
|
||
| work_items = _collect_quant_work(model, state_dict) | ||
|
|
||
| # Group work items by device for parallel dispatch. | ||
| device_groups: dict[torch.device, list[_WeightQuantWork]] = defaultdict(list) | ||
| for item in work_items: | ||
| device_groups[item.weight.device].append(item) | ||
|
|
||
| num_cuda_devices = sum(1 for d in device_groups if d.type == "cuda") | ||
| use_parallel = parallel and num_cuda_devices > 1 | ||
|
|
||
| t0 = time.monotonic() | ||
| with torch.inference_mode(): | ||
| for module_name, module in model.named_modules(): | ||
| if not isinstance(module, QuantModule): | ||
| continue | ||
| for attr_name, quantizer in module.named_children(): | ||
| if not ( | ||
| attr_name.endswith("weight_quantizer") | ||
| and isinstance(quantizer, TensorQuantizer) | ||
| and quantizer.fake_quant | ||
| and quantizer.is_enabled | ||
| ): | ||
| continue | ||
| weight_name = attr_name.removesuffix("_quantizer") | ||
| prefix = f"{module_name}." if module_name else "" | ||
| sd_key = f"{prefix}{weight_name}" | ||
| assert sd_key not in fakequant_weights, ( | ||
| f"Weight {sd_key} has already been fakequantized" | ||
| ) | ||
| if sd_key in state_dict: | ||
| w = state_dict[sd_key] | ||
| w_quant = quantizer(w.float()).to(w.dtype).cpu() | ||
| # Fold pre_quant_scale: (x*s)@fake_quant(W) = x@(fake_quant(W)*s) | ||
| # Only valid when input_quantizer does NOT fake-quant activations. If it does | ||
| # fake_quant(x*s), the non-linearity prevents folding s into W. | ||
| inp_attr = attr_name.replace("weight_quantizer", "input_quantizer") | ||
| if hasattr(module, inp_attr): | ||
| inp_q = getattr(module, inp_attr) | ||
| if ( | ||
| hasattr(inp_q, "_pre_quant_scale") | ||
| and inp_q._pre_quant_scale is not None | ||
| and inp_q._disabled | ||
| ): | ||
| scale = inp_q._pre_quant_scale.squeeze().to(device=w_quant.device) | ||
| w_quant = (w_quant * scale[None, :]).to(w_quant.dtype) | ||
| inp_q_key = get_unwrapped_name( | ||
| f"{module_name}.{inp_attr}" if module_name else inp_attr, model | ||
| ) | ||
| if use_parallel: | ||
| logger.info( | ||
| "Parallel export: %d weights across %d GPUs (%s)", | ||
| len(work_items), | ||
| num_cuda_devices, | ||
| ", ".join(f"{d}: {len(items)} weights" for d, items in device_groups.items()), | ||
| ) | ||
| with ThreadPoolExecutor(max_workers=num_cuda_devices) as pool: | ||
| # Submit GPU batches first (non-blocking) | ||
| futures = [ | ||
| pool.submit(_process_device_batch, items, device) | ||
| for device, items in device_groups.items() | ||
| if device.type == "cuda" | ||
| ] | ||
| # Process CPU weights inline while GPU futures run | ||
| for device, items in device_groups.items(): | ||
| if device.type != "cuda": | ||
| for sd_key, w_quant, inp_q_key in map(_process_weight, items): | ||
| state_dict[sd_key] = w_quant | ||
| fakequant_weights.add(sd_key) | ||
| if inp_q_key is not None: | ||
| input_quantizers_folded_pqs.add(inp_q_key) | ||
| # Collect GPU results | ||
| for future in futures: | ||
| for sd_key, w_quant, inp_q_key in future.result(): | ||
| state_dict[sd_key] = w_quant | ||
| fakequant_weights.add(sd_key) | ||
| if inp_q_key is not None: | ||
| input_quantizers_folded_pqs.add(inp_q_key) | ||
| state_dict[sd_key] = w_quant | ||
| fakequant_weights.add(sd_key) | ||
| else: | ||
| # Sequential fallback (single GPU, CPU, or parallel=False). | ||
| for item in work_items: | ||
| sd_key, w_quant, inp_q_key = _process_weight(item) | ||
| state_dict[sd_key] = w_quant | ||
| fakequant_weights.add(sd_key) | ||
| if inp_q_key is not None: | ||
| input_quantizers_folded_pqs.add(inp_q_key) | ||
|
|
||
| elapsed = time.monotonic() - t0 | ||
| logger.info( | ||
| "Export step 1 (%s): %d weights fake-quantized in %.1fs", | ||
| "parallel" if use_parallel else "sequential", | ||
| len(fakequant_weights), | ||
| elapsed, | ||
| ) | ||
|
|
||
| # Filter quantizer tensors out for a clean HF checkpoint. | ||
| clean_sd = {k: v for k, v in state_dict.items() if "quantizer" not in k} | ||
|
|
@@ -166,4 +278,5 @@ def export_hf_vllm_fq_checkpoint( | |
|
|
||
| for wq, orig_rotate in wqs_to_restore: | ||
| wq.enable() | ||
| wq._rotate = orig_rotate | ||
| if orig_rotate is not None: | ||
| wq._rotate = orig_rotate | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,139 @@ | ||
| # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Test parallel vs sequential export produces identical outputs.""" | ||
|
|
||
| import pytest | ||
| import torch | ||
| from _test_utils.torch.transformers_models import create_tiny_llama_dir | ||
| from transformers import AutoModelForCausalLM | ||
|
|
||
| import modelopt.torch.quantization as mtq | ||
| from modelopt.torch.export import export_hf_vllm_fq_checkpoint | ||
|
|
||
|
|
||
| def _quantize_model(tmp_path, suffix=""): | ||
| """Create and quantize a tiny LLaMA model. Returns (model, export_dir).""" | ||
| tiny_model_dir = create_tiny_llama_dir(tmp_path / f"model{suffix}", num_hidden_layers=4) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This |
||
| model = AutoModelForCausalLM.from_pretrained(tiny_model_dir) | ||
| model = model.cuda() | ||
| model.eval() | ||
|
|
||
| def forward_loop(model): | ||
| input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda() | ||
| with torch.no_grad(): | ||
| model(input_ids) | ||
|
|
||
| model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop) | ||
| return model | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("quant_cfg", [mtq.FP8_DEFAULT_CFG]) | ||
| def test_parallel_vs_sequential_identical(tmp_path, quant_cfg): | ||
| """Verify parallel export produces bitwise identical output to sequential.""" | ||
|
Comment on lines
+43
to
+44
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Docstring claims "bitwise identical" but implementation uses `torch.allclose`. The docstring at line 44 states "bitwise identical output" but line 83 uses `torch.allclose`, which permits a numerical tolerance. 📝 Option 1: Use torch.equal for true bitwise check for key in seq_sd:
- assert torch.allclose(seq_sd[key], par_sd[key]), (
+ assert torch.equal(seq_sd[key], par_sd[key]), (
f"Weight mismatch for {key}: max diff={torch.abs(seq_sd[key] - par_sd[key]).max()}"
)📝 Option 2: Update docstring to match implementation- """Verify parallel export produces bitwise identical output to sequential."""
+ """Verify parallel export produces numerically equivalent output to sequential."""Also applies to: 82-85 🤖 Prompt for AI Agents |
||
| num_gpus = torch.cuda.device_count() | ||
| if num_gpus < 2: | ||
| pytest.skip("Need >= 2 GPUs for parallel export test") | ||
|
|
||
| # Create a tiny model and spread across GPUs. | ||
| tiny_model_dir = create_tiny_llama_dir(tmp_path / "model", num_hidden_layers=4) | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| tiny_model_dir, device_map="auto", torch_dtype=torch.float16 | ||
| ) | ||
| model.eval() | ||
|
|
||
| def forward_loop(model): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. With only 4 hidden layers in a tiny model, `device_map="auto"` may place all weights on a single GPU, in which case this test would silently skip the parallel path. Suggested check: # After quantization, verify weights are on multiple GPUs
devices = {p.device for p in model.parameters()}
assert len([d for d in devices if d.type == 'cuda']) > 1, (
f"Model not spread across GPUs: {devices}. Test won't exercise parallel path."
) |
||
| first_device = next(model.parameters()).device | ||
| input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).to(first_device) | ||
| with torch.no_grad(): | ||
| model(input_ids) | ||
|
|
||
| model = mtq.quantize(model, quant_cfg, forward_loop) | ||
|
|
||
| # Export sequentially. | ||
| seq_dir = tmp_path / "export_sequential" | ||
| export_hf_vllm_fq_checkpoint(model, export_dir=seq_dir, parallel=False) | ||
|
|
||
| # Re-enable weight quantizers (export disables them — need to restore for second export). | ||
| # The function already re-enables them at the end, so we can just call it again. | ||
|
|
||
| # Export in parallel. | ||
| par_dir = tmp_path / "export_parallel" | ||
| export_hf_vllm_fq_checkpoint(model, export_dir=par_dir, parallel=True) | ||
|
|
||
| # Compare HF weights. | ||
| seq_model = AutoModelForCausalLM.from_pretrained(seq_dir) | ||
| par_model = AutoModelForCausalLM.from_pretrained(par_dir) | ||
| seq_sd = seq_model.state_dict() | ||
| par_sd = par_model.state_dict() | ||
|
|
||
| assert seq_sd.keys() == par_sd.keys(), "Key mismatch between sequential and parallel export" | ||
| for key in seq_sd: | ||
| assert torch.allclose(seq_sd[key], par_sd[key]), ( | ||
| f"Weight mismatch for {key}: max diff={torch.abs(seq_sd[key] - par_sd[key]).max()}" | ||
| ) | ||
|
|
||
| # Compare full modelopt state payload (weights_only=False: modelopt_state contains | ||
| # Python objects — dicts, strings, quantizer configs — that require unpickling). | ||
| seq_state = torch.load(seq_dir / "vllm_fq_modelopt_state.pth", weights_only=False) | ||
| par_state = torch.load(par_dir / "vllm_fq_modelopt_state.pth", weights_only=False) | ||
|
|
||
|
Comment on lines
+87
to
+91
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Add explicit safety justification for `weights_only=False`. Per coding guidelines, `torch.load(..., weights_only=False)` calls should carry a comment stating why the file being unpickled is trusted.
- # Python objects — dicts, strings, quantizer configs — that require unpickling).
+ # Compare full modelopt state payload.
+ # weights_only=False: these files are internally-generated by the test above (not
+ # user-supplied) and contain Python objects (dicts, quantizer configs) requiring unpickling.
seq_state = torch.load(seq_dir / "vllm_fq_modelopt_state.pth", weights_only=False)
par_state = torch.load(par_dir / "vllm_fq_modelopt_state.pth", weights_only=False)🤖 Prompt for AI Agents |
||
| # Compare modelopt_state_dict (quantizer metadata including quantizer_state). | ||
| seq_msd = seq_state.get("modelopt_state_dict", []) | ||
| par_msd = par_state.get("modelopt_state_dict", []) | ||
| assert len(seq_msd) == len(par_msd), "modelopt_state_dict length mismatch" | ||
| for (seq_mode, seq_ms), (par_mode, par_ms) in zip(seq_msd, par_msd): | ||
| assert seq_mode == par_mode, f"Mode mismatch: {seq_mode} vs {par_mode}" | ||
|
|
||
| # Compare modelopt_state_weights (per-quantizer tensor state). | ||
| seq_qsd = seq_state["modelopt_state_weights"] | ||
| par_qsd = par_state["modelopt_state_weights"] | ||
| assert seq_qsd.keys() == par_qsd.keys(), "Quantizer state dict key mismatch" | ||
| for key in seq_qsd: | ||
| seq_val = seq_qsd[key] | ||
| par_val = par_qsd[key] | ||
| if isinstance(seq_val, dict): | ||
| for k in seq_val: | ||
| if isinstance(seq_val[k], torch.Tensor): | ||
| assert torch.equal(seq_val[k], par_val[k]), ( | ||
| f"Quantizer state mismatch for {key}.{k}" | ||
| ) | ||
| else: | ||
| assert seq_val[k] == par_val[k], f"Quantizer state mismatch for {key}.{k}" | ||
| else: | ||
| assert seq_val == par_val, f"Quantizer state mismatch for {key}" | ||
|
|
||
|
|
||
| @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") | ||
| def test_single_gpu_fallback(tmp_path): | ||
| """Verify parallel=True gracefully falls back to sequential on single GPU.""" | ||
| tiny_model_dir = create_tiny_llama_dir(tmp_path / "model", num_hidden_layers=2) | ||
| model = AutoModelForCausalLM.from_pretrained(tiny_model_dir) | ||
| model = model.cuda() # All on cuda:0 | ||
| model.eval() | ||
|
|
||
| def forward_loop(model): | ||
| input_ids = torch.randint(0, model.config.vocab_size, (1, 128)).cuda() | ||
| with torch.no_grad(): | ||
| model(input_ids) | ||
|
|
||
| model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop) | ||
|
|
||
| # parallel=True but single GPU → should fall back to sequential without error. | ||
| export_dir = tmp_path / "export" | ||
| export_hf_vllm_fq_checkpoint(model, export_dir=export_dir, parallel=True) | ||
|
|
||
| assert (export_dir / "vllm_fq_modelopt_state.pth").exists() | ||
| reloaded = AutoModelForCausalLM.from_pretrained(export_dir) | ||
| assert reloaded is not None | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The `_enable_pre_quant_scale` check is a correctness improvement (prevents double-folding when `disable_pre_quant_scale_and_resmooth` was called), but it's a behavioral change unrelated to parallelization. Consider noting this in the PR description or splitting it into a separate fix for easier review/bisect.

Also, `getattr(candidate, "_enable_pre_quant_scale", True)` always returns `True` for the default because `TensorQuantizer.__init__` sets `self._enable_pre_quant_scale = True`. The `getattr` with default is unnecessary — `candidate._enable_pre_quant_scale` would suffice (since `candidate` is guaranteed to be a `TensorQuantizer` here).