Skip to content

Commit 6ac7288

Browse files
update entropy TA plot
1 parent cbfb840 commit 6ac7288

File tree

6 files changed

+321
-68
lines changed

6 files changed

+321
-68
lines changed

exp5_SCRUB_and_TA/experiment_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from src.evaluation.membership_inference_attack import MIA
1818
from src.evaluation.unlearning_evaluator import UnlearningEvaluator
1919
from prepare_image_data import get_image_unlearn_data
20-
from prepare_image_data_v2 import get_image_unlearn_data as get_image_unlearn_data_tsne_box
20+
from prepare_image_data_tsne import get_image_unlearn_data as get_image_unlearn_data_tsne_box
2121
import time
2222
import json
2323

exp5_SCRUB_and_TA/mnist_config.yaml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,15 @@ data:
2727
n_channels: 3
2828
# cifar parameters
2929
split_type: "tsne_box" # "tsne_box" or "random"
30-
tsne_box_coordinates:
30+
tsne_box_coordinates:
31+
# x_min: -4.2
32+
# x_max: -2.6
33+
# y_min: -6.25
34+
# y_max: -4.8
3135
# x_min: 10.2 #-6.2
3236
# x_max: 11.0 # -3.4
3337
# y_min: -4.3 # -1.6
34-
# y_max: -0.4 # 1.011
38+
# y_max: -0.4 # 1.0
3539
x_min: -6.2 #10.2
3640
x_max: -3.4 #11.0
3741
y_min: -1.6 # -4.3
@@ -64,7 +68,7 @@ unlearn:
6468
method: "scrub" # "scrub" or "teacher_ascend"
6569
teacher_ascend:
6670
versions: [["ce", false], ["entropy", false], ["ce-retain", false], ["entropy-retain", false], ["ce", true], ["entropy", true], ["ce-retain", true], ["entropy-retain", true], ["ce-retain-no-reg", false], ["entropy-retain-no-reg", false]]
67-
n_epochs: 150
71+
n_epochs: 100
6872
_lambda: 2 # 64
6973
scrub:
7074
alpha: 2

exp5_SCRUB_and_TA/plot_mnist_samples.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# Add the parent directory to the Python path to allow for package imports
77
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
88

9-
from exp5_SCRUB_and_TA.prepare_image_data_v2 import download_dataset
9+
from exp5_SCRUB_and_TA.prepare_image_data_tsne import download_dataset
1010

1111
def plot_mnist_samples():
1212
"""

exp5_SCRUB_and_TA/prepare_image_data_v2.py renamed to exp5_SCRUB_and_TA/prepare_image_data_tsne.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def split_data_by_tsne_box(boundary: dict = None,
198198
return train_dataset, retain_dataset, forget_dataset, validation_dataset, forget_indices
199199

200200

201-
def get_image_unlearn_data(root_dir: str, batch_size: int, seed: int, boundary: dict = None, subsample_size: int = None):
201+
def get_image_unlearn_data(root_dir: str, batch_size: int, seed: int, boundary: dict = None, subsample_size: int = None, return_tsne_results: bool = False):
202202

203203
if boundary is None:
204204
boundary = {
@@ -211,18 +211,31 @@ def get_image_unlearn_data(root_dir: str, batch_size: int, seed: int, boundary:
211211
# Make a folder if it does not exist
212212
coords_string = f"{boundary['x_min']}_{boundary['x_max']}_{boundary['y_min']}_{boundary['y_max']}"
213213
folder_name = os.path.join(root_dir, f"tsne_box_{coords_string}_{subsample_size}")
214-
214+
# pdb.set_trace()
215215
# load data if it exists
216216
if os.path.exists(os.path.join(folder_name, "data.pt")):
217217
print(f"Loading data from {folder_name}")
218218
data = torch.load(os.path.join(folder_name, "data.pt"), weights_only=False)
219+
220+
if return_tsne_results:
221+
return data['train_loader'], data['retain_loader'], data['forget_loader'], data['validation_loader'], data['forget_indices'], data['tsne_results']
222+
219223
return data['train_loader'], data['retain_loader'], data['forget_loader'], data['validation_loader'], data['forget_indices']
220224

221-
train_dataset, retain_dataset, forget_dataset, test_dataset, forget_indices = split_data_by_tsne_box(
222-
boundary=boundary,
223-
return_tsne_results=False,
224-
subsample_size=subsample_size
225-
)
225+
if return_tsne_results:
226+
train_dataset, retain_dataset, forget_dataset, test_dataset, forget_indices, tsne_results = split_data_by_tsne_box(
227+
boundary=boundary,
228+
return_tsne_results=return_tsne_results,
229+
subsample_size=subsample_size
230+
)
231+
232+
else:
233+
train_dataset, retain_dataset, forget_dataset, test_dataset, forget_indices = split_data_by_tsne_box(
234+
boundary=boundary,
235+
return_tsne_results=return_tsne_results,
236+
subsample_size=subsample_size
237+
)
238+
226239

227240
from prepare_image_data import create_image_dataloaders
228241

@@ -245,9 +258,12 @@ def get_image_unlearn_data(root_dir: str, batch_size: int, seed: int, boundary:
245258
'retain_loader': retain_loader,
246259
'forget_loader': forget_loader,
247260
'validation_loader': validation_loader,
248-
'forget_indices': forget_indices
261+
'forget_indices': forget_indices,
262+
'tsne_results': tsne_results
249263
}, os.path.join(folder_name, "data.pt"))
250264

265+
if return_tsne_results:
266+
return train_loader, retain_loader, forget_loader, validation_loader, forget_indices, tsne_results
251267

252268
return train_loader, retain_loader, forget_loader, validation_loader, forget_indices
253269

exp5_SCRUB_and_TA/ta_scrub_experiment.py

Lines changed: 190 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818
import graphviz
1919
import torch.nn as nn
2020
import matplotlib.pyplot as plt
21+
from matplotlib.colors import ListedColormap
22+
from scipy.interpolate import griddata
23+
2124
from torch.utils.data import RandomSampler, SequentialSampler
2225

2326
from src.models.neural_network import NeuralNetRS
@@ -26,9 +29,11 @@
2629
from src.evaluation.membership_inference_attack import MIA
2730
from src.evaluation.unlearning_evaluator import UnlearningEvaluator
2831
from prepare_image_data import get_image_unlearn_data
29-
from prepare_image_data_v2 import get_image_unlearn_data as get_image_unlearn_data_tsne_box
32+
from prepare_image_data_tsne import get_image_unlearn_data as get_image_unlearn_data_tsne_box
3033
import time
3134
import json
35+
import numpy as np
36+
3237

3338
def has_shuffle(dataloader) -> bool:
3439
"""
@@ -140,6 +145,156 @@ def calculate_model_metrics(model, dataloaders, device, model_name=None, save_pa
140145

141146
return metrics
142147

148+
def get_nearest_neighbor_idxs(X, y, tsne_results, num_neighbors: int = 5):
149+
"""
150+
A function that finds the nearest neighbors in t-SNE space to each point in the dataset X.
151+
The neighbors are constrained such that they must belong to the same class as the sample.
152+
153+
Returns:
154+
A numpy array of size [N x num_neighbors] containing the indices in X which have
155+
"""
156+
# calculate distance from every point to every other point in t-SNE space
157+
distances = np.linalg.norm((tsne_results[:, np.newaxis, :] - tsne_results[np.newaxis, :, :]), ord=None, axis=2)
158+
np.fill_diagonal(distances, np.inf) # make sure to exclude diagonal entries
159+
sorted_idxs = np.argsort(distances, axis=1) # sort distances in ascending order and get indices.
160+
161+
labels = y.argmax(dim=-1) # convert from OHE to labels
162+
neighbor_labels = labels[sorted_idxs] # N x N (labels of all neighbors)
163+
164+
same_class_mask = (neighbor_labels == labels[:, None]) # N x N mask where true means the neighbor has the same class
165+
masked_sorted_idxs = np.where(same_class_mask, sorted_idxs, -1) # fill different classes with index -1
166+
167+
nn_idxs = np.full((X.size(0), num_neighbors), -1)
168+
for i in range(X.size(0)):
169+
valid_idxs = sorted_idxs[i][masked_sorted_idxs[i] != -1]
170+
nn_idxs[i, :min(num_neighbors, len(valid_idxs))] = valid_idxs[:num_neighbors]
171+
172+
return nn_idxs
173+
174+
def interpolate_points_generator(X, nn_idxs, n_interp_points: int = 100):
    """
    Lazily yield points linearly interpolated between each sample and each of
    its neighbors.

    Args:
        X: Tensor of shape [N, ...] containing the points to interpolate.
        nn_idxs: [N, n_neighbors] array of neighbor indices into X, ordered
            like X. NOTE(review): -1 padding entries (produced by
            get_nearest_neighbor_idxs for samples with too few same-class
            neighbors) index the LAST row of X here — confirm that is intended.
        n_interp_points: Number of interpolation steps per (sample, neighbor)
            pair; the alphas are spaced 1/n_interp_points apart with both
            endpoints (0 and 1) excluded.

    Yields:
        Tensors of shape [1, ...] equal to x * alpha + (1 - alpha) * neighbor.
    """
    n_neighbors = nn_idxs.shape[1]
    # Interpolation coefficients strictly inside (0, 1).
    interp_range = np.arange(1 / n_interp_points, 1, step=1 / n_interp_points)

    for i in range(X.shape[0]):
        for j in range(n_neighbors):
            x = X[i]
            x_nbr = X[nn_idxs[i][j]]
            for alpha in interp_range:
                x_interp = x * alpha + (1 - alpha) * x_nbr
                yield x_interp.unsqueeze(0)
192+
193+
194+
def interpolate_points_batch(X, nn_idxs, n_interp_points: int = 100, batch_size: int = 512):
    """
    Group the stream of interpolated points into concatenated batches.

    Wraps interpolate_points_generator and yields tensors of up to batch_size
    interpolated points each; the final batch may be smaller.
    """
    pending = []
    for point in interpolate_points_generator(X, nn_idxs, n_interp_points):
        pending.append(point)
        if len(pending) == batch_size:
            yield torch.cat(pending)
            pending = []
    # Flush whatever is left over after the stream is exhausted.
    if pending:
        yield torch.cat(pending)
202+
203+
def create_entropy_tsne_plot(model, y,
                             image_generator,
                             tsne_results,
                             tsne_point_generator,
                             tsne_bounds_x: tuple[float, float],
                             tsne_bounds_y: tuple[float, float],
                             n_imgs_per_batch: int = 10,
                             grid_size: int = 500,
                             device='cuda'):
    """
    Render a heatmap of the model's predictive entropy over t-SNE space.

    Runs the model on batches of interpolated images, pairs each prediction's
    entropy with the matching interpolated t-SNE coordinate, interpolates the
    scattered entropies onto a regular grid with scipy's griddata, and overlays
    the original t-SNE points colored by class. The figure is saved to
    'grid_entropies.png' and shown.

    Args:
        model: Model exposing .inference(imgs) -> {'probabilities': Tensor}.
        y: One-hot labels of the ORIGINAL (non-interpolated) samples, ordered
            like tsne_results; batch itt of the generators is assumed to cover
            samples [itt*n_imgs_per_batch, (itt+1)*n_imgs_per_batch).
        image_generator: Iterator yielding interpolated image batches.
        tsne_results: [N, 2] t-SNE embedding of the original samples.
        tsne_point_generator: Iterator yielding t-SNE coordinates matching
            image_generator batch-for-batch.
        tsne_bounds_x, tsne_bounds_y: Currently unused — the plot extent is
            derived from the generated points instead; kept for interface
            stability.
        n_imgs_per_batch: Number of original samples represented per batch.
        grid_size: Resolution of the interpolation grid (grid_size x grid_size).
        device: Device the image batches are moved to before inference.
    """

    # Setup
    class_colors = ['#003f5c', '#2f4b7c', '#665191', '#a05195', '#d45087',
                    '#f95d6a', '#ff7c43', '#ffa600', '#aa5382', '#eea152']
    plt.style.use('seaborn-v0_8-paper')
    plt.figure(figsize=(8, 6))

    y = y.argmax(dim=-1)  # from one-hot to label

    # Accumulate data
    all_tsne = []
    all_entropy = []
    all_labels = []

    itt = 0
    while True:
        # next(gen, None) suffices: both inputs are generators, so the previous
        # next(iter(gen), None) wrapping was a no-op (and would loop forever on
        # a non-iterator iterable).
        imgs = next(image_generator, None)
        tsne_points = next(tsne_point_generator, None)
        # Stop when either stream runs out, so a length mismatch cannot cause
        # an AttributeError on None below.
        if imgs is None or tsne_points is None:
            break

        imgs = imgs.to(device)
        # Each original sample contributes imgs.size(0) // n_imgs_per_batch
        # interpolated images, so repeat its label accordingly.
        ys = y[itt * n_imgs_per_batch : (itt + 1) * n_imgs_per_batch]
        ys = ys.repeat_interleave(imgs.size(0) // n_imgs_per_batch)

        # No gradients needed for plotting; also guarantees .numpy() below
        # cannot fail on a grad-requiring tensor.
        with torch.no_grad():
            probs = model.inference(imgs)['probabilities']
        entropies = -(torch.log(probs + 1e-8) * probs).sum(dim=-1)

        all_tsne.append(tsne_points.cpu().numpy())
        all_entropy.append(entropies.detach().cpu().numpy())
        all_labels.append(ys.cpu().numpy())

        itt += 1

    all_tsne = np.concatenate(all_tsne, axis=0)
    all_entropy = np.concatenate(all_entropy, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # Interpolate entropy onto a regular grid spanning the generated points.
    x_min, x_max = all_tsne[:, 0].min(), all_tsne[:, 0].max()
    y_min, y_max = all_tsne[:, 1].min(), all_tsne[:, 1].max()

    grid_x, grid_y = np.meshgrid(
        np.linspace(x_min, x_max, grid_size),
        np.linspace(y_min, y_max, grid_size)
    )

    # Default (linear) interpolation leaves NaNs outside the convex hull of
    # the data; those are rendered transparent by the colormap below.
    grid_entropy = griddata(
        all_tsne, all_entropy,
        (grid_x, grid_y),
    )

    # Set colormap with transparency for missing areas
    cmap = plt.cm.viridis.copy()
    cmap.set_bad(color=(1, 1, 1, 0))  # Fully transparent for NaNs

    # Show entropy heatmap with transparency for missing data
    plt.imshow(
        grid_entropy,
        extent=(x_min, x_max, y_min, y_max),
        origin='lower',
        cmap=cmap,
        alpha=1.0,
        aspect='auto'
    )
    plt.colorbar(label='Entropy')

    # Overlay t-SNE points
    plt.scatter(tsne_results[:, 0], tsne_results[:, 1],
                c=y, cmap=ListedColormap(class_colors),
                s=10, alpha=0.2)

    plt.title("Interpolated Entropy in t-SNE Space")
    plt.xlabel("t-SNE dim 1")
    plt.ylabel("t-SNE dim 2")
    plt.tight_layout()
    plt.savefig('grid_entropies.png', dpi=300)
    plt.show()
292+
293+
294+
295+
296+
297+
143298
@hydra.main(config_path=".", config_name="mnist_config")
144299
def main(cfg):
145300
print("Starting experiment with configuration: %s", cfg.data.dataset_name)
@@ -172,6 +327,7 @@ def main(cfg):
172327
bounding_box_coords = {"x_min": cfg.data.tsne_box_coordinates.x_min, "x_max": cfg.data.tsne_box_coordinates.x_max, "y_min": cfg.data.tsne_box_coordinates.y_min, "y_max": cfg.data.tsne_box_coordinates.y_max}
173328
bounding_box_name = f"{cfg.data.tsne_box_coordinates.x_min}_{cfg.data.tsne_box_coordinates.x_max}_{cfg.data.tsne_box_coordinates.y_min}_{cfg.data.tsne_box_coordinates.y_max}"
174329
results_folder_name = f"{cfg.data.split_type}_{bounding_box_name}"
330+
175331
elif cfg.data.split_type == "random":
176332
results_folder_name = f"{cfg.data.split_type}_{cfg.data.n_forget_points}"
177333

@@ -185,25 +341,28 @@ def main(cfg):
185341
# --------- random sampled forget set ---------
186342
if cfg.data.split_type == 'random':
187343
dataloader_train, dataloader_retain, dataloader_forget, dataloader_val, forget_idxs = get_image_unlearn_data(root_dir=dataset_dir,
188-
dataset_name=cfg['data']['dataset_name'],
189-
n_forget_points=cfg.data.n_forget_points,
190-
subsample_size=cfg.data.subsample_size,
191-
patch_size=cfg.data.patch_size,
192-
batch_size=cfg.data.batch_size,
193-
seed=cfg.data.seed)
344+
dataset_name=cfg['data']['dataset_name'],
345+
n_forget_points=cfg.data.n_forget_points,
346+
subsample_size=cfg.data.subsample_size,
347+
patch_size=cfg.data.patch_size,
348+
batch_size=cfg.data.batch_size,
349+
seed=cfg.data.seed)
194350
# --------- tsne box forget set ---------
195351
elif cfg.data.split_type == 'tsne_box':
196-
dataloader_train, dataloader_retain, dataloader_forget, dataloader_val, forget_idxs = get_image_unlearn_data_tsne_box(root_dir=dataset_dir,
197-
batch_size=cfg.data.batch_size,
198-
subsample_size=cfg.data.subsample_size,
199-
seed=cfg.data.seed,
200-
boundary=bounding_box_coords)
352+
dataloader_train, dataloader_retain, dataloader_forget, dataloader_val, forget_idxs, tsne_results = get_image_unlearn_data_tsne_box(root_dir=dataset_dir,
353+
batch_size=cfg.data.batch_size,
354+
subsample_size=cfg.data.subsample_size,
355+
seed=cfg.data.seed,
356+
boundary=bounding_box_coords,
357+
return_tsne_results=True)
358+
201359
print("MNIST dataloaders created. Train size: %d, Forget size: %d",
202360
len(dataloader_train.dataset), len(dataloader_forget.dataset))
203361

204362
else:
205363
raise NotImplementedError(f"The dataset {cfg.data.dataset_name} is not supported!")
206364

365+
207366
assert (dataloader_train.dataset.y[forget_idxs] == dataloader_forget.dataset.y).all(), 'Forget indices applied to train do not correspond to the forget data!'
208367

209368
# ========================== Initialize logger ==========================
@@ -260,7 +419,7 @@ def main(cfg):
260419
else:
261420
# Load original model weights
262421
print("Loading pre-trained original model weights")
263-
original_model.load_state_dict(torch.load(orig_model_path))
422+
original_model.load_state_dict(torch.load(orig_model_path, map_location=DEVICE))
264423

265424
# Calculate and save metrics for original model
266425
calculate_model_metrics(original_model,
@@ -310,7 +469,7 @@ def main(cfg):
310469

311470
else:
312471
print("Loading pre-trained retrained model weights")
313-
retrained_model_sd = torch.load(f'{weights_dir}/retrained_model_weights_{seed}.pt')
472+
retrained_model_sd = torch.load(f'{weights_dir}/retrained_model_weights_{seed}.pt', map_location=DEVICE)
314473
retrained_model.load_state_dict(retrained_model_sd)
315474

316475
# Calculate and save metrics for retrained model
@@ -319,6 +478,23 @@ def main(cfg):
319478
DEVICE, 'Retrained model',
320479
save_path=f'{results_dir}/retrained_model_metrics_{seed}.json')
321480

481+
num_neighbors = 3
482+
n_interp_points = 25
483+
n_imgs_per_batch = 10
484+
batch_size = n_interp_points * num_neighbors * n_imgs_per_batch
485+
486+
nn_idxs = get_nearest_neighbor_idxs(dataloader_train.dataset.X, dataloader_train.dataset.y, tsne_results, num_neighbors)
487+
interp_img_generator = interpolate_points_batch(dataloader_train.dataset.X, nn_idxs, n_interp_points=n_interp_points, batch_size=batch_size)
488+
interp_tsne_generator = interpolate_points_batch(torch.tensor(tsne_results), nn_idxs, n_interp_points=n_interp_points, batch_size=batch_size)
489+
490+
create_entropy_tsne_plot(original_model, dataloader_train.dataset.y,
491+
interp_img_generator, tsne_results,
492+
interp_tsne_generator,
493+
(tsne_results[:, 0].min(), tsne_results[:, 0].max()),
494+
(tsne_results[:, 1].min(), tsne_results[:, 1].max()),
495+
n_imgs_per_batch, device='cuda')
496+
497+
import pdb; pdb.set_trace()
322498

323499
# ========================== Unlearn: Teacher Ascender ==========================
324500
if cfg.unlearn.method == 'teacher_ascend':

0 commit comments

Comments
 (0)