Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions mnist_minimal_example/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,27 +10,37 @@
# amplitude for the noise augmentation
augm_sigma = 0.08
data_dir = 'mnist_data'
_default_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def unnormalize(x):
'''go from normaized data x back to the original range'''
return x * data_std + data_mean


train_data = torchvision.datasets.MNIST(data_dir, train=True, download=True,
transform=T.Compose([T.ToTensor(), lambda x: (x - data_mean) / data_std]))
test_data = torchvision.datasets.MNIST(data_dir, train=False, download=True,
transform=T.Compose([T.ToTensor(), lambda x: (x - data_mean) / data_std]))
def setup(device=_default_device,
batchsize=batch_size,
mean=data_mean,
std=data_std,
folder=data_dir,
augm_std=augm_sigma):

# Sample a fixed batch of 1024 validation examples
val_x, val_l = zip(*list(train_data[i] for i in range(1024)))
val_x = torch.stack(val_x, 0).cuda()
val_l = torch.LongTensor(val_l).cuda()
train_data = torchvision.datasets.MNIST(folder, train=True, download=True,
transform=T.Compose([T.ToTensor(), lambda x: (x - mean) / std]))
test_data = torchvision.datasets.MNIST(folder, train=False, download=True,
transform=T.Compose([T.ToTensor(), lambda x: (x - mean) / std]))

# Exclude the validation batch from the training data
train_data.data = train_data.data[1024:]
train_data.targets = train_data.targets[1024:]
# Add the noise-augmentation to the (non-validation) training data:
train_data.transform = T.Compose([train_data.transform, lambda x: x + augm_sigma * torch.randn_like(x)])
# Sample a fixed batch of 1024 validation examples
val_x, val_l = zip(*list(train_data[i] for i in range(1024)))
val_x = torch.stack(val_x, 0).to(device)
val_l = torch.LongTensor(val_l).to(device)

train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
# Exclude the validation batch from the training data
train_data.data = train_data.data[1024:]
train_data.targets = train_data.targets[1024:]
# Add the noise-augmentation to the (non-validation) training data:
train_data.transform = T.Compose([train_data.transform, lambda x: x + augm_std * torch.randn_like(x)])

train_loader = DataLoader(train_data, batch_size=batchsize, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
test_loader = DataLoader(test_data, batch_size=batchsize, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)

return train_loader, test_loader, val_x, val_l
13 changes: 8 additions & 5 deletions mnist_minimal_example/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,22 @@

import model
import data
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

cinn = model.MNIST_cINN(5e-4)
cinn.cuda()
cinn.to(DEVICE)
scheduler = torch.optim.lr_scheduler.MultiStepLR(cinn.optimizer, milestones=[20, 40], gamma=0.1)

N_epochs = 60
t_start = time()
nll_mean = []

print('Epoch\tBatch/Total \tTime \tNLL train\tNLL val\tLR')
trainld, testld, val_x, val_l = data.setup(DEVICE)

for epoch in range(N_epochs):
for i, (x, l) in enumerate(data.train_loader):
x, l = x.cuda(), l.cuda()
for i, (x, l) in enumerate(trainld):
x, l = x.to(DEVICE), l.to(DEVICE)
z, log_j = cinn(x, l)

nll = torch.mean(z**2) / 2 - torch.mean(log_j) / model.ndim_total
Expand All @@ -32,11 +35,11 @@

if not i % 50:
with torch.no_grad():
z, log_j = cinn(data.val_x, data.val_l)
z, log_j = cinn(val_x, val_l)
nll_val = torch.mean(z**2) / 2 - torch.mean(log_j) / model.ndim_total

print('%.3i \t%.5i/%.5i \t%.2f \t%.6f\t%.6f\t%.2e' % (epoch,
i, len(data.train_loader),
i, len(trainld),
(time() - t_start)/60.,
np.mean(nll_mean),
nll_val.item(),
Expand Down