diff --git a/torch_crps/analytical/studentt.py b/torch_crps/analytical/studentt.py index e8dba89..da8958b 100644 --- a/torch_crps/analytical/studentt.py +++ b/torch_crps/analytical/studentt.py @@ -30,8 +30,8 @@ def standardized_studentt_cdf_via_scipy( "Install `torch-crps` with the 'studentt' dependency group, e.g. `pip install torch-crps[studentt]`." ) from e - z_np = z.detach().cpu().numpy() - nu_np = nu.detach().cpu().numpy() if isinstance(nu, torch.Tensor) else nu + z_np = z.detach().float().cpu().numpy() # float() handles bfloat16 + nu_np = nu.detach().float().cpu().numpy() if isinstance(nu, torch.Tensor) else nu # float() handles bfloat16 cdf_z_np = scipy_student_t.cdf(x=z_np, df=nu_np)