Commit e764ad9

realAsma and claude committed
Refactor llm_qat example with YAML configs and ModelOpt argument parser
Replace shell-based launch script with YAML config files and integrate ModelOpt's HfArgumentParser plugin for cleaner argument handling. Add auto-generated ARGUMENTS.md, update README with new usage patterns, and add unit tests for the argument parser plugin.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent fcb09bf commit e764ad9

13 files changed with 532 additions and 292 deletions

.pre-commit-config.yaml

Lines changed: 14 additions & 0 deletions
@@ -127,6 +127,20 @@ repos:
      - id: markdownlint-cli2
        args: ["--fix"]

+  - repo: local
+    hooks:
+      - id: generate-arguments-md
+        name: Regenerate examples/llm_qat/ARGUMENTS.md
+        entry: bash -c 'python examples/llm_qat/main.py --generate_docs examples/llm_qat/ARGUMENTS.md && git diff --exit-code examples/llm_qat/ARGUMENTS.md'
+        language: system
+        files: >-
+          (?x)^(
+            examples/llm_qat/main\.py|
+            modelopt/torch/opt/plugins/transformers\.py|
+            modelopt/torch/quantization/plugins/transformers_trainer\.py
+          )$
+        pass_filenames: false
+
##### Manual hooks (Expect many false positives)
# These hooks are only run with `pre-commit run --all-files --hook-stage manual <hook_id>`
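The hook regenerates `ARGUMENTS.md` by running `main.py --generate_docs` and fails if `git diff --exit-code` shows the committed file is stale. The generator itself is not part of this commit's visible hunks; below is a minimal sketch of how such a table could be emitted from HF-style argument dataclasses. The helper name and details are illustrative, not the actual `main.py` code.

```python
# Hypothetical doc-generator sketch; the real --generate_docs implementation
# in examples/llm_qat/main.py is not shown in this diff.
import dataclasses


def render_markdown_table(cls) -> str:
    """Render an HF-style argument dataclass as a markdown argument table."""
    lines = [
        f"## {cls.__name__}",
        "",
        "| Argument | Type | Default | Description |",
        "|----------|------|---------|-------------|",
    ]
    for f in dataclasses.fields(cls):
        default = None if f.default is dataclasses.MISSING else f.default
        help_text = f.metadata.get("help", "")
        type_name = f.type if isinstance(f.type, str) else getattr(f.type, "__name__", str(f.type))
        lines.append(f"| `--{f.name}` | `{type_name}` | `{default!r}` | {help_text} |")
    return "\n".join(lines)
```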

examples/llm_qat/ARGUMENTS.md

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
# Argument Reference

_Auto-generated — do not edit by hand._

## QuantizationArguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--quant_cfg` | `str` | `None` | Specify the quantization format for PTQ/QAT. If specified, PTQ/QAT will be enabled with the specified quantization format. |
| `--calib_size` | `int` | `512` | Specify the calibration size for quantization. The calibration dataset is used to set up the quantization scale parameters for PTQ/QAT. |
| `--compress` | `bool` | `False` | Whether to compress the model weights after quantization for QLoRA. This is useful for reducing the model size. |

## DataArguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--dataset` | `str` | `"Daring-Anteater"` | Specify the dataset. |
| `--train_size` | `int` | `0` | Number of training samples to use. If `0`, use default training size. |
| `--eval_size` | `int` | `0` | Number of evaluation samples to use. If `0`, use default evaluation size. |

## ModelArguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--model_name_or_path` | `str` | `"meta-llama/Llama-2-7b-hf"` | |
| `--teacher_model` | `str` | `None` | The name or path of the teacher model to use for distillation. |

## TrainingArguments

Extends [HuggingFace TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). Only additional/overridden arguments are shown below.

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--cache_dir` | `str` | `None` | |
| `--model_max_length` | `int` | `4096` | Maximum sequence length. Sequences will be right padded (and possibly truncated). |
| `--lora` | `bool` | `False` | Whether to add LoRA (Low-Rank Adaptation) adapter before training. When using real quantization, the LoRA adapter must be set, as quantized weights will be frozen during training. |
| `--distill` | `bool` | `False` | Select if training with distillation. |
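The tables above map onto HF-style argument dataclasses; the real `QuantizationArguments` is importable from `modelopt.torch.quantization.plugins.transformers_trainer`, as the README hunk further below shows. A rough sketch of the dataclass/parser pattern the plugin builds on, with field names and help strings taken from the table (this is not the ModelOpt source):

```python
# Illustrative sketch of the HfArgumentParser dataclass pattern behind these tables.
from dataclasses import dataclass, field
from typing import Optional

from transformers import HfArgumentParser, TrainingArguments


@dataclass
class QuantizationArguments:
    quant_cfg: Optional[str] = field(
        default=None,
        metadata={"help": "Specify the quantization format for PTQ/QAT."},
    )
    calib_size: int = field(
        default=512,
        metadata={"help": "Calibration size used to set up quantization scales."},
    )
    compress: bool = field(
        default=False,
        metadata={"help": "Whether to compress model weights after quantization (QLoRA)."},
    )


# e.g. python main.py --output_dir llama3-qat --quant_cfg NVFP4_DEFAULT_CFG
parser = HfArgumentParser((QuantizationArguments, TrainingArguments))
quant_args, training_args = parser.parse_args_into_dataclasses()
```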

examples/llm_qat/README.md

Lines changed: 100 additions & 22 deletions
@@ -39,6 +39,39 @@ Quantization aware distillation (QAD) can be used to further improve accuracy of

The Llama3-8B fine-tuning and QAT below requires a minimum of 2 x 80GB GPUs per machine.

+#### How to Launch
+
+Use `accelerate launch` with a backend config and pass arguments via CLI or a YAML config file:
+
+```sh
+# With YAML config (recommended)
+accelerate launch --config-file accelerate_config/fsdp2.yaml main.py \
+    --config configs/qat_nvfp4.yaml
+
+# With YAML + CLI overrides
+accelerate launch --config-file accelerate_config/fsdp2.yaml main.py \
+    --config configs/qat_nvfp4.yaml --learning_rate 5e-5
+
+# CLI only (no YAML)
+accelerate launch --config-file accelerate_config/fsdp2.yaml main.py \
+    --model_name_or_path meta-llama/Meta-Llama-3-8B \
+    --quant_cfg NVFP4_DEFAULT_CFG \
+    --num_train_epochs 2.0 \
+    --learning_rate 1e-5 \
+    --output_dir llama3-qat
+```
+
+#### Backend Configuration
+
+| Backend | Config File | Notes |
+|---------|------------|-------|
+| FSDP2 | `accelerate_config/fsdp2.yaml` | Recommended for multi-GPU |
+| FSDP1 | `accelerate_config/fsdp1.yaml` | Legacy FSDP |
+| DDP | `accelerate_config/ddp.yaml` | Add `--gradient_checkpointing True` |
+| DeepSpeed | `accelerate_config/deepspeed.yaml` | Add `--gradient_checkpointing True` |
+
+See [ARGUMENTS.md](ARGUMENTS.md) for the full argument reference.
+
#### QAT Example Workflow

In QAT, a model quantized using [mtq.quantize()](https://nvidia.github.io/Model-Optimizer/reference/generated/modelopt.torch.quantization.model_quant.html#modelopt.torch.quantization.model_quant.quantize) can be directly fine-tuned with the original training pipeline. During QAT, the scaling factors inside quantizers are frozen and the model weights are fine-tuned.
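The `--config` flag introduced in the How to Launch section above is handled by the argument parser plugin; `main.py` itself is not shown in this diff. A minimal sketch of the pattern, assuming a hypothetical standalone helper that treats YAML values as defaults and lets explicit CLI flags win:

```python
# Hypothetical sketch of YAML-config-plus-CLI-override parsing; the actual
# logic lives in ModelOpt's HfArgumentParser plugin and is not shown here.
import sys

import yaml
from transformers import HfArgumentParser


def parse_with_yaml_config(parser: HfArgumentParser, argv=None):
    """Treat --config <file.yaml> as defaults; later CLI flags override them."""
    args = list(sys.argv[1:] if argv is None else argv)
    if "--config" in args:
        i = args.index("--config")
        with open(args[i + 1]) as f:
            cfg = yaml.safe_load(f) or {}
        del args[i : i + 2]  # drop "--config <path>" from argv
        # YAML-derived flags go first so explicit CLI flags take precedence
        # (argparse keeps the last occurrence of a repeated option).
        args = [tok for k, v in cfg.items() for tok in (f"--{k}", str(v))] + args
    return parser.parse_args_into_dataclasses(args=args)
```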
@@ -103,6 +136,34 @@ python simple_qat_train.py --model meta-llama/Llama-3.2-3B

To train larger models with distributed training, please refer to [End-to-end QAT Example](#end-to-end-qat-example).

+#### QATTrainer Example Workflow
+
+`QATTrainer` is a drop-in replacement for HuggingFace's `Trainer` that handles quantization internally — no need to manually call `mtq.quantize()`. Quantization is configured via `quant_args`.
+
+```python
+from modelopt.torch.quantization.plugins.transformers_trainer import QATTrainer, QuantizationArguments
+
+...
+
+# [Not shown] load model, tokenizer, data loaders, etc.
+quant_args = QuantizationArguments(quant_cfg="NVFP4_DEFAULT_CFG")
+
+trainer = QATTrainer(
+    model=model,
+    processing_class=tokenizer,
+    args=training_args,
+    quant_args=quant_args,
+    **data_module,
+)
+
+trainer.train()  # QATTrainer quantizes the model and runs QAT
+
+# Save the final model weights (example usage)
+trainer.save_model()
+```
+
+> **_NOTE:_** `QADTrainer` (shown below) extends `QATTrainer` with distillation support for cases where QAT alone is not enough.
+
#### QAD Example Workflow

Here is an example workflow for performing QAD:
@@ -182,22 +243,24 @@ This folder contains end-to-end runnable fine-tuning/QAT pipeline where Llama3-8
First, we need to run un-quantized fine-tuning. Here is the command for that:

```sh
-./launch.sh --model meta-llama/Meta-Llama-3-8B \
-    --num_epochs 2.0 \
-    --lr 1e-5 \
+accelerate launch --config-file accelerate_config/fsdp2.yaml main.py \
+    --model_name_or_path meta-llama/Meta-Llama-3-8B \
+    --num_train_epochs 2.0 \
+    --learning_rate 1e-5 \
    --do_train True \
    --output_dir llama3-finetune
```

This will generate a fine-tuned checkpoint in `output_dir` specified above. You can load this checkpoint, quantize the model, evaluate PTQ results or run additional QAT.
-This can be accomplished by specifying the quantization format to the `launch.sh` script.
+This can be accomplished by specifying the quantization format.
In this example, we are quantizing the model with INT4 block-wise weights and INT8 per-tensor activation quantization.

To perform PTQ evaluation, run:

```sh
# Load the checkpoint from previous fine-tuning stage, quantize the model and evaluate without additional training
-./launch.sh --model llama3-finetune \
+accelerate launch --config-file accelerate_config/fsdp2.yaml main.py \
+    --model_name_or_path llama3-finetune \
    --do_train False \
    --quant_cfg NVFP4_DEFAULT_CFG
```
@@ -206,14 +269,22 @@ To perform QAT, run:

```sh
# Load the quantized checkpoint from previous fine-tuning stage and run additional training (QAT)
-./launch.sh --model llama3-finetune \
-    --num_epochs 2.0 \
-    --lr 1e-5 \
+accelerate launch --config-file accelerate_config/fsdp2.yaml main.py \
+    --model_name_or_path llama3-finetune \
+    --num_train_epochs 2.0 \
+    --learning_rate 1e-5 \
    --do_train True \
    --quant_cfg NVFP4_DEFAULT_CFG \
    --output_dir llama3-qat
```

+Or equivalently, using a YAML config:
+
+```sh
+accelerate launch --config-file accelerate_config/fsdp2.yaml main.py \
+    --config configs/qat_nvfp4.yaml
+```
+
You may alternatively perform QAT with any other quantization formats from **ModelOpt**. Please see more details on the supported quantization formats and how to use them as shown below:

```python
@@ -223,18 +294,14 @@ import modelopt.torch.quantization as mtq
help(mtq.config)
```

-You could also add your own customized quantization format to `CUSTOM_QUANT_CFG` from `main.py` and perform QAT.
-
> **_NOTE:_** QAT requires higher memory than the full-precision fine-tuning. A solution to avoid this extra memory usage is to use [activation checkpointing](https://pytorch.org/docs/stable/checkpoint.html) or gradient checkpointing. Activation checkpointing can be enabled easily with training frameworks such as Huggingface by adding an additional argument `gradient_checkpointing True`. Learn more [here](https://huggingface.co/docs/transformers/v4.20.1/en/perf_train_gpu_one#gradient-checkpointing). Activation checkpointing or gradient checkpointing is enabled by default in this example.

> **_NOTE:_** Like any other model training, the QAT model accuracy can be further improved by optimizing the training
> hyper-parameters such as learning rate, training duration etc.

-> **_NOTE:_** `launch.sh` defaults to use `LlamaDecoderLayer` as the transformer layer class. If your model uses a different class, you need to pass `--fsdp_transformer_layer_cls_to_wrap <your_layer_class>` to the `launch.sh` script. For example, for `Qwen/Qwen3-8B`, specify `--fsdp_transformer_layer_cls_to_wrap Qwen3DecoderLayer` as an additional argument.
-
### Results

-Here is an example result following the workflow above with slightly different hyper-parameters (We used an effective batch size of 128 by adjusting `--train_bs` and `--accum_steps` as per the available GPU memory).
+Here is an example result following the workflow above with slightly different hyper-parameters (We used an effective batch size of 128 by adjusting `--per_device_train_batch_size` and `--gradient_accumulation_steps` as per the available GPU memory).
As we can see below, QAT has improved the validation perplexity.

You could get slightly different numbers depending on your hyper-parameters - however you should be able to see consistent improvement
@@ -255,13 +322,22 @@ for QAT over PTQ alone.
To perform QAD with logits loss, run:

```sh
-./launch.sh --model llama3-finetune \
-    --num_epochs 3 \
-    --lr 4e-5 \
+accelerate launch --config-file accelerate_config/fsdp2.yaml main.py \
+    --config configs/qad_nvfp4.yaml
+```
+
+Or equivalently with CLI args:
+
+```sh
+accelerate launch --config-file accelerate_config/fsdp2.yaml main.py \
+    --model_name_or_path llama3-finetune \
+    --num_train_epochs 3 \
+    --learning_rate 4e-5 \
    --quant_cfg NVFP4_DEFAULT_CFG \
    --do_train True \
    --output_dir llama-qad \
-    --distill True
+    --distill True \
+    --teacher_model llama3-finetune
```

> **_NOTE:_** QAD doesn't support FSDP1 (<https://docs.pytorch.org/docs/stable/fsdp.html>) backend - only FSDP2.
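For context on the "logits loss" mentioned above: in QAD the quantized student's logits are matched against the full-precision teacher's logits. A standard formulation is sketched below; the exact loss used by the distillation trainer may differ.

```python
# Illustrative logits-distillation loss (standard KL on softened logits);
# not necessarily the exact formulation used by the QAD trainer.
import torch
import torch.nn.functional as F


def logits_distillation_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    temperature: float = 1.0,
) -> torch.Tensor:
    """KL divergence between softened teacher and student token distributions."""
    log_p_student = F.log_softmax(student_logits / temperature, dim=-1)
    p_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * temperature**2
```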
@@ -300,14 +376,15 @@ See more details on deployment of quantized model [here](../llm_ptq/README.md).

## End-to-end QLoRA with Real Quantization

-[QLoRA](https://arxiv.org/pdf/2305.14314) is a technique mainly intended for further reducing the training memory requirement of LoRA. In QLoRA, the LoRA backbone weights are quantized to reduce the model footprint. Unlike QAT which uses simulated quantization, QLoRA requires real quantization. To compress the model weights after quantization, we use the `mtq.compress()` function, which currently supports FP8, FP4, and INT4 formats. This feature can be enabled by passing `--compress True` to the `launch.sh` script. For detailed configuration options and patterns, please refer to the `modelopt.torch.quantization.compress` documentation.
+[QLoRA](https://arxiv.org/pdf/2305.14314) is a technique mainly intended for further reducing the training memory requirement of LoRA. In QLoRA, the LoRA backbone weights are quantized to reduce the model footprint. Unlike QAT which uses simulated quantization, QLoRA requires real quantization. To compress the model weights after quantization, we use the `mtq.compress()` function, which currently supports FP8, FP4, and INT4 formats. This feature can be enabled by passing `--compress True`. For detailed configuration options and patterns, please refer to the `modelopt.torch.quantization.compress` documentation.

To evaluate QLoRA quantized model before training, run:

```sh
# Load the HF checkpoint, quantize the model and evaluate without additional training
# Also compress the model after quantization
-./launch.sh --model meta-llama/Meta-Llama-3-8B \
+accelerate launch --config-file accelerate_config/ddp.yaml main.py \
+    --model_name_or_path meta-llama/Meta-Llama-3-8B \
    --do_train False \
    --quant_cfg NVFP4_DEFAULT_CFG \
    --compress True
@@ -318,9 +395,10 @@ To perform QLoRA training, run:
```sh
# Load the HF checkpoint, quantize the model, add LoRA adapter, and run additional training
# Also compress the model after quantization
-./launch.sh --model meta-llama/Meta-Llama-3-8B \
-    --num_epochs 0.5 \
-    --lr 1e-3 \
+accelerate launch --config-file accelerate_config/ddp.yaml main.py \
+    --model_name_or_path meta-llama/Meta-Llama-3-8B \
+    --num_train_epochs 0.5 \
+    --learning_rate 1e-3 \
    --do_train True \
    --output_dir llama3-fp4-qlora \
    --quant_cfg NVFP4_DEFAULT_CFG \
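When `--quant_cfg` and `--compress True` are passed, the trainer performs the quantize-and-compress step internally. For reference, the underlying calls correspond roughly to the sketch below; the calibration loop is a stand-in, and real runs calibrate on `--calib_size` samples of the training set.

```python
# Rough sketch of the real-quantization step QLoRA builds on; the example's
# trainer does this internally, so these calls are not needed when using main.py.
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)


def calib_forward_loop(m):
    # Stand-in calibration pass; use --calib_size real samples in practice.
    inputs = tokenizer("calibration text", return_tensors="pt").to(m.device)
    m(**inputs)


model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, calib_forward_loop)
mtq.compress(model)  # store the quantized weights in a real low-bit format
```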
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
# Full-precision fine-tuning (no quantization)
model_name_or_path: meta-llama/Meta-Llama-3-8B
num_train_epochs: 2.0
learning_rate: 1e-5
per_device_train_batch_size: 4
output_dir: llama3-finetune
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
# PTQ: Post-Training Quantization evaluation
model_name_or_path: meta-llama/Meta-Llama-3-8B
quant_cfg: NVFP4_DEFAULT_CFG
do_train: false
per_device_eval_batch_size: 4
output_dir: llama3-ptq-eval
examples/llm_qat/configs/qad_nvfp4.yaml

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
# QAD: Quantization-Aware Distillation with NVFP4
model_name_or_path: meta-llama/Meta-Llama-3-8B
teacher_model: meta-llama/Meta-Llama-3-8B
distill: true
quant_cfg: NVFP4_DEFAULT_CFG
num_train_epochs: 2.0
learning_rate: 1e-5
per_device_train_batch_size: 4
output_dir: llama3-qad-nvfp4
examples/llm_qat/configs/qat_nvfp4.yaml

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
# QAT: Quantization-Aware Training with NVFP4
model_name_or_path: meta-llama/Meta-Llama-3-8B
quant_cfg: NVFP4_DEFAULT_CFG
num_train_epochs: 2.0
learning_rate: 1e-5
per_device_train_batch_size: 4
output_dir: llama3-qat-nvfp4
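As with the How to Launch examples in the README, any value in these YAML configs can be overridden from the command line, for instance:

```sh
# YAML supplies the defaults; explicit CLI flags override them
accelerate launch --config-file accelerate_config/fsdp2.yaml main.py \
    --config configs/qat_nvfp4.yaml \
    --per_device_train_batch_size 2
```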
