From 79c5d1a76f73cc6d55d27e4deab2be02bb6f3a2c Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Sat, 7 Sep 2024 13:56:17 -0700 Subject: [PATCH] Use pathlib more and have default cache paths be relative to the repo directory --- intervention/appendix_plots.ipynb | 5 +-- intervention/circle_probe_interventions.py | 4 +-- intervention/days_of_week_task.py | 6 ++-- intervention/intervene_in_middle_of_circle.py | 2 +- intervention/main_text_plots.ipynb | 5 +-- intervention/months_of_year_task.py | 6 ++-- intervention/task.py | 32 +++++++++---------- intervention/utils.py | 5 +-- .../generate_feature_occurence_data.py | 8 ++--- sae_multid_feature_discovery/utils.py | 4 ++- 10 files changed, 41 insertions(+), 36 deletions(-) diff --git a/intervention/appendix_plots.ipynb b/intervention/appendix_plots.ipynb index e64984b..9b872e2 100644 --- a/intervention/appendix_plots.ipynb +++ b/intervention/appendix_plots.ipynb @@ -7,6 +7,7 @@ "outputs": [], "source": [ "# %%\n", + "from pathlib import Path\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from task import get_acts, get_acts_pca, get_all_acts\n", @@ -462,7 +463,7 @@ "\n", "for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " results_mistral = pd.read_csv(\n", - " f\"{BASE_DIR}/mistral_{task_name}/results.csv\", skipinitialspace=True\n", + " Path(BASE_DIR) / f\"mistral_{task_name}\" / \"results.csv\", skipinitialspace=True\n", " )\n", "\n", " results_mistral = results_mistral.rename(\n", @@ -480,7 +481,7 @@ " print(sum(results_mistral[\"mistral_correct\"]))\n", "\n", " results_llama = pd.read_csv(\n", - " f\"{BASE_DIR}/llama_{task_name}/results.csv\", skipinitialspace=True\n", + " Path(BASE_DIR) / f\"llama_{task_name}\" / \"results.csv\", skipinitialspace=True\n", " )\n", "\n", " results_llama = results_llama.rename(\n", diff --git a/intervention/circle_probe_interventions.py b/intervention/circle_probe_interventions.py index d936af3..5012fd9 100644 --- a/intervention/circle_probe_interventions.py +++ b/intervention/circle_probe_interventions.py @@ -171,7 +171,7 @@ probe_projections = {} target_to_embeddings = {} -os.makedirs(f"{task.prefix}/circle_probes_{circle_letter}", exist_ok=True) +(task.prefix / f"circle_probes_{circle_letter}").mkdir(exist_ok=True) all_maes = [] all_r_squareds = [] @@ -262,7 +262,7 @@ "probe_r": probe_r, "target_to_embedding": target_to_embedding, }, - f"{task.prefix}/circle_probes_{circle_letter}/{probe_file_extension}_layer_{layer}_token_{token}_pca_{pca_k}.pt", + task.prefix / f"circle_probes_{circle_letter}" / f"{probe_file_extension}_layer_{layer}_token_{token}_pca_{pca_k}.pt", ) mae = (predictions - multid_targets_train).abs().mean() diff --git a/intervention/days_of_week_task.py b/intervention/days_of_week_task.py index cbb5fa8..fa159da 100644 --- a/intervention/days_of_week_task.py +++ b/intervention/days_of_week_task.py @@ -1,5 +1,6 @@ # %% +from pathlib import Path import os from utils import setup_notebook, BASE_DIR @@ -49,9 +50,8 @@ def __init__(self, device, model_name="mistral", n_devices=None): # Tokens we expect as possible answers. Best of these can optionally be saved (as opposed to best logit overall) self.allowable_tokens = days_of_week - self.prefix = f"{BASE_DIR}{model_name}_days_of_week/" - if not os.path.exists(self.prefix): - os.makedirs(self.prefix) + self.prefix = Path(BASE_DIR) / f"{model_name}_days_of_week" + self.prefix.mkdir(parents=True, exist_ok=True) self.num_tokens_in_answer = 1 diff --git a/intervention/intervene_in_middle_of_circle.py b/intervention/intervene_in_middle_of_circle.py index 2fb75fb..95e1ddd 100644 --- a/intervention/intervene_in_middle_of_circle.py +++ b/intervention/intervene_in_middle_of_circle.py @@ -40,7 +40,7 @@ def vary_wthin_circle(circle_letter, duration, layer, token, pca_k, all_points): model = task.get_model() circle_projection_qr = torch.load( - f"{task.prefix}/circle_probes_{circle_letter}/cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt" + task.prefix / f"circle_probes_{circle_letter}" / f"cos_sin_layer_{layer}_token_{token}_pca_{pca_k}.pt" ) for problem in task.generate_problems(): diff --git a/intervention/main_text_plots.ipynb b/intervention/main_text_plots.ipynb index 2266226..bb65ad4 100644 --- a/intervention/main_text_plots.ipynb +++ b/intervention/main_text_plots.ipynb @@ -7,6 +7,7 @@ "outputs": [], "source": [ "# %%\n", + "from pathlib import Path\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from task import get_acts, get_acts_pca\n", @@ -516,7 +517,7 @@ "for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " for model_name in [\"mistral\", \"llama\"]:\n", " results = pd.read_csv(\n", - " f\"{BASE_DIR}/{model_name}_{task_name}/results.csv\", skipinitialspace=True\n", + " Path(BASE_DIR) / f\"{model_name}_{task_name}\" / \"results.csv\", skipinitialspace=True\n", " )\n", " number_correct = results[\"best_token\"] == results[\"ground_truth\"]\n", " print(task_name, model_name, np.sum(number_correct))\n", @@ -560,7 +561,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.7" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/intervention/months_of_year_task.py b/intervention/months_of_year_task.py index 161ef29..5f9d411 100644 --- a/intervention/months_of_year_task.py +++ b/intervention/months_of_year_task.py @@ -1,6 +1,7 @@ # %% import os +from pathlib import Path from utils import setup_notebook, BASE_DIR setup_notebook() @@ -71,9 +72,8 @@ def __init__(self, device, model_name="mistral", n_devices=None): # Tokens we expect as possible answers. Best of these can optionally be saved (as opposed to best logit overall) self.allowable_tokens = months_of_year - self.prefix = f"{BASE_DIR}{model_name}_months_of_year/" - if not os.path.exists(self.prefix): - os.makedirs(self.prefix) + self.prefix = Path(BASE_DIR) / f"{model_name}_months_of_year" + self.prefix.mkdir(parents=True, exist_ok=True) self.num_tokens_in_answer = 1 diff --git a/intervention/task.py b/intervention/task.py index 5c02857..93f798f 100644 --- a/intervention/task.py +++ b/intervention/task.py @@ -1,3 +1,4 @@ +from pathlib import Path from utils import BASE_DIR # Need this import to set the huggingface cache directory import os import numpy as np @@ -24,7 +25,6 @@ def __str__(self): def __repr__(self): return str(self) - def generate_and_save_acts( task, names_filter, @@ -39,10 +39,10 @@ def generate_and_save_acts( forward_batch_size = 2 num_tokens_to_generate = task.num_tokens_in_answer all_problems = task.generate_problems() - output_file = task.prefix + "results.csv" + output_file = task.prefix / "results.csv" if save_results_csv: - os.makedirs(task.prefix, exist_ok=True) + task.prefix.mkdir(parents=True, exist_ok=True) model_best_addition = "" if not save_best_logit else ", best_token" with open(output_file, "w") as f: f.write( @@ -98,7 +98,7 @@ def generate_and_save_acts( print(tensors.shape) torch.save( tensors, - f"{task.prefix}{save_file_prefix}{current_problem_index}.pt", + task.prefix / f"{save_file_prefix}{current_problem_index}.pt", ) if save_results_csv: @@ -146,7 +146,7 @@ def get_all_acts( all_problems = task.generate_problems() all_problems_already_generated = True for i in range(len(all_problems)): - if not os.path.exists(f"{task.prefix}{save_file_prefix}{i}.pt"): + if not (task.prefix / f"{save_file_prefix}{i}.pt").exists(): all_problems_already_generated = False break if not all_problems_already_generated or force_regenerate: @@ -163,7 +163,7 @@ def get_all_acts( all_acts = [] for i in range(0, len(all_problems)): tensors = torch.load( - f"{task.prefix}{save_file_prefix}{i}.pt", map_location="cpu" + task.prefix / f"{save_file_prefix}{i}.pt", map_location="cpu" ) all_acts.append(tensors) if len(all_acts) > 1: @@ -186,9 +186,9 @@ def get_acts( if save_file_prefix != "" and save_file_prefix[-1] != "_": save_file_prefix += "_" file_name = ( - f"{task.prefix}{save_file_prefix}layer{layer_fetch}_token{token_fetch}.pt" + task.prefix / f"{save_file_prefix}layer{layer_fetch}_token{token_fetch}.pt" ) - if not os.path.exists(file_name) or force_regenerate: + if not file_name.exists() or force_regenerate: print(file_name, "not exists") all_acts = get_all_acts( task, names_filter=names_filter, save_file_prefix=save_file_prefix @@ -196,7 +196,7 @@ def get_acts( for layer in range(all_acts.shape[1]): for token in range(all_acts.shape[2]): file_name = ( - f"{task.prefix}{save_file_prefix}layer{layer}_token{token}.pt" + task.prefix / f"{save_file_prefix}layer{layer}_token{token}.pt" ) torch.save( all_acts[:, layer, token, :].detach().cpu().clone(), file_name @@ -218,11 +218,11 @@ def get_acts_pca( names_filter=lambda x: "resid_post" in x or "hook_embed" in x, save_file_prefix="", ): - act_file_name = f"{task.prefix}pca/{save_file_prefix}/layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pt" - pca_pkl_file_name = f"{task.prefix}pca/{save_file_prefix}/layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pkl" - os.makedirs(f"{task.prefix}/pca/{save_file_prefix}", exist_ok=True) + act_file_name = task.prefix / "pca" / save_file_prefix / f"layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pt" + pca_pkl_file_name = task.prefix / "pca" / save_file_prefix / f"layer{layer}_token{token}_pca{pca_k}{'_normalize' if normalize_rms else ''}.pkl" + (task.prefix / "pca" / save_file_prefix).mkdir(parents=True, exist_ok=True) - if not os.path.exists(act_file_name) or not os.path.exists(pca_pkl_file_name): + if not act_file_name.exists() or not pca_pkl_file_name.exists(): acts = get_acts( task, layer, @@ -239,9 +239,9 @@ def get_acts_pca( def get_acts_pls(task, layer, token, pls_k, normalize_rms=False): - act_file_name = f"{task.prefix}/pls/layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pt" - pls_pkl_file_name = f"{task.prefix}/pls/layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pkl" - os.makedirs(f"{task.prefix}/pls", exist_ok=True) + act_file_name = task.prefix / "pls" / f"layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pt" + pls_pkl_file_name = task.prefix / "pls" / f"layer{layer}_token{token}_pls{pls_k}{'_normalize' if normalize_rms else ''}.pkl" + (task.prefix / "pls").mkdir(parents=True, exist_ok=True) # if not os.path.exists(act_file_name) or not os.path.exists(pls_pkl_file_name): if True: diff --git a/intervention/utils.py b/intervention/utils.py index 216a26c..d3bfd1f 100644 --- a/intervention/utils.py +++ b/intervention/utils.py @@ -1,9 +1,10 @@ import os import dill as pickle +from pathlib import Path -BASE_DIR = "/data/scratch/jae/" +BASE_DIR = Path(__file__).parent.parent / "cache" -os.environ["TRANSFORMERS_CACHE"] = f"{BASE_DIR}/.cache/" +os.environ["TRANSFORMERS_CACHE"] = f"{(Path(BASE_DIR) / '.cache').absolute()}/" def setup_notebook(): diff --git a/sae_multid_feature_discovery/generate_feature_occurence_data.py b/sae_multid_feature_discovery/generate_feature_occurence_data.py index 63733aa..022c09c 100644 --- a/sae_multid_feature_discovery/generate_feature_occurence_data.py +++ b/sae_multid_feature_discovery/generate_feature_occurence_data.py @@ -1,12 +1,12 @@ # %% - +from pathlib import Path import os from utils import BASE_DIR # hopefully this will help with memory fragmentation os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128" -os.environ["TRANSFORMERS_CACHE"] = f"{BASE_DIR}.cache/" +os.environ["TRANSFORMERS_CACHE"] = f"{(Path(BASE_DIR) / '.cache').absolute()}/" import einops import numpy as np @@ -52,8 +52,8 @@ num_sae_activations_to_save = 10**9 -save_folder = f"{BASE_DIR}{model_name}" -os.makedirs(save_folder, exist_ok=True) +save_folder = Path(BASE_DIR) / model_name +save_folder.mkdir(exist_ok=True, parents=True) t.set_grad_enabled(False) diff --git a/sae_multid_feature_discovery/utils.py b/sae_multid_feature_discovery/utils.py index f69f0ab..23f40dc 100644 --- a/sae_multid_feature_discovery/utils.py +++ b/sae_multid_feature_discovery/utils.py @@ -1,7 +1,9 @@ + +from pathlib import Path from huggingface_hub import hf_hub_download import os -BASE_DIR = "/data/scratch/jae/" +BASE_DIR = Path(__file__).parent.parent / "cache" def get_gpt2_sae(device, layer): from sae_lens import SAE