diff --git a/mnist_minimal_example/data.py b/mnist_minimal_example/data.py index 3b03689..d3ef6c1 100644 --- a/mnist_minimal_example/data.py +++ b/mnist_minimal_example/data.py @@ -10,27 +10,37 @@ # amplitude for the noise augmentation augm_sigma = 0.08 data_dir = 'mnist_data' +_default_device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') def unnormalize(x): '''go from normaized data x back to the original range''' return x * data_std + data_mean -train_data = torchvision.datasets.MNIST(data_dir, train=True, download=True, - transform=T.Compose([T.ToTensor(), lambda x: (x - data_mean) / data_std])) -test_data = torchvision.datasets.MNIST(data_dir, train=False, download=True, - transform=T.Compose([T.ToTensor(), lambda x: (x - data_mean) / data_std])) +def setup(device=_default_device, + batchsize=batch_size, + mean=data_mean, + std=data_std, + folder=data_dir, + augm_std=augm_sigma): -# Sample a fixed batch of 1024 validation examples -val_x, val_l = zip(*list(train_data[i] for i in range(1024))) -val_x = torch.stack(val_x, 0).cuda() -val_l = torch.LongTensor(val_l).cuda() + train_data = torchvision.datasets.MNIST(folder, train=True, download=True, + transform=T.Compose([T.ToTensor(), lambda x: (x - mean) / std])) + test_data = torchvision.datasets.MNIST(folder, train=False, download=True, + transform=T.Compose([T.ToTensor(), lambda x: (x - mean) / std])) -# Exclude the validation batch from the training data -train_data.data = train_data.data[1024:] -train_data.targets = train_data.targets[1024:] -# Add the noise-augmentation to the (non-validation) training data: -train_data.transform = T.Compose([train_data.transform, lambda x: x + augm_sigma * torch.randn_like(x)]) + # Sample a fixed batch of 1024 validation examples + val_x, val_l = zip(*list(train_data[i] for i in range(1024))) + val_x = torch.stack(val_x, 0).to(device) + val_l = torch.LongTensor(val_l).to(device) -train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True) -test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True, drop_last=True) + # Exclude the validation batch from the training data + train_data.data = train_data.data[1024:] + train_data.targets = train_data.targets[1024:] + # Add the noise-augmentation to the (non-validation) training data: + train_data.transform = T.Compose([train_data.transform, lambda x: x + augm_std * torch.randn_like(x)]) + + train_loader = DataLoader(train_data, batch_size=batchsize, shuffle=True, num_workers=4, pin_memory=True, drop_last=True) + test_loader = DataLoader(test_data, batch_size=batchsize, shuffle=False, num_workers=4, pin_memory=True, drop_last=True) + + return train_loader, test_loader, val_x, val_l diff --git a/mnist_minimal_example/train.py b/mnist_minimal_example/train.py index ac15b97..1951121 100644 --- a/mnist_minimal_example/train.py +++ b/mnist_minimal_example/train.py @@ -8,9 +8,10 @@ import model import data +DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') cinn = model.MNIST_cINN(5e-4) -cinn.cuda() +cinn.to(DEVICE) scheduler = torch.optim.lr_scheduler.MultiStepLR(cinn.optimizer, milestones=[20, 40], gamma=0.1) N_epochs = 60 @@ -18,9 +19,11 @@ nll_mean = [] print('Epoch\tBatch/Total \tTime \tNLL train\tNLL val\tLR') +trainld, testld, val_x, val_l = data.setup(DEVICE) + for epoch in range(N_epochs): - for i, (x, l) in enumerate(data.train_loader): - x, l = x.cuda(), l.cuda() + for i, (x, l) in enumerate(trainld): + x, l = x.to(DEVICE), l.to(DEVICE) z, log_j = cinn(x, l) nll = torch.mean(z**2) / 2 - torch.mean(log_j) / model.ndim_total @@ -32,11 +35,11 @@ if not i % 50: with torch.no_grad(): - z, log_j = cinn(data.val_x, data.val_l) + z, log_j = cinn(val_x, val_l) nll_val = torch.mean(z**2) / 2 - torch.mean(log_j) / model.ndim_total print('%.3i \t%.5i/%.5i \t%.2f \t%.6f\t%.6f\t%.2e' % (epoch, - i, len(data.train_loader), + i, len(trainld), (time() - t_start)/60., np.mean(nll_mean), nll_val.item(),