Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 22 additions & 52 deletions backend/app/utils/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from datetime import datetime
from contextlib import contextmanager

from cka import CKA
from cka import compute_cka
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from app.config import UMAP_DATA_SIZE
Expand Down Expand Up @@ -480,49 +480,19 @@ def filter_loader(loader, is_train=False):
print(f"Retrain model not found at {retrain_model_path}")
retrain_model_loaded = False

with CKA(
model1=model_before,
model2=model_after,
model1_name="Before Unlearning",
model2_name="After Unlearning",
model1_layers=detailed_layers,
model2_layers=detailed_layers,
device=device,
) as cka:
# Original comparison: before vs after
forget_train_cka_matrix = cka.compare(forget_class_train_loader)
other_train_cka_matrix = cka.compare(other_classes_train_loader)
forget_test_cka_matrix = cka.compare(forget_class_test_loader)
other_test_cka_matrix = cka.compare(other_classes_test_loader)

# Retrain comparison: retrain vs unlearned
retrain_forget_train_cka_matrix = None
retrain_other_train_cka_matrix = None
retrain_forget_test_cka_matrix = None
retrain_other_test_cka_matrix = None

if retrain_model_loaded and retrain_model is not None:
with CKA(
model1=retrain_model,
model2=model_after,
model1_name="Retrain Model",
model2_name="Unlearned Model",
model1_layers=detailed_layers,
model2_layers=detailed_layers,
device=device,
) as cka_retrain:
retrain_forget_train_cka_matrix = cka_retrain.compare(
forget_class_train_loader
)
retrain_other_train_cka_matrix = cka_retrain.compare(
other_classes_train_loader
)
retrain_forget_test_cka_matrix = cka_retrain.compare(
forget_class_test_loader
)
retrain_other_test_cka_matrix = cka_retrain.compare(
other_classes_test_loader
)
dataloaders = [
forget_class_train_loader,
other_classes_train_loader,
forget_class_test_loader,
other_classes_test_loader,
]

original_results = compute_cka(
model_before, model_after, dataloaders, layers=detailed_layers, device=device
)
retrain_results = compute_cka(
retrain_model, model_after, dataloaders, layers=detailed_layers, device=device
)

def format_cka_results(results):
if results is None:
Expand All @@ -546,23 +516,23 @@ def format_cka_results(results):
"similarity": {
"layers": detailed_layers,
"train": {
"forget_class": format_cka_results(forget_train_cka_matrix),
"other_classes": format_cka_results(other_train_cka_matrix),
"forget_class": format_cka_results(original_results[0]),
"other_classes": format_cka_results(original_results[1]),
},
"test": {
"forget_class": format_cka_results(forget_test_cka_matrix),
"other_classes": format_cka_results(other_test_cka_matrix),
"forget_class": format_cka_results(original_results[2]),
"other_classes": format_cka_results(original_results[3]),
},
},
"similarity_retrain": {
"layers": detailed_layers,
"train": {
"forget_class": format_cka_results(retrain_forget_train_cka_matrix),
"other_classes": format_cka_results(retrain_other_train_cka_matrix),
"forget_class": format_cka_results(retrain_results[0]),
"other_classes": format_cka_results(retrain_results[1]),
},
"test": {
"forget_class": format_cka_results(retrain_forget_test_cka_matrix),
"other_classes": format_cka_results(retrain_other_test_cka_matrix),
"forget_class": format_cka_results(retrain_results[2]),
"other_classes": format_cka_results(retrain_results[3]),
},
}
if retrain_model_loaded
Expand Down
2 changes: 1 addition & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ dependencies = [
"python-multipart",
"seaborn",
"huggingface_hub",
"pytorch-cka>=0.1.3",
"pytorch-cka>=1.0.1",
]


Expand Down