1818import graphviz
1919import torch .nn as nn
2020import matplotlib .pyplot as plt
21+ from matplotlib .colors import ListedColormap
22+ from scipy .interpolate import griddata
23+
2124from torch .utils .data import RandomSampler , SequentialSampler
2225
2326from src .models .neural_network import NeuralNetRS
2629from src .evaluation .membership_inference_attack import MIA
2730from src .evaluation .unlearning_evaluator import UnlearningEvaluator
2831from prepare_image_data import get_image_unlearn_data
29- from prepare_image_data_v2 import get_image_unlearn_data as get_image_unlearn_data_tsne_box
32+ from prepare_image_data_tsne import get_image_unlearn_data as get_image_unlearn_data_tsne_box
3033import time
3134import json
35+ import numpy as np
36+
3237
3338def has_shuffle (dataloader ) -> bool :
3439 """
@@ -140,6 +145,156 @@ def calculate_model_metrics(model, dataloaders, device, model_name=None, save_pa
140145
141146 return metrics
142147
148+ def get_nearest_neighbor_idxs (X , y , tsne_results , num_neighbors : int = 5 ):
149+ """
150+ A function that finds the nearest neighbors in t-SNE space to each point in the dataset X.
151+ The neighbors are constrained such that they must belong to the same class as the sample.
152+
153+ Returns:
154+ A numpy array of size [N x num_neighbors] containing the indices in X which have
155+ """
156+ # calculate distance from every point to every other point in t-SNE space
157+ distances = np .linalg .norm ((tsne_results [:, np .newaxis , :] - tsne_results [np .newaxis , :, :]), ord = None , axis = 2 )
158+ np .fill_diagonal (distances , np .inf ) # make sure to exclude diagonal entries
159+ sorted_idxs = np .argsort (distances , axis = 1 ) # sort distances in ascending order and get indices.
160+
161+ labels = y .argmax (dim = - 1 ) # convert from OHE to labels
162+ neighbor_labels = labels [sorted_idxs ] # N x N (labels of all neighbors)
163+
164+ same_class_mask = (neighbor_labels == labels [:, None ]) # N x N mask where true means the neighbor has the same class
165+ masked_sorted_idxs = np .where (same_class_mask , sorted_idxs , - 1 ) # fill different classes with index -1
166+
167+ nn_idxs = np .full ((X .size (0 ), num_neighbors ), - 1 )
168+ for i in range (X .size (0 )):
169+ valid_idxs = sorted_idxs [i ][masked_sorted_idxs [i ] != - 1 ]
170+ nn_idxs [i , :min (num_neighbors , len (valid_idxs ))] = valid_idxs [:num_neighbors ]
171+
172+ return nn_idxs
173+
def interpolate_points_generator(X, nn_idxs, n_interp_points: int = 100):
    """
    Lazily yield points linearly interpolated between each sample and each of
    its neighbors.

    Args:
        X: N x M tensor of points (torch tensor; rows are samples).
        nn_idxs: N x n_neighbors matrix of neighbor indices, row-aligned with
            X (as returned by `get_nearest_neighbor_idxs`). Entries of -1
            (padding for missing same-class neighbors) are skipped; previously
            they were used as-is and silently indexed X[-1].
        n_interp_points: controls the interpolation step 1/n_interp_points;
            roughly n_interp_points - 1 interior points are produced per
            (sample, neighbor) pair — the endpoints themselves are excluded.

    Yields:
        1 x M tensors, each the convex combination alpha * x + (1 - alpha) * x_nbr
        for alpha on a uniform grid strictly inside (0, 1).
    """
    n_neighbors = nn_idxs.shape[1]
    # Interior coefficients 1/k, 2/k, ..., just below 1 (endpoints excluded).
    interp_range = np.arange(1 / n_interp_points, 1, step=1 / n_interp_points)

    for i in range(X.shape[0]):
        x = X[i]
        for j in range(n_neighbors):
            nbr_idx = nn_idxs[i][j]
            if nbr_idx < 0:
                # Padding entry: this sample had fewer valid same-class neighbors.
                continue
            x_nbr = X[nbr_idx]
            for alpha in interp_range:
                yield (x * alpha + (1 - alpha) * x_nbr).unsqueeze(0)
192+
193+
def interpolate_points_batch(X, nn_idxs, n_interp_points: int = 100, batch_size: int = 512):
    """
    Group the stream of points produced by `interpolate_points_generator`
    into concatenated batches.

    Args:
        X: N x M tensor of points.
        nn_idxs: N x n_neighbors neighbor-index matrix, row-aligned with X.
        n_interp_points: interpolation resolution forwarded to the generator.
        batch_size: maximum number of interpolated points per yielded batch.

    Yields:
        Tensors of shape [batch_size x M] (the final batch may be smaller),
        built by concatenating consecutive interpolated points.
    """
    pending = []
    for point in interpolate_points_generator(X, nn_idxs, n_interp_points):
        pending.append(point)
        if len(pending) == batch_size:
            yield torch.cat(pending)
            pending = []
    # Flush whatever is left over after the generator is exhausted.
    if pending:
        yield torch.cat(pending)
202+
def create_entropy_tsne_plot(model, y,
                             image_generator,
                             tsne_results,
                             tsne_point_generator,
                             tsne_bounds_x: tuple[float, float],
                             tsne_bounds_y: tuple[float, float],
                             n_imgs_per_batch: int = 10,
                             grid_size: int = 500,
                             device='cuda'):
    """
    Plot predictive entropy over t-SNE space and save it as 'grid_entropies.png'.

    Runs `model.inference` on batches of (interpolated) images, computes the
    entropy of the predicted class probabilities for each, interpolates those
    entropies onto a regular grid in t-SNE space via `scipy.interpolate.griddata`,
    and overlays the original t-SNE scatter colored by class.

    Args:
        model: model exposing `inference(imgs) -> {'probabilities': Tensor}`.
        y: [N x C] one-hot labels aligned with `tsne_results`.
        image_generator: iterator of image batches (e.g. from
            `interpolate_points_batch`); consumed until exhausted.
        tsne_results: [N x 2] t-SNE embedding of the original samples.
        tsne_point_generator: iterator of t-SNE coordinate batches, aligned
            batch-for-batch with `image_generator`.
        tsne_bounds_x, tsne_bounds_y: accepted for interface compatibility but
            currently unused — the plot extent is derived from the data.
        n_imgs_per_batch: number of source images represented in each batch;
            used to slice labels per batch.
        grid_size: resolution of the interpolation grid per axis.
        device: device the image batches are moved to before inference.

    Returns:
        None. Side effects: writes 'grid_entropies.png' and calls plt.show().
    """
    # Setup
    class_colors = ['#003f5c', '#2f4b7c', '#665191', '#a05195', '#d45087',
                    '#f95d6a', '#ff7c43', '#ffa600', '#aa5382', '#eea152']
    plt.style.use('seaborn-v0_8-paper')
    plt.figure(figsize=(8, 6))

    y = y.argmax(dim=-1)  # from one-hot to label

    # Accumulate data
    all_tsne = []
    all_entropy = []
    all_labels = []

    itt = 0
    while True:
        # Both arguments are iterators (generators), so draw from them
        # directly; stop as soon as either stream is exhausted.
        imgs = next(image_generator, None)
        tsne_points = next(tsne_point_generator, None)
        if imgs is None or tsne_points is None:
            break

        imgs = imgs.to(device)
        ys = y[itt * n_imgs_per_batch: (itt + 1) * n_imgs_per_batch]
        # NOTE(review): assumes every batch holds the same number of
        # interpolants per image; a final partial batch would mis-size
        # this repeat — confirm against the batch generator.
        ys = ys.repeat_interleave(imgs.size(0) // n_imgs_per_batch)

        # Inference only — no gradients needed.
        with torch.no_grad():
            probs = model.inference(imgs)['probabilities']
            # Shannon entropy per sample; epsilon guards log(0).
            entropies = -(torch.log(probs + 1e-8) * probs).sum(dim=-1)

        all_tsne.append(tsne_points.cpu().numpy())
        all_entropy.append(entropies.cpu().numpy())
        all_labels.append(ys.cpu().numpy())

        itt += 1

    all_tsne = np.concatenate(all_tsne, axis=0)
    all_entropy = np.concatenate(all_entropy, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    # Interpolate entropy onto a regular grid spanning the observed points.
    x_min, x_max = all_tsne[:, 0].min(), all_tsne[:, 0].max()
    y_min, y_max = all_tsne[:, 1].min(), all_tsne[:, 1].max()

    grid_x, grid_y = np.meshgrid(
        np.linspace(x_min, x_max, grid_size),
        np.linspace(y_min, y_max, grid_size)
    )

    # Default ('linear') interpolation leaves NaNs outside the convex hull,
    # which render as transparent below.
    grid_entropy = griddata(
        all_tsne, all_entropy,
        (grid_x, grid_y),
    )

    # Set colormap with transparency for missing areas
    cmap = plt.cm.viridis.copy()
    cmap.set_bad(color=(1, 1, 1, 0))  # fully transparent for NaNs

    # Show entropy heatmap with transparency for missing data
    plt.imshow(
        grid_entropy,
        extent=(x_min, x_max, y_min, y_max),
        origin='lower',
        cmap=cmap,
        alpha=1.0,
        aspect='auto'
    )
    plt.colorbar(label='Entropy')

    # Overlay the original t-SNE points, colored by class.
    plt.scatter(tsne_results[:, 0], tsne_results[:, 1],
                c=y, cmap=ListedColormap(class_colors),
                s=10, alpha=0.2)

    plt.title("Interpolated Entropy in t-SNE Space")
    plt.xlabel("t-SNE dim 1")
    plt.ylabel("t-SNE dim 2")
    plt.tight_layout()
    plt.savefig('grid_entropies.png', dpi=300)
    plt.show()
292+
293+
294+
295+
296+
297+
143298@hydra .main (config_path = "." , config_name = "mnist_config" )
144299def main (cfg ):
145300 print ("Starting experiment with configuration: %s" , cfg .data .dataset_name )
@@ -172,6 +327,7 @@ def main(cfg):
172327 bounding_box_coords = {"x_min" : cfg .data .tsne_box_coordinates .x_min , "x_max" : cfg .data .tsne_box_coordinates .x_max , "y_min" : cfg .data .tsne_box_coordinates .y_min , "y_max" : cfg .data .tsne_box_coordinates .y_max }
173328 bounding_box_name = f"{ cfg .data .tsne_box_coordinates .x_min } _{ cfg .data .tsne_box_coordinates .x_max } _{ cfg .data .tsne_box_coordinates .y_min } _{ cfg .data .tsne_box_coordinates .y_max } "
174329 results_folder_name = f"{ cfg .data .split_type } _{ bounding_box_name } "
330+
175331 elif cfg .data .split_type == "random" :
176332 results_folder_name = f"{ cfg .data .split_type } _{ cfg .data .n_forget_points } "
177333
@@ -185,25 +341,28 @@ def main(cfg):
185341 # --------- random sampled forget set ---------
186342 if cfg .data .split_type == 'random' :
187343 dataloader_train , dataloader_retain , dataloader_forget , dataloader_val , forget_idxs = get_image_unlearn_data (root_dir = dataset_dir ,
188- dataset_name = cfg ['data' ]['dataset_name' ],
189- n_forget_points = cfg .data .n_forget_points ,
190- subsample_size = cfg .data .subsample_size ,
191- patch_size = cfg .data .patch_size ,
192- batch_size = cfg .data .batch_size ,
193- seed = cfg .data .seed )
344+ dataset_name = cfg ['data' ]['dataset_name' ],
345+ n_forget_points = cfg .data .n_forget_points ,
346+ subsample_size = cfg .data .subsample_size ,
347+ patch_size = cfg .data .patch_size ,
348+ batch_size = cfg .data .batch_size ,
349+ seed = cfg .data .seed )
194350 # --------- tsne box forget set ---------
195351 elif cfg .data .split_type == 'tsne_box' :
196- dataloader_train , dataloader_retain , dataloader_forget , dataloader_val , forget_idxs = get_image_unlearn_data_tsne_box (root_dir = dataset_dir ,
197- batch_size = cfg .data .batch_size ,
198- subsample_size = cfg .data .subsample_size ,
199- seed = cfg .data .seed ,
200- boundary = bounding_box_coords )
352+ dataloader_train , dataloader_retain , dataloader_forget , dataloader_val , forget_idxs , tsne_results = get_image_unlearn_data_tsne_box (root_dir = dataset_dir ,
353+ batch_size = cfg .data .batch_size ,
354+ subsample_size = cfg .data .subsample_size ,
355+ seed = cfg .data .seed ,
356+ boundary = bounding_box_coords ,
357+ return_tsne_results = True )
358+
201359 print ("MNIST dataloaders created. Train size: %d, Forget size: %d" ,
202360 len (dataloader_train .dataset ), len (dataloader_forget .dataset ))
203361
204362 else :
205363 raise NotImplementedError (f"The dataset { cfg .data .dataset_name } is not supported!" )
206364
365+
207366 assert (dataloader_train .dataset .y [forget_idxs ] == dataloader_forget .dataset .y ).all (), 'Forget indices applied to train do not correspond to the forget data!'
208367
209368 # ========================== Initialize logger ==========================
@@ -260,7 +419,7 @@ def main(cfg):
260419 else :
261420 # Load original model weights
262421 print ("Loading pre-trained original model weights" )
263- original_model .load_state_dict (torch .load (orig_model_path ))
422+ original_model .load_state_dict (torch .load (orig_model_path , map_location = DEVICE ))
264423
265424 # Calculate and save metrics for original model
266425 calculate_model_metrics (original_model ,
@@ -310,7 +469,7 @@ def main(cfg):
310469
311470 else :
312471 print ("Loading pre-trained retrained model weights" )
313- retrained_model_sd = torch .load (f'{ weights_dir } /retrained_model_weights_{ seed } .pt' )
472+ retrained_model_sd = torch .load (f'{ weights_dir } /retrained_model_weights_{ seed } .pt' , map_location = DEVICE )
314473 retrained_model .load_state_dict (retrained_model_sd )
315474
316475 # Calculate and save metrics for retrained model
@@ -319,6 +478,23 @@ def main(cfg):
319478 DEVICE , 'Retrained model' ,
320479 save_path = f'{ results_dir } /retrained_model_metrics_{ seed } .json' )
321480
481+ num_neighbors = 3
482+ n_interp_points = 25
483+ n_imgs_per_batch = 10
484+ batch_size = n_interp_points * num_neighbors * n_imgs_per_batch
485+
486+ nn_idxs = get_nearest_neighbor_idxs (dataloader_train .dataset .X , dataloader_train .dataset .y , tsne_results , num_neighbors )
487+ interp_img_generator = interpolate_points_batch (dataloader_train .dataset .X , nn_idxs , n_interp_points = n_interp_points , batch_size = batch_size )
488+ interp_tsne_generator = interpolate_points_batch (torch .tensor (tsne_results ), nn_idxs , n_interp_points = n_interp_points , batch_size = batch_size )
489+
490+ create_entropy_tsne_plot (original_model , dataloader_train .dataset .y ,
491+ interp_img_generator , tsne_results ,
492+ interp_tsne_generator ,
493+ (tsne_results [:, 0 ].min (), tsne_results [:, 0 ].max ()),
494+ (tsne_results [:, 1 ].min (), tsne_results [:, 1 ].max ()),
495+ n_imgs_per_batch , device = 'cuda' )
496+
497+ import pdb ; pdb .set_trace ()
322498
323499 # ========================== Unlearn: Teacher Ascender ==========================
324500 if cfg .unlearn .method == 'teacher_ascend' :
0 commit comments