From 7a2f20e9fef9cf283a72ce6c9c78a9bb68a48dce Mon Sep 17 00:00:00 2001 From: Bastian Grumbrecht Date: Sat, 27 Dec 2025 21:36:18 +0100 Subject: [PATCH 1/2] feat:add Narwhals utilities for horizontal statistical operations. --- src/centimators/narwhals_utils.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/centimators/narwhals_utils.py b/src/centimators/narwhals_utils.py index b506710..1848900 100644 --- a/src/centimators/narwhals_utils.py +++ b/src/centimators/narwhals_utils.py @@ -13,7 +13,7 @@ def _ensure_numpy(data, allow_series: bool = False): """Convert data to numpy array, handling both numpy arrays and dataframes. Args: - data: Input data (numpy array, dataframe, or series) + data: Input data (numpy array, dataframe, series, or PyTorch tensor) allow_series: Whether to allow series inputs Returns: @@ -24,6 +24,14 @@ def _ensure_numpy(data, allow_series: bool = False): try: return nw.from_native(data, allow_series=allow_series).to_numpy() except Exception: + # Handle PyTorch tensors (including CUDA tensors) + try: + import torch + if isinstance(data, torch.Tensor): + # Move to CPU if on GPU, then convert to numpy + return data.detach().cpu().numpy() + except ImportError: + pass return numpy.asarray(data) From f55782fd3e199ab7c33fc099f0c4447c5367587a Mon Sep 17 00:00:00 2001 From: Bastian Grumbrecht Date: Sat, 10 Jan 2026 22:20:30 +0100 Subject: [PATCH 2/2] make the spearman loss pickle-able --- src/centimators/losses.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/centimators/losses.py b/src/centimators/losses.py index 124155f..e804419 100644 --- a/src/centimators/losses.py +++ b/src/centimators/losses.py @@ -16,8 +16,9 @@ import keras.ops as K from keras.losses import Loss from keras.config import epsilon +from keras.saving import register_keras_serializable - +@register_keras_serializable() class SpearmanCorrelation(Loss): """Differentiable Spearman rank correlation loss.