Refactor llm_qat example with YAML configs and ModelOpt argument parser
Replace shell-based launch script with YAML config files and integrate
ModelOpt's HfArgumentParser plugin for cleaner argument handling. Add
auto-generated ARGUMENTS.md, update README with new usage patterns, and
add unit tests for the argument parser plugin.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
## QuantizationArguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
|`--quant_cfg`|`str`|`None`| Specify the quantization format for PTQ/QAT. If specified, PTQ/QAT will be enabled with the specified quantization format. |
|`--calib_size`|`int`|`512`| Specify the calibration size for quantization. The calibration dataset is used to set up the quantization scale parameters for PTQ/QAT. |
|`--compress`|`bool`|`False`| Whether to compress the model weights after quantization for QLoRA. This is useful for reducing the model size. |
## DataArguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
|`--dataset`|`str`|`"Daring-Anteater"`| Specify the dataset. |
|`--train_size`|`int`|`0`| Number of training samples to use. If `0`, use default training size. |
|`--eval_size`|`int`|`0`| Number of evaluation samples to use. If `0`, use default evaluation size. |
|`--teacher_model`|`str`|`None`| The name or path of the teacher model to use for distillation. |
## TrainingArguments

Extends [HuggingFace TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). Only additional/overridden arguments are shown below.

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
|`--cache_dir`|`str`|`None`||
|`--model_max_length`|`int`|`4096`| Maximum sequence length. Sequences will be right padded (and possibly truncated). |
|`--lora`|`bool`|`False`| Whether to add a LoRA (Low-Rank Adaptation) adapter before training. When using real quantization, the LoRA adapter must be set, as quantized weights will be frozen during training. |
|`--distill`|`bool`|`False`| Whether to train with distillation. |

See [ARGUMENTS.md](ARGUMENTS.md) for the full argument reference.
#### QAT Example Workflow
In QAT, a model quantized using [mtq.quantize()](https://nvidia.github.io/Model-Optimizer/reference/generated/modelopt.torch.quantization.model_quant.html#modelopt.torch.quantization.model_quant.quantize) can be directly fine-tuned with the original training pipeline. During QAT, the scaling factors inside quantizers are frozen and the model weights are fine-tuned.
To train larger models with distributed training, please refer to [End-to-end QAT Example](#end-to-end-qat-example).
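For illustration, here is a minimal sketch of this flow (the quantization config, calibration dataloader, and trainer object below are assumptions, not code from this example):

```python
import modelopt.torch.quantization as mtq

# Calibration forward loop used by mtq.quantize() to collect activation statistics.
def forward_loop(model):
    for batch in calib_dataloader:  # assumed calibration dataloader
        model(**batch)

# Quantize in place with a chosen format (FP8 shown only as an example), then
# fine-tune with the original training pipeline; the quantizer scales stay frozen.
model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
trainer.train()  # assumed pre-configured HuggingFace Trainer
```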
#### QATTrainer Example Workflow

`QATTrainer` is a drop-in replacement for HuggingFace's `Trainer` that handles quantization internally, so there is no need to manually call `mtq.quantize()`. Quantization is configured via `quant_args`.
```python
from modelopt.torch.quantization.plugins.transformers_trainer import QATTrainer, QuantizationArguments

...

# [Not shown] load model, tokenizer, data loaders etc

# Sketch of the remaining steps (field and variable names below are assumptions; see ARGUMENTS.md):
quant_args = QuantizationArguments(quant_cfg="NVFP4_DEFAULT_CFG")
trainer = QATTrainer(model=model, args=training_args, train_dataset=train_dataset, quant_args=quant_args)
trainer.train()
```
This will generate a fine-tuned checkpoint in `output_dir` specified above. You can load this checkpoint, quantize the model, evaluate PTQ results or run additional QAT.
This can be accomplished by specifying the quantization format.
In this example, we are quantizing the model with INT4 block-wise weights and INT8 per-tensor activation quantization.

To perform PTQ evaluation, run:

```sh
# Load the checkpoint from previous fine-tuning stage, quantize the model and evaluate without additional training
```
You may alternatively perform QAT with any other quantization formats from **ModelOpt**. Please see more details on the supported quantization formats and how to use them as shown below:
```python
import modelopt.torch.quantization as mtq

help(mtq.config)
```
You could also add your own customized quantization format to `CUSTOM_QUANT_CFG` from `main.py` and perform QAT.
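For example, a custom format entry could look like the sketch below (the entry name, module patterns, and values are illustrative assumptions; the actual `CUSTOM_QUANT_CFG` structure lives in `main.py`):

```python
# Hypothetical custom entry: INT4 block-wise weights with INT8 per-tensor activations.
CUSTOM_QUANT_CFG = {
    "int4_weight_int8_activation": {
        "quant_cfg": {
            "*weight_quantizer": {"num_bits": 4, "block_sizes": {-1: 128}, "enable": True},
            "*input_quantizer": {"num_bits": 8, "axis": None, "enable": True},
            "*lm_head*": {"enable": False},  # keep the output layer unquantized
            "default": {"enable": False},
        },
        "algorithm": "max",
    }
}
```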
> **_NOTE:_** QAT requires more memory than full-precision fine-tuning. A solution to avoid this extra memory usage is to use [activation checkpointing](https://pytorch.org/docs/stable/checkpoint.html) or gradient checkpointing. Activation checkpointing can be enabled easily with training frameworks such as Hugging Face by adding the additional argument `gradient_checkpointing True`. Learn more [here](https://huggingface.co/docs/transformers/v4.20.1/en/perf_train_gpu_one#gradient-checkpointing). Activation (gradient) checkpointing is enabled by default in this example.
> **_NOTE:_** Like any other model training, the QAT model accuracy can be further improved by optimizing the training hyper-parameters such as learning rate, training duration, etc.
> **_NOTE:_** This example defaults to `LlamaDecoderLayer` as the transformer layer class for FSDP wrapping. If your model uses a different class, pass `--fsdp_transformer_layer_cls_to_wrap <your_layer_class>` as an additional argument. For example, for `Qwen/Qwen3-8B`, specify `--fsdp_transformer_layer_cls_to_wrap Qwen3DecoderLayer`.
### Results
Here is an example result following the workflow above with slightly different hyper-parameters (we used an effective batch size of 128 by adjusting `--per_device_train_batch_size` and `--gradient_accumulation_steps` as per the available GPU memory).
As we can see below, QAT has improved the validation perplexity.
You could get slightly different numbers depending on your hyper-parameters; however, you should be able to see consistent improvement.

> **_NOTE:_** QAD doesn't support the FSDP1 (<https://docs.pytorch.org/docs/stable/fsdp.html>) backend, only FSDP2.
See more details on deployment of quantized model [here](../llm_ptq/README.md).
## End-to-end QLoRA with Real Quantization
[QLoRA](https://arxiv.org/pdf/2305.14314) is a technique mainly intended for further reducing the training memory requirement of LoRA. In QLoRA, the LoRA backbone weights are quantized to reduce the model footprint. Unlike QAT, which uses simulated quantization, QLoRA requires real quantization. To compress the model weights after quantization, we use the `mtq.compress()` function, which currently supports FP8, FP4, and INT4 formats. This feature can be enabled by passing `--compress True`. For detailed configuration options and patterns, please refer to the `modelopt.torch.quantization.compress` documentation.
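As a rough sketch of what `--compress True` does under the hood (the config choice and calibration loop below are assumptions):

```python
import modelopt.torch.quantization as mtq

# Real-quantize the backbone weights, then compress them so the quantized
# representation (e.g. FP4/INT4) is actually stored, shrinking the memory footprint.
model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop)  # assumed calibration forward loop
mtq.compress(model)
# LoRA adapters are then trained on top of the frozen, compressed backbone.
```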
To evaluate the QLoRA quantized model before training, run:
```sh
# Load the HF checkpoint, quantize the model and evaluate without additional training
```