diff --git a/README.md b/README.md index 7e7f6a8..a5b5fb3 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ This is the github repo for our paper ["Not All Language Model Features Are Line ## Reproducing each figure -Below are instructions to reproduce each figure (aspirationally). +Below are instructions to reproduce each figure (aspirationally). The required pthon packages to run this repo are ``` @@ -17,7 +17,7 @@ either manually using pip or using the existing requirements.txt if you are on a machine with Cuda 12.1: ``` python -m venv multid -pip install -r requirements.txt +pip install -r requirements.txt OR pip install transformer_lens sae_lens transformers datasets torch adjustText circuitsvis ipython ``` @@ -31,16 +31,16 @@ To reproduce the intervention results, you will first need to run intervention e ``` cd intervention -python3 circle_probe_interventions.py day a mistral --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin -python3 circle_probe_interventions.py month a mistral --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin -python3 circle_probe_interventions.py day a llama --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin -python3 circle_probe_interventions.py month a llama --device 0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin +python3 circle_probe_interventions.py day a mistral --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin +python3 circle_probe_interventions.py month a mistral --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin +python3 circle_probe_interventions.py day a llama --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin +python3 circle_probe_interventions.py month a llama --device cuda:0 --intervention_pca_k 5 --probe_on_cos --probe_on_sin ``` You can then reproduce *Figure 3*, *Figure 5*, *Figure 6*, and *Table 1* by running the corresponding cells in intervention/main_text_plots.ipynb. -After running these intervention experiments, you can reproduce *Figure 6* by running +After running these intervention experiments, you can reproduce *Figure 6* by running ``` cd intervention python3 intervene_in_middle_of_circle.py --only_paper_plots @@ -132,5 +132,3 @@ If you have any questions about the paper or reproducing results, feel free to e year={2024} } ``` - - diff --git a/intervention/circle_probe_interventions.py b/intervention/circle_probe_interventions.py index d936af3..ba481c5 100644 --- a/intervention/circle_probe_interventions.py +++ b/intervention/circle_probe_interventions.py @@ -41,7 +41,12 @@ choices=["llama", "mistral"], help="Choose 'llama' or 'mistral' model", ) - parser.add_argument("--device", type=int, default=4, help="CUDA device number") + parser.add_argument( + "--device", + type=str, + default="cuda:4" if torch.cuda.is_available() else "cpu", + help="Device to use", + ) parser.add_argument( "--use_inverse_regression_probe", action="store_true", @@ -73,7 +78,7 @@ help="Probe on linear representation with center of 0.", ) args = parser.parse_args() - device = f"cuda:{args.device}" + device = args.device day_month_choice = args.problem_type circle_letter = args.intervene_on model_name = args.model @@ -100,7 +105,7 @@ # use_inverse_regression_probe = False # intervention_pca_k = 5 - device = "cuda:4" + device = "cuda:4" if torch.cuda.is_available() else "cpu" circle_letter = "c" day_month_choice = "day" model_name = "mistral" diff --git a/intervention/days_of_week_task.py b/intervention/days_of_week_task.py index cbb5fa8..4e8e061 100644 --- a/intervention/days_of_week_task.py +++ b/intervention/days_of_week_task.py @@ -11,7 +11,7 @@ from task import activation_patching -device = "cuda:4" +device = "cuda:4" if torch.cuda.is_available() else "cpu" # # %% diff --git a/intervention/months_of_year_task.py b/intervention/months_of_year_task.py index 161ef29..6adfef3 100644 --- a/intervention/months_of_year_task.py +++ b/intervention/months_of_year_task.py @@ -5,13 +5,14 @@ setup_notebook() +import torch import numpy as np import transformer_lens from task import Problem, get_acts, plot_pca, get_all_acts, get_acts_pca from task import activation_patching -device = "cuda:4" +device = "cuda:4" if torch.cuda.is_available() else "cpu" # # %%