-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathlayout.py
More file actions
119 lines (90 loc) · 4.12 KB
/
layout.py
File metadata and controls
119 lines (90 loc) · 4.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import argparse
import csv
import numpy as np
import tqdm
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from torch.utils.data.sampler import SubsetRandomSampler
from UI_embedding.plotter import plot_loss
from autoencoder import ScreenLayoutDataset, LayoutAutoEncoder, LayoutTrainer
from autoencoder import ScreenVisualLayout, ScreenVisualLayoutDataset, ImageAutoEncoder, ImageTrainer
# Command-line entry point: trains either the layout autoencoder (--type 0)
# or the visual/image autoencoder (--type 1) on the given screen dataset.
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dataset", required=True, type=str, help="dataset of screens to train on")
parser.add_argument("-b", "--batch_size", type=int, default=64, help="traces in a batch")
parser.add_argument("-e", "--epochs", type=int, default=10, help="number of epochs")
parser.add_argument("-r", "--rate", type=float, default=0.001, help="learning rate")
# Restrict --type to the two supported models; any other value previously
# fell through both training branches and the script silently did nothing.
parser.add_argument("-t", "--type", type=int, default=0, choices=(0, 1),
                    help="0 to create layout autoencoder, 1 to create visual autoencoder")
args = parser.parse_args()
def _split_loaders(dataset, batch_size, val_fraction=0.1):
    """Randomly split *dataset* into training and validation DataLoaders.

    Indices are shuffled with NumPy's global RNG. The first
    ``floor(val_fraction * len(dataset))`` shuffled indices become the
    validation subset and the remainder the training subset. Each loader
    draws its subset in random order via ``SubsetRandomSampler``.

    Returns:
        ``(train_loader, test_loader)``
    """
    dataset_size = len(dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(val_fraction * dataset_size))
    np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]
    train_loader = DataLoader(dataset, batch_size=batch_size,
                              sampler=SubsetRandomSampler(train_indices))
    test_loader = DataLoader(dataset, batch_size=batch_size,
                             sampler=SubsetRandomSampler(val_indices))
    return train_loader, test_loader


def _run_training(trainer, epochs, checkpoint_path=None):
    """Run ``epochs`` rounds of ``trainer.train`` followed by ``trainer.test``.

    The value returned by each call is printed and recorded. A checkpoint is
    written with ``trainer.save`` whenever ``epoch % 50 == 0`` (this includes
    epoch 0). When *checkpoint_path* is None, ``save`` receives only the
    epoch, leaving the location to the trainer (presumably a default path in
    its ``save`` signature; confirm).

    Returns:
        ``(train_losses, test_losses)``, one entry per epoch.
    """
    train_losses = []
    test_losses = []
    for epoch in tqdm.tqdm(range(epochs)):
        print("--------")
        print(str(epoch) + " loss:")
        train_loss = trainer.train(epoch)
        print(train_loss)
        print("--------")
        train_losses.append(train_loss)
        test_loss = trainer.test(epoch)
        test_losses.append(test_loss)
        print(test_loss)
        print("--------")
        if epoch % 50 == 0:
            print("saved on epoch " + str(epoch))
            if checkpoint_path is None:
                trainer.save(epoch)
            else:
                trainer.save(epoch, checkpoint_path)
    return train_losses, test_losses


if args.type == 0:
    # Layout autoencoder.
    dataset = ScreenLayoutDataset(args.dataset)
    train_loader, test_loader = _split_loaders(dataset, args.batch_size)
    model = LayoutAutoEncoder()
    model.cuda()
    trainer = LayoutTrainer(model, train_loader, test_loader, args.rate)
    # NOTE(review): periodic checkpoints here call trainer.save(epoch) with no
    # path, unlike the final save below; confirm LayoutTrainer.save's default.
    train_loss_data, test_loss_data = _run_training(trainer, args.epochs)
    plot_loss(train_loss_data, test_loss_data, "output/autoencoder")
    trainer.save(args.epochs, "output/autoencoder")
elif args.type == 1:
    # Visual (image) autoencoder.
    dataset = ScreenVisualLayoutDataset(args.dataset)
    train_loader, test_loader = _split_loaders(dataset, args.batch_size)
    model = ImageAutoEncoder()
    model.cuda()
    trainer = ImageTrainer(model, train_loader, test_loader, args.rate)
    train_loss_data, test_loss_data = _run_training(
        trainer, args.epochs, "output/visual_encoder_fast")
    plot_loss(train_loss_data, test_loss_data, "output/visual_encoder_fast")
    trainer.save(args.epochs, "output/visual_encoder_fast")
    # One CSV row per epoch: train loss, then test loss.
    with open("output/visual_encoder_fast.csv", 'w', newline='') as myfile:
        csv.writer(myfile).writerows(zip(train_loss_data, test_loss_data))
else:
    # Any other --type previously fell through and the script did nothing.
    parser.error("unknown --type " + str(args.type) + "; expected 0 or 1")