-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
115 lines (95 loc) · 4.19 KB
/
train.py
File metadata and controls
115 lines (95 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
import json
import math
from pathlib import Path
import time
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torch import optim
from PIL import Image
from torchvision import transforms
from whichway.model import get_model, get_device, DEFAULT_MODEL_PATH
class OrientationDataset(Dataset):
    """Grayscale images paired with their rotation, encoded as ``[sin, cos]``.

    The dataset is described by a JSON "answersheet": a sequence of objects,
    each with a ``filename`` (resolved relative to the answersheet's own
    directory) and a ``degrees`` rotation angle.
    """

    def __init__(self, answersheet_path: Path):
        """Load the answersheet at ``answersheet_path`` (a ``Path`` or path string)."""
        answersheet_path = Path(answersheet_path)
        # Image filenames listed in the answersheet are relative to its directory.
        self.sample_path = answersheet_path.parent
        self.samples = json.loads(answersheet_path.read_text(encoding="utf-8"))
        self.to_tensor = transforms.ToTensor()

    @staticmethod
    def _encode_angle(degrees: float) -> torch.Tensor:
        """Encode ``degrees`` as a float32 ``[sin, cos]`` pair.

        The two-component encoding is continuous around the circle, so 0° and
        360° (or -1° and 359°) produce the same regression target.
        """
        radians = math.radians(degrees)
        return torch.tensor([math.sin(radians), math.cos(radians)], dtype=torch.float32)

    def __getitem__(self, index: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image_tensor, [sin, cos] target)`` for sample ``index``.

        Raises:
            IndexError: if ``index`` is outside the sample list (the exception
                the ``Dataset``/sequence protocol expects, and not stripped
                under ``python -O`` the way ``assert`` is).
        """
        try:
            sample = self.samples[index]
        except IndexError:
            raise IndexError(
                f"Index out of range for samples {index=} samples={len(self.samples)}"
            ) from None
        filename = self.sample_path / sample["filename"]
        # Close the file as soon as the pixels are decoded; convert() returns a
        # new, fully loaded image that no longer needs the file handle.
        with Image.open(filename) as image:
            grayscale = image.convert("L")
        return self.to_tensor(grayscale), self._encode_angle(sample["degrees"])

    def __len__(self) -> int:
        """Return the number of samples listed in the answersheet."""
        return len(self.samples)
def _run_training_epoch(model, loader, criterion, optimizer, device) -> float:
    """Run one optimisation pass over ``loader`` and return the mean per-sample loss.

    ``criterion`` is assumed to use a mean reduction (the ``nn.MSELoss()``
    default), so each batch loss is re-weighted by its batch size. A short
    final batch then counts in proportion to its size rather than as much as
    a full batch.
    """
    model.train()
    loss_sum = 0.0
    sample_count = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        sample_count += batch_size
    return loss_sum / sample_count


def _angular_error_degrees(outputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return the per-row absolute angle error in degrees, wrapped into [0, 180].

    Column 0 of both tensors is read as sin and column 1 as cos, matching the
    ``[sin, cos]`` targets produced by ``OrientationDataset``.
    """
    predicted = torch.atan2(outputs[:, 0], outputs[:, 1])
    expected = torch.atan2(labels[:, 0], labels[:, 1])
    diff = torch.rad2deg(predicted - expected).abs()
    # The raw difference lies in [0, 360]; going the other way round the circle
    # is shorter when it exceeds 180 (e.g. 350 deg vs 10 deg is a 20 deg error).
    return torch.minimum(diff, 360 - diff)


def _count_within_thresholds(model, loader, device, thresholds) -> tuple[dict, int]:
    """Evaluate ``model`` on ``loader``; return ({threshold: hit count}, total samples)."""
    model.eval()
    hits = dict.fromkeys(thresholds, 0)
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            errors = _angular_error_degrees(model(images), labels)
            total += labels.size(0)
            for threshold in hits:
                hits[threshold] += (errors < threshold).sum().item()
    return hits, total


def train(
    train_path: Path,
    val_path: Path,
    output_path: Path,
    num_epochs: int,
    batch_size: int = 32,
    learning_rate: float = 0.001,
    thresholds: tuple[float, ...] = (5, 10, 20),
):
    """Train the orientation model, report validation accuracy per epoch, then save it.

    Args:
        train_path: Training answersheet JSON (see ``OrientationDataset``).
        val_path: Validation answersheet JSON.
        output_path: Where the final ``state_dict`` is written with ``torch.save``.
        num_epochs: Number of passes over the training set.
        batch_size: Batch size for both the training and validation loaders.
        learning_rate: Adam learning rate.
        thresholds: Angular error limits (degrees) reported as accuracy
            percentages after each epoch.

    Raises:
        ValueError: If either answersheet contains no samples. Without this
            check, the per-epoch averages would divide by zero, but only after
            a full epoch of training had already run.
    """
    train_dataset = OrientationDataset(train_path)
    val_dataset = OrientationDataset(val_path)
    if len(train_dataset) == 0:
        raise ValueError(f"training answersheet {train_path} contains no samples")
    if len(val_dataset) == 0:
        raise ValueError(f"validation answersheet {val_path} contains no samples")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=0)

    model = get_model()
    # NOTE(review): the model is never moved explicitly, so get_model() presumably
    # returns it already placed on get_device() — confirm in whichway.model.
    device = get_device()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        start_time = time.monotonic()
        avg_loss = _run_training_epoch(model, train_loader, criterion, optimizer, device)
        hits, total = _count_within_thresholds(model, val_loader, device, thresholds)
        seconds_taken = time.monotonic() - start_time

        print(f"[Epoch {epoch + 1}/{num_epochs}] after {seconds_taken:.1f} seconds:")
        print(f" loss: {avg_loss:.4f}")
        for threshold, count in hits.items():
            print(f" within {threshold} degrees: {count / total * 100:.2f}%")
        print()

    torch.save(model.state_dict(), output_path)
    print(f"Saved model to {output_path}")
def _positive_int(text: str) -> int:
    """argparse ``type=`` converter: parse ``text`` as an int and reject values below 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def main(argv=None):
    """Parse command-line options and run ``train``.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (``None``
            uses the real command line), so the CLI can be driven programmatically.
    """
    parser = argparse.ArgumentParser(description="Train the image orientation model.")
    parser.add_argument("-m", "--model", type=Path, default=DEFAULT_MODEL_PATH, help="path to save model state dict")
    parser.add_argument("-t", "--training", type=Path, required=True, help="path to training answersheet.json")
    parser.add_argument("-v", "--validation", type=Path, required=True, help="path to validation answersheet.json")
    # A zero or negative count would run no epochs yet still overwrite the
    # saved model (DEFAULT_MODEL_PATH by default) with an untrained one.
    parser.add_argument("-e", "--epochs", type=_positive_int, default=10, help="number of training epochs to run")
    args = parser.parse_args(argv)
    train(train_path=args.training, val_path=args.validation, output_path=args.model, num_epochs=args.epochs)


if __name__ == "__main__":
    main()