-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
96 lines (82 loc) · 3.25 KB
/
train.py
File metadata and controls
96 lines (82 loc) · 3.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch import optim
from utils_data import Dataset_Seg
from utils_eval import evaluate
from utils_model import UNet
from pathlib import Path
from utils import set_seed
def train_model(
        path_to_data,
        model,
        device,
        epochs: int = 5,
        batch_size: int = 1,
        learning_rate: float = 1e-5,
        save_checkpoint: bool = True,
        amp: bool = False,
        weight_decay: float = 1e-8,
        gradient_clipping: float = 1.0,
):
    """Train a single-output segmentation model with a BCE-with-logits loss.

    Builds train/val loaders from ``Dataset_Seg``, optimizes with Adam,
    lowers the learning rate when the validation score plateaus, and
    optionally writes checkpoints into ``path_to_data``.

    Args:
        path_to_data: Dataset root passed to ``Dataset_Seg``; checkpoints are
            written into this same directory.
        model: Network to train. This function does not move it, so it must
            already live on ``device``. Its output is squeezed on dim 1, so a
            single output channel (e.g. ``n_classes=1``) is expected.
        device: ``torch.device`` (or device string) to train on.
        epochs: Number of passes over the training set.
        batch_size: Batch size used by both the train and val loaders.
        learning_rate: Initial Adam learning rate.
        save_checkpoint: If True, save ``checkpoint_best.pth`` whenever the
            validation score improves and ``checkpoint_last.pth`` after the
            final epoch.
        amp: Enable automatic mixed precision (autocast + gradient scaling).
        weight_decay: L2 penalty applied by Adam.
        gradient_clipping: Maximum total gradient norm per step.
    """
    device = torch.device(device)
    checkpoint_dir = Path(path_to_data)

    # Dataset
    train_dataset = Dataset_Seg(path_to_data, 'train', augment=True)
    val_dataset = Dataset_Seg(path_to_data, 'val')
    # Pinned host memory only speeds up host->CUDA copies; on CPU-only runs
    # it is useless and triggers warnings.
    loader_args = dict(batch_size=batch_size, pin_memory=device.type == 'cuda')
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_args)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_args)

    # Optimizer (weight_decay was previously accepted but silently ignored).
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # 'max' mode: the score returned by evaluate() is treated as higher-is-better.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.BCEWithLogitsLoss()
    # torch.autocast only accepts certain device types; fall back to 'cpu'
    # for others (e.g. 'mps').
    autocast_device = 'cuda' if device.type == 'cuda' else 'cpu'

    # Start at -inf so the first epoch always produces a "best" checkpoint,
    # even when its score is 0.
    best_val_score = float('-inf')
    for epoch in range(1, epochs + 1):
        model.train()
        with tqdm(total=len(train_dataset), desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for images, true_masks in train_loader:
                images = images.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                # Without autocast, amp=True would only scale the loss and
                # never actually run the forward pass in reduced precision.
                with torch.autocast(autocast_device, enabled=amp):
                    masks_pred = model(images)
                    loss = criterion(masks_pred.squeeze(1), true_masks.float().squeeze(1))

                optimizer.zero_grad()
                grad_scaler.scale(loss).backward()
                # Unscale first so clipping operates on the true gradient norm.
                grad_scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                pbar.set_postfix(**{'loss (batch)': loss.item()})

        val_score = evaluate(model, val_loader, device, amp)
        scheduler.step(val_score)
        print('validation score:', val_score)

        if save_checkpoint:
            if val_score > best_val_score:
                best_val_score = val_score
                torch.save(model.state_dict(), checkpoint_dir / 'checkpoint_best.pth')
            if epoch == epochs:
                torch.save(model.state_dict(), checkpoint_dir / 'checkpoint_last.pth')
if __name__ == '__main__':
    # Hard-coded run configuration. The data path looks like a Google Colab
    # Drive mount; adjust it for other environments.
    run_config = {
        'path_to_data': '/content/drive/MyDrive/archive',
        'epochs': 50,
        'batch_size': 32,
        'learning_rate': 1e-5,
    }
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set_seed() runs before the network is constructed; assuming it seeds
    # torch's RNG, this makes the initial weights reproducible.
    set_seed()
    # One output channel, matching the squeeze(1) + BCE-with-logits loss used
    # during training.
    unet = UNet(n_channels=3, n_classes=1, bilinear=False)
    unet.to(device=device)

    train_model(model=unet, device=device, **run_config)