From 7f4d5d9a9386a1fcb50fcf635b05c130d0747a98 Mon Sep 17 00:00:00 2001
From: Peter Steinbach
Date: Mon, 15 Jun 2020 12:08:21 +0200
Subject: [PATCH] mnist_minimal_example: fall back to the CPU for training when CUDA is unavailable
---
mnist_minimal_example/data.py | 40 +++++++++++++++++++++-------------
mnist_minimal_example/train.py | 13 ++++++-----
2 files changed, 33 insertions(+), 20 deletions(-)
diff --git a/mnist_minimal_example/data.py b/mnist_minimal_example/data.py
index 3b03689..d3ef6c1 100644
--- a/mnist_minimal_example/data.py
+++ b/mnist_minimal_example/data.py
@@ -10,27 +10,37 @@
# amplitude for the noise augmentation
augm_sigma = 0.08
data_dir = 'mnist_data'
+_default_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
def unnormalize(x):
'''go from normaized data x back to the original range'''
return x * data_std + data_mean
-train_data = torchvision.datasets.MNIST(data_dir, train=True, download=True,
- transform=T.Compose([T.ToTensor(), lambda x: (x - data_mean) / data_std]))
-test_data = torchvision.datasets.MNIST(data_dir, train=False, download=True,
- transform=T.Compose([T.ToTensor(), lambda x: (x - data_mean) / data_std]))
+def setup(device=_default_device,
+ batchsize=batch_size,
+ mean=data_mean,
+ std=data_std,
+ folder=data_dir,
+ augm_std=augm_sigma):
-# Sample a fixed batch of 1024 validation examples
-val_x, val_l = zip(*list(train_data[i] for i in range(1024)))
-val_x = torch.stack(val_x, 0).cuda()
-val_l = torch.LongTensor(val_l).cuda()
+ train_data = torchvision.datasets.MNIST(folder, train=True, download=True,
+ transform=T.Compose([T.ToTensor(), lambda x: (x - mean) / std]))
+ test_data = torchvision.datasets.MNIST(folder, train=False, download=True,
+ transform=T.Compose([T.ToTensor(), lambda x: (x - mean) / std]))
-# Exclude the validation batch from the training data
-train_data.data = train_data.data[1024:]
-train_data.targets = train_data.targets[1024:]
-# Add the noise-augmentation to the (non-validation) training data:
-train_data.transform = T.Compose([train_data.transform, lambda x: x + augm_sigma * torch.randn_like(x)])
+ # Sample a fixed batch of 1024 validation examples
+ val_x, val_l = zip(*list(train_data[i] for i in range(1024)))
+ val_x = torch.stack(val_x, 0).to(device)
+ val_l = torch.LongTensor(val_l).to(device)
-train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
-test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
+ # Exclude the validation batch from the training data
+ train_data.data = train_data.data[1024:]
+ train_data.targets = train_data.targets[1024:]
+ # Add the noise-augmentation to the (non-validation) training data:
+ train_data.transform = T.Compose([train_data.transform, lambda x: x + augm_std * torch.randn_like(x)])
+
+ train_loader = DataLoader(train_data, batch_size=batchsize, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
+ test_loader = DataLoader(test_data, batch_size=batchsize, shuffle=False, num_workers=4, pin_memory=True, drop_last=True)
+
+ return train_loader, test_loader, val_x, val_l
diff --git a/mnist_minimal_example/train.py b/mnist_minimal_example/train.py
index ac15b97..1951121 100644
--- a/mnist_minimal_example/train.py
+++ b/mnist_minimal_example/train.py
@@ -8,9 +8,10 @@
import model
import data
+DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cinn = model.MNIST_cINN(5e-4)
-cinn.cuda()
+cinn.to(DEVICE)
scheduler = torch.optim.lr_scheduler.MultiStepLR(cinn.optimizer, milestones=[20, 40], gamma=0.1)
N_epochs = 60
@@ -18,9 +19,11 @@
nll_mean = []
print('Epoch\tBatch/Total \tTime \tNLL train\tNLL val\tLR')
+trainld, testld, val_x, val_l = data.setup(DEVICE)
+
for epoch in range(N_epochs):
- for i, (x, l) in enumerate(data.train_loader):
- x, l = x.cuda(), l.cuda()
+ for i, (x, l) in enumerate(trainld):
+ x, l = x.to(DEVICE), l.to(DEVICE)
z, log_j = cinn(x, l)
nll = torch.mean(z**2) / 2 - torch.mean(log_j) / model.ndim_total
@@ -32,11 +35,11 @@
if not i % 50:
with torch.no_grad():
- z, log_j = cinn(data.val_x, data.val_l)
+ z, log_j = cinn(val_x, val_l)
nll_val = torch.mean(z**2) / 2 - torch.mean(log_j) / model.ndim_total
print('%.3i \t%.5i/%.5i \t%.2f \t%.6f\t%.6f\t%.2e' % (epoch,
- i, len(data.train_loader),
+ i, len(trainld),
(time() - t_start)/60.,
np.mean(nll_mean),
nll_val.item(),