From 9213652a596d00fe19bf51b0d9edfa6beb160166 Mon Sep 17 00:00:00 2001 From: Tauheed Elahee Date: Fri, 20 Feb 2026 14:05:19 -0500 Subject: [PATCH] Update torch.load calls to include weights_only parameter --- intervention/task.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/intervention/task.py b/intervention/task.py index 5c02857..562ab6b 100644 --- a/intervention/task.py +++ b/intervention/task.py @@ -163,7 +163,9 @@ def get_all_acts( all_acts = [] for i in range(0, len(all_problems)): tensors = torch.load( - f"{task.prefix}{save_file_prefix}{i}.pt", map_location="cpu" + f"{task.prefix}{save_file_prefix}{i}.pt", + map_location="cpu", + weights_only=False, ) all_acts.append(tensors) if len(all_acts) > 1: @@ -201,7 +203,7 @@ def get_acts( torch.save( all_acts[:, layer, token, :].detach().cpu().clone(), file_name ) - data = torch.load(file_name) + data = torch.load(file_name, weights_only=False) if normalize_rms: eps = 1e-5 scale = (data.pow(2).mean(-1, keepdim=True) + eps).sqrt() @@ -235,7 +237,9 @@ def get_acts_pca( pca_acts = pca_object.transform(acts) torch.save(pca_acts, act_file_name) pkl.dump(pca_object, open(pca_pkl_file_name, "wb")) - return torch.load(act_file_name), pkl.load(open(pca_pkl_file_name, "rb")) + return torch.load(act_file_name, weights_only=False), pkl.load( + open(pca_pkl_file_name, "rb") + ) def get_acts_pls(task, layer, token, pls_k, normalize_rms=False): @@ -255,7 +259,9 @@ def get_acts_pls(task, layer, token, pls_k, normalize_rms=False): torch.save(torch.tensor(pls_acts), act_file_name) pkl.dump(pls, open(pls_pkl_file_name, "wb")) - return torch.load(act_file_name), pkl.load(open(pls_pkl_file_name, "rb")) + return torch.load(act_file_name, weights_only=False), pkl.load( + open(pls_pkl_file_name, "rb") + ) def _set_plotting_sizes():