-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
133 lines (109 loc) · 4.5 KB
/
train.py
File metadata and controls
133 lines (109 loc) · 4.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import math
from pathlib import Path
from typing import Dict, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from src.models import Discriminator, Generator
from src.utils import convert_float_matrix_to_int_list, generate_even_data
def train(
    max_int: int = 128,
    batch_size: int = 16,
    training_steps: int = 500,
    learning_rate: float = 0.001,
    print_output_every_n_steps: int = 10,
):
    """Trains the even GAN.

    Args:
        max_int: The maximum integer our dataset goes to. It is used to set
            the size of the binary lists.
        batch_size: The number of examples in a training batch.
        training_steps: The number of steps to train on.
        learning_rate: The learning rate for the generator and discriminator.
        print_output_every_n_steps: The number of training steps before we
            print generated output.

    Returns:
        generator: The trained generator model.
        discriminator: The trained discriminator model.
        history: Dict with per-step "gen_loss" and "dis_loss" lists.
    """
    # Bits needed to encode integers in [0, max_int); assumes max_int is a
    # power of two (e.g. 128 -> 7-bit codes).
    input_length = int(math.log2(max_int))

    # Models
    generator = Generator(input_length)
    discriminator = Discriminator(input_length)

    # Optimizers
    # BUG FIX: the learning_rate argument was previously ignored (both
    # optimizers hard-coded lr=0.001).
    generator_optimizer = torch.optim.Adam(
        generator.parameters(), lr=learning_rate
    )
    discriminator_optimizer = torch.optim.Adam(
        discriminator.parameters(), lr=learning_rate
    )

    # Binary cross-entropy: the discriminator emits a probability of "real".
    loss = nn.BCELoss()
    gen_loss = []
    dis_loss = []

    for i in range(training_steps):
        # zero the gradients on each iteration
        generator_optimizer.zero_grad()

        # Create noisy input for the generator (random bits cast to float,
        # since the linear layers need float input).
        noise = torch.randint(0, 2, size=(batch_size, input_length)).float()
        generated_data = generator(noise)

        # Generate examples of even real data.
        # true labels: [1,1,1,...] i.e. all ones
        # true data: [[0,0,0,0,1,0,0],...] i.e. binary codes for even numbers
        true_labels, true_data = generate_even_data(
            max_int, batch_size=batch_size
        )
        true_labels = torch.tensor(true_labels).float()
        true_data = torch.tensor(true_data).float()

        # Train the generator: use all-ones labels and don't step the
        # discriminator, because we want the generator to make things the
        # discriminator classifies as true.
        discriminator_out_gen_data = discriminator(generated_data)
        generator_loss = loss(
            discriminator_out_gen_data.squeeze(), true_labels
        )
        gen_loss.append(generator_loss.item())
        generator_loss.backward()
        generator_optimizer.step()

        # Train the discriminator: teach it to label true data 1 ...
        discriminator_optimizer.zero_grad()
        discriminator_out_true_data = discriminator(true_data)
        discriminator_loss_true_data = loss(
            discriminator_out_true_data.squeeze(), true_labels
        )

        # ... and generated data 0. detach() so this loss does not
        # backpropagate through the generator.
        discriminator_out_fake_data = discriminator(generated_data.detach())
        fake_labels = torch.zeros(batch_size)  # [0,0,0,...]
        discriminator_loss_fake_data = loss(
            discriminator_out_fake_data.squeeze(), fake_labels
        )

        # Total discriminator loss: mean of the real- and fake-data losses.
        discriminator_loss = (
            discriminator_loss_true_data + discriminator_loss_fake_data
        ) / 2
        dis_loss.append(discriminator_loss.item())
        discriminator_loss.backward()
        discriminator_optimizer.step()

        if i % print_output_every_n_steps == 0:
            output = convert_float_matrix_to_int_list(generated_data)
            even_count = sum(1 for x in output if x % 2 == 0)
            # BUG FIX: the even-count denominator was hard-coded to 16
            # instead of batch_size.
            print(
                f"steps: {i}, output: {output}, "
                f"even count: {even_count}/{batch_size}, "
                f"Gen Loss: {np.round(generator_loss.item(), 4)}, "
                f"Dis Loss: {np.round(discriminator_loss.item(), 4)}"
            )

    history = {"dis_loss": dis_loss, "gen_loss": gen_loss}
    return generator, discriminator, history
def plot_loss(loss_history: Dict) -> None:
    """Plot the GAN loss curves and save them to output/loss_curve.png.

    Args:
        loss_history: Dict with "dis_loss" and "gen_loss" keys, each mapping
            to a list of per-step loss values (as returned by train()).
    """
    plt.plot(loss_history["dis_loss"], color='blue', linewidth=2, label="dis")
    plt.plot(loss_history["gen_loss"], color='orange', linewidth=2, label="gen")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.title("GAN Loss curve")
    plt.legend()
    plt.grid(alpha=0.3)
    plt.tight_layout()
    # BUG FIX: savefig raises FileNotFoundError if the target directory is
    # missing; create it first.
    Path("output").mkdir(parents=True, exist_ok=True)
    plt.savefig("output/loss_curve.png")
if __name__ == "__main__":
    # Train the even-number GAN, then save its loss curves to disk.
    trained_generator, trained_discriminator, loss_history = train()
    plot_loss(loss_history)