-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLinearProbModel.py
More file actions
74 lines (61 loc) · 2.69 KB
/
LinearProbModel.py
File metadata and controls
74 lines (61 loc) · 2.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
from torch import nn
from tqdm import tqdm
from BaseModel import BaseModel
class LinearProbModel(BaseModel):
    """Linear probe: a single fully-connected layer trained on (frozen) features.

    Maps ``input_dim``-dimensional feature vectors to ``output_dim`` class
    logits, and wraps a standard fit/evaluate/predict loop around a
    cross-entropy objective optimized with Adam.
    """

    # training parameters (defaults for __init__)
    BATCH_SIZE = 32
    LEARNING_RATE = 3e-4
    TRAIN_EPOCHS_NUM = 10
    BETAS = (0.9, 0.999)
    WEIGHT_DECAY = 1e-6
    LOG_INTERVAL = 10  # print training progress every LOG_INTERVAL batches
    MODEL_NAME = 'LinearProbModel.pth'
    TRAINED_MODELS_DIR = 'trained_models'

    def __init__(self, input_dim, output_dim, batch_size=BATCH_SIZE,
                 learning_rate=LEARNING_RATE, train_epochs_num=TRAIN_EPOCHS_NUM,
                 betas=BETAS, weight_decay=WEIGHT_DECAY):
        """Build the linear layer, move it to the device, and set up Adam.

        Args:
            input_dim: dimensionality of the input feature vectors.
            output_dim: number of classes (size of the logit vector).
            batch_size, learning_rate, train_epochs_num, betas, weight_decay:
                training hyper-parameters; defaults are the class constants.
        """
        super().__init__()  # NOTE(review): BaseModel presumably sets self.device — confirm
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.train_epochs_num = train_epochs_num
        self.classes = None  # label lookup table; populated by fit(), required by predict()
        self.fc = nn.Linear(input_dim, output_dim)
        self.to(self.device)
        # Optimizer is created after .to(device) so it tracks the on-device parameters.
        self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate,
                                          betas=betas, weight_decay=weight_decay)
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, batch):
        """Return class logits for ``batch[0]`` (the feature tensor of a loader batch)."""
        return self.fc(batch[0].to(self.device))

    def predict(self, x):
        """Return predicted class labels for batch ``x``.

        Raises:
            RuntimeError: if fit() has not been called yet (self.classes unset).
        """
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`,
        # which would turn this guard into a confusing downstream indexing error.
        if self.classes is None:
            raise RuntimeError('Model is not trained yet')
        return self.classes[torch.argmax(self.forward(x), dim=1)]

    def fit(self, train_loader, save_model=True):
        """Train the probe for self.train_epochs_num epochs over train_loader.

        Args:
            train_loader: DataLoader yielding (features, ..., labels) batches.
                NOTE(review): assumes a torch Subset-style wrapper
                (`dataset.dataset.classes`) and a `CLASSES_IDX` attribute on the
                dataset giving the label position within each batch — confirm.
            save_model: if True, persist weights via self.save_model() when done.
        """
        self.train()
        print('setting classes.')
        self.classes = train_loader.dataset.dataset.classes
        print(f'Training {self.MODEL_NAME.split(".")[0]}...')
        for epoch in range(self.train_epochs_num):
            for batch_idx, batch in enumerate(train_loader):
                self.optimizer.zero_grad()
                output = self.forward(batch)
                loss = self.criterion(output, batch[train_loader.dataset.CLASSES_IDX].to(self.device))
                loss.backward()
                self.optimizer.step()
                if batch_idx % self.LOG_INTERVAL == 0:
                    print(f'Train Epoch: {epoch} [{batch_idx * len(batch[0])}/{len(train_loader.dataset)} '
                          f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
        print('Finished Training')
        if save_model:
            self.save_model()

    def evaluate(self, test_loader):
        """Return top-1 accuracy of the probe on test_loader.

        Args:
            test_loader: DataLoader with the same batch layout as in fit().
        """
        correct = 0
        print('Evaluating...')
        self.eval()
        with torch.no_grad():
            for batch in tqdm(test_loader):
                output = self.forward(batch)
                pred = torch.argmax(output, dim=1)
                # BUGFIX: move the targets to the model device before comparing —
                # `pred` lives on self.device while the loader yields CPU tensors,
                # so `eq` would raise a device-mismatch error on GPU. fit() already
                # does this for the loss; evaluate() must match.
                target = batch[test_loader.dataset.CLASSES_IDX].to(self.device)
                correct += pred.eq(target.view_as(pred)).sum().item()
        acc = correct / len(test_loader.dataset)
        print(f' - acc of linear probing on test dataset: {acc:.4f}')
        return acc