diff --git a/intervention/circle_finding_utils.py b/intervention/circle_finding_utils.py index 53f2ce6..64da461 100644 --- a/intervention/circle_finding_utils.py +++ b/intervention/circle_finding_utils.py @@ -113,7 +113,7 @@ def get_logit_diffs_from_subspace_formula_resid_intervention( probe_r = probe_r.to(device) target_embedding_in_q_space = target_to_embedding.to(device) @ probe_r.inverse() - pca_projection_matrix = torch.tensor(pca_projection_matrix).to(device).T.float() + pca_projection_matrix = torch.tensor(pca_projection_matrix).float().to(device).T all_pcas = ( torch.tensor(