diff --git a/intervention/appendix_plots.ipynb b/intervention/appendix_plots.ipynb index e64984b..9c9f7e1 100644 --- a/intervention/appendix_plots.ipynb +++ b/intervention/appendix_plots.ipynb @@ -28,7 +28,8 @@ "\n", "os.makedirs(\"figs/paper_plots\", exist_ok=True)\n", "\n", - "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"" + "device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n", + "dtype = \"float32\"" ] }, { @@ -130,9 +131,9 @@ "for model_name in [\"mistral\", \"llama\"]:\n", " for task_name in [\"task_name\", \"months_of_year\"]:\n", " if task_name == \"{task_name}\":\n", - " task = DaysOfWeekTask(model_name=model_name, device=device)\n", + " task = DaysOfWeekTask(model_name=model_name, device=device, dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(model_name=model_name, device=device)\n", + " task = MonthsOfYearTask(model_name=model_name, device=device, dtype=dtype)\n", "\n", " for keep_same_index in [0, 1]:\n", " for layer_type in [\"mlp\", \"attention\", \"resid\", \"attention_head\"]:\n", @@ -186,9 +187,9 @@ "for model_name in [\"mistral\", \"llama\"]:\n", " for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " if task_name == \"days_of_week\":\n", - " task = DaysOfWeekTask(device, model_name=model_name)\n", + " task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(device, model_name=model_name)\n", + " task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n", "\n", " for patching_type in [\"mlp\", \"attention\"]:\n", " fig, ax = plt.subplots(figsize=(10, 5))\n", @@ -283,9 +284,9 @@ "for model_name in [\"mistral\", \"llama\"]:\n", " for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " if task_name == \"days_of_week\":\n", - " task = DaysOfWeekTask(device, model_name=model_name)\n", + " task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(device, model_name=model_name)\n", + " task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n", "\n", " fig, ax = plt.subplots(figsize=(10, 5))\n", "\n", @@ -369,9 +370,9 @@ "data = []\n", "for model_name, task_name in all_top_heads.keys():\n", " if task_name == \"days_of_week\":\n", - " task = DaysOfWeekTask(device, model_name=model_name)\n", + " task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(device, model_name=model_name)\n", + " task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n", "\n", " acts = get_all_acts(\n", " task,\n", diff --git a/intervention/circle_probe_interventions.py b/intervention/circle_probe_interventions.py index d936af3..71f748d 100644 --- a/intervention/circle_probe_interventions.py +++ b/intervention/circle_probe_interventions.py @@ -42,6 +42,7 @@ help="Choose 'llama' or 'mistral' model", ) parser.add_argument("--device", type=int, default=4, help="CUDA device number") + parser.add_argument("--dtype", type=str, default="float32", help="Data type for torch tensors") parser.add_argument( "--use_inverse_regression_probe", action="store_true", @@ -74,6 +75,7 @@ ) args = parser.parse_args() device = f"cuda:{args.device}" + dtype = args.dtype day_month_choice = args.problem_type circle_letter = args.intervene_on model_name = args.model @@ -101,6 +103,7 @@ # intervention_pca_k = 5 device = "cuda:4" + dtype = "float32" circle_letter = "c" day_month_choice = "day" model_name = "mistral" @@ -131,9 +134,9 @@ # %% if day_month_choice == "day": - task = DaysOfWeekTask(device, model_name=model_name) + task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype) else: - task = MonthsOfYearTask(device, model_name=model_name) + task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype) # %% diff --git a/intervention/days_of_week_task.py b/intervention/days_of_week_task.py index cbb5fa8..c82c01a 100644 --- a/intervention/days_of_week_task.py +++ b/intervention/days_of_week_task.py @@ -39,13 +39,15 @@ class DaysOfWeekTask: - def __init__(self, device, model_name="mistral", n_devices=None): + def __init__(self, device, model_name="mistral", n_devices=None, dtype="float32"): self.device = device self.model_name = model_name self.n_devices = n_devices + self.dtype = dtype + # Tokens we expect as possible answers. Best of these can optionally be saved (as opposed to best logit overall) self.allowable_tokens = days_of_week @@ -152,7 +154,7 @@ def get_model(self): if self._lazy_model is None: if self.model_name == "mistral": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( - "mistral-7b", device=self.device, n_devices=self.n_devices + "mistral-7b", device=self.device, n_devices=self.n_devices, dtype=self.dtype ) elif self.model_name == "llama": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( @@ -160,6 +162,7 @@ def get_model(self): "meta-llama/Meta-Llama-3-8B", device=self.device, n_devices=self.n_devices, + dtype=self.dtype, ) return self._lazy_model diff --git a/intervention/intervene_in_middle_of_circle.py b/intervention/intervene_in_middle_of_circle.py index 2fb75fb..6e056ca 100644 --- a/intervention/intervene_in_middle_of_circle.py +++ b/intervention/intervene_in_middle_of_circle.py @@ -257,6 +257,7 @@ def get_circle_hook(layer, circle_point): parser.add_argument( "--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu" ) + parser.add_argument("--dtype", type=str, default="float32") args = parser.parse_args() @@ -265,7 +266,7 @@ def get_circle_hook(layer, circle_point): if args.only_paper_plots: task_level_granularity = "day" model_name = "mistral" - task = DaysOfWeekTask(device, model_name=model_name) + task = DaysOfWeekTask(device, model_name=model_name, dtype=args.dtype) layer = 5 bs = range(2, 6) pca_k = 5 @@ -282,9 +283,9 @@ def get_circle_hook(layer, circle_point): bs = range(1, 13) for b in bs: if task_level_granularity == "day": - task = DaysOfWeekTask(device, model_name=model_name) + task = DaysOfWeekTask(device, model_name=model_name, dtype=args.dtype) elif task_level_granularity == "month": - task = MonthsOfYearTask(device, model_name=model_name) + task = MonthsOfYearTask(device, model_name=model_name, dtype=args.dtype) else: raise ValueError(f"Unknown {task_level_granularity}") for pca_k in [5]: diff --git a/intervention/main_text_plots.ipynb b/intervention/main_text_plots.ipynb index 2266226..19cf6d4 100644 --- a/intervention/main_text_plots.ipynb +++ b/intervention/main_text_plots.ipynb @@ -25,7 +25,9 @@ "\n", "os.makedirs(\"figs/paper_plots\", exist_ok=True)\n", "\n", - "torch.set_grad_enabled(False)" + "torch.set_grad_enabled(False)\n", + "device = \"cpu\"\n", + "dtype = \"float32\"" ] }, { @@ -57,7 +59,7 @@ "\n", "\n", "# Left plot\n", - "task = DaysOfWeekTask(\"cpu\", \"mistral\")\n", + "task = DaysOfWeekTask(device, \"mistral\", dtype=dtype)\n", "problems = task.generate_problems()\n", "tokens = task.allowable_tokens\n", "acts = get_acts_pca(task, layer=30, token=task.a_token, pca_k=2)[0]\n", @@ -88,7 +90,7 @@ "ax1.set_ylim(-8, 8)\n", "\n", "# Right plot\n", - "task = MonthsOfYearTask(\"cpu\", \"llama\")\n", + "task = MonthsOfYearTask(device, \"llama\", dtype=dtype)\n", "problems = task.generate_problems()\n", "tokens = task.allowable_tokens\n", "acts = get_acts_pca(task, layer=3, token=task.a_token, pca_k=2)[0]\n", @@ -350,7 +352,7 @@ "s = 0.1\n", "\n", "\n", - "task = DaysOfWeekTask(\"cpu\", model_name=\"mistral\")\n", + "task = DaysOfWeekTask(device, model_name=\"mistral\", dtype=dtype)\n", "layer = 5\n", "token = task.a_token\n", "durations = range(2, 6)\n", @@ -430,7 +432,7 @@ "fig = plt.figure(figsize=(1.65, 1.5))\n", "ax = plt.gca()\n", "\n", - "task = DaysOfWeekTask(\"cpu\", model_name=\"mistral\")\n", + "task = DaysOfWeekTask(device, model_name=\"mistral\", dtype=dtype)\n", "acts = get_acts(task, layer_fetch=25, token_fetch=task.before_c_token)\n", "\n", "problems = task.generate_problems()\n", @@ -524,13 +526,13 @@ "# GPT 2\n", "from transformer_lens import HookedTransformer\n", "\n", - "model = HookedTransformer.from_pretrained(\"gpt2\")\n", + "model = HookedTransformer.from_pretrained(\"gpt2\", device=device, dtype=dtype)\n", "\n", "for task_name in [\"days_of_week\", \"months_of_year\"]:\n", " if task_name == \"days_of_week\":\n", - " task = DaysOfWeekTask(\"cpu\", model_name=\"gpt2\")\n", + " task = DaysOfWeekTask(device, model_name=\"gpt2\", dtype=dtype)\n", " else:\n", - " task = MonthsOfYearTask(\"cpu\", model_name=\"gpt2\")\n", + " task = MonthsOfYearTask(device, model_name=\"gpt2\", dtype=dtype)\n", " problems = task.generate_problems()\n", " answer_logits = [model.to_single_token(token) for token in task.allowable_tokens]\n", " num_correct = 0\n", @@ -546,7 +548,7 @@ ], "metadata": { "kernelspec": { - "display_name": "multiplexing", + "display_name": "multid", "language": "python", "name": "python3" }, diff --git a/intervention/months_of_year_task.py b/intervention/months_of_year_task.py index 161ef29..b98018b 100644 --- a/intervention/months_of_year_task.py +++ b/intervention/months_of_year_task.py @@ -61,13 +61,15 @@ class MonthsOfYearTask: - def __init__(self, device, model_name="mistral", n_devices=None): + def __init__(self, device, model_name="mistral", n_devices=None, dtype="float32"): self.device = device self.model_name = model_name self.n_devices = n_devices + self.dtype = dtype + # Tokens we expect as possible answers. Best of these can optionally be saved (as opposed to best logit overall) self.allowable_tokens = months_of_year @@ -163,7 +165,10 @@ def get_model(self): if self._lazy_model is None: if self.model_name == "mistral": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( - "mistral-7b", device=self.device, n_devices=self.n_devices + "mistral-7b", + device=self.device, + n_devices=self.n_devices, + dtype=self.dtype, ) elif self.model_name == "llama": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( @@ -171,6 +176,7 @@ def get_model(self): "meta-llama/Meta-Llama-3-8B", device=self.device, n_devices=self.n_devices, + dtype=self.dtype, ) return self._lazy_model