From 8e377ab0d8641d27b321ac05816a30a67acbf675 Mon Sep 17 00:00:00 2001 From: wpalka-icg Date: Wed, 9 Apr 2025 13:01:43 +0200 Subject: [PATCH 1/5] layers --- pipnet/cast_layer.py | 13 +++++++++++++ pipnet/clamp_layer.py | 11 +++++++++++ pipnet/gumbel_layer.py | 14 ++++++++++++++ pipnet/one_hot_encoding_layer.py | 19 +++++++++++++++++++ pipnet/sum_layer.py | 10 ++++++++++ 5 files changed, 67 insertions(+) create mode 100644 pipnet/cast_layer.py create mode 100644 pipnet/clamp_layer.py create mode 100644 pipnet/gumbel_layer.py create mode 100644 pipnet/one_hot_encoding_layer.py create mode 100644 pipnet/sum_layer.py diff --git a/pipnet/cast_layer.py b/pipnet/cast_layer.py new file mode 100644 index 0000000..4b67e54 --- /dev/null +++ b/pipnet/cast_layer.py @@ -0,0 +1,13 @@ +import torch +import torch.nn as nn + +class CastLayer(nn.Module): + def __init__(self, dtype, round_before_cast=False): + super(CastLayer, self).__init__() + self.dtype = dtype + self.round_before_cast = round_before_cast + + def forward(self, x): + if self.round_before_cast: + x = torch.round(x) # Apply rounding if flag is enabled + return x.to(self.dtype) \ No newline at end of file diff --git a/pipnet/clamp_layer.py b/pipnet/clamp_layer.py new file mode 100644 index 0000000..0ae8ce5 --- /dev/null +++ b/pipnet/clamp_layer.py @@ -0,0 +1,11 @@ +import torch +import torch.nn as nn + +class ClampLayer(nn.Module): + def __init__(self, min_value, max_value): + super(ClampLayer, self).__init__() + self.min_value = min_value + self.max_value = max_value + + def forward(self, x): + return torch.clamp(x, min=self.min_value, max=self.max_value) \ No newline at end of file diff --git a/pipnet/gumbel_layer.py b/pipnet/gumbel_layer.py new file mode 100644 index 0000000..27d2f0a --- /dev/null +++ b/pipnet/gumbel_layer.py @@ -0,0 +1,14 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class GumbelSoftmaxLayer(nn.Module): + def __init__(self, tau=1.0, hard=False, eps=None, dim=None): + super().__init__() + self.tau = tau + self.hard = hard + self.eps = eps + self.dim = dim + + def forward(self, logits): + return F.gumbel_softmax(logits, tau=self.tau, hard=self.hard, eps=self.eps, dim=self.dim) \ No newline at end of file diff --git a/pipnet/one_hot_encoding_layer.py b/pipnet/one_hot_encoding_layer.py new file mode 100644 index 0000000..aa13535 --- /dev/null +++ b/pipnet/one_hot_encoding_layer.py @@ -0,0 +1,19 @@ +import torch +import torch.nn as nn + +class SoftOneHotEncodingLayer(nn.Module): + def __init__(self, num_classes, temperature=0.5): + super(SoftOneHotEncodingLayer, self).__init__() + self.num_classes = num_classes + self.temperature = temperature + + def forward(self, x): + # Compute softmax over classes for probabilistic one-hot + x = x.unsqueeze(-1) # Add dimension for broadcasting + indices = torch.arange(self.num_classes, device=x.device).float() + logits = -torch.abs(x - indices - 1) / self.temperature + probabilities = torch.softmax(logits, dim=-1) + mask = torch.clamp(x, 0, 1) + result = probabilities * mask + + return result \ No newline at end of file diff --git a/pipnet/sum_layer.py b/pipnet/sum_layer.py new file mode 100644 index 0000000..bb6a5bc --- /dev/null +++ b/pipnet/sum_layer.py @@ -0,0 +1,10 @@ +import torch +import torch.nn as nn + +class SumLayer(nn.Module): + def __init__(self, dim=None): + super().__init__() + self.dim = dim + + def forward(self, input): + return torch.sum(input, dim=self.dim) \ No newline at end of file From 0b47f54e6666c1220edd9d85b88dce1b16089e71 Mon Sep 17 00:00:00 2001 From: wpalka-icg Date: Wed, 9 Apr 2025 13:05:56 +0200 Subject: [PATCH 2/5] working tensorboard --- main.py | 34 +++++++++----- pipnet/pipnet.py | 23 ++++++++-- pipnet/test.py | 112 ++++++++++++++++++++++++++++++++++++++++++--- pipnet/train.py | 20 ++++++-- util/args.py | 4 ++ util/vis_pipnet.py | 30 ++++++++++-- 6 files changed, 192 insertions(+), 31 deletions(-) diff --git a/main.py b/main.py index 0ecf0a7..5372154 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,7 @@ +from datetime import datetime + +from torch.utils.tensorboard import SummaryWriter + from pipnet.pipnet import PIPNet, get_network from util.log import Log import torch.nn as nn @@ -56,6 +60,7 @@ def run_pipnet(args=None): device = torch.device('cuda:'+str(device_ids[0])) else: device = torch.device('cpu') + device_ids.append(0) # Log which device was actually used print("Device used: ", device, "with id", device_ids, flush=True) @@ -81,9 +86,12 @@ def run_pipnet(args=None): classification_layer = classification_layer ) net = net.to(device=device) - net = nn.DataParallel(net, device_ids = device_ids) + net = nn.DataParallel(net, device_ids = device_ids) - optimizer_net, optimizer_classifier, params_to_freeze, params_to_train, params_backbone = get_optimizer_nn(net, args) + optimizer_net, optimizer_classifier, params_to_freeze, params_to_train, params_backbone = get_optimizer_nn(net, args) + global_epoch = 0 + dirname = datetime.now().strftime("%Y_%m_%d__%H_%M_%S") + tb_writer = SummaryWriter(f"tensorboard/{dirname}") # Initialize or load model with torch.no_grad(): @@ -157,13 +165,16 @@ def run_pipnet(args=None): print("\nPretrain Epoch", epoch, "with batch size", trainloader_pretraining.batch_size, flush=True) # Pretrain prototypes - train_info = train_pipnet(net, trainloader_pretraining, optimizer_net, optimizer_classifier, scheduler_net, None, criterion, epoch, args.epochs_pretrain, device, pretrain=True, finetune=False) + train_info = train_pipnet(net, trainloader_pretraining, optimizer_net, optimizer_classifier, scheduler_net, None, criterion, epoch, args.epochs_pretrain, global_epoch, device, pretrain=True, finetune=False, tb_writer=tb_writer) lrs_pretrain_net+=train_info['lrs_net'] + eval_info = eval_pipnet(net, testloader, global_epoch, device, log, tensorboard=tb_writer) + plt.clf() plt.plot(lrs_pretrain_net) plt.savefig(os.path.join(args.log_dir,'lr_pretrain_net.png')) log.log_values('log_epoch_overview', epoch, "n.a.", "n.a.", "n.a.", "n.a.", "n.a.", "n.a.", "n.a.", train_info['loss']) - + global_epoch += 1 + if args.state_dict_dir_net == '': net.eval() torch.save({'model_state_dict': net.state_dict(), 'optimizer_net_state_dict': optimizer_net.state_dict()}, os.path.join(os.path.join(args.log_dir, 'checkpoints'), 'net_pretrained')) @@ -239,11 +250,11 @@ def run_pipnet(args=None): print("Classifier bias: ", net.module._classification.bias, flush=True) torch.set_printoptions(profile="default") - train_info = train_pipnet(net, trainloader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, args.epochs, device, pretrain=False, finetune=finetune) + train_info = train_pipnet(net, trainloader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, args.epochs, global_epoch, device, pretrain=False, finetune=finetune) lrs_net+=train_info['lrs_net'] lrs_classifier+=train_info['lrs_class'] # Evaluate model - eval_info = eval_pipnet(net, testloader, epoch, device, log) + eval_info = eval_pipnet(net, testloader, global_epoch, device, log) log.log_values('log_epoch_overview', epoch, eval_info['top1_accuracy'], eval_info['top5_accuracy'], eval_info['almost_sim_nonzeros'], eval_info['local_size_all_classes'], eval_info['almost_nonzeros'], eval_info['num non-zero prototypes'], train_info['train_accuracy'], train_info['loss']) with torch.no_grad(): @@ -261,6 +272,7 @@ def run_pipnet(args=None): plt.clf() plt.plot(lrs_classifier) plt.savefig(os.path.join(args.log_dir,'lr_class.png')) + global_epoch += 1 net.eval() torch.save({'model_state_dict': net.state_dict(), 'optimizer_net_state_dict': optimizer_net.state_dict(), 'optimizer_classifier_state_dict': optimizer_classifier.state_dict()}, os.path.join(os.path.join(args.log_dir, 'checkpoints'), 'net_trained_last')) @@ -361,11 +373,11 @@ def run_pipnet(args=None): tqdm_dir = os.path.join(args.log_dir,'tqdm.txt') if not os.path.isdir(args.log_dir): os.mkdir(args.log_dir) - - sys.stdout.close() - sys.stderr.close() - sys.stdout = open(print_dir, 'w') - sys.stderr = open(tqdm_dir, 'w') + print(torch.cuda.is_available()) + # sys.stdout.close() + # sys.stderr.close() + # sys.stdout = open(print_dir, 'w') + # sys.stderr = open(tqdm_dir, 'w') run_pipnet(args) sys.stdout.close() diff --git a/pipnet/pipnet.py b/pipnet/pipnet.py index 3722bd8..8a48b05 100644 --- a/pipnet/pipnet.py +++ b/pipnet/pipnet.py @@ -7,6 +7,13 @@ import torch from torch import Tensor +from pipnet.cast_layer import CastLayer +from pipnet.clamp_layer import ClampLayer +from pipnet.gumbel_layer import GumbelSoftmaxLayer +from pipnet.one_hot_encoding_layer import SoftOneHotEncodingLayer +from pipnet.sum_layer import SumLayer + + class PIPNet(nn.Module): def __init__(self, num_classes: int, @@ -87,7 +94,7 @@ def get_network(num_classes: int, args: argparse.Namespace): num_prototypes = first_add_on_layer_in_channels print("Number of prototypes: ", num_prototypes, flush=True) add_on_layers = nn.Sequential( - nn.Softmax(dim=1), #softmax over every prototype for each patch, such that for every location in image, sum over prototypes is 1 + GumbelSoftmaxLayer(dim=1, tau=args.temperature), #softmax over every prototype for each patch, such that for every location in image, sum over prototypes is 1 ) else: num_prototypes = args.num_features @@ -99,12 +106,20 @@ def get_network(num_classes: int, args: argparse.Namespace): pool_layer = nn.Sequential( nn.AdaptiveMaxPool2d(output_size=(1,1)), #outputs (bs, ps,1,1) nn.Flatten() #outputs (bs, ps) - ) + ) + max_class_occurrences = 3 + pool_layer = nn.Sequential( + SumLayer(dim=[2,3]), #sum over all patches, outputs (bs, ps) + ClampLayer(0, max_class_occurrences), #clamp to 3 + # CastLayer(torch.int, round_before_cast=True), #cast to int + SoftOneHotEncodingLayer(max_class_occurrences, temperature=args.temperature), #one-hot encoding (bs, ps, 4) + nn.Flatten() + ) if args.bias: - classification_layer = NonNegLinear(num_prototypes, num_classes, bias=True) + classification_layer = NonNegLinear(num_prototypes * max_class_occurrences, num_classes, bias=True) else: - classification_layer = NonNegLinear(num_prototypes, num_classes, bias=False) + classification_layer = NonNegLinear(num_prototypes * max_class_occurrences, num_classes, bias=False) return features, add_on_layers, pool_layer, classification_layer, num_prototypes diff --git a/pipnet/test.py b/pipnet/test.py index 18d8238..151b6d2 100644 --- a/pipnet/test.py +++ b/pipnet/test.py @@ -1,3 +1,11 @@ +import io +import itertools +import math +from typing import List + +from matplotlib import pyplot as plt +from scipy.interpolate import make_interp_spline +from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm import numpy as np import torch @@ -13,7 +21,8 @@ def eval_pipnet(net, test_loader: DataLoader, epoch, device, - log: Log = None, + log: Log = None, + tensorboard: SummaryWriter|None = None, progress_prefix: str = 'Eval Epoch' ) -> dict: @@ -41,6 +50,7 @@ def eval_pipnet(net, mininterval=5., ncols=0) (xs, ys) = next(iter(test_loader)) + inputs, results = [], [] # Iterate through the test set for i, (xs, ys) in test_iter: xs, ys = xs.to(device), ys.to(device) @@ -81,14 +91,20 @@ def eval_pipnet(net, y_preds += ys_pred_scores.detach().tolist() y_trues += ys.detach().tolist() y_preds_classes += ys_pred.detach().tolist() + inputs.append(xs) + results.append((pooled, out, ys_pred)) + - del out - del pooled - del ys_pred + # del out + # del pooled + # del ys_pred - print("PIP-Net abstained from a decision for", abstained.item(), "images", flush=True) + print("PIP-Net abstained from a decision for", abstained.item(), "images", flush=True) + info['abstained'] = abstained.item() info['num non-zero prototypes'] = torch.gt(net.module._classification.weight,1e-3).any(dim=0).sum().item() - print("sparsity ratio: ", (torch.numel(net.module._classification.weight)-torch.count_nonzero(torch.nn.functional.relu(net.module._classification.weight-1e-3)).item()) / torch.numel(net.module._classification.weight), flush=True) + sparsity_ratio = (torch.numel(net.module._classification.weight)-torch.count_nonzero(torch.nn.functional.relu(net.module._classification.weight-1e-3)).item()) / torch.numel(net.module._classification.weight) + print("sparsity ratio: ", sparsity_ratio, flush=True) + info['sparsity ratio'] = sparsity_ratio info['confusion_matrix'] = cm info['test_accuracy'] = acc_from_cm(cm) info['top1_accuracy'] = global_top1acc/len(test_loader.dataset) @@ -97,6 +113,9 @@ def eval_pipnet(net, info['local_size_all_classes'] = local_size_total / len(test_loader.dataset) info['almost_nonzeros'] = global_anz/len(test_loader.dataset) + if tensorboard is not None: + log_to_tensorboard(info, net.module, inputs, results, classes=test_loader.dataset.class_to_idx.values(), global_epoch=epoch, tb_writer=tensorboard) + if net.module._num_classes == 2: tp = cm[0][0] fn = cm[0][1] @@ -128,6 +147,87 @@ def eval_pipnet(net, return info +def log_to_tensorboard(info: dict, model_module, inp, preds, classes: List[str], global_epoch: int, tb_writer: SummaryWriter): + for key, value in info.items(): + match key: + case 'confusion_matrix': + # figure = plot_confusion_matrix(value, class_names=classes) + # image = plot_to_image(figure) + # tb_writer.add_image(key, image, global_epoch) + pass + case _: + tb_writer.add_scalar(key, value, global_epoch) + + if hasattr(model_module, '_classification'): + weights = model_module._classification.weight.flatten() + tb_writer.add_figure('classification_weights', plot_tensor(weights, samples=66), global_epoch) + + inp = inp[0] + tb_writer.add_graph(model_module, inp) + pooled, out, ys_pred = [torch.cat([pred[i] for pred in preds], dim=0) for i in range(3)] + tb_writer.add_figure('pooled', plot_batch(pooled), global_epoch) + tb_writer.add_figure('out', plot_batch(out), global_epoch) + tb_writer.add_histogram('ys_pred', ys_pred, global_epoch) + +def plot_batch(tensor: torch.Tensor, samples = 8, title: str = None): + batch_size = tensor.shape[0] + if batch_size > samples: + tensor = tensor[::batch_size//samples] + size = len(tensor) + plt.figure(figsize=(10, 10)) + for i in range(len(tensor)): + plt.subplot(math.ceil(size / 2), 2, i+1) + plot_tensor(tensor[i]) + return plt.gcf() + + +def plot_tensor(tensor: torch.Tensor, title: str = None, samples: int = 50): + array = tensor.cpu().numpy() + array = smooth_array(array, samples) + plt.plot(array) + if title is not None: + plt.title(title) + return plt.gcf() + +def smooth_array(array: np.ndarray, samples: int = 50): + window_size = len(array) // samples + kernel = np.ones(window_size) / window_size + smoothed_tensor = np.convolve(array, kernel, mode='valid') + return smoothed_tensor + +def plot_confusion_matrix(cm, class_names): + figure = plt.figure(figsize=(8, 8)) + plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Accent) + plt.title("Confusion matrix") + plt.colorbar() + tick_marks = np.arange(len(class_names)) + plt.xticks(tick_marks, class_names, rotation=45) + plt.yticks(tick_marks, class_names) + + cm = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2) + threshold = cm.max() / 2. + + for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): + color = "white" if cm[i, j] > threshold else "black" + plt.text(j, i, cm[i, j], horizontalalignment="center", color=color) + + plt.tight_layout() + plt.ylabel('True label') + plt.xlabel('Predicted label') + + return figure + +def plot_to_image(figure): + figure.canvas.draw() + data = np.frombuffer(figure.canvas.tostring_argb(), dtype=np.uint8) + data = data.reshape(figure.canvas.get_width_height()[::-1] + (4,)) # (H, W, 4) + + data = data[:, :, [1, 2, 3]] # Drop the alpha channel (0th index) + tensor = torch.tensor(data) + tensor = tensor.float() / 255.0 + plt.close(figure) + return tensor.permute(2, 0, 1) + def acc_from_cm(cm: np.ndarray) -> float: """ Compute the accuracy from the confusion matrix diff --git a/pipnet/train.py b/pipnet/train.py index bb013ef..5cc385f 100644 --- a/pipnet/train.py +++ b/pipnet/train.py @@ -3,9 +3,10 @@ import torch.nn.functional as F import torch.optim import torch.utils.data -import math +from torch.utils.tensorboard import SummaryWriter -def train_pipnet(net, train_loader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, nr_epochs, device, pretrain=False, finetune=False, progress_prefix: str = 'Train Epoch'): + +def train_pipnet(net, train_loader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, nr_epochs, global_epoch, device, pretrain=False, finetune=False, progress_prefix: str = 'Train Epoch', tb_writer: SummaryWriter = None): # Make sure the model is in train mode net.train() @@ -65,7 +66,7 @@ def train_pipnet(net, train_loader, optimizer_net, optimizer_classifier, schedul # Perform a forward pass through the network proto_features, pooled, out = net(torch.cat([xs1, xs2])) - loss, acc = calculate_loss(proto_features, pooled, out, ys, align_pf_weight, t_weight, unif_weight, cl_weight, net.module._classification.normalization_multiplier, pretrain, finetune, criterion, train_iter, print=True, EPS=1e-8) + loss, acc = calculate_loss(proto_features, pooled, out, ys, align_pf_weight, t_weight, unif_weight, cl_weight, net.module._classification.normalization_multiplier, pretrain, finetune, criterion, train_iter, global_epoch, print=True, EPS=1e-8, tb_writer=tb_writer) # Compute the gradient loss.backward() @@ -99,7 +100,7 @@ def train_pipnet(net, train_loader, optimizer_net, optimizer_classifier, schedul return train_info -def calculate_loss(proto_features, pooled, out, ys1, align_pf_weight, t_weight, unif_weight, cl_weight, net_normalization_multiplier, pretrain, finetune, criterion, train_iter, print=True, EPS=1e-10): +def calculate_loss(proto_features, pooled, out, ys1, align_pf_weight, t_weight, unif_weight, cl_weight, net_normalization_multiplier, pretrain, finetune, criterion, train_iter, global_epoch, print=True, EPS=1e-10, tb_writer=None): ys = torch.cat([ys1,ys1]) pooled1, pooled2 = pooled.chunk(2) pf1, pf2 = proto_features.chunk(2) @@ -132,7 +133,16 @@ def calculate_loss(proto_features, pooled, out, ys1, align_pf_weight, t_weight, ys_pred_max = torch.argmax(out, dim=1) correct = torch.sum(torch.eq(ys_pred_max, ys)) acc = correct.item() / float(len(ys)) - if print: + if print: + if tb_writer is not None: + scalars = {'loss': loss.item(), 'align_loss': a_loss_pf.item(), 'tanh_loss': tanh_loss.item(), + 'acc': acc, 't_weight': t_weight, 'align_pf_weight': align_pf_weight, 'cl_weight': cl_weight, + 'num_scores>0.1': torch.count_nonzero(torch.relu(pooled-0.1),dim=1).float().mean().item()} + if not pretrain: + scalars['class_loss'] = class_loss.item() + for key, value in scalars.items(): + tb_writer.add_scalar(key, value, global_epoch) + with torch.no_grad(): if pretrain: train_iter.set_postfix_str( diff --git a/util/args.py b/util/args.py index 24c3ede..d942cec 100644 --- a/util/args.py +++ b/util/args.py @@ -117,6 +117,10 @@ def get_args() -> argparse.Namespace: type=str, default='./experiments', help='Folder with images that PIP-Net will predict and explain, that are not in the training or test set. E.g. images with 2 objects or OOD image. Images should be in subfolder. E.g. images in ./experiments/images/, and argument --./experiments') + parser.add_argument('--temperature', + type=float, + default=0.5, + help='') args = parser.parse_args() if len(args.log_dir.split('/'))>2: diff --git a/util/vis_pipnet.py b/util/vis_pipnet.py index 055db1a..b769e71 100644 --- a/util/vis_pipnet.py +++ b/util/vis_pipnet.py @@ -1,3 +1,5 @@ +import cv2 +import numpy as np from tqdm import tqdm import argparse import torch @@ -115,12 +117,14 @@ def visualize_topk(net, projectloader, num_classes, device, foldername, args: ar max_per_prototype, max_idx_per_prototype = torch.max(softmaxes, dim=0) max_per_prototype_h, max_idx_per_prototype_h = torch.max(max_per_prototype, dim=1) max_per_prototype_w, max_idx_per_prototype_w = torch.max(max_per_prototype_h, dim=1) #shape (num_prototypes) - + + occurrences_per_prototype = 3 + _p = int(p / occurrences_per_prototype) c_weight = torch.max(classification_weights[:,p]) #ignore prototypes that are not relevant to any class if (c_weight > 1e-10) or ('pretrain' in foldername): - h_idx = max_idx_per_prototype_h[p, max_idx_per_prototype_w[p]] - w_idx = max_idx_per_prototype_w[p] + h_idx = max_idx_per_prototype_h[_p, max_idx_per_prototype_w[_p]] + w_idx = max_idx_per_prototype_w[_p] img_to_open = imgs[i] if isinstance(img_to_open, tuple) or isinstance(img_to_open, list): #dataset contains tuples of (img,label) @@ -131,8 +135,8 @@ def visualize_topk(net, projectloader, num_classes, device, foldername, args: ar h_coor_min, h_coor_max, w_coor_min, w_coor_max = get_img_coordinates(args.image_size, softmaxes.shape, patchsize, skip, h_idx, w_idx) img_tensor_patch = img_tensor[0, :, h_coor_min:h_coor_max, w_coor_min:w_coor_max] - saved[p]+=1 - tensors_per_prototype[p].append(img_tensor_patch) + saved[_p]+=1 + tensors_per_prototype[_p].append(img_tensor_patch) print("Abstained: ", abstained, flush=True) all_tensors = [] @@ -159,6 +163,22 @@ def visualize_topk(net, projectloader, num_classes, device, foldername, args: ar else: print("Pretrained prototypes not visualized. Try to pretrain longer.", flush=True) return topks + + +def add_text_to_tensor(image_tensor, text, position, font_scale=1, color=(255, 255, 255), thickness=2): + # Convert the tensor to a NumPy array + image_np = image_tensor.permute(1, 2, 0).numpy() # Change shape to [H, W, C] + + # Normalize the image to 0-255 and convert to uint8 + image_np = (image_np * 255).astype(np.uint8) + + # Add text to the image using OpenCV + image_with_text = cv2.putText(image_np, text, position, cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness) + + # Convert the NumPy array back to a PyTorch tensor + modified_tensor = torch.from_numpy(image_with_text.astype(np.float32) / 255).permute(2, 0, 1) + + return modified_tensor def visualize(net, projectloader, num_classes, device, foldername, args: argparse.Namespace): From 23dfd328aeb5578b360d524b2dc71faa6a123b23 Mon Sep 17 00:00:00 2001 From: wpalka-icg Date: Tue, 15 Apr 2025 19:06:07 +0200 Subject: [PATCH 3/5] back to classic --- pipnet/pipnet.py | 21 +++------------------ util/vis_pipnet.py | 26 ++++---------------------- 2 files changed, 7 insertions(+), 40 deletions(-) diff --git a/pipnet/pipnet.py b/pipnet/pipnet.py index 8a48b05..5f30b99 100644 --- a/pipnet/pipnet.py +++ b/pipnet/pipnet.py @@ -7,13 +7,6 @@ import torch from torch import Tensor -from pipnet.cast_layer import CastLayer -from pipnet.clamp_layer import ClampLayer -from pipnet.gumbel_layer import GumbelSoftmaxLayer -from pipnet.one_hot_encoding_layer import SoftOneHotEncodingLayer -from pipnet.sum_layer import SumLayer - - class PIPNet(nn.Module): def __init__(self, num_classes: int, @@ -94,7 +87,7 @@ def get_network(num_classes: int, args: argparse.Namespace): num_prototypes = first_add_on_layer_in_channels print("Number of prototypes: ", num_prototypes, flush=True) add_on_layers = nn.Sequential( - GumbelSoftmaxLayer(dim=1, tau=args.temperature), #softmax over every prototype for each patch, such that for every location in image, sum over prototypes is 1 + nn.Softmax(dim=1), #softmax over every prototype for each patch, such that for every location in image, sum over prototypes is 1 ) else: num_prototypes = args.num_features @@ -107,19 +100,11 @@ def get_network(num_classes: int, args: argparse.Namespace): nn.AdaptiveMaxPool2d(output_size=(1,1)), #outputs (bs, ps,1,1) nn.Flatten() #outputs (bs, ps) ) - max_class_occurrences = 3 - pool_layer = nn.Sequential( - SumLayer(dim=[2,3]), #sum over all patches, outputs (bs, ps) - ClampLayer(0, max_class_occurrences), #clamp to 3 - # CastLayer(torch.int, round_before_cast=True), #cast to int - SoftOneHotEncodingLayer(max_class_occurrences, temperature=args.temperature), #one-hot encoding (bs, ps, 4) - nn.Flatten() - ) if args.bias: - classification_layer = NonNegLinear(num_prototypes * max_class_occurrences, num_classes, bias=True) + classification_layer = NonNegLinear(num_prototypes, num_classes, bias=True) else: - classification_layer = NonNegLinear(num_prototypes * max_class_occurrences, num_classes, bias=False) + classification_layer = NonNegLinear(num_prototypes, num_classes, bias=False) return features, add_on_layers, pool_layer, classification_layer, num_prototypes diff --git a/util/vis_pipnet.py b/util/vis_pipnet.py index b769e71..6fee1e9 100644 --- a/util/vis_pipnet.py +++ b/util/vis_pipnet.py @@ -118,13 +118,11 @@ def visualize_topk(net, projectloader, num_classes, device, foldername, args: ar max_per_prototype_h, max_idx_per_prototype_h = torch.max(max_per_prototype, dim=1) max_per_prototype_w, max_idx_per_prototype_w = torch.max(max_per_prototype_h, dim=1) #shape (num_prototypes) - occurrences_per_prototype = 3 - _p = int(p / occurrences_per_prototype) c_weight = torch.max(classification_weights[:,p]) #ignore prototypes that are not relevant to any class if (c_weight > 1e-10) or ('pretrain' in foldername): - h_idx = max_idx_per_prototype_h[_p, max_idx_per_prototype_w[_p]] - w_idx = max_idx_per_prototype_w[_p] + h_idx = max_idx_per_prototype_h[p, max_idx_per_prototype_w[p]] + w_idx = max_idx_per_prototype_w[p] img_to_open = imgs[i] if isinstance(img_to_open, tuple) or isinstance(img_to_open, list): #dataset contains tuples of (img,label) @@ -135,8 +133,8 @@ def visualize_topk(net, projectloader, num_classes, device, foldername, args: ar h_coor_min, h_coor_max, w_coor_min, w_coor_max = get_img_coordinates(args.image_size, softmaxes.shape, patchsize, skip, h_idx, w_idx) img_tensor_patch = img_tensor[0, :, h_coor_min:h_coor_max, w_coor_min:w_coor_max] - saved[_p]+=1 - tensors_per_prototype[_p].append(img_tensor_patch) + saved[p]+=1 + tensors_per_prototype[p].append(img_tensor_patch) print("Abstained: ", abstained, flush=True) all_tensors = [] @@ -163,22 +161,6 @@ def visualize_topk(net, projectloader, num_classes, device, foldername, args: ar else: print("Pretrained prototypes not visualized. Try to pretrain longer.", flush=True) return topks - - -def add_text_to_tensor(image_tensor, text, position, font_scale=1, color=(255, 255, 255), thickness=2): - # Convert the tensor to a NumPy array - image_np = image_tensor.permute(1, 2, 0).numpy() # Change shape to [H, W, C] - - # Normalize the image to 0-255 and convert to uint8 - image_np = (image_np * 255).astype(np.uint8) - - # Add text to the image using OpenCV - image_with_text = cv2.putText(image_np, text, position, cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, thickness) - - # Convert the NumPy array back to a PyTorch tensor - modified_tensor = torch.from_numpy(image_with_text.astype(np.float32) / 255).permute(2, 0, 1) - - return modified_tensor def visualize(net, projectloader, num_classes, device, foldername, args: argparse.Namespace): From 31621b6fb526a61bbb4a79beff7684862d235696 Mon Sep 17 00:00:00 2001 From: wpalka-icg Date: Wed, 16 Apr 2025 10:11:49 +0200 Subject: [PATCH 4/5] fixed tb --- main.py | 6 +++--- pipnet/test.py | 47 +++++++++++++++++++++++++++++++++++++---------- 2 files changed, 40 insertions(+), 13 deletions(-) diff --git a/main.py b/main.py index 5372154..1c2cd2c 100644 --- a/main.py +++ b/main.py @@ -89,7 +89,7 @@ def run_pipnet(args=None): net = nn.DataParallel(net, device_ids = device_ids) optimizer_net, optimizer_classifier, params_to_freeze, params_to_train, params_backbone = get_optimizer_nn(net, args) - global_epoch = 0 + global_epoch = 1 dirname = datetime.now().strftime("%Y_%m_%d__%H_%M_%S") tb_writer = SummaryWriter(f"tensorboard/{dirname}") @@ -250,11 +250,11 @@ def run_pipnet(args=None): print("Classifier bias: ", net.module._classification.bias, flush=True) torch.set_printoptions(profile="default") - train_info = train_pipnet(net, trainloader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, args.epochs, global_epoch, device, pretrain=False, finetune=finetune) + train_info = train_pipnet(net, trainloader, optimizer_net, optimizer_classifier, scheduler_net, scheduler_classifier, criterion, epoch, args.epochs, global_epoch, device, pretrain=False, finetune=finetune, tb_writer=tb_writer) lrs_net+=train_info['lrs_net'] lrs_classifier+=train_info['lrs_class'] # Evaluate model - eval_info = eval_pipnet(net, testloader, global_epoch, device, log) + eval_info = eval_pipnet(net, testloader, global_epoch, device, log, tensorboard=tb_writer) log.log_values('log_epoch_overview', epoch, eval_info['top1_accuracy'], eval_info['top5_accuracy'], eval_info['almost_sim_nonzeros'], eval_info['local_size_all_classes'], eval_info['almost_nonzeros'], eval_info['num non-zero prototypes'], train_info['train_accuracy'], train_info['loss']) with torch.no_grad(): diff --git a/pipnet/test.py b/pipnet/test.py index 151b6d2..356e5ea 100644 --- a/pipnet/test.py +++ b/pipnet/test.py @@ -115,6 +115,7 @@ def eval_pipnet(net, if tensorboard is not None: log_to_tensorboard(info, net.module, inputs, results, classes=test_loader.dataset.class_to_idx.values(), global_epoch=epoch, tb_writer=tensorboard) + tensorboard.flush() if net.module._num_classes == 2: tp = cm[0][0] @@ -159,8 +160,12 @@ def log_to_tensorboard(info: dict, model_module, inp, preds, classes: List[str], tb_writer.add_scalar(key, value, global_epoch) if hasattr(model_module, '_classification'): - weights = model_module._classification.weight.flatten() - tb_writer.add_figure('classification_weights', plot_tensor(weights, samples=66), global_epoch) + weights = model_module._classification.weight.flatten().cpu() + fig = plot_tensor(weights, samples=66) + tb_writer.add_figure('classification_weights', fig, global_epoch) + + fig = plot_tensor(torch.relu(weights), samples=66) + tb_writer.add_figure('classification_non_neg', fig, global_epoch) inp = inp[0] tb_writer.add_graph(model_module, inp) @@ -172,18 +177,40 @@ def log_to_tensorboard(info: dict, model_module, inp, preds, classes: List[str], def plot_batch(tensor: torch.Tensor, samples = 8, title: str = None): batch_size = tensor.shape[0] if batch_size > samples: - tensor = tensor[::batch_size//samples] - size = len(tensor) - plt.figure(figsize=(10, 10)) - for i in range(len(tensor)): - plt.subplot(math.ceil(size / 2), 2, i+1) - plot_tensor(tensor[i]) - return plt.gcf() + indices = torch.linspace(0, batch_size - 1, steps=samples).long() + sample_tensor = tensor[indices] + else: + sample_tensor = tensor + + stats = [ + ("Max over batch", torch.max(tensor, dim=0).values), + ("Min over batch", torch.min(tensor, dim=0).values), + ("Mean over batch", torch.mean(tensor, dim=0)) + ] + + fig = plt.figure(figsize=(12, 10)) + if title: + fig.suptitle(title, fontsize=16) + + ax = plt.subplot(2, 2, 1) + for s in sample_tensor: + plot_tensor(s, subplot=True) + ax.set_title("Samples") + + for i, (stat_title, stat_tensor) in enumerate(stats, start=2): + ax = plt.subplot(2, 2, i) + plot_tensor(stat_tensor, subplot=True) + ax.set_title(stat_title) + + plt.tight_layout() + return fig -def plot_tensor(tensor: torch.Tensor, title: str = None, samples: int = 50): +def plot_tensor(tensor: torch.Tensor, title: str = None, samples: int = 50, subplot: bool = False): array = tensor.cpu().numpy() array = smooth_array(array, samples) + if not subplot: + plt.figure() plt.plot(array) if title is not None: plt.title(title) From 1e04449ff9c21e709e52004f7c689a2ac8f13511 Mon Sep 17 00:00:00 2001 From: wpalka-icg Date: Wed, 16 Apr 2025 17:06:03 +0200 Subject: [PATCH 5/5] reduced changes --- main.py | 4 ++-- pipnet/cast_layer.py | 13 ------------- pipnet/clamp_layer.py | 11 ----------- pipnet/gumbel_layer.py | 14 -------------- pipnet/one_hot_encoding_layer.py | 19 ------------------- pipnet/sum_layer.py | 10 ---------- pipnet/test.py | 3 --- util/args.py | 4 ---- util/vis_pipnet.py | 2 -- 9 files changed, 2 insertions(+), 78 deletions(-) delete mode 100644 pipnet/cast_layer.py delete mode 100644 pipnet/clamp_layer.py delete mode 100644 pipnet/gumbel_layer.py delete mode 100644 pipnet/one_hot_encoding_layer.py delete mode 100644 pipnet/sum_layer.py diff --git a/main.py b/main.py index 1c2cd2c..7825013 100644 --- a/main.py +++ b/main.py @@ -167,7 +167,7 @@ def run_pipnet(args=None): # Pretrain prototypes train_info = train_pipnet(net, trainloader_pretraining, optimizer_net, optimizer_classifier, scheduler_net, None, criterion, epoch, args.epochs_pretrain, global_epoch, device, pretrain=True, finetune=False, tb_writer=tb_writer) lrs_pretrain_net+=train_info['lrs_net'] - eval_info = eval_pipnet(net, testloader, global_epoch, device, log, tensorboard=tb_writer) + # eval_info = eval_pipnet(net, testloader, global_epoch, device, log, tensorboard=tb_writer) plt.clf() plt.plot(lrs_pretrain_net) @@ -373,7 +373,7 @@ def run_pipnet(args=None): tqdm_dir = os.path.join(args.log_dir,'tqdm.txt') if not os.path.isdir(args.log_dir): os.mkdir(args.log_dir) - print(torch.cuda.is_available()) + # sys.stdout.close() # sys.stderr.close() # sys.stdout = open(print_dir, 'w') diff --git a/pipnet/cast_layer.py b/pipnet/cast_layer.py deleted file mode 100644 index 4b67e54..0000000 --- a/pipnet/cast_layer.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch -import torch.nn as nn - -class CastLayer(nn.Module): - def __init__(self, dtype, round_before_cast=False): - super(CastLayer, self).__init__() - self.dtype = dtype - self.round_before_cast = round_before_cast - - def forward(self, x): - if self.round_before_cast: - x = torch.round(x) # Apply rounding if flag is enabled - return x.to(self.dtype) \ No newline at end of file diff --git a/pipnet/clamp_layer.py b/pipnet/clamp_layer.py deleted file mode 100644 index 0ae8ce5..0000000 --- a/pipnet/clamp_layer.py +++ /dev/null @@ -1,11 +0,0 @@ -import torch -import torch.nn as nn - -class ClampLayer(nn.Module): - def __init__(self, min_value, max_value): - super(ClampLayer, self).__init__() - self.min_value = min_value - self.max_value = max_value - - def forward(self, x): - return torch.clamp(x, min=self.min_value, max=self.max_value) \ No newline at end of file diff --git a/pipnet/gumbel_layer.py b/pipnet/gumbel_layer.py deleted file mode 100644 index 27d2f0a..0000000 --- a/pipnet/gumbel_layer.py +++ /dev/null @@ -1,14 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -class GumbelSoftmaxLayer(nn.Module): - def __init__(self, tau=1.0, hard=False, eps=None, dim=None): - super().__init__() - self.tau = tau - self.hard = hard - self.eps = eps - self.dim = dim - - def forward(self, logits): - return F.gumbel_softmax(logits, tau=self.tau, hard=self.hard, eps=self.eps, dim=self.dim) \ No newline at end of file diff --git a/pipnet/one_hot_encoding_layer.py b/pipnet/one_hot_encoding_layer.py deleted file mode 100644 index aa13535..0000000 --- a/pipnet/one_hot_encoding_layer.py +++ /dev/null @@ -1,19 +0,0 @@ -import torch -import torch.nn as nn - -class SoftOneHotEncodingLayer(nn.Module): - def __init__(self, num_classes, temperature=0.5): - super(SoftOneHotEncodingLayer, self).__init__() - self.num_classes = num_classes - self.temperature = temperature - - def forward(self, x): - # Compute softmax over classes for probabilistic one-hot - x = x.unsqueeze(-1) # Add dimension for broadcasting - indices = torch.arange(self.num_classes, device=x.device).float() - logits = -torch.abs(x - indices - 1) / self.temperature - probabilities = torch.softmax(logits, dim=-1) - mask = torch.clamp(x, 0, 1) - result = probabilities * mask - - return result \ No newline at end of file diff --git a/pipnet/sum_layer.py b/pipnet/sum_layer.py deleted file mode 100644 index bb6a5bc..0000000 --- a/pipnet/sum_layer.py +++ /dev/null @@ -1,10 +0,0 @@ -import torch -import torch.nn as nn - -class SumLayer(nn.Module): - def __init__(self, dim=None): - super().__init__() - self.dim = dim - - def forward(self, input): - return torch.sum(input, dim=self.dim) \ No newline at end of file diff --git a/pipnet/test.py b/pipnet/test.py index 356e5ea..71bf485 100644 --- a/pipnet/test.py +++ b/pipnet/test.py @@ -1,10 +1,7 @@ -import io import itertools -import math from typing import List from matplotlib import pyplot as plt -from scipy.interpolate import make_interp_spline from torch.utils.tensorboard import SummaryWriter from tqdm import tqdm import numpy as np diff --git a/util/args.py b/util/args.py index d942cec..24c3ede 100644 --- a/util/args.py +++ b/util/args.py @@ -117,10 +117,6 @@ def get_args() -> argparse.Namespace: type=str, default='./experiments', help='Folder with images that PIP-Net will predict and explain, that are not in the training or test set. E.g. images with 2 objects or OOD image. Images should be in subfolder. E.g. images in ./experiments/images/, and argument --./experiments') - parser.add_argument('--temperature', - type=float, - default=0.5, - help='') args = parser.parse_args() if len(args.log_dir.split('/'))>2: diff --git a/util/vis_pipnet.py b/util/vis_pipnet.py index 6fee1e9..1248989 100644 --- a/util/vis_pipnet.py +++ b/util/vis_pipnet.py @@ -1,5 +1,3 @@ -import cv2 -import numpy as np from tqdm import tqdm import argparse import torch