diff --git a/backend/app/utils/evaluation.py b/backend/app/utils/evaluation.py index 42b480a7..6d62cc43 100644 --- a/backend/app/utils/evaluation.py +++ b/backend/app/utils/evaluation.py @@ -7,7 +7,7 @@ from datetime import datetime from contextlib import contextmanager -from cka import CKA +from cka import compute_cka from torch.utils.data import DataLoader, Subset from torchvision import datasets, transforms from app.config import UMAP_DATA_SIZE @@ -480,49 +480,29 @@ def filter_loader(loader, is_train=False): print(f"Retrain model not found at {retrain_model_path}") retrain_model_loaded = False - with CKA( - model1=model_before, - model2=model_after, - model1_name="Before Unlearning", - model2_name="After Unlearning", - model1_layers=detailed_layers, - model2_layers=detailed_layers, - device=device, - ) as cka: - # Original comparison: before vs after - forget_train_cka_matrix = cka.compare(forget_class_train_loader) - other_train_cka_matrix = cka.compare(other_classes_train_loader) - forget_test_cka_matrix = cka.compare(forget_class_test_loader) - other_test_cka_matrix = cka.compare(other_classes_test_loader) - - # Retrain comparison: retrain vs unlearned - retrain_forget_train_cka_matrix = None - retrain_other_train_cka_matrix = None - retrain_forget_test_cka_matrix = None - retrain_other_test_cka_matrix = None - - if retrain_model_loaded and retrain_model is not None: - with CKA( - model1=retrain_model, - model2=model_after, - model1_name="Retrain Model", - model2_name="Unlearned Model", - model1_layers=detailed_layers, - model2_layers=detailed_layers, - device=device, - ) as cka_retrain: - retrain_forget_train_cka_matrix = cka_retrain.compare( - forget_class_train_loader - ) - retrain_other_train_cka_matrix = cka_retrain.compare( - other_classes_train_loader - ) - retrain_forget_test_cka_matrix = cka_retrain.compare( - forget_class_test_loader - ) - retrain_other_test_cka_matrix = cka_retrain.compare( - other_classes_test_loader - ) + dataloaders = [ + 
forget_class_train_loader, + other_classes_train_loader, + forget_class_test_loader, + other_classes_test_loader, + ] + + original_results = compute_cka( + model_before, model_after, dataloaders, layers=detailed_layers, device=device + ) + # Only run the retrain comparison when a retrain model was actually loaded, + # otherwise keep one None per dataloader, which format_cka_results accepts. + retrain_results = ( + compute_cka( + retrain_model, + model_after, + dataloaders, + layers=detailed_layers, + device=device, + ) + if retrain_model_loaded and retrain_model is not None + else [None] * len(dataloaders) + ) def format_cka_results(results): if results is None: @@ -546,23 +526,23 @@ def format_cka_results(results): "similarity": { "layers": detailed_layers, "train": { - "forget_class": format_cka_results(forget_train_cka_matrix), - "other_classes": format_cka_results(other_train_cka_matrix), + "forget_class": format_cka_results(original_results[0]), + "other_classes": format_cka_results(original_results[1]), }, "test": { - "forget_class": format_cka_results(forget_test_cka_matrix), - "other_classes": format_cka_results(other_test_cka_matrix), + "forget_class": format_cka_results(original_results[2]), + "other_classes": format_cka_results(original_results[3]), }, }, "similarity_retrain": { "layers": detailed_layers, "train": { - "forget_class": format_cka_results(retrain_forget_train_cka_matrix), - "other_classes": format_cka_results(retrain_other_train_cka_matrix), + "forget_class": format_cka_results(retrain_results[0]), + "other_classes": format_cka_results(retrain_results[1]), }, "test": { - "forget_class": format_cka_results(retrain_forget_test_cka_matrix), - "other_classes": format_cka_results(retrain_other_test_cka_matrix), + "forget_class": format_cka_results(retrain_results[2]), + "other_classes": format_cka_results(retrain_results[3]), }, } if retrain_model_loaded diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 2c26a23b..60f518f7 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ "python-multipart", "seaborn", "huggingface_hub", - "pytorch-cka>=0.1.3", + "pytorch-cka>=1.0.1", ]