Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions tests/test_negative_lambda_reg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np
import pytest

from transformer_instant.elm import elm_fit
from transformer_instant.linalg import ridge_fit_closed_form
from transformer_instant.utils import krr_fit


def test_ridge_negative_lambda_reg():
    """ridge_fit_closed_form must reject a negative lambda_reg with ValueError."""
    # Seeded generator: the data values are irrelevant to the check, but a
    # deterministic fixture keeps any failure reproducible.
    rng = np.random.default_rng(0)
    X = rng.standard_normal((4, 3))
    y = rng.standard_normal((4, 1))
    # match= pins the failure to the lambda_reg validation ("lambda_reg must
    # be non-negative") rather than an unrelated ValueError raised later.
    with pytest.raises(ValueError, match="lambda_reg"):
        ridge_fit_closed_form(X, y, lambda_reg=-0.1)


def test_krr_negative_lambda_reg():
    """krr_fit must reject a negative lambda_reg with ValueError."""
    # Seeded generator: the data values are irrelevant to the check, but a
    # deterministic fixture keeps any failure reproducible.
    rng = np.random.default_rng(0)
    X = rng.standard_normal((4, 3))
    y = rng.standard_normal((4, 1))
    # match= pins the failure to the lambda_reg validation ("lambda_reg must
    # be non-negative") rather than an unrelated ValueError (e.g. the
    # unsupported-kernel check in krr_fit).
    with pytest.raises(ValueError, match="lambda_reg"):
        krr_fit(X, y, lambda_reg=-0.1)


def test_elm_negative_lambda_reg():
    """elm_fit must reject a negative lambda_reg with ValueError."""
    # Seeded generator: the data values are irrelevant to the check, but a
    # deterministic fixture keeps any failure reproducible.
    rng = np.random.default_rng(0)
    X = rng.standard_normal((4, 3))
    y = rng.standard_normal((4, 1))
    # match= pins the failure to the lambda_reg validation ("lambda_reg must
    # be non-negative") rather than an unrelated ValueError (e.g. the
    # unsupported-activation check in elm_fit).
    with pytest.raises(ValueError, match="lambda_reg"):
        elm_fit(X, y, lambda_reg=-0.1)
9 changes: 6 additions & 3 deletions transformer_instant/elm.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ def elm_fit(
) -> dict[str, Array | float | int | str]:
X = np.asarray(X, dtype=np.float64)
Y = np.asarray(Y, dtype=np.float64)
lambda_reg = float(lambda_reg)
if lambda_reg < 0.0:
raise ValueError("lambda_reg must be non-negative")
n, d = X.shape
rng = np.random.default_rng(random_seed)
W_hidden = rng.normal(
Expand All @@ -42,7 +45,7 @@ def elm_fit(
else:
raise ValueError("Unsupported activation: choose 'tanh' or 'relu'")
H = act(cast(Array, (X @ W_hidden + b_hidden).astype(np.float64)))
if H.shape[0] == H.shape[1] and float(lambda_reg) <= 1e-8:
if H.shape[0] == H.shape[1] and lambda_reg <= 1e-8:
try:
beta = cast(Array, np.linalg.solve(H, Y).astype(np.float64))
except Exception:
Expand All @@ -52,7 +55,7 @@ def elm_fit(
else:
HtH = H.T @ H
if lambda_reg > 0.0:
HtH = HtH + float(lambda_reg) * np.eye(HtH.shape[0], dtype=HtH.dtype)
HtH = HtH + lambda_reg * np.eye(HtH.shape[0], dtype=HtH.dtype)
Hty = H.T @ Y
beta = solve_spd(HtH, Hty)
return {
Expand All @@ -61,7 +64,7 @@ def elm_fit(
"beta": beta,
"activation": activation,
"weight_scale": float(weight_scale),
"lambda_reg": float(lambda_reg),
"lambda_reg": lambda_reg,
"hidden_units": int(hidden_units),
}

Expand Down
5 changes: 4 additions & 1 deletion transformer_instant/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ def ridge_fit_closed_form(X: Array, Y: Array, lambda_reg: float = 1e-3) -> Array
"""
X = np.asarray(X, dtype=np.float64)
Y = np.asarray(Y, dtype=np.float64)
lambda_reg = float(lambda_reg)
if lambda_reg < 0.0:
raise ValueError("lambda_reg must be non-negative")
if Y.ndim == 1:
Y = Y.reshape(-1, 1)
squeeze = True
Expand All @@ -53,7 +56,7 @@ def ridge_fit_closed_form(X: Array, Y: Array, lambda_reg: float = 1e-3) -> Array
n, d = X.shape
xtx = X.T @ X
if lambda_reg > 0.0:
xtx = xtx + float(lambda_reg) * np.eye(d, dtype=xtx.dtype)
xtx = xtx + lambda_reg * np.eye(d, dtype=xtx.dtype)
xty = X.T @ Y
W = solve_spd(xtx, xty)
if squeeze and W.ndim == 2 and W.shape[1] == 1:
Expand Down
7 changes: 5 additions & 2 deletions transformer_instant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@ def krr_fit(
"""
X = np.asarray(X, dtype=np.float64)
Y = np.asarray(Y, dtype=np.float64)
lambda_reg = float(lambda_reg)
if lambda_reg < 0.0:
raise ValueError("lambda_reg must be non-negative")
if kernel != "rbf":
raise ValueError("Only 'rbf' kernel supported in krr_fit")
# Build symmetric Gram with self-consistent diagonal for better precision
K = rbf_cross_kernel(X, X, length_scale=length_scale, variance=variance)
K = 0.5 * (K + K.T)
n = K.shape[0]
lambda_eff = 0.0 if float(lambda_reg) <= 1e-8 else float(lambda_reg)
lambda_eff = 0.0 if lambda_reg <= 1e-8 else lambda_reg
A = K + lambda_eff * np.eye(n, dtype=K.dtype)
alpha = solve_spd(A, Y)
return {
Expand All @@ -41,7 +44,7 @@ def krr_fit(
"kernel": kernel,
"length_scale": float(length_scale),
"variance": float(variance),
"lambda_reg": float(lambda_reg),
"lambda_reg": lambda_reg,
}


Expand Down
Loading