-
Notifications
You must be signed in to change notification settings - Fork 53
Expand file tree
/
Copy pathtrain.py
More file actions
49 lines (37 loc) · 1.32 KB
/
train.py
File metadata and controls
49 lines (37 loc) · 1.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
The training entry script for the FastGen project. Works for both DDP and FSDP training.
"""
import argparse
import warnings
from fastgen.configs.config import BaseConfig
from fastgen.utils import instantiate
from fastgen.trainer import Trainer
import fastgen.utils.logging_utils as logger
from fastgen.utils.distributed import synchronize, clean_up
from fastgen.utils.scripts import parse_args, setup
# PyTorch 2.6 prints a spurious "Grad strides do not match bucket view strides"
# warning (a known false positive), so silence it process-wide at import time,
# before main() builds the model and trainer.
warnings.filterwarnings(
    action="ignore",
    message="Grad strides do not match bucket view strides",
)
def main(config: BaseConfig) -> None:
    """Build the model and the trainer from ``config`` and run training.

    The same entry point serves both DDP and FSDP runs. ``synchronize()`` is
    called after each phase (model construction, trainer construction,
    training); presumably it is a cross-rank barrier -- confirm in
    ``fastgen.utils.distributed``.

    Args:
        config: Resolved experiment configuration. ``config.model_class`` is
            the instantiation spec for the model; ``config.model`` is handed to
            it as its ``config`` entry for the duration of instantiation only.
    """
    # Initialize the model. The model config is attached to the instantiation
    # spec only while the model is being built, then detached again. Doing the
    # detach in ``finally`` guarantees the spec is cleared even if
    # instantiation raises.
    # NOTE(review): detaching presumably avoids carrying a second copy of the
    # model config inside ``config.model_class`` (e.g. when the config is
    # saved by the trainer) -- confirm.
    config.model_class.config = config.model
    try:
        model = instantiate(config.model_class)
    finally:
        config.model_class.config = None
    synchronize()

    # Initialize the trainer.
    logger.info("Initializing trainer...")
    fastgen_trainer = Trainer(config)
    logger.success("Trainer initialized successfully")
    synchronize()

    # Start training.
    fastgen_trainer.run(model)
    synchronize()
if __name__ == "__main__":
    # Parse the command line, resolve it into a config, and train with it.
    cli_args = parse_args(argparse.ArgumentParser(description="Training"))
    main(setup(cli_args))
    # Release distributed resources once training has returned
    # (presumably process-group teardown; see fastgen.utils.distributed).
    clean_up()
    logger.info("Training finished.")