From 9e8646449f1115b44d431aa2c642284b8468f316 Mon Sep 17 00:00:00 2001 From: ajratner Date: Tue, 5 Dec 2017 18:24:56 -0800 Subject: [PATCH 1/6] Loading pretrained TAN from single path --- keras/keras_cifar10_example.py | 16 ++++------------ keras/tanda_keras.py | 5 +++-- keras/utils.py | 18 +++++++++++++----- pretrained/cifar10/run_log.json | 2 ++ 4 files changed, 22 insertions(+), 19 deletions(-) diff --git a/keras/keras_cifar10_example.py b/keras/keras_cifar10_example.py index a6d4bc1..3e722e8 100644 --- a/keras/keras_cifar10_example.py +++ b/keras/keras_cifar10_example.py @@ -12,18 +12,11 @@ import os import keras -from experiments.cifar10.train import tfs from keras.datasets import cifar10 from keras.models import Sequential from keras.layers import Dense, Dropout, Activation, Flatten from keras.layers import Conv2D, MaxPooling2D from tanda_keras import TANDAImageDataGenerator -from utils import load_pretrained_tan - - -TAN_PATH = os.path.join(os.environ['TANDAHOME'], 'pretrained', 'cifar10') -CONFIG_PATH = os.path.join(TAN_PATH, 'run_log.json') -CHECKPOINT_PATH = os.path.join(TAN_PATH, 'checkpoints', 'tan_checkpoint') batch_size = 32 @@ -77,12 +70,11 @@ x_test = x_test.astype('float32') x_train /= 255 x_test /= 255 - - - print('Loading TAN for real-time data augmentation.') + # This will do preprocessing and realtime data augmentation using a TAN - tan = load_pretrained_tan(CONFIG_PATH, CHECKPOINT_PATH, tfs) - datagen = TANDAImageDataGenerator(tan) + print('Loading TAN for real-time data augmentation.') + tan_path = os.path.join(os.environ['TANDAHOME'], 'pretrained', 'cifar10') + datagen = TANDAImageDataGenerator(tan_path) print('Training model.') # Fit the model on the batches generated by datagen.flow(). diff --git a/keras/tanda_keras.py b/keras/tanda_keras.py index c89e37e..4a00e21 100644 --- a/keras/tanda_keras.py +++ b/keras/tanda_keras.py @@ -5,6 +5,7 @@ from keras import backend as K from keras.preprocessing.image import ImageDataGenerator +from utils import load_pretrained_tan class TANDAImageDataGenerator(ImageDataGenerator): @@ -36,7 +37,7 @@ class TANDAImageDataGenerator(ImageDataGenerator): """ def __init__(self, - tan, + tan_path, featurewise_center=False, samplewise_center=False, featurewise_std_normalization=False, @@ -57,7 +58,7 @@ def __init__(self, preprocessing_function=preprocessing_function, data_format=data_format ) - self.tan = tan + self.tan = load_pretrained_tan(tan_path) self.session = K.get_session() def random_transform(self, x, seed=None): diff --git a/keras/utils.py b/keras/utils.py index d396749..a333a06 100644 --- a/keras/utils.py +++ b/keras/utils.py @@ -4,6 +4,7 @@ from __future__ import unicode_literals import json +from import_lib import import_module from experiments.train_scripts import GENERATORS from experiments.utils import parse_config_str @@ -14,16 +15,23 @@ from tanda.tan import PretrainedTAN from tanda.transformer import PadCropTransformer - -def load_pretrained_tan(config_path, checkpoint_path, tfs, dims=[32, 32, 3]): - # Load config - with open(config_path, 'r') as f: +def load_pretrained_tan(path): + # Load config dictionary from run log + with open(os.path.join(path, 'run_log.json'), 'r') as f: config = json.load(f) + + # Load TFs + # Assume they are present in config['train_module'] as list called tfs + tfs = import_module(config['train_module']).tfs + # Build transformer - T = PadCropTransformer(tfs, dims=dims) + T = PadCropTransformer(tfs, dims=config['dims']) + # Build generator k = T.n_actions g_class = GENERATORS[config['generator']] G = g_class(k, config['seq_len'], **parse_config_str(config['gen_config'])) + # Build TAN + checkpoint_path = os.path.join(path, 'checkpoints', 'tan_checkpoint') return PretrainedTAN(G, T, dims, K.get_session(), checkpoint_path) diff --git a/pretrained/cifar10/run_log.json b/pretrained/cifar10/run_log.json index fd410f2..e93f64b 100644 --- a/pretrained/cifar10/run_log.json +++ b/pretrained/cifar10/run_log.json @@ -2,6 +2,7 @@ "batch_size": 32, "commit_hash": "b'9cdfe57\\n'", "date_stamp": null, + "dims": [32, 32, 3], "disc_lr": 1e-05, "discriminator": "dcnn", "end_batch_size": 50, @@ -53,6 +54,7 @@ "tan_checkpoint_path": null, "time_stamp": null, "train_disc": true, + "train_module": "experiments.cifar10.train", "transform_validation_set": false, "transformer": "image", "validation_set": true From 22820788e279744913286a25f0ef7387aee47f50 Mon Sep 17 00:00:00 2001 From: ajratner Date: Tue, 5 Dec 2017 18:39:11 -0800 Subject: [PATCH 2/6] Sneaking dims and module name into run_log --- experiments/cifar10/train.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/experiments/cifar10/train.py b/experiments/cifar10/train.py index e80ef21..6cad946 100644 --- a/experiments/cifar10/train.py +++ b/experiments/cifar10/train.py @@ -3,6 +3,9 @@ from __future__ import print_function from __future__ import unicode_literals +import sys +import re + from .dataset import load_cifar10_data from experiments.train_scripts import flags, select_fold, train from experiments.tfs.image import * @@ -35,7 +38,6 @@ ##################################################################### if __name__ == '__main__': - # Load CIFAR10 data dims = [32, 32, 3] DATA_DIR = 'experiments/cifar10/data/cifar-10-batches-py' @@ -45,6 +47,12 @@ if FLAGS.n_folds > 0: X_train, Y_train = select_fold(X_train, Y_train) + # Make sure dims and current module name is included in the run log + # Note: this is currently kind of hackey, should clean up... + FLAGS.__flags['train_module'] = re.sub(r'\/', '.', + re.sub(r'\.py$', '', re.sub(r'.*tanda/', '', __file__))) + FLAGS.__flags['dims'] = dims + # Run training scripts train(X_train, dims, tfs, Y_train=Y_train, X_valid=X_valid, Y_valid=Y_valid, n_classes=10) From 2f11687741963047176656c3ac310f0bbf4bacdf Mon Sep 17 00:00:00 2001 From: ajratner Date: Wed, 6 Dec 2017 12:39:53 -0800 Subject: [PATCH 3/6] Addressing PR comments --- keras/tanda_keras.py | 10 +++++++--- keras/utils.py | 2 +- pretrained/cifar10/{ => logs}/run_log.json | 0 3 files changed, 8 insertions(+), 4 deletions(-) rename pretrained/cifar10/{ => logs}/run_log.json (100%) diff --git a/keras/tanda_keras.py b/keras/tanda_keras.py index 4a00e21..39db083 100644 --- a/keras/tanda_keras.py +++ b/keras/tanda_keras.py @@ -12,7 +12,8 @@ class TANDAImageDataGenerator(ImageDataGenerator): """Generate minibatches of image data with real-time data augmentation using a trained TAN. # Arguments - tan: trained `TAN` object. + tan: trained `TAN` object, or path to a trained `TAN` object (the + directory which contains `log` and `checkpoint`) featurewise_center: set input mean to 0 over the dataset. samplewise_center: set each sample mean to 0. featurewise_std_normalization: divide inputs by std of the dataset. @@ -37,7 +38,7 @@ class TANDAImageDataGenerator(ImageDataGenerator): """ def __init__(self, - tan_path, + tan, featurewise_center=False, samplewise_center=False, featurewise_std_normalization=False, @@ -58,7 +59,10 @@ def __init__(self, preprocessing_function=preprocessing_function, data_format=data_format ) - self.tan = load_pretrained_tan(tan_path) + if isinstance(tan, str): + self.tan = load_pretrained_tan(tan) + else: + self.tan = tan self.session = K.get_session() def random_transform(self, x, seed=None): diff --git a/keras/utils.py b/keras/utils.py index a333a06..8d6854f 100644 --- a/keras/utils.py +++ b/keras/utils.py @@ -17,7 +17,7 @@ def load_pretrained_tan(path): # Load config dictionary from run log - with open(os.path.join(path, 'run_log.json'), 'r') as f: + with open(os.path.join(path, 'logs', 'run_log.json'), 'r') as f: config = json.load(f) # Load TFs diff --git a/pretrained/cifar10/run_log.json b/pretrained/cifar10/logs/run_log.json similarity index 100% rename from pretrained/cifar10/run_log.json rename to pretrained/cifar10/logs/run_log.json From c76be20f998f4a928e8d2b88f8566e94767e2bf2 Mon Sep 17 00:00:00 2001 From: ajratner Date: Thu, 7 Dec 2017 11:02:51 -0800 Subject: [PATCH 4/6] Now we pickle the TFs using the cloud lib (a good logging step anyway) --- experiments/cifar10/train.py | 6 ------ experiments/train_scripts.py | 14 ++++++++++++-- keras/utils.py | 6 +++--- 3 files changed, 15 insertions(+), 11 deletions(-) diff --git a/experiments/cifar10/train.py b/experiments/cifar10/train.py index 6cad946..abe208d 100644 --- a/experiments/cifar10/train.py +++ b/experiments/cifar10/train.py @@ -47,12 +47,6 @@ if FLAGS.n_folds > 0: X_train, Y_train = select_fold(X_train, Y_train) - # Make sure dims and current module name is included in the run log - # Note: this is currently kind of hackey, should clean up... - FLAGS.__flags['train_module'] = re.sub(r'\/', '.', - re.sub(r'\.py$', '', re.sub(r'.*tanda/', '', __file__))) - FLAGS.__flags['dims'] = dims - # Run training scripts train(X_train, dims, tfs, Y_train=Y_train, X_valid=X_valid, Y_valid=Y_valid, n_classes=10) diff --git a/experiments/train_scripts.py b/experiments/train_scripts.py index 6122744..c3fbf7a 100644 --- a/experiments/train_scripts.py +++ b/experiments/train_scripts.py @@ -8,6 +8,8 @@ import re import tensorflow as tf import tensorflow.contrib.slim as slim +import sys +import cloud from .utils import parse_config_str from collections import OrderedDict @@ -329,6 +331,12 @@ def train_tan(X, dims, tfs, log_path, d_class=None, t_class=None, if FLAGS.is_test: print("LOGDIR: %s" % LOGDIR) + # Also pickle and save the TFs + # Note: Can be reloaded with standard pickle.load + tfs_pickle_path = os.path.join(log_path, 'tan', FLAGS.run_index, 'tfs.pkl') + with open(tfs_pickle_path, 'w') as f: + cloud.serialization.cloudpickle.dump(tfs, f) + # Assemble TAN model based on FLAGS tan = assemble_tan( dims, tfs, d_class=d_class, t_class=t_class, t_kwargs=t_kwargs @@ -346,8 +354,9 @@ def train_tan(X, dims, tfs, log_path, d_class=None, t_class=None, nvo, _ = slim.model_analyzer.analyze_vars(tan_vars_o, print_info=False) print("# vars: {0} gen, {1} disc, {2} other".format(nvg, nvd, nvo)) - # Initialize and save log file + # Initialize and save log file; also save dims here log_dict = create_run_log(LOGDIR, FLAGS) + log_dict['dims'] = dims # As default create ImagePlotter for routing images into Tensorboard if plotter is None and FLAGS.plot_every > 0: @@ -565,8 +574,9 @@ def train_end_model(X_train_full, Y_train_full, X_valid, Y_valid, if FLAGS.is_test: print("LOGDIR: %s" % LOGDIR) - # Initialize and save log file + # Initialize and save log file; also save dims here log_dict = create_run_log(LOGDIR, FLAGS) + log_dict['dims'] = dims # Create ImagePlotter for routing images into Tensorboard plot_names = ['plot_%s' % run_type] diff --git a/keras/utils.py b/keras/utils.py index 8d6854f..3535086 100644 --- a/keras/utils.py +++ b/keras/utils.py @@ -4,7 +4,7 @@ from __future__ import unicode_literals import json -from import_lib import import_module +import pickle from experiments.train_scripts import GENERATORS from experiments.utils import parse_config_str @@ -21,8 +21,8 @@ def load_pretrained_tan(path): config = json.load(f) # Load TFs - # Assume they are present in config['train_module'] as list called tfs - tfs = import_module(config['train_module']).tfs + with open(os.path.join(path, 'tfs.pkl'), 'w') as f: + tfs = pickle.load(f) # Build transformer T = PadCropTransformer(tfs, dims=config['dims']) From 5feb73590be3a09125e3e6eeecc9f661e61e7da5 Mon Sep 17 00:00:00 2001 From: ajratner Date: Tue, 12 Dec 2017 22:49:38 -0800 Subject: [PATCH 5/6] Addressing review comments --- experiments/cifar10/train.py | 3 +- experiments/train_scripts.py | 1 - keras/utils.py | 2 +- pretrained/cifar10/logs/run_log.json | 61 ---------------------------- python-package-requirement.txt | 1 + 5 files changed, 3 insertions(+), 65 deletions(-) delete mode 100644 pretrained/cifar10/logs/run_log.json diff --git a/experiments/cifar10/train.py b/experiments/cifar10/train.py index abe208d..3adad85 100644 --- a/experiments/cifar10/train.py +++ b/experiments/cifar10/train.py @@ -3,10 +3,9 @@ from __future__ import print_function from __future__ import unicode_literals -import sys import re -from .dataset import load_cifar10_data +from dataset import load_cifar10_data from experiments.train_scripts import flags, select_fold, train from experiments.tfs.image import * from functools import partial diff --git a/experiments/train_scripts.py b/experiments/train_scripts.py index 0c88aa4..2510de5 100644 --- a/experiments/train_scripts.py +++ b/experiments/train_scripts.py @@ -8,7 +8,6 @@ import re import tensorflow as tf import tensorflow.contrib.slim as slim -import sys import cloud from .utils import parse_config_str diff --git a/keras/utils.py b/keras/utils.py index 3535086..c95bc99 100644 --- a/keras/utils.py +++ b/keras/utils.py @@ -4,7 +4,7 @@ from __future__ import unicode_literals import json -import pickle +from six import pickle from experiments.train_scripts import GENERATORS from experiments.utils import parse_config_str diff --git a/pretrained/cifar10/logs/run_log.json b/pretrained/cifar10/logs/run_log.json deleted file mode 100644 index e93f64b..0000000 --- a/pretrained/cifar10/logs/run_log.json +++ /dev/null @@ -1,61 +0,0 @@ -{ - "batch_size": 32, - "commit_hash": "b'9cdfe57\\n'", - "date_stamp": null, - "dims": [32, 32, 3], - "disc_lr": 1e-05, - "discriminator": "dcnn", - "end_batch_size": 50, - "end_batch_size_u": 10, - "end_discriminator": "resnet", - "end_epochs": 100, - "end_lr": 0.001, - "end_lr_mode": "constant", - "end_lr_schedule": 0, - "end_optimizer": "momentum", - "end_per_img_std": true, - "end_weight_decay": 0.0002, - "eval_every": 1, - "gamma": 0.5, - "gen_config": "init_type=train,feed_actions=True,n_stack=1,logit_range=6.0", - "gen_lr": 0.0001, - "generator": "gru", - "is_test": true, - "log_path": null, - "log_root": "experiments/log", - "ls_term": 0.1, - "ls_term_n_passes": 1, - "mse_layer": 1, - "mse_term": 0.001, - "n_disc_steps": 1, - "n_epochs": 1, - "n_folds": -1, - "n_gen_steps": 1, - "n_per_class": 500, - "n_sample": 5, - "n_tan_train": 0, - "p_transform_drop": null, - "p_transform_init": 0.1, - "p_transform_max": 0.5, - "p_transform_min": 0.5, - "p_transform_rate": 1.05, - "per_img_std": false, - "plot_every": 1, - "rand_loss_every": 1, - "run_fold": 0, - "run_index": 0, - "run_name": "cifar_tan", - "run_type": "tan-only", - "save_action_seqs": true, - "save_end_model_every": 50, - "save_model": false, - "seq_len": 10, - "subsample_seed": 1701, - "tan_checkpoint_path": null, - "time_stamp": null, - "train_disc": true, - "train_module": "experiments.cifar10.train", - "transform_validation_set": false, - "transformer": "image", - "validation_set": true -} \ No newline at end of file diff --git a/python-package-requirement.txt b/python-package-requirement.txt index 1281863..ed8c296 100644 --- a/python-package-requirement.txt +++ b/python-package-requirement.txt @@ -6,3 +6,4 @@ scikit-image>=0.13 scipy>=0.18 six tensorflow>=1.2 +cloud \ No newline at end of file From 4dbc41c85f8656867ef35639d6b32a95cbf86a56 Mon Sep 17 00:00:00 2001 From: ajratner Date: Wed, 13 Dec 2017 15:02:31 -0800 Subject: [PATCH 6/6] Nits on nits on nits --- experiments/cifar10/train.py | 2 -- keras/utils.py | 4 ++-- python-package-requirement.txt | 4 ++-- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/experiments/cifar10/train.py b/experiments/cifar10/train.py index 3adad85..604df09 100644 --- a/experiments/cifar10/train.py +++ b/experiments/cifar10/train.py @@ -3,8 +3,6 @@ from __future__ import print_function from __future__ import unicode_literals -import re - from dataset import load_cifar10_data from experiments.train_scripts import flags, select_fold, train from experiments.tfs.image import * diff --git a/keras/utils.py b/keras/utils.py index c95bc99..4e2d893 100644 --- a/keras/utils.py +++ b/keras/utils.py @@ -4,7 +4,7 @@ from __future__ import unicode_literals import json -from six import pickle +from six import cPickle from experiments.train_scripts import GENERATORS from experiments.utils import parse_config_str @@ -22,7 +22,7 @@ def load_pretrained_tan(path): # Load TFs with open(os.path.join(path, 'tfs.pkl'), 'w') as f: - tfs = pickle.load(f) + tfs = cPickle.load(f) # Build transformer T = PadCropTransformer(tfs, dims=config['dims']) diff --git a/python-package-requirement.txt b/python-package-requirement.txt index ed8c296..e306e3a 100644 --- a/python-package-requirement.txt +++ b/python-package-requirement.txt @@ -1,3 +1,4 @@ +cloud matplotlib numpy>=1.11 pandas @@ -5,5 +6,4 @@ pillow scikit-image>=0.13 scipy>=0.18 six -tensorflow>=1.2 -cloud \ No newline at end of file +tensorflow>=1.2 \ No newline at end of file