diff --git a/experiments/cifar10/train.py b/experiments/cifar10/train.py index bb5b8e4..604df09 100644 --- a/experiments/cifar10/train.py +++ b/experiments/cifar10/train.py @@ -35,7 +35,6 @@ ##################################################################### if __name__ == '__main__': - # Load CIFAR10 data dims = [32, 32, 3] DATA_DIR = 'experiments/cifar10/data/cifar-10-batches-py' diff --git a/experiments/train_scripts.py b/experiments/train_scripts.py index a2f5a9a..2510de5 100644 --- a/experiments/train_scripts.py +++ b/experiments/train_scripts.py @@ -8,6 +8,7 @@ import re import tensorflow as tf import tensorflow.contrib.slim as slim +import cloud from .utils import parse_config_str from collections import OrderedDict @@ -329,6 +330,12 @@ def train_tan(X, dims, tfs, log_path, d_class=None, t_class=None, if FLAGS.is_test: print("LOGDIR: %s" % LOGDIR) + # Also pickle and save the TFs + # Note: Can be reloaded with standard pickle.load + tfs_pickle_path = os.path.join(log_path, 'tan', FLAGS.run_index, 'tfs.pkl') + with open(tfs_pickle_path, 'w') as f: + cloud.serialization.cloudpickle.dump(tfs, f) + # Assemble TAN model based on FLAGS tan = assemble_tan( dims, tfs, d_class=d_class, t_class=t_class, t_kwargs=t_kwargs @@ -346,8 +353,9 @@ def train_tan(X, dims, tfs, log_path, d_class=None, t_class=None, nvo, _ = slim.model_analyzer.analyze_vars(tan_vars_o, print_info=False) print("# vars: {0} gen, {1} disc, {2} other".format(nvg, nvd, nvo)) - # Initialize and save log file + # Initialize and save log file; also save dims here log_dict = create_run_log(LOGDIR, FLAGS) + log_dict['dims'] = dims # As default create ImagePlotter for routing images into Tensorboard if plotter is None and FLAGS.plot_every > 0: @@ -565,8 +573,9 @@ def train_end_model(X_train_full, Y_train_full, X_valid, Y_valid, if FLAGS.is_test: print("LOGDIR: %s" % LOGDIR) - # Initialize and save log file + # Initialize and save log file; also save dims here log_dict = create_run_log(LOGDIR, FLAGS) + log_dict['dims'] = dims # Create ImagePlotter for routing images into Tensorboard plot_names = ['plot_%s' % run_type] diff --git a/keras/keras_cifar10_example.py b/keras/keras_cifar10_example.py index 3214580..5d4bbce 100644 --- a/keras/keras_cifar10_example.py +++ b/keras/keras_cifar10_example.py @@ -13,18 +13,12 @@ import os import keras -from experiments.cifar10.train import tfs from experiments.utils import balanced_subsample from keras.datasets import cifar10 from keras.models import Sequential from keras.layers import Dense, Dropout, Activation, Flatten from keras.layers import Conv2D, MaxPooling2D from tanda_keras import TANDAImageDataGenerator -from utils import load_pretrained_tan - -TAN_PATH = # TODO: Insert path here! -CONFIG_PATH = os.path.join(TAN_PATH, 'logs', 'run_log.json') -CHECKPOINT_PATH = os.path.join(TAN_PATH, 'checkpoints', 'tan_checkpoint') batch_size = 32 @@ -32,7 +26,6 @@ epochs = 100 train_frac = 0.1 - if __name__ == '__main__': # The data, shuffled and split between train and test sets: (x_train, y_train), (x_test, y_test) = cifar10.load_data() @@ -86,12 +79,11 @@ x_test = x_test.astype('float32') x_train /= 255 x_test /= 255 - - - print('Loading TAN for real-time data augmentation.') + # This will do preprocessing and realtime data augmentation using a TAN - tan = load_pretrained_tan(CONFIG_PATH, CHECKPOINT_PATH, tfs) - datagen = TANDAImageDataGenerator(tan) + print('Loading TAN for real-time data augmentation.') + TAN_PATH = None # TODO: Your pre-trained TAN directory path here! + datagen = TANDAImageDataGenerator(TAN_PATH) print('Training model.') # Fit the model on the batches generated by datagen.flow(). diff --git a/keras/tanda_keras.py b/keras/tanda_keras.py index c89e37e..39db083 100644 --- a/keras/tanda_keras.py +++ b/keras/tanda_keras.py @@ -5,13 +5,15 @@ from keras import backend as K from keras.preprocessing.image import ImageDataGenerator +from utils import load_pretrained_tan class TANDAImageDataGenerator(ImageDataGenerator): """Generate minibatches of image data with real-time data augmentation using a trained TAN. # Arguments - tan: trained `TAN` object. + tan: trained `TAN` object, or path to a trained `TAN` object (the + directory which contains `log` and `checkpoint`) featurewise_center: set input mean to 0 over the dataset. samplewise_center: set each sample mean to 0. featurewise_std_normalization: divide inputs by std of the dataset. @@ -57,7 +59,10 @@ def __init__(self, preprocessing_function=preprocessing_function, data_format=data_format ) - self.tan = tan + if isinstance(tan, str): + self.tan = load_pretrained_tan(tan) + else: + self.tan = tan self.session = K.get_session() def random_transform(self, x, seed=None): diff --git a/keras/utils.py b/keras/utils.py index d396749..4e2d893 100644 --- a/keras/utils.py +++ b/keras/utils.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals import json +from six import cPickle from experiments.train_scripts import GENERATORS from experiments.utils import parse_config_str @@ -14,16 +15,23 @@ from tanda.tan import PretrainedTAN from tanda.transformer import PadCropTransformer - -def load_pretrained_tan(config_path, checkpoint_path, tfs, dims=[32, 32, 3]): - # Load config - with open(config_path, 'r') as f: +def load_pretrained_tan(path): + # Load config dictionary from run log + with open(os.path.join(path, 'logs', 'run_log.json'), 'r') as f: config = json.load(f) + + # Load TFs + with open(os.path.join(path, 'tfs.pkl'), 'w') as f: + tfs = cPickle.load(f) + # Build transformer - T = PadCropTransformer(tfs, dims=dims) + T = PadCropTransformer(tfs, dims=config['dims']) + # Build generator k = T.n_actions g_class = GENERATORS[config['generator']] G = g_class(k, config['seq_len'], **parse_config_str(config['gen_config'])) + # Build TAN + checkpoint_path = os.path.join(path, 'checkpoints', 'tan_checkpoint') return PretrainedTAN(G, T, dims, K.get_session(), checkpoint_path) diff --git a/python-package-requirement.txt b/python-package-requirement.txt index 1281863..e306e3a 100644 --- a/python-package-requirement.txt +++ b/python-package-requirement.txt @@ -1,3 +1,4 @@ +cloud matplotlib numpy>=1.11 pandas @@ -5,4 +6,4 @@ pillow scikit-image>=0.13 scipy>=0.18 six -tensorflow>=1.2 +tensorflow>=1.2 \ No newline at end of file