-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
127 lines (95 loc) · 3.5 KB
/
train.py
File metadata and controls
127 lines (95 loc) · 3.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import deepinv as dinv
import tqdm
import os
from dataset import load_dataset
from models import ModuloSEFLnet
def scale_eq_loss(
    physics, x, model, loss_fn=None, num_trans=3, alpha=0.1, sat_range=0.4
):
    """Return a weighted scale-equivariance loss for a reconstruction model.

    For each of ``num_trans`` draws, ``x`` is multiplied by one random scalar
    drawn uniformly from ``[1 - sat_range / 2, 1 + sat_range / 2)``. The scaled
    signal is passed through ``physics`` to obtain a virtual measurement, which
    is reconstructed with ``model``. The reconstruction is compared against the
    scaled signal with ``loss_fn``.

    Args:
        physics: Forward operator, called as ``physics(x)``.
        x: Ground-truth batch tensor.
        model: Reconstruction network, called as ``model(y)``.
        loss_fn: Callable ``loss_fn(estimate, target)``. Defaults to
            mean-reduced MSE (``torch.nn.functional.mse_loss``).
        num_trans: Number of random scalings to average over; must be >= 1.
        alpha: Weight applied to the averaged loss.
        sat_range: Width of the scale-factor interval, centred on 1.

    Returns:
        ``alpha`` times the mean of the ``num_trans`` per-draw losses.

    Raises:
        ValueError: If ``num_trans`` is less than 1.
    """
    if loss_fn is None:
        loss_fn = torch.nn.functional.mse_loss
    if num_trans < 1:
        # Guard explicitly instead of failing later with ZeroDivisionError.
        raise ValueError(f"num_trans must be >= 1, got {num_trans}")

    total_loss = 0.0
    for _ in range(num_trans):
        # A single scalar per draw, shared by the whole batch.
        scale_factor = torch.rand(1).item() * sat_range - (sat_range / 2) + 1.0
        x_scale = x * scale_factor
        y_virtual = physics(x_scale)
        x_hat = model(y_virtual)
        # (estimate, target) order, matching torch loss conventions.
        total_loss += loss_fn(x_hat, x_scale)
    return total_loss * (alpha / num_trans)
# Modulo-sensing settings, shared by the physics operator and the model.
THRESHOLD = 1.0
MODE = "floor"

# Data location, plus the maximum value forwarded to load_dataset and the
# PSNR/SSIM metrics.
MAX_VALUE = 4.0
DATA_ROOT = os.path.join(".", "data", "unmodnet")

# Network width (input = output channels) and training length.
n_channels = 3
epochs = 5000

# Run on the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Build the train/test splits; MAX_VALUE is forwarded as the dataset's max_val.
train_dataset, test_dataset = load_dataset(DATA_ROOT, max_val=MAX_VALUE)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
    # Pinned host memory only helps host->GPU copies; on a CPU-only machine
    # PyTorch warns that it cannot be used, so request it only for CUDA.
    pin_memory=device.type == "cuda",
    # Every training batch therefore has exactly 16 samples.
    drop_last=True,
)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
# Modulo forward operator and the reconstruction network, both on `device`.
physics = dinv.physics.SpatialUnwrapping(threshold=THRESHOLD, mode=MODE).to(device)
model = ModuloSEFLnet(mx=THRESHOLD, in_channels=n_channels, out_channels=n_channels).to(
    device
)
model_name = "ModuloSEFLnet"
ckpt_path = os.path.join("ckpts", model_name + ".pth")
# Resume from the latest checkpoint if one exists.
if os.path.exists(ckpt_path):
    # map_location remaps stored tensors onto the current device, so a
    # checkpoint written on a GPU can be resumed on a CPU-only machine.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    print(f"Loaded checkpoint from {ckpt_path}")
else:
    print(f"No checkpoint found at {ckpt_path}, training from scratch.")
fn_loss = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-5)
# Print the number of trainable parameters, in thousands (K).
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters in the model: {num_params/1e3:.2f} K")
# Evaluation metrics use the same peak value as the dataset normalisation.
psnr_fn = dinv.loss.metric.PSNR(max_pixel=MAX_VALUE)
ssim_fn = dinv.loss.metric.SSIM(max_pixel=MAX_VALUE)
# Best test PSNR seen so far. Starting at -inf guarantees the first evaluation
# writes a "_best" checkpoint even if its PSNR is not positive.
max_psnr = float("-inf")
# torch.save does not create parent directories, so make sure ./ckpts exists.
os.makedirs("ckpts", exist_ok=True)
save_path = os.path.join("ckpts", model_name + ".pth")
best_path = os.path.join("ckpts", model_name + "_best.pth")

for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    with tqdm.tqdm(
        total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs}", unit="batch"
    ) as pbar:
        for batch in train_loader:
            x = batch[0].to(device)
            y = physics(x)
            optimizer.zero_grad()
            x_rec = model(y)
            # Supervised reconstruction loss plus the scale-equivariance term.
            loss = fn_loss(x_rec, x) + scale_eq_loss(
                physics, x, model, loss_fn=fn_loss
            )
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            pbar.update(1)
            pbar.set_postfix(
                loss=f"{loss.item():.6f}", avg=f"{(epoch_loss / pbar.n):.6f}"
            )
    # Overwrite the latest-weights checkpoint every epoch (used for resuming).
    torch.save(model.state_dict(), save_path)

    model.eval()
    with torch.no_grad():
        total_psnr = 0.0
        total_ssim = 0.0
        num_samples = 0
        for batch in test_loader:
            x = batch[0].to(device)
            y = physics(x)
            x_rec = model(y)
            batch_size = x.shape[0]
            # Weight each batch mean by its size so a short final batch does
            # not skew the dataset-level average.
            total_psnr += psnr_fn(x_rec, x).mean().item() * batch_size
            total_ssim += ssim_fn(x_rec, x).mean().item() * batch_size
            num_samples += batch_size
        avg_psnr = total_psnr / num_samples
        avg_ssim = total_ssim / num_samples
    print(f"Test PSNR: {avg_psnr:.2f} dB, SSIM: {avg_ssim:.4f}")
    # Keep a separate copy of the weights with the best test PSNR so far.
    if avg_psnr > max_psnr:
        torch.save(model.state_dict(), best_path)
        max_psnr = avg_psnr