-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
69 lines (57 loc) · 2.06 KB
/
train.py
File metadata and controls
69 lines (57 loc) · 2.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import os
import json
import argparse
import torch
import dataloaders
import models
import inspect
import math
from utils import losses
from utils import Logger
from utils.torchsummary import summary
from trainer import Trainer
import time
def get_instance(module, name, config, *args):
    """Build the object that ``config[name]`` describes, looked up in ``module``.

    ``config[name]['type']`` names an attribute of ``module`` (typically a
    class); that attribute is called with the extra positional ``args`` first,
    followed by the mapping ``config[name]['args']`` splatted as keyword
    arguments.

    Args:
        module: object (usually a module/package) that holds the factory.
        name: key of the entry in ``config`` to instantiate.
        config: configuration mapping containing ``config[name]``.
        *args: positional arguments forwarded ahead of the configured kwargs.

    Returns:
        Whatever the looked-up callable returns.
    """
    entry = config[name]
    factory = getattr(module, entry['type'])
    return factory(*args, **entry['args'])
def main(config, resume):
    """Assemble data loaders, model and loss from *config*, then run training.

    Args:
        config: configuration dict. Keys read here: 'train_loader',
            'val_loader' and 'arch' (each resolved via ``get_instance``),
            plus 'loss' and 'ignore_index'. The whole dict is also handed
            to ``Trainer``.
        resume: checkpoint path forwarded unchanged to ``Trainer``
            (``None`` when the script was started without ``--resume``).
    """
    logger = Logger()

    # DATA LOADERS -- both resolved by name from the ``dataloaders`` package.
    train_dl, val_dl = (get_instance(dataloaders, key, config)
                        for key in ('train_loader', 'val_loader'))

    # MODEL -- the architecture receives the training dataset's
    # ``num_classes`` as its positional argument.
    n_classes = train_dl.dataset.num_classes
    model = get_instance(models, 'arch', config, n_classes)
    print(f'\n{model}\n')

    # LOSS -- class looked up by name in ``utils.losses``.
    loss_cls = getattr(losses, config['loss'])
    loss_fn = loss_cls(ignore_index=config['ignore_index'])

    # TRAINING
    Trainer(
        model=model,
        loss=loss_fn,
        resume=resume,
        config=config,
        train_loader=train_dl,
        val_loader=val_dl,
        train_logger=logger,
    ).train()
if __name__ == '__main__':
    # PARSE THE ARGS
    parser = argparse.ArgumentParser(description='PyTorch Training')
    parser.add_argument('-c', '--config', default='config.json', type=str,
                        help='Path to the config file (default: config.json)')
    parser.add_argument('-r', '--resume', default=None, type=str,
                        help='Path to the .pth model checkpoint to resume training')
    parser.add_argument('-d', '--device', default=None, type=str,
                        help='indices of GPUs to enable (default: all)')
    args = parser.parse_args()

    # Restrict the visible GPUs before anything below can initialize CUDA.
    # Deserializing a checkpoint that contains CUDA tensors initializes CUDA,
    # and after that, changing CUDA_VISIBLE_DEVICES has no effect on this
    # process.
    if args.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    if args.resume:
        # When resuming, the config stored in the checkpoint wins, so the
        # --config file is not read at all (it need not even exist).
        # Only the config dict is needed here, so map tensors to the CPU
        # instead of touching the GPU just to read it.
        config = torch.load(args.resume, map_location='cpu')['config']
    else:
        # Close the file deterministically; JSON is specified as UTF-8.
        with open(args.config, encoding='utf-8') as config_file:
            config = json.load(config_file)

    # === Timing starts ===
    # perf_counter is monotonic, so the duration is immune to clock changes.
    start_time = time.perf_counter()
    main(config, args.resume)
    elapsed = time.perf_counter() - start_time
    print(f"[Done] Total elapsed time: {elapsed:.2f} seconds "
          f"({elapsed / 60:.2f} minutes, {elapsed / 3600:.2f} hours)")