Skip to content

Commit 31fc7ab

Browse files
Merge pull request #1 from GT-LIT-Lab/mps-support
add support for mac GPUS
2 parents 023226a + 135e685 commit 31fc7ab

1 file changed

Lines changed: 9 additions & 2 deletions

File tree

encoding/models/nested_cv.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,14 @@ def fit_predict(
8181
alphas = np.logspace(-1, 8, 10)
8282

8383
# Determine device - use GPU if available and requested
84-
device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
84+
85+
if use_gpu:
86+
if torch.backends.mps.is_available():
87+
device = "mps:0"
88+
elif torch.cuda.is_available():
89+
device = "cuda"
90+
else:
91+
device = "cpu"
8592
logger.info(f"Using device: {device}")
8693
logger.info(f"Folding type: {folding_type}")
8794

@@ -397,7 +404,7 @@ def _find_best_alphas(
397404
# Find the best alpha for each voxel
398405
best_alpha_idx = torch.argmax(mean_inner_corrs, dim=0) # Shape: (n_voxels,)
399406
best_valphas = torch.tensor(
400-
[alphas[i] for i in best_alpha_idx], device=X_train.device
407+
[alphas[i] for i in best_alpha_idx], device=X_train.device, dtype=torch.float32
401408
)
402409
if logger:
403410
logger.info("Found best alphas for each voxel")

0 commit comments

Comments
 (0)