import os
import torch
import json
import sys

from models import getModel
from optimizer import getOptimizer
from utility.loss import smooth_crossentropy
from utility.inputData import DataLoader
from utility.dataLogger import DataLogger
from utility.utils import initialize
from utility.LRScheduler import getLRScheduler, _LRScheduler
from utility.modelSaver import ModelSaver
from utility.args import Args
"""
run:
    python -m torch.distributed.run main.py
"""
Args.add_argument("--logDir", type=str, help="main directory to store logs")
Args.add_argument("--logSubDir", type=str, help="subdir in logDir to store logs for this run")
Args.add_argument("--epochs", type=int, help="total number of epochs")
Args.add_argument("--contin", type=bool, help="whether to continue from a checkpoint; in continue mode, parameters are read from params.json and the input file is ignored")
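# note: with plain argparse, type=bool would treat any non-empty string as True;
# this assumes the Args wrapper parses boolean flags correctly.

# example launch (single node, 4 GPUs; the values are placeholders, and the
# model/optimizer modules may register further required flags):
#   python -m torch.distributed.run --nproc_per_node=4 main.py \
#       --logDir ./logs --logSubDir run0 --epochs 100
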
if __name__ == "__main__":
    Args.parse_args()

    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")

    logDir = os.path.join(Args.logDir, Args.logSubDir)

    if Args.contin:
        with open(os.path.join(logDir, "params.json"), "r") as file:
            parameters = json.load(file)
        Args.parse_args_contin(parameters)

    torch.distributed.init_process_group(backend="nccl", init_method="env://", rank=int(os.getenv("SLURM_PROCID", -1)))  # set rank to 'SLURM_PROCID' if started with slurm, else to -1
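    # map the global rank onto this node's GPUs; this assumes one process per
    # GPU and the same number of GPUs on every node: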
    local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    initialize()  # set up seed and cudnn
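
    # only rank 0 creates the log directory and persists the run parameters: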
    if torch.distributed.get_rank() == 0:
        os.makedirs(logDir, exist_ok=True)
        with open(os.path.join(logDir, "params.json"), "w") as file:
            json.dump(vars(Args.data), file, indent=4)
    dataLogger = DataLogger()
    dataset = DataLoader()

    model = getModel()(num_classes=dataset.numClasses)
    model = model.cuda(local_rank)
    if hasattr(torch, "compile") and sys.version_info.minor < 11:  # torch.compile does not support Python 3.11 yet
        model = torch.compile(model)
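
    # synchronize BatchNorm statistics across processes, then wrap the model
    # in DDP, which all-reduces gradients during the backward pass: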
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    model = torch.nn.parallel.DistributedDataParallel(model)

    optimizer = getOptimizer()(model.parameters())
    lrScheduler: _LRScheduler = getLRScheduler(optimizer)
    modelSaver = ModelSaver(model=model, optimizer=optimizer)
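
    # resume from the stored checkpoint if requested; loadModel returns the
    # epoch the checkpoint was saved at, so training restarts one epoch later: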
    startEpoch = 1
    if Args.contin:
        startEpoch = modelSaver.loadModel("checkpoint.model")
        startEpoch += 1
        model = model.cuda(local_rank)
        if startEpoch >= Args.epochs:
            raise RuntimeError(f"Can't continue model from epoch {startEpoch} to max epoch {Args.epochs}.")
    else:
        modelSaver(0)
    torch.distributed.barrier()  # wait until all workers are done with initialization

    dataLogger.printHeader()

    state = {
        "model": model,
        "lrScheduler": lrScheduler,
        "optimizer": optimizer,
    }
    for epoch in range(startEpoch, Args.epochs + 1):
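        # reseed the DistributedSampler so each epoch draws a fresh shuffle
        # that stays consistent across processes: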
        dataset.train.sampler.set_epoch(epoch)

        model.train()
        numBatches = len(dataset.train)
        dataLogger.startTrain(trainDataLen=numBatches)
        for i, batch in enumerate(dataset.train):
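            # the project's scheduler is stepped every batch, taking the
            # fractional progress through the current epoch so it can realize
            # smooth within-epoch LR schedules: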
            lrScheduler.step(epoch - 1, (i + 1) / numBatches)
            inputs, targets = (b.cuda(local_rank) for b in batch)

            predictions = model(inputs)
            loss = smooth_crossentropy(predictions, targets)
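            # the loss comes back unreduced (one value per sample): take the
            # mean for backward, but keep the raw tensor for logging below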
            loss.mean().backward()
            if Args.grad_clip_norm != 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), Args.grad_clip_norm)
            optimizer.step()
            optimizer.zero_grad()

            with torch.no_grad():
                state["loss"] = loss
                state["predictions"] = predictions
                state["targets"] = targets
                dataLogger(state)
        dataLogger.flush()
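
        # test pass: model.eval() switches BatchNorm/Dropout to inference
        # behaviour, torch.no_grad() disables gradient tracking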
        dataLogger.startTest()
        model.eval()
        with torch.no_grad():
            for batch in dataset.test:
                inputs, targets = (b.cuda(local_rank) for b in batch)
                predictions = model(inputs)
                loss = smooth_crossentropy(predictions, targets)
                state["loss"] = loss
                state["predictions"] = predictions
                state["targets"] = targets
                dataLogger(state)
        dataLogger.flush()

        modelSaver(epoch)

    dataLogger.printFooter()
    torch.distributed.destroy_process_group()