From 98d7222c16c1e387ae829505230f6805c2fb6e22 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Sat, 7 Sep 2024 15:02:27 -0700 Subject: [PATCH] Fix to support MPS: convert to float32 earlier --- intervention/circle_finding_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/intervention/circle_finding_utils.py b/intervention/circle_finding_utils.py index 53f2ce6..64da461 100644 --- a/intervention/circle_finding_utils.py +++ b/intervention/circle_finding_utils.py @@ -113,7 +113,7 @@ def get_logit_diffs_from_subspace_formula_resid_intervention( probe_r = probe_r.to(device) target_embedding_in_q_space = target_to_embedding.to(device) @ probe_r.inverse() - pca_projection_matrix = torch.tensor(pca_projection_matrix).to(device).T.float() + pca_projection_matrix = torch.tensor(pca_projection_matrix).float().to(device).T all_pcas = ( torch.tensor(