
Commit 7f82fe2

realAsma authored and claude committed
Refactor llm_qat example with YAML configs, DistillArguments, and ModelOpt argument parser
- Add qlora_nvfp4.yaml config and expand qat/qad/finetune YAML configs with full parameter sets
- Remove ptq_eval.yaml (superseded by unified quantize.py flow)
- Add DistillArguments class in modelopt/torch/distill with distill, teacher_model, criterion fields and to_distill_kwargs() helper
- Move --distill from TrainingArguments to DistillArguments
- Remove TrainModelArguments from train.py, use DistillArguments instead
- Reorder transformers.py: patching on top, arguments middle, training bottom
- Include QuantizeArguments in train.py parser for ARGUMENTS.md generation
- Rename output_dir -> quantize_output_dir in QuantizeArguments to avoid conflict with HF TrainingArguments.output_dir
- Regenerate ARGUMENTS.md with all argument classes
- Group distillation args in qad_nvfp4.yaml config
- Update README.md, tests, and quantize.py for renamed field

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
1 parent df80a0f commit 7f82fe2

37 files changed

Lines changed: 2144 additions & 850 deletions

.pre-commit-config.yaml

Lines changed: 16 additions & 1 deletion
```diff
@@ -93,7 +93,7 @@ repos:
           examples/llm_eval/lm_eval_hf.py|
           examples/llm_eval/mmlu.py|
           examples/llm_eval/modeling.py|
-          examples/llm_qat/main.py|
+          examples/llm_qat/train.py|
           examples/llm_sparsity/weight_sparsity/finetune.py|
           examples/specdec_bench/specdec_bench/models/specbench_medusa.py|
           examples/speculative_decoding/main.py|
@@ -122,6 +122,21 @@ repos:
         args: ["-c", "pyproject.toml", "-q"]
         additional_dependencies: ["bandit[toml]"]
 
+  - repo: local
+    hooks:
+      - id: generate-arguments-md
+        name: Regenerate examples/llm_qat/ARGUMENTS.md
+        entry: bash -c 'python examples/llm_qat/train.py --generate_docs examples/llm_qat/ARGUMENTS.md'
+        language: system
+        files: >-
+          (?x)^(
+            examples/llm_qat/arguments\.py|
+            examples/llm_qat/train\.py|
+            modelopt/torch/opt/plugins/transformers\.py|
+            modelopt/torch/quantization/plugins/transformers_trainer\.py
+          )$
+        pass_filenames: false
+
   - repo: https://github.com/DavidAnson/markdownlint-cli2
     rev: v0.18.1
     hooks:
```
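The hook's `files:` pattern above is a verbose regex, so it only triggers regeneration when one of the four listed files changes. It can be checked directly with Python's `re` module:

```python
import re

# The hook's `files:` pattern, verbatim: (?x) enables verbose mode, so the
# whitespace inside the alternation group is ignored.
pattern = re.compile(
    r"""(?x)^(
        examples/llm_qat/arguments\.py|
        examples/llm_qat/train\.py|
        modelopt/torch/opt/plugins/transformers\.py|
        modelopt/torch/quantization/plugins/transformers_trainer\.py
    )$"""
)

print(bool(pattern.match("examples/llm_qat/train.py")))     # True
print(bool(pattern.match("examples/llm_qat/quantize.py")))  # False
```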

examples/llm_qad/README.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -2,6 +2,8 @@
 
 Quantization-Aware Distillation (QAD) training scripts for language models using Megatron-LM. These scripts enable training quantized (e.g., NVFP4) student models with knowledge distillation from full-precision teacher models.
 
+> **Note:** For Hugging Face LLM QAD, see the [LLM QAT QAD section](../llm_qat/README.md#end-to-end-qad-example).
+
 ## Overview
 
 | Script | Purpose |
```

examples/llm_qat/.gitignore

Lines changed: 1 addition & 0 deletions
```diff
@@ -0,0 +1 @@
+.cache/
```

examples/llm_qat/ARGUMENTS.md

Lines changed: 45 additions & 0 deletions
New file:

```markdown
# Argument Reference

_Auto-generated — do not edit by hand._

## DistillArguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--distill` | `bool` | `False` | Enable training with knowledge distillation. |
| `--teacher_model` | `str` | `None` | The name or path of the teacher model to use for distillation. |
| `--criterion` | `str` | `"logits_loss"` | Distillation loss criterion. Currently only 'logits_loss' is supported. |

## DataArguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--dataset_config` | `str` | `"configs/dataset/blend.yaml"` | Path to a dataset blend YAML config file. See configs/dataset/README.md for schema documentation. |
| `--train_samples` | `int` | `0` | Override train_samples from dataset config. 0 = use config value. |
| `--eval_samples` | `int` | `0` | Override eval_samples from dataset config. 0 = use config value. |

## ModelArguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--model_name_or_path` | `str` | `"meta-llama/Llama-2-7b-hf"` | |
| `--model_max_length` | `int` | `4096` | Maximum sequence length. Sequences will be right padded (and possibly truncated). |

## QuantizeArguments

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--recipe` | `str` | `None` | Path to a quantization recipe YAML file (built-in or custom). Built-in recipes can be specified by relative path, e.g. 'general/ptq/nvfp4_default-fp8_kv'. |
| `--calib_size` | `int` | `512` | Specify the calibration size for quantization. The calibration dataset is used to setup the quantization scale parameters. |
| `--calib_batch_size` | `int` | `1` | Batch size for calibration data during quantization. |
| `--compress` | `bool` | `False` | Whether to compress the model weights after quantization for QLoRA. This is useful for reducing the model size. |
| `--quantize_output_dir` | `str` | `"quantized_model"` | Directory to save the quantized model checkpoint. |

## TrainingArguments

Extends [HuggingFace TrainingArguments](https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments). Only additional/overridden arguments are shown below.

| Argument | Type | Default | Description |
|----------|------|---------|-------------|
| `--cache_dir` | `str` | `None` | |
| `--lora` | `bool` | `False` | Whether to add LoRA (Low-Rank Adaptation) adapter before training. When using real quantization, the LoRA adapter must be set, as quantized weights will be frozen during training. |
```
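ARGUMENTS.md is auto-generated from the argument dataclasses. A hedged sketch of how such a table could be produced from `dataclasses.field` metadata (an assumption for illustration, not the actual ModelOpt generator):

```python
from dataclasses import dataclass, field, fields


@dataclass
class ModelArguments:
    # A subset of the real class, for illustration.
    model_name_or_path: str = field(default="meta-llama/Llama-2-7b-hf")
    model_max_length: int = field(
        default=4096,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )


def to_markdown_table(cls) -> str:
    # Emit one table row per dataclass field, pulling the description from
    # the field's "help" metadata.
    rows = [
        "| Argument | Type | Default | Description |",
        "|----------|------|---------|-------------|",
    ]
    for f in fields(cls):
        rows.append(
            f"| `--{f.name}` | `{f.type.__name__}` | `{f.default!r}` | {f.metadata.get('help', '')} |"
        )
    return "\n".join(rows)


print(to_markdown_table(ModelArguments))
```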

examples/llm_qat/README.md

Lines changed: 179 additions & 264 deletions
Large diffs are not rendered by default.

examples/llm_qat/accelerate_config/fsdp1.yaml

Lines changed: 0 additions & 29 deletions
This file was deleted.

examples/llm_qat/arguments.py

Lines changed: 118 additions & 0 deletions
New file:

```python
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Shared argument dataclasses for llm_qat scripts (quantize.py, train.py)."""

from dataclasses import field

import transformers

from modelopt.torch.opt.plugins.transformers import ModelOptHFArguments


class ModelArguments(ModelOptHFArguments):
    model_name_or_path: str = field(default="meta-llama/Llama-2-7b-hf")
    model_max_length: int = field(
        default=4096,
        metadata={
            "help": (
                "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
            )
        },
    )


class DataArguments(ModelOptHFArguments):
    dataset_config: str = field(
        default="configs/dataset/blend.yaml",
        metadata={
            "help": (
                "Path to a dataset blend YAML config file. "
                "See configs/dataset/README.md for schema documentation."
            )
        },
    )
    train_samples: int = field(
        default=0,
        metadata={"help": "Override train_samples from dataset config. 0 = use config value."},
    )
    eval_samples: int = field(
        default=0,
        metadata={"help": "Override eval_samples from dataset config. 0 = use config value."},
    )


class TrainingArguments(ModelOptHFArguments, transformers.TrainingArguments):
    cache_dir: str | None = field(default=None)
    dataloader_drop_last: bool = field(default=True)
    bf16: bool = field(default=True)
    lora: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to add LoRA (Low-Rank Adaptation) adapter before training. When using real quantization, "
                "the LoRA adapter must be set, as quantized weights will be frozen during training."
            )
        },
    )
    # Sensible defaults (previously set by launch.sh)
    eval_strategy: str = field(default="steps")
    load_best_model_at_end: bool = field(default=True)
    save_total_limit: int = field(default=2)
    warmup_ratio: float = field(default=0.1)
    logging_steps: int = field(default=1)
    report_to: str = field(default="tensorboard")
    do_eval: bool = field(default=True)
    eval_accumulation_steps: int = field(default=1)
    learning_rate: float = field(default=1e-4)


class QuantizeArguments(ModelOptHFArguments):
    recipe: str | None = field(
        default=None,
        metadata={
            "help": (
                "Path to a quantization recipe YAML file (built-in or custom). "
                "Built-in recipes can be specified by relative path, e.g. "
                "'general/ptq/nvfp4_default-fp8_kv'."
            ),
        },
    )
    calib_size: int = field(
        default=512,
        metadata={
            "help": (
                "Specify the calibration size for quantization. The calibration dataset is used to"
                " setup the quantization scale parameters."
            )
        },
    )
    calib_batch_size: int = field(
        default=1,
        metadata={"help": "Batch size for calibration data during quantization."},
    )
    compress: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether to compress the model weights after quantization for QLoRA. "
                "This is useful for reducing the model size."
            )
        },
    )
    quantize_output_dir: str = field(
        default="quantized_model",
        metadata={"help": "Directory to save the quantized model checkpoint."},
    )
```
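These dataclasses are consumed by an HF-style argument parser that turns each field into a CLI flag. A minimal stand-in for that behavior using only the standard library (a sketch of the pattern, not the `transformers.HfArgumentParser` implementation):

```python
import argparse
from dataclasses import dataclass, field, fields


@dataclass
class QuantizeArguments:
    # A subset of the real class, with abbreviated help strings.
    calib_size: int = field(default=512, metadata={"help": "Calibration set size."})
    compress: bool = field(default=False, metadata={"help": "Compress weights after quantization."})
    quantize_output_dir: str = field(
        default="quantized_model", metadata={"help": "Output directory."}
    )


def build_parser(cls) -> argparse.ArgumentParser:
    # Map each dataclass field to a flag: booleans become store_true switches,
    # everything else keeps its annotated type and default.
    parser = argparse.ArgumentParser()
    for f in fields(cls):
        if f.type is bool:
            parser.add_argument(f"--{f.name}", action="store_true", help=f.metadata.get("help", ""))
        else:
            parser.add_argument(
                f"--{f.name}", type=f.type, default=f.default, help=f.metadata.get("help", "")
            )
    return parser


ns = build_parser(QuantizeArguments).parse_args(["--calib_size", "256", "--compress"])
print(ns.calib_size, ns.compress, ns.quantize_output_dir)  # 256 True quantized_model
```

Note how the renamed `quantize_output_dir` field becomes `--quantize_output_dir`, leaving `--output_dir` free for the inherited HF `TrainingArguments`.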
File renamed without changes.
File renamed without changes.
File renamed without changes.
