import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import math, os
from tensorflow.examples.tutorials.mnist import input_data
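# Note: this script targets TensorFlow 1.x; tf.contrib, tf.placeholder and the
# tensorflow.examples.tutorials MNIST loader were all removed in TensorFlow 2.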
# Define some handy network layers
def lrelu(x, rate=0.1):
    return tf.maximum(tf.minimum(x * rate, 0), x)
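# (max(min(rate * x, 0), x) equals the usual leaky ReLU: x for x > 0,
#  rate * x for x <= 0.)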
def conv2d_lrelu(inputs, num_outputs, kernel_size, stride):
    conv = tf.contrib.layers.convolution2d(inputs, num_outputs, kernel_size, stride,
                                           weights_initializer=tf.contrib.layers.xavier_initializer(),
                                           activation_fn=tf.identity)
    conv = lrelu(conv)
    return conv

def conv2d_t_relu(inputs, num_outputs, kernel_size, stride):
    conv = tf.contrib.layers.convolution2d_transpose(inputs, num_outputs, kernel_size, stride,
                                                     weights_initializer=tf.contrib.layers.xavier_initializer(),
                                                     activation_fn=tf.identity)
    conv = tf.nn.relu(conv)
    return conv

def fc_lrelu(inputs, num_outputs):
    fc = tf.contrib.layers.fully_connected(inputs, num_outputs,
                                           weights_initializer=tf.contrib.layers.xavier_initializer(),
                                           activation_fn=tf.identity)
    fc = lrelu(fc)
    return fc

def fc_relu(inputs, num_outputs):
    fc = tf.contrib.layers.fully_connected(inputs, num_outputs,
                                           weights_initializer=tf.contrib.layers.xavier_initializer(),
                                           activation_fn=tf.identity)
    fc = tf.nn.relu(fc)
    return fc
# Encoder and decoder use the DC-GAN architecture
def encoder(x, z_dim):
    with tf.variable_scope('encoder'):
        conv1 = conv2d_lrelu(x, 64, 4, 2)
        conv2 = conv2d_lrelu(conv1, 128, 4, 2)
        conv2 = tf.reshape(conv2, [-1, np.prod(conv2.get_shape().as_list()[1:])])
        fc1 = fc_lrelu(conv2, 1024)
        return tf.contrib.layers.fully_connected(fc1, z_dim, activation_fn=tf.identity)
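# With 28x28 MNIST inputs and the stride-2 'SAME'-padded convolutions above
# (contrib's default padding), the encoder maps 28x28x1 -> 14x14x64 -> 7x7x128,
# flattens, and projects through 1024 units down to z_dim.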
def decoder(z, reuse=False):
    with tf.variable_scope('decoder') as vs:
        if reuse:
            vs.reuse_variables()
        fc1 = fc_relu(z, 1024)
        fc2 = fc_relu(fc1, 7*7*128)
        fc2 = tf.reshape(fc2, tf.stack([tf.shape(fc2)[0], 7, 7, 128]))
        conv1 = conv2d_t_relu(fc2, 64, 4, 2)
        output = tf.contrib.layers.convolution2d_transpose(conv1, 1, 4, 2, activation_fn=tf.sigmoid)
        return output
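# The two transposed convolutions upsample 7x7 -> 14x14 -> 28x28; the final
# sigmoid keeps outputs in [0, 1], matching the MNIST pixel range.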
def compute_kernel(x, y):
    x_size = tf.shape(x)[0]
    y_size = tf.shape(y)[0]
    dim = tf.shape(x)[1]
    tiled_x = tf.tile(tf.reshape(x, tf.stack([x_size, 1, dim])), tf.stack([1, y_size, 1]))
    tiled_y = tf.tile(tf.reshape(y, tf.stack([1, y_size, dim])), tf.stack([x_size, 1, 1]))
    return tf.exp(-tf.reduce_mean(tf.square(tiled_x - tiled_y), axis=2) / tf.cast(dim, tf.float32))
def compute_mmd(x, y):
    x_kernel = compute_kernel(x, x)
    y_kernel = compute_kernel(y, y)
    xy_kernel = compute_kernel(x, y)
    return tf.reduce_mean(x_kernel) + tf.reduce_mean(y_kernel) - 2 * tf.reduce_mean(xy_kernel)
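# compute_mmd is the (biased) sample estimate of the squared Maximum Mean Discrepancy:
#   MMD^2(X, Y) = E[k(x, x')] + E[k(y, y')] - 2 E[k(x, y)]
# with a Gaussian kernel k. Note that compute_kernel averages the squared
# differences over coordinates and then divides by dim again, so the effective
# kernel bandwidth grows with the latent dimension.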
# Convert a numpy array of shape [batch_size, height, width, 1] into a displayable array
# of shape [height*sqrt(batch_size), width*sqrt(batch_size)] by tiling the images
def convert_to_display(samples):
    cnt, height, width = int(math.floor(math.sqrt(samples.shape[0]))), samples.shape[1], samples.shape[2]
    samples = np.transpose(samples, axes=[1, 0, 2, 3])
    samples = np.reshape(samples, [height, cnt, cnt, width])
    samples = np.transpose(samples, axes=[1, 0, 2, 3])
    samples = np.reshape(samples, [height*cnt, width*cnt])
    return samples
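# For example, the 100 generated samples below, of shape [100, 28, 28, 1],
# become a single 280x280 array laid out as a 10x10 grid of digits
# (cnt = floor(sqrt(100)) = 10).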
mnist = input_data.read_data_sets('mnist_data')
plt.ion()
# Build the computation graph for training
z_dim = 20
train_x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1])
train_z = encoder(train_x, z_dim)
train_xr = decoder(train_z)
# Build the computation graph for generating samples
gen_z = tf.placeholder(tf.float32, shape=[None, z_dim])
gen_x = decoder(gen_z, reuse=True)
# Compare the generated z with true samples from a standard Gaussian, and compute their MMD distance
true_samples = tf.random_normal(tf.stack([200, z_dim]))
loss_mmd = compute_mmd(true_samples, train_z)
loss_nll = tf.reduce_mean(tf.square(train_xr - train_x))
loss = loss_nll + loss_mmd
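# Despite the name, loss_nll is a per-pixel mean squared error; for a Gaussian
# observation model with fixed variance it matches the true negative
# log-likelihood only up to scale and an additive constant.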
trainer = tf.train.AdamOptimizer(1e-3).minimize(loss)
batch_size = 200
sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Start training
for i in range(10000):
    batch_x, batch_y = mnist.train.next_batch(batch_size)
    batch_x = batch_x.reshape(-1, 28, 28, 1)
    _, nll, mmd = sess.run([trainer, loss_nll, loss_mmd], feed_dict={train_x: batch_x})
    if i % 100 == 0:
        print("Negative log likelihood is %f, mmd loss is %f" % (nll, mmd))
    if i % 500 == 0:
        samples = sess.run(gen_x, feed_dict={gen_z: np.random.normal(size=(100, z_dim))})
        plt.imshow(convert_to_display(samples), cmap='Greys_r')
        plt.show()
        plt.pause(0.001)
# If the latent space is 2-dimensional, visualize it by plotting the latent codes
# of different digits in different colors
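# (With z_dim = 20 as set above, this block is skipped; change z_dim to 2 before
# building the graph to enable the scatter plot.)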
if z_dim == 2:
    z_list, label_list = [], []
    test_batch_size = 500
    for i in range(20):
        batch_x, batch_y = mnist.test.next_batch(test_batch_size)
        batch_x = batch_x.reshape(-1, 28, 28, 1)
        z_list.append(sess.run(train_z, feed_dict={train_x: batch_x}))
        label_list.append(batch_y)
    z = np.concatenate(z_list, axis=0)
    label = np.concatenate(label_list)
    plt.scatter(z[:, 0], z[:, 1], c=label)
    plt.show()