From 8b4089eaf5ca37cb9cb7957da8ab1cdd27ccd600 Mon Sep 17 00:00:00 2001 From: Jason Gross Date: Tue, 27 Aug 2024 22:28:44 -0700 Subject: [PATCH] Don't fail with "AssertionError: Not enough CUDA devices to support n_devices 2" when there aren't 2 GPUs ``` File "/home/jason/MultiDimensionalFeatures/multid/lib/python3.10/site-packages/transformer_lens/HookedTransformerConfig.py", line 315, in __post_init__ torch.cuda.device_count() >= self.n_devices AssertionError: Not enough CUDA devices to support n_devices 2 ``` --- intervention/days_of_week_task.py | 7 ++++++- intervention/months_of_year_task.py | 7 ++++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/intervention/days_of_week_task.py b/intervention/days_of_week_task.py index cbb5fa8..0819ab3 100644 --- a/intervention/days_of_week_task.py +++ b/intervention/days_of_week_task.py @@ -7,6 +7,7 @@ import numpy as np import transformer_lens +import torch from task import Problem, get_acts, plot_pca, get_all_acts, get_acts_pca from task import activation_patching @@ -148,7 +149,11 @@ def generate_problems(self): def get_model(self): if self.n_devices is None: - self.n_devices = 2 if "llama" == self.model_name else 1 + self.n_devices = ( + min(2, max(1, torch.cuda.device_count())) + if "llama" == self.model_name + else 1 + ) if self._lazy_model is None: if self.model_name == "mistral": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained( diff --git a/intervention/months_of_year_task.py b/intervention/months_of_year_task.py index 161ef29..49e9e2b 100644 --- a/intervention/months_of_year_task.py +++ b/intervention/months_of_year_task.py @@ -7,6 +7,7 @@ import numpy as np import transformer_lens +import torch from task import Problem, get_acts, plot_pca, get_all_acts, get_acts_pca from task import activation_patching @@ -159,7 +160,11 @@ def generate_problems(self): def get_model(self): if self.n_devices is None: - self.n_devices = 2 if "llama" == self.model_name else 1 + self.n_devices = ( + min(2, max(1, torch.cuda.device_count())) + if "llama" == self.model_name + else 1 + ) if self._lazy_model is None: if self.model_name == "mistral": self._lazy_model = transformer_lens.HookedTransformer.from_pretrained(