File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -81,7 +81,14 @@ def fit_predict(
8181 alphas = np .logspace (- 1 , 8 , 10 )
8282
8383 # Determine device - use GPU if available and requested
84- device = "cuda" if use_gpu and torch .cuda .is_available () else "cpu"
84+
85+ if use_gpu :
86+ if torch .backends .mps .is_available ():
87+ device = "mps:0"
88+ elif torch .cuda .is_available ():
89+ device = "cuda"
90+ else :
91+ device = "cpu"
8592 logger .info (f"Using device: { device } " )
8693 logger .info (f"Folding type: { folding_type } " )
8794
@@ -397,7 +404,7 @@ def _find_best_alphas(
397404 # Find the best alpha for each voxel
398405 best_alpha_idx = torch .argmax (mean_inner_corrs , dim = 0 ) # Shape: (n_voxels,)
399406 best_valphas = torch .tensor (
400- [alphas [i ] for i in best_alpha_idx ], device = X_train .device
407+ [alphas [i ] for i in best_alpha_idx ], device = X_train .device , dtype = torch . float32
401408 )
402409 if logger :
403410 logger .info ("Found best alphas for each voxel" )
You can’t perform that action at this time.
0 commit comments