-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
20 lines (15 loc) · 637 Bytes
/
train.py
File metadata and controls
20 lines (15 loc) · 637 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# Train the models.
#
# NOTE(review): this is a flat script fragment; train_loader, num_epochs,
# decoder, device, optimizer, criterion, torch and pack_padded_sequence are
# all defined outside this snippet.
total_step = len(train_loader)  # batches per epoch; not used in this loop
for epoch in range(num_epochs):
    # Put the decoder in training mode at the start of every epoch
    # (presumably because an evaluation pass elsewhere may switch it to
    # eval mode).
    decoder.train()
    for i, (features, captions, lengths) in enumerate(train_loader):
        # The original moved only `captions` to `device`, which fails with a
        # device mismatch once the decoder lives on a GPU. `.to` is a no-op
        # for a tensor that is already on `device`.
        # NOTE(review): assumes `features` is a tensor - TODO confirm.
        features = features.to(device)
        captions = captions.to(device)
        # `lengths` is squeezed along dim 1, so it presumably arrives with
        # shape (batch, 1) - TODO confirm. It is deliberately NOT moved to
        # `device`: pack_padded_sequence requires a lengths tensor on the CPU.
        lens = lengths.squeeze(1)
        # Pack the padded captions so padding positions are dropped from the
        # loss targets. enforce_sorted=False accepts batches in any length
        # order; element [0] of the PackedSequence is its flat `data` tensor.
        # NOTE(review): packed data is in length-sorted order, so `outputs`
        # must be packed the same way for the loss to line up - verify in
        # the decoder.
        targets = pack_padded_sequence(
            captions, lens, batch_first=True, enforce_sorted=False
        )[0]
        optimizer.zero_grad()
        # Guarantee autograd records this step even if gradients were
        # disabled globally elsewhere.
        with torch.set_grad_enabled(True):
            # Forward, backward and optimize.
            # NOTE(review): the decoder receives the unsqueezed `lengths`,
            # not `lens`; presumably it handles that shape itself - verify.
            outputs = decoder(features, captions, lengths)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()