-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTrace.py
More file actions
180 lines (147 loc) · 7.39 KB
/
Trace.py
File metadata and controls
180 lines (147 loc) · 7.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import time, os, argparse, pdb, pprint,torch
import torch.cuda.nvtx as nvtx
from tqdm import tqdm
from src.util import customDict, greedyDecode, log, getUniqueFolder, rebatch
from src.transformer import *
from src.noamOpt import NoamOpt
from src.dataIterator import BatchIterator
from src.batch import Batch
from src.labelSmoothing import LabelSmoothing
from src.lossCompute import MultiGPULossCompute
from src.dataLoader import generateDataloaders
def runEpoch(dataIter, model, lossCompute, logfile):
    """Run ``model`` over every batch in ``dataIter`` and return the loss per token.

    ``lossCompute`` is called once per batch and its return values are summed.
    Whether it also back-propagates or steps an optimizer is decided by the
    ``lossCompute`` object itself (``train`` in this file passes ``opt=None``
    for the validation pass). On batches 1, 51, 101, ... a progress line is
    written to ``logfile`` through ``log`` and printed.

    Args:
        dataIter: iterable of batches exposing ``src``, ``trg``, ``srcMask``,
            ``trgMask``, ``trgY`` and ``nTokens``.
        model: callable as ``model(src, trg, srcMask, trgMask)``.
        lossCompute: callable as ``lossCompute(out, trgY, nTokens)``.
        logfile: destination handed through to ``log``.

    Returns:
        ``sum(batch losses) / sum(batch.nTokens)`` over the whole iterator.

    Raises:
        ZeroDivisionError: if ``dataIter`` yields no batches.
    """
    # perf_counter is monotonic, so interval measurements cannot go negative.
    startInit = time.perf_counter()
    start = startInit
    totalTokens = 0
    totalLoss = 0
    tokens = 0
    log("[epoch] \n", logfile)
    for i, batch in enumerate(dataIter):
        nvtx.range_push("Batch: {}".format(i))
        nvtx.range_push("Forward pass")
        # Call the module (not .forward) so registered hooks run as usual.
        out = model(batch.src, batch.trg, batch.srcMask, batch.trgMask)
        nvtx.range_pop()  # "Forward pass"
        nvtx.range_push("Loss compute")
        loss = lossCompute(out, batch.trgY, batch.nTokens)
        nvtx.range_pop()  # "Loss compute"
        # Close the per-batch range; it used to be left open, so the NVTX
        # range stack grew by one entry on every batch.
        nvtx.range_pop()  # "Batch: i"
        totalLoss += loss
        totalTokens += batch.nTokens
        tokens += batch.nTokens
        if i % 50 == 1:
            now = time.perf_counter()
            elapsed = now - start
            elapsedInit = now - startInit
            # Keep elapsed as a float. Converting it to a LongTensor truncated
            # to whole seconds, which skipped every sub-second window and
            # computed TPS from a truncated duration.
            if elapsed <= 0:
                continue
            message = "Epoch: %d Loss: %f TPS %f Batch time elapsed: %d total time elapsed %d" % (
                i, loss / batch.nTokens, tokens / elapsed, elapsed, elapsedInit)
            log(message + "\n", logfile)
            print(message)
            tokens = 0
            start = time.perf_counter()
    return totalLoss / totalTokens
# Running maxima of source/target lengths for the batch currently being built
# by batchSizeFn. They must exist before the first call: a module-level
# ``global`` statement is a no-op and does not create them, so a first call
# with ``count != 1`` used to raise NameError.
maxSrcInBatch = 0
maxTgtInBatch = 0


def batchSizeFn(new, count, sofar):
    """Return the padded size of a batch after adding example ``new``.

    Used as ``batch_size_fn`` for the BatchIterators built in ``train``. The
    size is ``count`` times the longest source (or target + 2) seen in the
    current batch, whichever is larger. Both running maxima are reset when
    ``count == 1``, i.e. when a new batch starts.

    Args:
        new: example exposing sized ``src`` and ``trg`` sequences.
        count: number of examples in the batch including ``new``.
        sofar: current batch size; unused, kept for the ``batch_size_fn``
            call signature.

    Returns:
        ``max(count * maxSrcInBatch, count * maxTgtInBatch)``.
    """
    global maxSrcInBatch, maxTgtInBatch
    if count == 1:
        maxSrcInBatch = 0
        maxTgtInBatch = 0
    maxSrcInBatch = max(maxSrcInBatch, len(new.src))
    # +2 presumably reserves room for start/end tokens on the target -- TODO confirm
    maxTgtInBatch = max(maxTgtInBatch, len(new.trg) + 2)
    srcElements = count * maxSrcInBatch
    tgtElements = count * maxTgtInBatch
    return max(srcElements, tgtElements)
def _saveCheckpoint(checkpoint, folder, mode, loss, validLosses):
    """Save ``checkpoint`` into ``folder`` according to ``mode``.

    Args:
        checkpoint: object handed to ``torch.save``.
        folder: directory the checkpoint file is written into.
        mode: ``'all'`` writes ``all_loss_<loss>.chkpt`` on every call;
            ``'best'`` overwrites ``best.chkpt`` only when ``loss`` is the
            lowest entry of ``validLosses``. Any other value saves nothing.
        loss: validation loss of the current epoch.
        validLosses: every validation loss so far, including ``loss``.
    """
    if mode == 'all':
        # The placeholder is named, so it must be filled by keyword;
        # ``.format(loss)`` raised KeyError('loss').
        modelName = mode + '_loss_{loss:3.3f}.chkpt'.format(loss=float(loss))
        torch.save(checkpoint, os.path.join(folder, modelName))
    elif mode == 'best':
        modelName = mode + '.chkpt'
        # Lower validation loss is better. The former ``max`` comparison was
        # always true (``loss`` is itself in ``validLosses``), so "best" just
        # overwrote the file every epoch.
        if loss <= min(validLosses):
            torch.save(checkpoint, os.path.join(folder, modelName))
            print('*' * 8, "CHECKPOINT UPDATED", '*' * 8)


def _printSampleTranslation(model, dataIterator, SRC, TGT, maxlen=60):
    """Greedy-decode the first sentence of the first batch and print it with its reference."""
    batch = next(iter(dataIterator), None)
    if batch is None:
        return
    model.eval()
    # Inference only: no autograd graph is needed for decoding.
    with torch.no_grad():
        src = batch.src.transpose(0, 1)[:1]
        srcMask = (src != SRC.vocab.stoi["<blank>"]).unsqueeze(-2)
        out = greedyDecode(model, src, srcMask, maxlen=maxlen, startSymbol=TGT.vocab.stoi["<s>"])
    print("Translation: ", end="\t")
    for row in range(out.size(0)):
        for col in range(out.size(1)):
            sym = TGT.vocab.itos[out[row, col]]
            if sym == "</s>":
                break
            print(sym, end=" ")
    print()
    print("Target: ", end="\t")
    # Target position 0 is skipped -- presumably the start token; TODO confirm.
    for pos in range(1, batch.trg.size(0)):
        sym = TGT.vocab.itos[batch.trg.data[pos, 0]]
        if sym == "</s>":
            break
        print(sym, end=" ")
    print()


def train(**kwargs):
    """Train a transformer with the settings in ``kwargs`` and checkpoint it.

    ``kwargs`` is copied into this module's globals, so the body reads
    ``dataArgs``, ``manmpArgs``, ``modelType``, ``trainMode``, ``batchSize``,
    ``devices``, ``epochs`` and ``saveMode`` as globals. After every epoch a
    validation pass runs, and its loss decides checkpointing (see
    ``_saveCheckpoint``). Finally one greedy-decoded validation sentence is
    printed next to its reference.
    """
    globals().update(kwargs)
    print('loading data...')
    SRC, TGT, trainData, valData, _ = generateDataloaders(**dataArgs)
    padIdx = TGT.vocab.stoi["<blank>"]
    model = createModel(len(SRC.vocab), len(TGT.vocab), modelType, N=6, dModel=512,
                        mode=trainMode, bias=True, **manmpArgs)
    model.cuda()
    criterion = LabelSmoothing(size=len(TGT.vocab), paddingIdx=padIdx, smoothing=0.1)
    criterion.cuda()

    def sortKey(example):
        return (len(example.src), len(example.trg))

    device = torch.device(devices[0])
    trainIterator = BatchIterator(trainData, batch_size=batchSize, device=device, repeat=False,
                                  sort_key=sortKey, batch_size_fn=batchSizeFn, train=True)
    validIterator = BatchIterator(valData, batch_size=batchSize, device=device, repeat=False,
                                  sort_key=sortKey, batch_size_fn=batchSizeFn, train=False)
    modelParallel = nn.DataParallel(model, device_ids=devices)
    modelOptimizer = NoamOpt(model.srcEmbed[0].embeddingDim, 1, 2000,
                             torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
    folder = getUniqueFolder('./models/', 'model')
    # makedirs also creates ./models/ itself if it is missing; os.mkdir failed then.
    os.makedirs(folder, exist_ok=True)
    logfile = os.path.join(folder, 'logfile')
    print("Training model")
    validLosses = []
    for epoch in tqdm(range(epochs)):
        nvtx.range_push("EPOCH: {}".format(epoch))
        modelParallel.train()
        log("Train epoch: " + str(epoch) + "\n", logfile)
        runEpoch((rebatch(padIdx, b) for b in trainIterator), modelParallel,
                 MultiGPULossCompute(model.generator, criterion, devices=devices, opt=modelOptimizer),
                 logfile)
        nvtx.range_push("VAL EPOCH: {}".format(epoch))
        modelParallel.eval()
        log("Validation Epoch: " + str(epoch) + "\n", logfile)
        loss = runEpoch((rebatch(padIdx, b) for b in validIterator), modelParallel,
                        MultiGPULossCompute(model.generator, criterion, devices=devices, opt=None),
                        logfile)
        validLosses.append(loss)
        nvtx.range_pop()  # "VAL EPOCH"
        checkpoint = {
            'stateDict': model.state_dict(),
            'setting': kwargs,
            'validationLoss': loss,
        }
        _saveCheckpoint(checkpoint, folder, saveMode, loss, validLosses)
        print(loss)
        nvtx.range_pop()  # "EPOCH"
    _printSampleTranslation(model, validIterator, SRC, TGT)
def parseBool(value):
    """Convert a command-line string such as 'true', 'False', '1' or 'no' to a bool.

    argparse's ``type=bool`` turns every non-empty string (including "False")
    into True, so flags taking an explicit boolean value need this converter.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.set_defaults(method=train)
    deviceList = list(range(torch.cuda.device_count()))
    parser.add_argument('-devices', type=int, default=deviceList, nargs='+', help="A list of GPUs to use")
    parser.add_argument('-datapath', type=str, default="data/", help="path where data is kept")
    parser.add_argument('-batchSize', type=int, default=10000, help="batch size")
    parser.add_argument('-epochs', type=int, default=5, help="number of training epochs")
    # choices= replaces assert-based validation, which is stripped under `python -O`.
    parser.add_argument('-saveMode', type=str, default='best', choices=['all', 'best'],
                        help='save state dicts of model. [all for all epochs, best for latest epoch]')
    parser.add_argument('-modelType', type=str, default='small',
                        help="small for transformer with 8 heads and 512 dimensions, large for 16heads and 1024 dimensions")
    parser.add_argument('-sourceLang', type=str, default='en', help="source language")
    parser.add_argument('-targetLang', type=str, default='fr', help='target language')
    parser.add_argument('-trainMode', type=str, default='none', choices=['manmp', 'automp', 'none'], help='mode')
    parser.add_argument('-activationBits', type=int, default=8, help='activation bits if in case of quantization training')
    parser.add_argument('-weightBits', type=int, default=16, help='weight bits if in case of quantization training/inference')
    # parseBool instead of type=bool, so "-requantizeOutputs False" really yields False.
    parser.add_argument('-requantizeOutputs', type=parseBool, default=False, help='requantize outputs?')
    arguments = parser.parse_args()
    manmpArgsList = ['activationBits', 'weightBits', 'requantizeOutputs']
    dataArgsList = ['datapath', 'sourceLang', 'targetLang']
    allParams = vars(arguments)
    # Options grouped into dataArgs/manmpArgs are not also passed top-level.
    groupedKeys = set(manmpArgsList + dataArgsList)
    cleanedParams = {
        'dataArgs': customDict(dataArgsList, allParams),
        'manmpArgs': customDict(manmpArgsList, allParams),
    }
    for key, value in allParams.items():
        if key not in groupedKeys:
            cleanedParams[key] = value
    pprint.pprint(cleanedParams, width=3)
    arguments.method(**cleanedParams)