Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions intervention/appendix_plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
"\n",
"os.makedirs(\"figs/paper_plots\", exist_ok=True)\n",
"\n",
"device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\""
"device = \"cuda:0\" if torch.cuda.is_available() else \"cpu\"\n",
"dtype = \"float32\""
]
},
{
Expand Down Expand Up @@ -130,9 +131,9 @@
"for model_name in [\"mistral\", \"llama\"]:\n",
"    for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
"        if task_name == \"days_of_week\":\n",
" task = DaysOfWeekTask(model_name=model_name, device=device)\n",
" task = DaysOfWeekTask(model_name=model_name, device=device, dtype=dtype)\n",
" else:\n",
" task = MonthsOfYearTask(model_name=model_name, device=device)\n",
" task = MonthsOfYearTask(model_name=model_name, device=device, dtype=dtype)\n",
"\n",
" for keep_same_index in [0, 1]:\n",
" for layer_type in [\"mlp\", \"attention\", \"resid\", \"attention_head\"]:\n",
Expand Down Expand Up @@ -186,9 +187,9 @@
"for model_name in [\"mistral\", \"llama\"]:\n",
" for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
" if task_name == \"days_of_week\":\n",
" task = DaysOfWeekTask(device, model_name=model_name)\n",
" task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n",
" else:\n",
" task = MonthsOfYearTask(device, model_name=model_name)\n",
" task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n",
"\n",
" for patching_type in [\"mlp\", \"attention\"]:\n",
" fig, ax = plt.subplots(figsize=(10, 5))\n",
Expand Down Expand Up @@ -283,9 +284,9 @@
"for model_name in [\"mistral\", \"llama\"]:\n",
" for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
" if task_name == \"days_of_week\":\n",
" task = DaysOfWeekTask(device, model_name=model_name)\n",
" task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n",
" else:\n",
" task = MonthsOfYearTask(device, model_name=model_name)\n",
" task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n",
"\n",
" fig, ax = plt.subplots(figsize=(10, 5))\n",
"\n",
Expand Down Expand Up @@ -369,9 +370,9 @@
"data = []\n",
"for model_name, task_name in all_top_heads.keys():\n",
" if task_name == \"days_of_week\":\n",
" task = DaysOfWeekTask(device, model_name=model_name)\n",
" task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)\n",
" else:\n",
" task = MonthsOfYearTask(device, model_name=model_name)\n",
" task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)\n",
"\n",
" acts = get_all_acts(\n",
" task,\n",
Expand Down
7 changes: 5 additions & 2 deletions intervention/circle_probe_interventions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
help="Choose 'llama' or 'mistral' model",
)
parser.add_argument("--device", type=int, default=4, help="CUDA device number")
parser.add_argument("--dtype", type=str, default="float32", help="Data type for torch tensors")
parser.add_argument(
"--use_inverse_regression_probe",
action="store_true",
Expand Down Expand Up @@ -74,6 +75,7 @@
)
args = parser.parse_args()
device = f"cuda:{args.device}"
dtype = args.dtype
day_month_choice = args.problem_type
circle_letter = args.intervene_on
model_name = args.model
Expand Down Expand Up @@ -101,6 +103,7 @@
# intervention_pca_k = 5

device = "cuda:4"
dtype = "float32"
circle_letter = "c"
day_month_choice = "day"
model_name = "mistral"
Expand Down Expand Up @@ -131,9 +134,9 @@
# %%

if day_month_choice == "day":
task = DaysOfWeekTask(device, model_name=model_name)
task = DaysOfWeekTask(device, model_name=model_name, dtype=dtype)
else:
task = MonthsOfYearTask(device, model_name=model_name)
task = MonthsOfYearTask(device, model_name=model_name, dtype=dtype)

# %%

Expand Down
7 changes: 5 additions & 2 deletions intervention/days_of_week_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@


class DaysOfWeekTask:
def __init__(self, device, model_name="mistral", n_devices=None):
def __init__(self, device, model_name="mistral", n_devices=None, dtype="float32"):
self.device = device

self.model_name = model_name

self.n_devices = n_devices

self.dtype = dtype

# Tokens we expect as possible answers. Best of these can optionally be saved (as opposed to best logit overall)
self.allowable_tokens = days_of_week

Expand Down Expand Up @@ -152,14 +154,15 @@ def get_model(self):
if self._lazy_model is None:
if self.model_name == "mistral":
self._lazy_model = transformer_lens.HookedTransformer.from_pretrained(
"mistral-7b", device=self.device, n_devices=self.n_devices
"mistral-7b", device=self.device, n_devices=self.n_devices, dtype=self.dtype
)
elif self.model_name == "llama":
self._lazy_model = transformer_lens.HookedTransformer.from_pretrained(
# "NousResearch/Meta-Llama-3-8B",
"meta-llama/Meta-Llama-3-8B",
device=self.device,
n_devices=self.n_devices,
dtype=self.dtype,
)
return self._lazy_model

Expand Down
7 changes: 4 additions & 3 deletions intervention/intervene_in_middle_of_circle.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def get_circle_hook(layer, circle_point):
parser.add_argument(
"--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu"
)
parser.add_argument("--dtype", type=str, default="float32")

args = parser.parse_args()

Expand All @@ -265,7 +266,7 @@ def get_circle_hook(layer, circle_point):
if args.only_paper_plots:
task_level_granularity = "day"
model_name = "mistral"
task = DaysOfWeekTask(device, model_name=model_name)
task = DaysOfWeekTask(device, model_name=model_name, dtype=args.dtype)
layer = 5
bs = range(2, 6)
pca_k = 5
Expand All @@ -282,9 +283,9 @@ def get_circle_hook(layer, circle_point):
bs = range(1, 13)
for b in bs:
if task_level_granularity == "day":
task = DaysOfWeekTask(device, model_name=model_name)
task = DaysOfWeekTask(device, model_name=model_name, dtype=args.dtype)
elif task_level_granularity == "month":
task = MonthsOfYearTask(device, model_name=model_name)
task = MonthsOfYearTask(device, model_name=model_name, dtype=args.dtype)
else:
raise ValueError(f"Unknown {task_level_granularity}")
for pca_k in [5]:
Expand Down
20 changes: 11 additions & 9 deletions intervention/main_text_plots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
"\n",
"os.makedirs(\"figs/paper_plots\", exist_ok=True)\n",
"\n",
"torch.set_grad_enabled(False)"
"torch.set_grad_enabled(False)\n",
"device = \"cpu\"\n",
"dtype = \"float32\""
]
},
{
Expand Down Expand Up @@ -57,7 +59,7 @@
"\n",
"\n",
"# Left plot\n",
"task = DaysOfWeekTask(\"cpu\", \"mistral\")\n",
"task = DaysOfWeekTask(device, \"mistral\", dtype=dtype)\n",
"problems = task.generate_problems()\n",
"tokens = task.allowable_tokens\n",
"acts = get_acts_pca(task, layer=30, token=task.a_token, pca_k=2)[0]\n",
Expand Down Expand Up @@ -88,7 +90,7 @@
"ax1.set_ylim(-8, 8)\n",
"\n",
"# Right plot\n",
"task = MonthsOfYearTask(\"cpu\", \"llama\")\n",
"task = MonthsOfYearTask(device, \"llama\", dtype=dtype)\n",
"problems = task.generate_problems()\n",
"tokens = task.allowable_tokens\n",
"acts = get_acts_pca(task, layer=3, token=task.a_token, pca_k=2)[0]\n",
Expand Down Expand Up @@ -350,7 +352,7 @@
"s = 0.1\n",
"\n",
"\n",
"task = DaysOfWeekTask(\"cpu\", model_name=\"mistral\")\n",
"task = DaysOfWeekTask(device, model_name=\"mistral\", dtype=dtype)\n",
"layer = 5\n",
"token = task.a_token\n",
"durations = range(2, 6)\n",
Expand Down Expand Up @@ -430,7 +432,7 @@
"fig = plt.figure(figsize=(1.65, 1.5))\n",
"ax = plt.gca()\n",
"\n",
"task = DaysOfWeekTask(\"cpu\", model_name=\"mistral\")\n",
"task = DaysOfWeekTask(device, model_name=\"mistral\", dtype=dtype)\n",
"acts = get_acts(task, layer_fetch=25, token_fetch=task.before_c_token)\n",
"\n",
"problems = task.generate_problems()\n",
Expand Down Expand Up @@ -524,13 +526,13 @@
"# GPT 2\n",
"from transformer_lens import HookedTransformer\n",
"\n",
"model = HookedTransformer.from_pretrained(\"gpt2\")\n",
"model = HookedTransformer.from_pretrained(\"gpt2\", device=device, dtype=dtype)\n",
"\n",
"for task_name in [\"days_of_week\", \"months_of_year\"]:\n",
" if task_name == \"days_of_week\":\n",
" task = DaysOfWeekTask(\"cpu\", model_name=\"gpt2\")\n",
" task = DaysOfWeekTask(device, model_name=\"gpt2\", dtype=dtype)\n",
" else:\n",
" task = MonthsOfYearTask(\"cpu\", model_name=\"gpt2\")\n",
" task = MonthsOfYearTask(device, model_name=\"gpt2\", dtype=dtype)\n",
" problems = task.generate_problems()\n",
" answer_logits = [model.to_single_token(token) for token in task.allowable_tokens]\n",
" num_correct = 0\n",
Expand All @@ -546,7 +548,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "multiplexing",
"display_name": "multid",
"language": "python",
"name": "python3"
},
Expand Down
10 changes: 8 additions & 2 deletions intervention/months_of_year_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,15 @@


class MonthsOfYearTask:
def __init__(self, device, model_name="mistral", n_devices=None):
def __init__(self, device, model_name="mistral", n_devices=None, dtype="float32"):
self.device = device

self.model_name = model_name

self.n_devices = n_devices

self.dtype = dtype

# Tokens we expect as possible answers. Best of these can optionally be saved (as opposed to best logit overall)
self.allowable_tokens = months_of_year

Expand Down Expand Up @@ -163,14 +165,18 @@ def get_model(self):
if self._lazy_model is None:
if self.model_name == "mistral":
self._lazy_model = transformer_lens.HookedTransformer.from_pretrained(
"mistral-7b", device=self.device, n_devices=self.n_devices
"mistral-7b",
device=self.device,
n_devices=self.n_devices,
dtype=self.dtype,
)
elif self.model_name == "llama":
self._lazy_model = transformer_lens.HookedTransformer.from_pretrained(
# "NousResearch/Meta-Llama-3-8B",
"meta-llama/Meta-Llama-3-8B",
device=self.device,
n_devices=self.n_devices,
dtype=self.dtype,
)
return self._lazy_model

Expand Down