diff --git a/tests/test_negative_lambda_reg.py b/tests/test_negative_lambda_reg.py new file mode 100644 index 0000000..411d2fa --- /dev/null +++ b/tests/test_negative_lambda_reg.py @@ -0,0 +1,27 @@ +import numpy as np +import pytest + +from transformer_instant.elm import elm_fit +from transformer_instant.linalg import ridge_fit_closed_form +from transformer_instant.utils import krr_fit + + +def test_ridge_negative_lambda_reg(): + X = np.random.randn(4, 3) + y = np.random.randn(4, 1) + with pytest.raises(ValueError): + ridge_fit_closed_form(X, y, lambda_reg=-0.1) + + +def test_krr_negative_lambda_reg(): + X = np.random.randn(4, 3) + y = np.random.randn(4, 1) + with pytest.raises(ValueError): + krr_fit(X, y, lambda_reg=-0.1) + + +def test_elm_negative_lambda_reg(): + X = np.random.randn(4, 3) + y = np.random.randn(4, 1) + with pytest.raises(ValueError): + elm_fit(X, y, lambda_reg=-0.1) diff --git a/transformer_instant/elm.py b/transformer_instant/elm.py index eeb9a56..930c646 100644 --- a/transformer_instant/elm.py +++ b/transformer_instant/elm.py @@ -29,6 +29,9 @@ def elm_fit( ) -> dict[str, Array | float | int | str]: X = np.asarray(X, dtype=np.float64) Y = np.asarray(Y, dtype=np.float64) + lambda_reg = float(lambda_reg) + if lambda_reg < 0.0: + raise ValueError("lambda_reg must be non-negative") n, d = X.shape rng = np.random.default_rng(random_seed) W_hidden = rng.normal( @@ -42,7 +45,7 @@ def elm_fit( else: raise ValueError("Unsupported activation: choose 'tanh' or 'relu'") H = act(cast(Array, (X @ W_hidden + b_hidden).astype(np.float64))) - if H.shape[0] == H.shape[1] and float(lambda_reg) <= 1e-8: + if H.shape[0] == H.shape[1] and lambda_reg <= 1e-8: try: beta = cast(Array, np.linalg.solve(H, Y).astype(np.float64)) except Exception: @@ -52,7 +55,7 @@ def elm_fit( else: HtH = H.T @ H if lambda_reg > 0.0: - HtH = HtH + float(lambda_reg) * np.eye(HtH.shape[0], dtype=HtH.dtype) + HtH = HtH + lambda_reg * np.eye(HtH.shape[0], dtype=HtH.dtype) Hty = H.T @ Y beta = solve_spd(HtH, Hty) return { @@ -61,7 +64,7 @@ def elm_fit( "beta": beta, "activation": activation, "weight_scale": float(weight_scale), - "lambda_reg": float(lambda_reg), + "lambda_reg": lambda_reg, "hidden_units": int(hidden_units), } diff --git a/transformer_instant/linalg.py b/transformer_instant/linalg.py index 8863429..f70441d 100644 --- a/transformer_instant/linalg.py +++ b/transformer_instant/linalg.py @@ -45,6 +45,9 @@ def ridge_fit_closed_form(X: Array, Y: Array, lambda_reg: float = 1e-3) -> Array """ X = np.asarray(X, dtype=np.float64) Y = np.asarray(Y, dtype=np.float64) + lambda_reg = float(lambda_reg) + if lambda_reg < 0.0: + raise ValueError("lambda_reg must be non-negative") if Y.ndim == 1: Y = Y.reshape(-1, 1) squeeze = True @@ -53,7 +56,7 @@ def ridge_fit_closed_form(X: Array, Y: Array, lambda_reg: float = 1e-3) -> Array n, d = X.shape xtx = X.T @ X if lambda_reg > 0.0: - xtx = xtx + float(lambda_reg) * np.eye(d, dtype=xtx.dtype) + xtx = xtx + lambda_reg * np.eye(d, dtype=xtx.dtype) xty = X.T @ Y W = solve_spd(xtx, xty) if squeeze and W.ndim == 2 and W.shape[1] == 1: diff --git a/transformer_instant/utils.py b/transformer_instant/utils.py index ae30824..0732696 100644 --- a/transformer_instant/utils.py +++ b/transformer_instant/utils.py @@ -26,13 +26,16 @@ def krr_fit( """ X = np.asarray(X, dtype=np.float64) Y = np.asarray(Y, dtype=np.float64) + lambda_reg = float(lambda_reg) + if lambda_reg < 0.0: + raise ValueError("lambda_reg must be non-negative") if kernel != "rbf": raise ValueError("Only 'rbf' kernel supported in krr_fit") # Build symmetric Gram with self-consistent diagonal for better precision K = rbf_cross_kernel(X, X, length_scale=length_scale, variance=variance) K = 0.5 * (K + K.T) n = K.shape[0] - lambda_eff = 0.0 if float(lambda_reg) <= 1e-8 else float(lambda_reg) + lambda_eff = 0.0 if lambda_reg <= 1e-8 else lambda_reg A = K + lambda_eff * np.eye(n, dtype=K.dtype) alpha = solve_spd(A, Y) return { @@ -41,7 +44,7 @@ def krr_fit( "kernel": kernel, "length_scale": float(length_scale), "variance": float(variance), - "lambda_reg": float(lambda_reg), + "lambda_reg": lambda_reg, }