-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
40 lines (36 loc) · 1.38 KB
/
main.py
File metadata and controls
40 lines (36 loc) · 1.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import CNN
import ChessDataset as CD
import torch.optim as optim
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from test import test
def _checkpoint_and_validate(model, batch_id, epoch):
    """Save the current weights, evaluate them with ``test``, and log the result.

    The checkpoint is written to ``results_batch{batch_id}_epoch{epoch}.pth``.
    That path is passed to ``test``, and whatever ``test`` returns is treated as
    the validation MSE. It is appended to the ``validation`` file and printed.
    """
    filename = f"results_batch{batch_id}_epoch{epoch}.pth"
    torch.save(model.state_dict(), filename)
    # NOTE(review): ``test`` receives only the path, so it presumably builds and
    # loads its own model copy; this training model stays in train mode.
    mse = test(filename)
    with open("validation", "a", encoding="utf-8") as file:
        file.write(f"Batch:{batch_id}, epoch{epoch}, mse: {mse}\n")
    print(f"Batch: {batch_id}, MSE: {mse}\n")


def main(
    *,
    epochs=2,
    num_samples=205353531,
    batch_size=3000,
    lr=1e-3,
    validate_every=12000,
    dataset_path="Dataset/cleaned_dataset",
    final_path="results_200mil_epoch2.pth",
):
    """Train ``CNN.CNN`` on ``ChessDataset`` with MSE loss and Adam.

    Every ``validate_every`` batches (including batch 0 of each epoch) a
    checkpoint is saved and scored via ``test``. The final weights are saved to
    ``final_path`` after the last epoch. All defaults reproduce the original
    hard-coded run.

    Args:
        epochs: Number of passes over the dataset.
        num_samples: Second argument to ``ChessDataset``. Per the original
            note, 205,353,531 is the limit; presumably this is the number of
            rows available (TODO confirm against ChessDataset).
        batch_size: DataLoader batch size.
        lr: Adam learning rate.
        validate_every: Checkpoint/validation interval, in batches.
        dataset_path: Path handed to ``ChessDataset``.
        final_path: Where the final ``state_dict`` is written.

    Raises:
        ValueError: If ``validate_every`` is not positive.
    """
    if validate_every <= 0:
        raise ValueError(f"validate_every must be positive, got {validate_every}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN.CNN().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    data = CD.ChessDataset(dataset_path, num_samples)
    loader = DataLoader(data, batch_size=batch_size, shuffle=True, num_workers=1)

    for epoch in range(epochs):
        model.train()
        # Keep the running sum as a tensor so we don't force a device->host
        # sync with .item() on every step; it is read back once per epoch.
        total_loss = 0.0
        num_batches = 0
        for batch_id, (X, y) in enumerate(loader):
            X, y = X.to(device), y.to(device)
            pred = model(X)
            batch_loss = criterion(pred, y)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            total_loss = total_loss + batch_loss.detach()
            num_batches += 1
            # Also fires at batch 0, giving a baseline checkpoint each epoch.
            if batch_id % validate_every == 0:
                _checkpoint_and_validate(model, batch_id, epoch)
        if num_batches:
            # Mean of the per-batch MSE values over this epoch.
            mean_loss = float(total_loss) / num_batches
            print(f"Epoch {epoch}: mean training loss {mean_loss}")

    torch.save(model.state_dict(), final_path)
# Script entry point: start training only when executed directly, not when
# this module is imported (e.g. by a notebook or another script).
if __name__ == "__main__":
    main()