diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..6a176ee32a --- /dev/null +++ b/.gitignore @@ -0,0 +1,111 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class +*.sw[op] +*.pkl +*pkl.gz +results/ +*.pth +*.DS_Store +*.Rhistory + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ diff --git a/README.txt b/README.txt index 3355f5d134..0871a0d397 100644 --- a/README.txt +++ b/README.txt @@ -15,11 +15,9 @@ Recommended hardware: 4 NVIDIA Tesla P-100 GPUs or 8 NVIDIA Tesla K-80 GPUs Instructions for preparing the data: 1. Download the dataset CUB_200_2011.tgz from http://www.vision.caltech.edu/visipedia/CUB-200-2011.html 2. Unpack CUB_200_2011.tgz -3. Crop the images using information from bounding_boxes.txt (included in the dataset) -4. Split the cropped images into training and test sets, using train_test_split.txt (included in the dataset) -5. Put the cropped training images in the directory "./datasets/cub200_cropped/train_cropped/" -6. Put the cropped test images in the directory "./datasets/cub200_cropped/test_cropped/" -7. Augment the training set using img_aug.py (included in this code package) +3. Run initial_images_processing.py to process dataset into proper form + -- Ex. python initial_images_processing.py /path/to/CUB_200_2011/images/ +4. Augment the training set using img_aug.py (included in this code package) -- this will create an augmented training set in the following directory: "./datasets/cub200_cropped/train_cropped_augmented/" diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000..b6437822aa --- /dev/null +++ b/environment.yml @@ -0,0 +1,17 @@ +name: protopnet +dependencies: + - python=3.7 + - ipython + - numpy + - pandas + - pytorch + - torchvision + - cudatoolkit + - scipy + - opencv + - matplotlib + - pip: + - Augmentor +channels: + - pytorch + - conda-forge diff --git a/img_aug.py b/img_aug.py index ca05b1954a..085eb0e36e 100644 --- a/img_aug.py +++ b/img_aug.py @@ -1,5 +1,7 @@ import Augmentor import os + + def makedir(path): ''' if path does not exist in the file system, create it @@ -17,7 +19,7 @@ def makedir(path): for i in range(len(folders)): fd = folders[i] - tfd = target_folders[i] + tfd = os.path.abspath(target_folders[i]) # rotation p = Augmentor.Pipeline(source_directory=fd, output_directory=tfd) p.rotate(probability=1, max_left_rotation=15, max_right_rotation=15) diff --git a/initial_images_processing.py b/initial_images_processing.py new file mode 100644 index 0000000000..279b7e7c13 --- /dev/null +++ b/initial_images_processing.py @@ -0,0 +1,62 @@ +import argparse +from pathlib import Path + +import cv2 + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('images_dir', help='relative path to the images folder as described in CUB 200-2011') + args = parser.parse_args() + + imgs_txt = Path(args.images_dir).joinpath('../images.txt') + bb_txt = Path(args.images_dir).joinpath('../bounding_boxes.txt') + train_test = Path(args.images_dir).joinpath('../train_test_split.txt') + train_dir = Path(__file__).parent.joinpath('datasets/cub200_cropped/train_cropped') + test_dir = Path(__file__).parent.joinpath('datasets/cub200_cropped/test_cropped') + if not train_dir.exists(): + train_dir.mkdir(parents=True) + if not test_dir.exists(): + test_dir.mkdir(parents=True) + + imgs_to_data = [] + with open(str(imgs_txt)) as imgs: + img_index = imgs.readlines() + with open(str(bb_txt)) as bb: + bb_index = bb.readlines() + with open(str(train_test)) as tt: + tt_index = tt.readlines() + + for i, line in enumerate(img_index): + n1, filename = line.strip().split(' ') + n2, x, y, width, height = bb_index[i].strip().split(' ') + n3, is_train = tt_index[i].strip().split(' ') + if n1 != n2 or n2 != n3: + raise Exception('something went wrong and indexing on images.txt/bounding_boxes.txt/train_test_split.txt is off') + imgs_to_data.append([ + Path(args.images_dir).joinpath(filename), + int(float(x)), + int(float(y)), + int(float(width)), + int(float(height)), + bool(int(is_train)), + ]) + + for path, x, y, w, h, is_train in imgs_to_data: + im = cv2.imread(str(path)) + # crop and save + im = im[y:y+h, x:x+w, :] + + if is_train: + final_dir = train_dir.joinpath(path.parent.name) + else: + final_dir = test_dir.joinpath(path.parent.name) + final_path = final_dir.joinpath(path.name) + if not final_dir.exists(): + final_dir.mkdir() + + cv2.imwrite(str(final_path), im) + + +if __name__ == "__main__": + main() diff --git a/main.py b/main.py index 46491e3881..bcf7414c60 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,6 @@ import os import shutil +import copy import torch import torch.utils.data @@ -12,12 +13,14 @@ from helpers import makedir import model -import push +from push import Pusher import prune import train_and_test as tnt import save from log import create_logger from preprocess import mean, std, preprocess_input_function +from settings import * + parser = argparse.ArgumentParser() parser.add_argument('-gpuid', nargs=1, type=str, default='0') # python3 main.py -gpuid=0,1,2,3 @@ -26,8 +29,6 @@ print(os.environ['CUDA_VISIBLE_DEVICES']) # book keeping namings and code -from settings import base_architecture, img_size, prototype_shape, num_classes, \ - prototype_activation_function, add_on_layers_type, experiment_run base_architecture_type = re.match('^[a-z]*', base_architecture).group(0) @@ -48,8 +49,14 @@ proto_bound_boxes_filename_prefix = 'bb' # load the data -from settings import train_dir, test_dir, train_push_dir, \ - train_batch_size, test_batch_size, train_push_batch_size + + +def perform_push(pusher, epoch_number): + if use_protobank: + pusher.push_protobank(epoch_number) + else: + pusher.push_orig(epoch_number) + normalize = transforms.Normalize(mean=mean, std=std) @@ -100,7 +107,8 @@ prototype_shape=prototype_shape, num_classes=num_classes, prototype_activation_function=prototype_activation_function, - add_on_layers_type=add_on_layers_type) + add_on_layers_type=add_on_layers_type, + bank_size=bank_size) #if prototype_activation_function == 'linear': # ppnet.set_last_layer_incorrect_connection(incorrect_strength=0) ppnet = ppnet.cuda() @@ -108,35 +116,54 @@ class_specific = True # define optimizer -from settings import joint_optimizer_lrs, joint_lr_step_size +if use_protobank: + prototype_params = { + 'params': ppnet.protobank_tensor, + 'lr': joint_optimizer_lrs['prototype_vectors'] + } +else: + prototype_params = { + 'params': ppnet.prototype_vectors, + 'lr': joint_optimizer_lrs['prototype_vectors'] + } + joint_optimizer_specs = \ [{'params': ppnet.features.parameters(), 'lr': joint_optimizer_lrs['features'], 'weight_decay': 1e-3}, # bias are now also being regularized {'params': ppnet.add_on_layers.parameters(), 'lr': joint_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3}, - {'params': ppnet.prototype_vectors, 'lr': joint_optimizer_lrs['prototype_vectors']}, + prototype_params, ] joint_optimizer = torch.optim.Adam(joint_optimizer_specs) joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=joint_lr_step_size, gamma=0.1) -from settings import warm_optimizer_lrs warm_optimizer_specs = \ [{'params': ppnet.add_on_layers.parameters(), 'lr': warm_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3}, - {'params': ppnet.prototype_vectors, 'lr': warm_optimizer_lrs['prototype_vectors']}, + prototype_params, ] warm_optimizer = torch.optim.Adam(warm_optimizer_specs) -from settings import last_layer_optimizer_lr last_layer_optimizer_specs = [{'params': ppnet.last_layer.parameters(), 'lr': last_layer_optimizer_lr}] last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs) -# weighting of different training losses -from settings import coefs - -# number of training epochs, number of warm epochs, push start epoch, push epochs -from settings import num_train_epochs, num_warm_epochs, push_start, push_epochs - # train the model log('start training') -import copy +pusher = Pusher( + train_push_loader, + prototype_network_parallel=ppnet_multi, + bank_size=bank_size, + class_specific=class_specific, + preprocess_input_function=preprocess_input_function, # normalize if needed + prototype_layer_stride=1, + dir_for_saving_prototypes=img_dir, # if not None, prototypes will be saved here + prototype_img_filename_prefix=prototype_img_filename_prefix, + prototype_self_act_filename_prefix=prototype_self_act_filename_prefix, + proto_bound_boxes_filename_prefix=proto_bound_boxes_filename_prefix, + save_prototype_class_identity=True, + log=log +) + +if do_initial_push: + perform_push(pusher, 0) + for epoch in range(num_train_epochs): log('epoch: \t{0}'.format(epoch)) @@ -156,19 +183,7 @@ target_accu=0.70, log=log) if epoch >= push_start and epoch in push_epochs: - push.push_prototypes( - train_push_loader, # pytorch dataloader (must be unnormalized in [0,1]) - prototype_network_parallel=ppnet_multi, # pytorch network with prototype_vectors - class_specific=class_specific, - preprocess_input_function=preprocess_input_function, # normalize if needed - prototype_layer_stride=1, - root_dir_for_saving_prototypes=img_dir, # if not None, prototypes will be saved here - epoch_number=epoch, # if not provided, prototypes saved previously will be overwritten - prototype_img_filename_prefix=prototype_img_filename_prefix, - prototype_self_act_filename_prefix=prototype_self_act_filename_prefix, - proto_bound_boxes_filename_prefix=proto_bound_boxes_filename_prefix, - save_prototype_class_identity=True, - log=log) + perform_push(pusher, epoch) accu = tnt.test(model=ppnet_multi, dataloader=test_loader, class_specific=class_specific, log=log) save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + 'push', accu=accu, @@ -184,6 +199,5 @@ class_specific=class_specific, log=log) save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + '_' + str(i) + 'push', accu=accu, target_accu=0.70, log=log) - -logclose() +logclose() diff --git a/model.py b/model.py index b8870ab161..77e5fa7aba 100644 --- a/model.py +++ b/model.py @@ -33,7 +33,7 @@ class PPNet(nn.Module): def __init__(self, features, img_size, prototype_shape, proto_layer_rf_info, num_classes, init_weights=True, prototype_activation_function='log', - add_on_layers_type='bottleneck'): + add_on_layers_type='bottleneck', bank_size=20): super(PPNet, self).__init__() self.img_size = img_size @@ -41,7 +41,7 @@ def __init__(self, features, img_size, prototype_shape, self.num_prototypes = prototype_shape[0] self.num_classes = num_classes self.epsilon = 1e-4 - + # prototype_activation_function could be 'log', 'linear', # or a generic function that converts distance to similarity score self.prototype_activation_function = prototype_activation_function @@ -101,10 +101,14 @@ def __init__(self, features, img_size, prototype_shape, nn.Conv2d(in_channels=self.prototype_shape[1], out_channels=self.prototype_shape[1], kernel_size=1), nn.Sigmoid() ) - + self.prototype_vectors = nn.Parameter(torch.rand(self.prototype_shape), requires_grad=True) + self.bank_size = bank_size + protobank_shape = list(self.prototype_shape) + protobank_shape[0] = protobank_shape[0] * self.bank_size + self.protobank_tensor = nn.Parameter(torch.rand(protobank_shape), requires_grad=True) # do not make this just a tensor, # since it will not be moved automatically to gpu self.ones = nn.Parameter(torch.ones(self.prototype_shape), @@ -170,13 +174,39 @@ def _l2_convolution(self, x): return distances - def prototype_distances(self, x): + def _l2_protobank(self, x): ''' x is the raw input ''' + x2 = x ** 2 + x2_patch_sum = F.conv2d(input=x2, weight=self.ones) + + p2 = self.protobank_tensor ** 2 + p2 = torch.sum(p2, dim=(1, 2, 3)) + # p2 is a vector of shape (num_prototypes,) + # then we reshape it to (num_prototypes, 1, 1) + p2_reshape = p2.view(-1, 1, 1) + + # vectorize + xpb = F.conv2d(input=x, weight=self.protobank_tensor) + last_dim = xpb.size()[-1] + tmp = -2 * xpb + p2_reshape + # torch split is the opposite of cat. We did not want view after the conv op to + # break up the tensor. Otherwise we need to use view pretty extensively to + # ensure proper mathematic operations occur + tmp = torch.cat([v.unsqueeze(0) for v in tmp.split(self.num_prototypes, dim=1)], dim=0) + tmp = tmp + x2_patch_sum + distances = F.relu(tmp).view(-1, self.num_prototypes, last_dim, last_dim) + return distances + + def prototype_distances(self, x): conv_features = self.conv_features(x) distances = self._l2_convolution(conv_features) - return distances + return conv_features, distances + + def protobank_distances(self, x): + conv_features = self.conv_features(x) + return conv_features, self._l2_protobank(conv_features) def distance_2_similarity(self, distances): if self.prototype_activation_function == 'log': @@ -186,8 +216,8 @@ def distance_2_similarity(self, distances): else: return self.prototype_activation_function(distances) - def forward(self, x): - distances = self.prototype_distances(x) + def forward_orig(self, x): + _, distances = self.prototype_distances(x) ''' we cannot refactor the lines below for similarity scores because we need to return min_distances @@ -201,11 +231,23 @@ def forward(self, x): logits = self.last_layer(prototype_activations) return logits, min_distances - def push_forward(self, x): - '''this method is needed for the pushing operation''' - conv_output = self.conv_features(x) - distances = self._l2_convolution(conv_output) - return conv_output, distances + def forward_protobank(self, x): + _, distances = self.protobank_distances(x) + last_dim = distances.size(-1) + min_distances = -F.max_pool2d(-distances, kernel_size=(last_dim, last_dim)) + # this is the version where we just take a minima. There can also be other + # variants like where you just take an average/sum of the memory bank. We can + # code this up later tho. + min_distances = min_distances.view(self.bank_size, -1, self.num_prototypes).min(dim=0)[0] + prototype_activations = self.distance_2_similarity(min_distances) + logits = self.last_layer(prototype_activations) + return logits, min_distances + + def forward(self, x): + if self.bank_size > 1: + return self.forward_protobank(x) + else: + return self.forward_orig(x) def prune_prototypes(self, prototypes_to_prune): ''' @@ -288,7 +330,8 @@ def _initialize_weights(self): def construct_PPNet(base_architecture, pretrained=True, img_size=224, prototype_shape=(2000, 512, 1, 1), num_classes=200, prototype_activation_function='log', - add_on_layers_type='bottleneck'): + add_on_layers_type='bottleneck', + bank_size=20): features = base_architecture_to_features[base_architecture](pretrained=pretrained) layer_filter_sizes, layer_strides, layer_paddings = features.conv_info() proto_layer_rf_info = compute_proto_layer_rf_info_v2(img_size=img_size, @@ -303,5 +346,5 @@ def construct_PPNet(base_architecture, pretrained=True, img_size=224, num_classes=num_classes, init_weights=True, prototype_activation_function=prototype_activation_function, - add_on_layers_type=add_on_layers_type) - + add_on_layers_type=add_on_layers_type, + bank_size=bank_size) diff --git a/push.py b/push.py index f3eb146b41..2b2f5fd512 100644 --- a/push.py +++ b/push.py @@ -1,305 +1,452 @@ -import torch -import numpy as np -import matplotlib.pyplot as plt -import cv2 import os import copy import time +import cv2 +import torch.nn.functional as F +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch import nn + from receptive_field import compute_rf_prototype from helpers import makedir, find_high_activation_crop -# push each prototype to the nearest patch in the training set -def push_prototypes(dataloader, # pytorch dataloader (must be unnormalized in [0,1]) - prototype_network_parallel, # pytorch network with prototype_vectors - class_specific=True, - preprocess_input_function=None, # normalize if needed - prototype_layer_stride=1, - root_dir_for_saving_prototypes=None, # if not None, prototypes will be saved here - epoch_number=None, # if not provided, prototypes saved previously will be overwritten - prototype_img_filename_prefix=None, - prototype_self_act_filename_prefix=None, - proto_bound_boxes_filename_prefix=None, - save_prototype_class_identity=True, # which class the prototype image comes from - log=print, - prototype_activation_function_in_numpy=None): - - prototype_network_parallel.eval() - log('\tpush') - - start = time.time() - prototype_shape = prototype_network_parallel.module.prototype_shape - n_prototypes = prototype_network_parallel.module.num_prototypes - # saves the closest distance seen so far - global_min_proto_dist = np.full(n_prototypes, np.inf) - # saves the patch representation that gives the current smallest distance - global_min_fmap_patches = np.zeros( - [n_prototypes, - prototype_shape[1], - prototype_shape[2], - prototype_shape[3]]) - - ''' - proto_rf_boxes and proto_bound_boxes column: - 0: image index in the entire dataset - 1: height start index - 2: height end index - 3: width start index - 4: width end index - 5: (optional) class identity - ''' - if save_prototype_class_identity: - proto_rf_boxes = np.full(shape=[n_prototypes, 6], - fill_value=-1) - proto_bound_boxes = np.full(shape=[n_prototypes, 6], - fill_value=-1) - else: - proto_rf_boxes = np.full(shape=[n_prototypes, 5], - fill_value=-1) - proto_bound_boxes = np.full(shape=[n_prototypes, 5], - fill_value=-1) - - if root_dir_for_saving_prototypes != None: - if epoch_number != None: - proto_epoch_dir = os.path.join(root_dir_for_saving_prototypes, - 'epoch-'+str(epoch_number)) - makedir(proto_epoch_dir) - else: - proto_epoch_dir = root_dir_for_saving_prototypes - else: - proto_epoch_dir = None - search_batch_size = dataloader.batch_size +# XXX my two ideas here are weight averaging for push and a memory bank. I think memory bank is the most +# interesting item to implement first, but they both have similar implementation in code +class Pusher(object): + def __init__(self, + dataloader, + prototype_network_parallel, # pytorch network with prototype_vectors + bank_size, + class_specific=True, + preprocess_input_function=None, # normalize if needed + prototype_layer_stride=1, + dir_for_saving_prototypes=None, # if not None, prototypes will be saved here + prototype_img_filename_prefix=None, + prototype_self_act_filename_prefix=None, + proto_bound_boxes_filename_prefix=None, + save_prototype_class_identity=True, # which class the prototype image comes from + log=print, + prototype_activation_function_in_numpy=None): + self.dataloader = dataloader + self.prototype_network_parallel = prototype_network_parallel + self.class_specific = class_specific + self.preprocess_input_function = preprocess_input_function + self.prototype_layer_stride = prototype_layer_stride + self.dir_for_saving_prototypes = dir_for_saving_prototypes + self.prototype_img_filename_prefix = prototype_img_filename_prefix + self.prototype_self_act_filename_prefix = prototype_self_act_filename_prefix + self.save_prototype_class_identity = save_prototype_class_identity + self.log = log + self.prototype_activation_function_in_numpy = prototype_activation_function_in_numpy + self.bank_size = bank_size - num_classes = prototype_network_parallel.module.num_classes + def push_orig(self, epoch_number): + self.prototype_network_parallel.eval() + self.log('\tpush') + + start = time.time() + prototype_shape = self.prototype_network_parallel.module.prototype_shape + n_prototypes = self.prototype_network_parallel.module.num_prototypes + # saves the closest distance seen so far + global_min_proto_dist = np.full(n_prototypes, np.inf) + # saves the patch representation that gives the current smallest distance + global_min_fmap_patches = np.zeros( + [n_prototypes, + prototype_shape[1], + prototype_shape[2], + prototype_shape[3]]) - for push_iter, (search_batch_input, search_y) in enumerate(dataloader): ''' - start_index_of_search keeps track of the index of the image - assigned to serve as prototype + proto_rf_boxes and proto_bound_boxes column: + 0: image index in the entire dataset + 1: height start index + 2: height end index + 3: width start index + 4: width end index + 5: (optional) class identity ''' - start_index_of_search_batch = push_iter * search_batch_size + if self.save_prototype_class_identity: + proto_rf_boxes = np.full(shape=[n_prototypes, 6], + fill_value=-1) + proto_bound_boxes = np.full(shape=[n_prototypes, 6], + fill_value=-1) + else: + proto_rf_boxes = np.full(shape=[n_prototypes, 5], + fill_value=-1) + proto_bound_boxes = np.full(shape=[n_prototypes, 5], + fill_value=-1) - update_prototypes_on_batch(search_batch_input, - start_index_of_search_batch, - prototype_network_parallel, - global_min_proto_dist, - global_min_fmap_patches, - proto_rf_boxes, - proto_bound_boxes, - class_specific=class_specific, - search_y=search_y, - num_classes=num_classes, - preprocess_input_function=preprocess_input_function, - prototype_layer_stride=prototype_layer_stride, - dir_for_saving_prototypes=proto_epoch_dir, - prototype_img_filename_prefix=prototype_img_filename_prefix, - prototype_self_act_filename_prefix=prototype_self_act_filename_prefix, - prototype_activation_function_in_numpy=prototype_activation_function_in_numpy) - - if proto_epoch_dir != None and proto_bound_boxes_filename_prefix != None: - np.save(os.path.join(proto_epoch_dir, proto_bound_boxes_filename_prefix + '-receptive_field' + str(epoch_number) + '.npy'), - proto_rf_boxes) - np.save(os.path.join(proto_epoch_dir, proto_bound_boxes_filename_prefix + str(epoch_number) + '.npy'), - proto_bound_boxes) - - log('\tExecuting push ...') - prototype_update = np.reshape(global_min_fmap_patches, - tuple(prototype_shape)) - prototype_network_parallel.module.prototype_vectors.data.copy_(torch.tensor(prototype_update, dtype=torch.float32).cuda()) - # prototype_network_parallel.cuda() - end = time.time() - log('\tpush time: \t{0}'.format(end - start)) - -# update each prototype for current search batch -def update_prototypes_on_batch(search_batch_input, - start_index_of_search_batch, - prototype_network_parallel, - global_min_proto_dist, # this will be updated - global_min_fmap_patches, # this will be updated - proto_rf_boxes, # this will be updated - proto_bound_boxes, # this will be updated - class_specific=True, - search_y=None, # required if class_specific == True - num_classes=None, # required if class_specific == True - preprocess_input_function=None, - prototype_layer_stride=1, - dir_for_saving_prototypes=None, - prototype_img_filename_prefix=None, - prototype_self_act_filename_prefix=None, - prototype_activation_function_in_numpy=None): - - prototype_network_parallel.eval() - - if preprocess_input_function is not None: - # print('preprocessing input for pushing ...') - # search_batch = copy.deepcopy(search_batch_input) - search_batch = preprocess_input_function(search_batch_input) - - else: - search_batch = search_batch_input - - with torch.no_grad(): - search_batch = search_batch.cuda() - # this computation currently is not parallelized - protoL_input_torch, proto_dist_torch = prototype_network_parallel.module.push_forward(search_batch) - - protoL_input_ = np.copy(protoL_input_torch.detach().cpu().numpy()) - proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy()) - - del protoL_input_torch, proto_dist_torch - - if class_specific: + if self.dir_for_saving_prototypes != None: + if epoch_number != None: + proto_epoch_dir = os.path.join(self.dir_for_saving_prototypes, + 'epoch-'+str(epoch_number)) + makedir(proto_epoch_dir) + else: + # XXX I think dir_for_saving_proto and root_dir are actually + # different variables and it wasnt a misnaming. Oh well + # I'll come back to this later + proto_epoch_dir = self.dir_for_saving_prototypes + else: + proto_epoch_dir = None + + search_batch_size = self.dataloader.batch_size + + num_classes = self.prototype_network_parallel.module.num_classes + + for push_iter, (search_batch_input, search_y) in enumerate(self.dataloader): + ''' + start_index_of_search keeps track of the index of the image + assigned to serve as prototype + ''' + start_index_of_search_batch = push_iter * search_batch_size + + self.update_prototypes_on_batch(search_batch_input, + start_index_of_search_batch, + global_min_proto_dist, + global_min_fmap_patches, + proto_rf_boxes, + proto_bound_boxes, + search_y=search_y, + num_classes=num_classes) + + if proto_epoch_dir != None and self.proto_bound_boxes_filename_prefix != None: + np.save(os.path.join(proto_epoch_dir, self.proto_bound_boxes_filename_prefix + '-receptive_field' + str(epoch_number) + '.npy'), + proto_rf_boxes) + np.save(os.path.join(proto_epoch_dir, self.proto_bound_boxes_filename_prefix + str(epoch_number) + '.npy'), + proto_bound_boxes) + + # XXX push here is different because we're choosing top K vectors. + self.log('\tExecuting push ...') + prototype_update = np.reshape(global_min_fmap_patches, + tuple(prototype_shape)) + self.prototype_network_parallel.module.prototype_vectors.data.copy_(torch.tensor(prototype_update, dtype=torch.float32).cuda()) + # prototype_network_parallel.cuda() + end = time.time() + self.log('\tpush time: \t{0}'.format(end - start)) + + def push_protobank(self, epoch_number): + self.prototype_network_parallel.eval() + self.log('\tpush') + + start = time.time() + prototype_shape = self.prototype_network_parallel.module.prototype_shape + protobank_shape = list(prototype_shape) + protobank_shape[0] = protobank_shape[0] * self.bank_size + n_prototypes = self.prototype_network_parallel.module.num_prototypes + # XXX skip proto_rf impl + # + # XXX skip dir save impl + search_batch_size = self.dataloader.batch_size + num_classes = self.prototype_network_parallel.module.num_classes + all_proto_dist = np.full((self.bank_size, n_prototypes), np.inf) + # saves the patch representation that gives the current smallest distance + # this is really the only var we care about in the entire update process + all_fmap_patches = np.zeros([ + self.bank_size, n_prototypes, prototype_shape[1], prototype_shape[2], prototype_shape[3] + ]) + + for push_iter, (search_batch_input, search_y) in enumerate(self.dataloader): + ''' + start_index_of_search keeps track of the index of the image + assigned to serve as prototype + ''' + start_index_of_search_batch = push_iter * search_batch_size + + self.generic_update_search(search_batch_input, + start_index_of_search_batch, + all_proto_dist, + all_fmap_patches, + search_y=search_y, + num_classes=num_classes) + + # XXX didnt impl bound boxes filename + self.log('\tExecuting push ...') + prototype_update = np.reshape(all_fmap_patches, tuple(protobank_shape)) + self.prototype_network_parallel.module.protobank_tensor.data.copy_(torch.tensor(prototype_update, dtype=torch.float32).cuda()) + # prototype_network_parallel.cuda() + end = time.time() + self.log('\tTotal push time: \t{0}'.format(end - start)) + + def generic_update_search(self, + search_batch_input, + start_index_of_search_batch, + all_proto_dist, + all_fmap_patches, + search_y, + num_classes): + """ + Make generic update on prototype vectors based on some kind of + memory bank available to us. This follows a top-K method + """ + self.prototype_network_parallel.eval() + # XXX no preprocess func implemented + with torch.no_grad(): + search_batch_input = search_batch_input.cuda() + # you could run this across multi-gpu but you'd have to make mod + # to the DataParallel class because it doesn't allow you to + # call other methods besides forward func + conv_features, distances = self.prototype_network_parallel.module.protobank_distances(search_batch_input) + + pool = nn.MaxPool2d(distances.size()[-1], distances.size()[-1], return_indices=True) + protoL_input_ = np.copy(conv_features.detach().cpu().numpy()) + proto_dist_ = np.copy(distances.detach().cpu().numpy()) + pooled, pooled_idx = pool(-distances) + pooled = -pooled.detach().cpu().numpy() + + # XXX non-class specific not implemented class_to_img_index_dict = {key: [] for key in range(num_classes)} # img_y is the image's integer label for img_index, img_y in enumerate(search_y): img_label = img_y.item() class_to_img_index_dict[img_label].append(img_index) - prototype_shape = prototype_network_parallel.module.prototype_shape - n_prototypes = prototype_shape[0] - proto_h = prototype_shape[2] - proto_w = prototype_shape[3] - max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3] + prototype_shape = self.prototype_network_parallel.module.prototype_shape + n_prototypes = prototype_shape[0] + proto_h = prototype_shape[2] + proto_w = prototype_shape[3] + max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3] - for j in range(n_prototypes): - #if n_prototypes_per_class != None: - if class_specific: + for j in range(n_prototypes): + #if n_prototypes_per_class != None: # target_class is the class of the class_specific prototype - target_class = torch.argmax(prototype_network_parallel.module.prototype_class_identity[j]).item() + # XXX non-class specific not implemented + target_class = torch.argmax(self.prototype_network_parallel.module.prototype_class_identity[j]).item() # if there is not images of the target_class from this batch # we go on to the next prototype if len(class_to_img_index_dict[target_class]) == 0: continue - proto_dist_j = proto_dist_[class_to_img_index_dict[target_class]][:,j,:,:] + proto_dist_j = proto_dist_[class_to_img_index_dict[target_class]][:, j] + # XXX well apparently this function is more complex than i realized. Its not + # just finding a specific prototype input, but it finds a specific pixel + # that they're trying to set. This complicates the actual setting of + # the conv value because in my impl i was doing max pooling and setting + # by row. + # + # XXX So this in turn leads me to question the whole thing about push. I mean + # at the end, its just focused on minimizing the l2 distance to 0 + # after min pooling. But why pick the min for the push? Why not a mean? + # I should read thru the math again to see if theres rationale for the min + # I mean, setting it to anything could have the desired outcome of + # minimizing l2 + for i in range(self.bank_size): + batch_min_proto_dist_j = np.amin(proto_dist_j) + # this works originally because the pooling done is basically an + # argmin across the 7x7. If you want to simplfiy this you can + # directly perform the argmin here. + which_less_than = batch_min_proto_dist_j < all_proto_dist[:, j] + if which_less_than.any(): + batch_argmin_proto_dist_j = \ + list(np.unravel_index(np.argmin(proto_dist_j, axis=None), + proto_dist_j.shape)) + # XXX no non-class specific + img_index_in_batch = class_to_img_index_dict[target_class][batch_argmin_proto_dist_j[0]] + + # retrieve the corresponding feature map patch + fmap_height_start_index = batch_argmin_proto_dist_j[1] * self.prototype_layer_stride + fmap_height_end_index = fmap_height_start_index + proto_h + fmap_width_start_index = batch_argmin_proto_dist_j[2] * self.prototype_layer_stride + fmap_width_end_index = fmap_width_start_index + proto_w + + batch_min_fmap_patch_j = protoL_input_[img_index_in_batch, + :, + fmap_height_start_index:fmap_height_end_index, + fmap_width_start_index:fmap_width_end_index] + + # find the bank idx to update first and then do it + at_bank_idx = np.searchsorted(which_less_than, True) + # do insert instead of update, and then truncate + dist_vals = all_proto_dist[:, j] + all_proto_dist[:, j] = np.insert(dist_vals, at_bank_idx, batch_min_proto_dist_j)[:-1] + fmap_vals = all_fmap_patches[:, j] + all_fmap_patches[:, j] = np.insert(fmap_vals, at_bank_idx, batch_min_fmap_patch_j,axis=0)[:-1] + # remove the rest of this. I'd like to separate it out into another + # function. Currently we dont need it and we can always refer to the + # original function if we want the code + + def update_prototypes_on_batch(self, + search_batch_input, + start_index_of_search_batch, + global_min_proto_dist, # this will be updated + global_min_fmap_patches, # this will be updated + proto_rf_boxes, # this will be updated + proto_bound_boxes, # this will be updated + search_y=None, # required if class_specific == True + num_classes=None): # required if class_specific == True + + self.prototype_network_parallel.eval() + + if self.preprocess_input_function is not None: + # print('preprocessing input for pushing ...') + # search_batch = copy.deepcopy(search_batch_input) + search_batch = self.preprocess_input_function(search_batch_input) + else: - # if it is not class specific, then we will search through - # every example - proto_dist_j = proto_dist_[:,j,:,:] - - batch_min_proto_dist_j = np.amin(proto_dist_j) - if batch_min_proto_dist_j < global_min_proto_dist[j]: - batch_argmin_proto_dist_j = \ - list(np.unravel_index(np.argmin(proto_dist_j, axis=None), - proto_dist_j.shape)) - if class_specific: - ''' - change the argmin index from the index among - images of the target class to the index in the entire search - batch - ''' - batch_argmin_proto_dist_j[0] = class_to_img_index_dict[target_class][batch_argmin_proto_dist_j[0]] - - # retrieve the corresponding feature map patch - img_index_in_batch = batch_argmin_proto_dist_j[0] - fmap_height_start_index = batch_argmin_proto_dist_j[1] * prototype_layer_stride - fmap_height_end_index = fmap_height_start_index + proto_h - fmap_width_start_index = batch_argmin_proto_dist_j[2] * prototype_layer_stride - fmap_width_end_index = fmap_width_start_index + proto_w - - batch_min_fmap_patch_j = protoL_input_[img_index_in_batch, - :, - fmap_height_start_index:fmap_height_end_index, - fmap_width_start_index:fmap_width_end_index] - - global_min_proto_dist[j] = batch_min_proto_dist_j - global_min_fmap_patches[j] = batch_min_fmap_patch_j - - # get the receptive field boundary of the image patch - # that generates the representation - protoL_rf_info = prototype_network_parallel.module.proto_layer_rf_info - rf_prototype_j = compute_rf_prototype(search_batch.size(2), batch_argmin_proto_dist_j, protoL_rf_info) - - # get the whole image - original_img_j = search_batch_input[rf_prototype_j[0]] - original_img_j = original_img_j.numpy() - original_img_j = np.transpose(original_img_j, (1, 2, 0)) - original_img_size = original_img_j.shape[0] - - # crop out the receptive field - rf_img_j = original_img_j[rf_prototype_j[1]:rf_prototype_j[2], - rf_prototype_j[3]:rf_prototype_j[4], :] - - # save the prototype receptive field information - proto_rf_boxes[j, 0] = rf_prototype_j[0] + start_index_of_search_batch - proto_rf_boxes[j, 1] = rf_prototype_j[1] - proto_rf_boxes[j, 2] = rf_prototype_j[2] - proto_rf_boxes[j, 3] = rf_prototype_j[3] - proto_rf_boxes[j, 4] = rf_prototype_j[4] - if proto_rf_boxes.shape[1] == 6 and search_y is not None: - proto_rf_boxes[j, 5] = search_y[rf_prototype_j[0]].item() - - # find the highly activated region of the original image - proto_dist_img_j = proto_dist_[img_index_in_batch, j, :, :] - if prototype_network_parallel.module.prototype_activation_function == 'log': - proto_act_img_j = np.log((proto_dist_img_j + 1) / (proto_dist_img_j + prototype_network_parallel.module.epsilon)) - elif prototype_network_parallel.module.prototype_activation_function == 'linear': - proto_act_img_j = max_dist - proto_dist_img_j + search_batch = search_batch_input + + with torch.no_grad(): + search_batch = search_batch.cuda() + # this computation currently is not parallelized + conv_features, distances = self.prototype_network_parallel.module.prototype_distances(search_batch) + + protoL_input_ = np.copy(conv_features.detach().cpu().numpy()) + proto_dist_ = np.copy(distances.detach().cpu().numpy()) + + del conv_features, distances + + if self.class_specific: + class_to_img_index_dict = {key: [] for key in range(num_classes)} + # img_y is the image's integer label + for img_index, img_y in enumerate(search_y): + img_label = img_y.item() + class_to_img_index_dict[img_label].append(img_index) + + prototype_shape = self.prototype_network_parallel.module.prototype_shape + n_prototypes = prototype_shape[0] + proto_h = prototype_shape[2] + proto_w = prototype_shape[3] + max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3] + + for j in range(n_prototypes): + #if n_prototypes_per_class != None: + if self.class_specific: + # target_class is the class of the class_specific prototype + target_class = torch.argmax(self.prototype_network_parallel.module.prototype_class_identity[j]).item() + # if there is not images of the target_class from this batch + # we go on to the next prototype + if len(class_to_img_index_dict[target_class]) == 0: + continue + proto_dist_j = proto_dist_[class_to_img_index_dict[target_class]][:,j,:,:] else: - proto_act_img_j = prototype_activation_function_in_numpy(proto_dist_img_j) - upsampled_act_img_j = cv2.resize(proto_act_img_j, dsize=(original_img_size, original_img_size), - interpolation=cv2.INTER_CUBIC) - proto_bound_j = find_high_activation_crop(upsampled_act_img_j) - # crop out the image patch with high activation as prototype image - proto_img_j = original_img_j[proto_bound_j[0]:proto_bound_j[1], - proto_bound_j[2]:proto_bound_j[3], :] - - # save the prototype boundary (rectangular boundary of highly activated region) - proto_bound_boxes[j, 0] = proto_rf_boxes[j, 0] - proto_bound_boxes[j, 1] = proto_bound_j[0] - proto_bound_boxes[j, 2] = proto_bound_j[1] - proto_bound_boxes[j, 3] = proto_bound_j[2] - proto_bound_boxes[j, 4] = proto_bound_j[3] - if proto_bound_boxes.shape[1] == 6 and search_y is not None: - proto_bound_boxes[j, 5] = search_y[rf_prototype_j[0]].item() - - if dir_for_saving_prototypes is not None: - if prototype_self_act_filename_prefix is not None: - # save the numpy array of the prototype self activation - np.save(os.path.join(dir_for_saving_prototypes, - prototype_self_act_filename_prefix + str(j) + '.npy'), - proto_act_img_j) - if prototype_img_filename_prefix is not None: - # save the whole image containing the prototype as png - plt.imsave(os.path.join(dir_for_saving_prototypes, - prototype_img_filename_prefix + '-original' + str(j) + '.png'), - original_img_j, - vmin=0.0, - vmax=1.0) - # overlay (upsampled) self activation on original image and save the result - rescaled_act_img_j = upsampled_act_img_j - np.amin(upsampled_act_img_j) - rescaled_act_img_j = rescaled_act_img_j / np.amax(rescaled_act_img_j) - heatmap = cv2.applyColorMap(np.uint8(255*rescaled_act_img_j), cv2.COLORMAP_JET) - heatmap = np.float32(heatmap) / 255 - heatmap = heatmap[...,::-1] - overlayed_original_img_j = 0.5 * original_img_j + 0.3 * heatmap - plt.imsave(os.path.join(dir_for_saving_prototypes, - prototype_img_filename_prefix + '-original_with_self_act' + str(j) + '.png'), - overlayed_original_img_j, - vmin=0.0, - vmax=1.0) - - # if different from the original (whole) image, save the prototype receptive field as png - if rf_img_j.shape[0] != original_img_size or rf_img_j.shape[1] != original_img_size: - plt.imsave(os.path.join(dir_for_saving_prototypes, - prototype_img_filename_prefix + '-receptive_field' + str(j) + '.png'), - rf_img_j, + # if it is not class specific, then we will search through + # every example + proto_dist_j = proto_dist_[:,j,:,:] + + batch_min_proto_dist_j = np.amin(proto_dist_j) + if batch_min_proto_dist_j < global_min_proto_dist[j]: + batch_argmin_proto_dist_j = \ + list(np.unravel_index(np.argmin(proto_dist_j, axis=None), + proto_dist_j.shape)) + if self.class_specific: + ''' + change the argmin index from the index among + images of the target class to the index in the entire search + batch + ''' + batch_argmin_proto_dist_j[0] = class_to_img_index_dict[target_class][batch_argmin_proto_dist_j[0]] + + # retrieve the corresponding feature map patch + img_index_in_batch = batch_argmin_proto_dist_j[0] + fmap_height_start_index = batch_argmin_proto_dist_j[1] * prototype_layer_stride + fmap_height_end_index = fmap_height_start_index + proto_h + fmap_width_start_index = batch_argmin_proto_dist_j[2] * prototype_layer_stride + fmap_width_end_index = fmap_width_start_index + proto_w + + batch_min_fmap_patch_j = protoL_input_[img_index_in_batch, + :, + fmap_height_start_index:fmap_height_end_index, + fmap_width_start_index:fmap_width_end_index] + + global_min_proto_dist[j] = batch_min_proto_dist_j + global_min_fmap_patches[j] = batch_min_fmap_patch_j + + # get the receptive field boundary of the image patch + # that generates the representation + protoL_rf_info = self.prototype_network_parallel.module.proto_layer_rf_info + rf_prototype_j = compute_rf_prototype(search_batch.size(2), batch_argmin_proto_dist_j, protoL_rf_info) + + # get the whole image + original_img_j = search_batch_input[rf_prototype_j[0]] + original_img_j = original_img_j.numpy() + original_img_j = np.transpose(original_img_j, (1, 2, 0)) + original_img_size = original_img_j.shape[0] + + # crop out the receptive field + rf_img_j = original_img_j[rf_prototype_j[1]:rf_prototype_j[2], + rf_prototype_j[3]:rf_prototype_j[4], :] + + # save the prototype receptive field information + proto_rf_boxes[j, 0] = rf_prototype_j[0] + start_index_of_search_batch + proto_rf_boxes[j, 1] = rf_prototype_j[1] + proto_rf_boxes[j, 2] = rf_prototype_j[2] + proto_rf_boxes[j, 3] = rf_prototype_j[3] + proto_rf_boxes[j, 4] = rf_prototype_j[4] + if proto_rf_boxes.shape[1] == 6 and search_y is not None: + proto_rf_boxes[j, 5] = search_y[rf_prototype_j[0]].item() + + # find the highly activated region of the original image + proto_dist_img_j = proto_dist_[img_index_in_batch, j, :, :] + if self.prototype_network_parallel.module.prototype_activation_function == 'log': + proto_act_img_j = np.log((proto_dist_img_j + 1) / (proto_dist_img_j + self.prototype_network_parallel.module.epsilon)) + elif self.prototype_network_parallel.module.prototype_activation_function == 'linear': + proto_act_img_j = max_dist - proto_dist_img_j + else: + proto_act_img_j = self.prototype_activation_function_in_numpy(proto_dist_img_j) + upsampled_act_img_j = cv2.resize(proto_act_img_j, dsize=(original_img_size, original_img_size), + interpolation=cv2.INTER_CUBIC) + proto_bound_j = find_high_activation_crop(upsampled_act_img_j) + # crop out the image patch with high activation as prototype image + proto_img_j = original_img_j[proto_bound_j[0]:proto_bound_j[1], + proto_bound_j[2]:proto_bound_j[3], :] + + # save the prototype boundary (rectangular boundary of highly activated region) + proto_bound_boxes[j, 0] = proto_rf_boxes[j, 0] + proto_bound_boxes[j, 1] = proto_bound_j[0] + proto_bound_boxes[j, 2] = proto_bound_j[1] + proto_bound_boxes[j, 3] = proto_bound_j[2] + proto_bound_boxes[j, 4] = proto_bound_j[3] + if proto_bound_boxes.shape[1] == 6 and search_y is not None: + proto_bound_boxes[j, 5] = search_y[rf_prototype_j[0]].item() + + if self.dir_for_saving_prototypes is not None: + if self.prototype_self_act_filename_prefix is not None: + # save the numpy array of the prototype self activation + np.save(os.path.join(self.dir_for_saving_prototypes, + self.prototype_self_act_filename_prefix + str(j) + '.npy'), + proto_act_img_j) + if self.prototype_img_filename_prefix is not None: + # save the whole image containing the prototype as png + plt.imsave(os.path.join(self.dir_for_saving_prototypes, + self.prototype_img_filename_prefix + '-original' + str(j) + '.png'), + original_img_j, vmin=0.0, vmax=1.0) - overlayed_rf_img_j = overlayed_original_img_j[rf_prototype_j[1]:rf_prototype_j[2], - rf_prototype_j[3]:rf_prototype_j[4]] - plt.imsave(os.path.join(dir_for_saving_prototypes, - prototype_img_filename_prefix + '-receptive_field_with_self_act' + str(j) + '.png'), - overlayed_rf_img_j, + # overlay (upsampled) self activation on original image and save the result + rescaled_act_img_j = upsampled_act_img_j - np.amin(upsampled_act_img_j) + rescaled_act_img_j = rescaled_act_img_j / np.amax(rescaled_act_img_j) + heatmap = cv2.applyColorMap(np.uint8(255*rescaled_act_img_j), cv2.COLORMAP_JET) + heatmap = np.float32(heatmap) / 255 + heatmap = heatmap[...,::-1] + overlayed_original_img_j = 0.5 * original_img_j + 0.3 * heatmap + plt.imsave(os.path.join(self.dir_for_saving_prototypes, + self.prototype_img_filename_prefix + '-original_with_self_act' + str(j) + '.png'), + overlayed_original_img_j, vmin=0.0, vmax=1.0) - - # save the prototype image (highly activated region of the whole image) - plt.imsave(os.path.join(dir_for_saving_prototypes, - prototype_img_filename_prefix + str(j) + '.png'), - proto_img_j, - vmin=0.0, - vmax=1.0) - - if class_specific: - del class_to_img_index_dict + + # if different from the original (whole) image, save the prototype receptive field as png + if rf_img_j.shape[0] != original_img_size or rf_img_j.shape[1] != original_img_size: + plt.imsave(os.path.join(self.dir_for_saving_prototypes, + self.prototype_img_filename_prefix + '-receptive_field' + str(j) + '.png'), + rf_img_j, + vmin=0.0, + vmax=1.0) + overlayed_rf_img_j = overlayed_original_img_j[rf_prototype_j[1]:rf_prototype_j[2], + rf_prototype_j[3]:rf_prototype_j[4]] + plt.imsave(os.path.join(self.dir_for_saving_prototypes, + self.prototype_img_filename_prefix + '-receptive_field_with_self_act' + str(j) + '.png'), + overlayed_rf_img_j, + vmin=0.0, + vmax=1.0) + + # save the prototype image (highly activated region of the whole image) + plt.imsave(os.path.join(self.dir_for_saving_prototypes, + self.prototype_img_filename_prefix + str(j) + '.png'), + proto_img_j, + vmin=0.0, + vmax=1.0) + + if class_specific: + del class_to_img_index_dict diff --git a/settings.py b/settings.py index 1f7f7f99c1..e900e5c1ad 100644 --- a/settings.py +++ b/settings.py @@ -1,4 +1,4 @@ -base_architecture = 'vgg19' +base_architecture = 'resnet34' img_size = 224 prototype_shape = (2000, 128, 1, 1) num_classes = 200 @@ -10,7 +10,7 @@ data_path = './datasets/cub200_cropped/' train_dir = data_path + 'train_cropped_augmented/' test_dir = data_path + 'test_cropped/' -train_push_dir = data_path + 'train_cropped/' +train_push_dir = data_path + 'train_cropped_augmented/' train_batch_size = 80 test_batch_size = 100 train_push_batch_size = 75 @@ -37,3 +37,8 @@ push_start = 10 push_epochs = [i for i in range(num_train_epochs) if i % 10 == 0] + +# updates on original spec +do_initial_push = False +use_protobank = True +bank_size = 6 diff --git a/train_and_test.py b/train_and_test.py index cb8c1c85df..9ff8f14827 100644 --- a/train_and_test.py +++ b/train_and_test.py @@ -56,12 +56,12 @@ def _train_or_test(model, dataloader, optimizer=None, class_specific=True, use_l avg_separation_cost = \ torch.sum(min_distances * prototypes_of_wrong_class, dim=1) / torch.sum(prototypes_of_wrong_class, dim=1) avg_separation_cost = torch.mean(avg_separation_cost) - + if use_l1_mask: l1_mask = 1 - torch.t(model.module.prototype_class_identity).cuda() l1 = (model.module.last_layer.weight * l1_mask).norm(p=1) else: - l1 = model.module.last_layer.weight.norm(p=1) + l1 = model.module.last_layer.weight.norm(p=1) else: min_distance, _ = torch.min(min_distances, dim=1) @@ -126,7 +126,7 @@ def _train_or_test(model, dataloader, optimizer=None, class_specific=True, use_l def train(model, dataloader, optimizer, class_specific=False, coefs=None, log=print): assert(optimizer is not None) - + log('\ttrain') model.train() return _train_or_test(model=model, dataloader=dataloader, optimizer=optimizer, @@ -146,9 +146,10 @@ def last_only(model, log=print): for p in model.module.add_on_layers.parameters(): p.requires_grad = False model.module.prototype_vectors.requires_grad = False + model.module.protobank_tensor.requires_grad = False for p in model.module.last_layer.parameters(): p.requires_grad = True - + log('\tlast layer') @@ -158,9 +159,10 @@ def warm_only(model, log=print): for p in model.module.add_on_layers.parameters(): p.requires_grad = True model.module.prototype_vectors.requires_grad = True + model.module.protobank_tensor.requires_grad = True for p in model.module.last_layer.parameters(): p.requires_grad = True - + log('\twarm') @@ -170,7 +172,8 @@ def joint(model, log=print): for p in model.module.add_on_layers.parameters(): p.requires_grad = True model.module.prototype_vectors.requires_grad = True + model.module.protobank_tensor.requires_grad = True for p in model.module.last_layer.parameters(): p.requires_grad = True - + log('\tjoint')