-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
38 lines (29 loc) · 1.08 KB
/
train.py
File metadata and controls
38 lines (29 loc) · 1.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import random
import torch
import torch.nn.functional as F
from torch.nn import NLLLoss
from src import train_loop, configure, get_model, get_dataloaders
# Pin PyTorch's default generator and Python's built-in `random` module to the
# same fixed seed. The two generators are independent of each other, so the
# order in which they are seeded does not matter.
torch.manual_seed(42)
random.seed(42)
def masked_nll_loss(logits, target_ids, loss_mask):
    """Return the mask-weighted mean negative log-likelihood.

    Args:
        logits: Unnormalised class scores. The last dimension indexes the
            classes; all leading dimensions are flattened together.
        target_ids: Integer class indices, one per flattened row of
            ``logits``.
        loss_mask: Per-position weights (bool, int or float) with as many
            elements as ``target_ids``. A zero removes that position from
            the loss.

    Returns:
        A scalar tensor equal to the mask-weighted sum of the per-position
        losses divided by ``loss_mask.sum()``. When the mask sums to zero
        the result is ``0`` instead of the ``NaN`` that ``0 / 0`` would give.
    """
    # cross_entropy is log_softmax followed by nll_loss in one call, so no
    # loss module has to be constructed on every step. reshape (unlike view)
    # also accepts non-contiguous logits.
    per_position = F.cross_entropy(
        logits.reshape(-1, logits.shape[-1]),
        target_ids.reshape(-1),
        reduction="none",
    )

    mask_total = loss_mask.sum()
    # A batch with nothing selected would otherwise divide 0 by 0 and hand a
    # NaN loss to the optimizer. Dividing by one leaves the (zero) numerator
    # untouched; every non-zero total is used exactly as before.
    denominator = torch.where(
        mask_total != 0, mask_total, torch.ones_like(mask_total)
    )
    return (loss_mask.reshape(-1) * per_position).sum() / denominator


if __name__ == "__main__":
    args = configure()
    training_dataloader, validation_dataloader = get_dataloaders(args)
    model, optimizer = get_model(args, return_config=False, return_optimizer=True)

    # NOTE(review): `train_loop` comes from `src` and is not visible here.
    # Presumably it wraps the per-batch step below into a callable that
    # iterates the dataloaders and applies the optimizer -- confirm in `src`.
    @train_loop(
        model=model,
        optimizer=optimizer,
        args=args,
    )
    def train(model, batch):
        """Compute the loss for one batch.

        The model is fed ``batch.masked_input_ids`` with
        ``batch.attention_mask`` and is scored against ``batch.input_ids``,
        counting only the positions weighted by ``batch.masked_mask``.
        """
        out = model(
            input_ids=batch.masked_input_ids,
            attention_mask=batch.attention_mask,
        )
        return masked_nll_loss(out.logits, batch.input_ids, batch.masked_mask)

    train(training_dataloader, validation_dataloader)