Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ jobs:

- name: Compile Python files
run: pdm run python -m compileall src tests

- name: Run tests
run: pdm run python -m pytest

- name: Lint
run: pdm run ruff check .

- name: Check formatting
run: pdm run ruff format --check .

- name: Run tests
run: pdm run python -m pytest
run: pdm run ruff format --check .
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,10 @@ dev = [
"pytest>=9.0.3",
"ruff>=0.15.12",
]

[tool.pdm.scripts]
check = "python -m compileall src tests"
lint = "ruff check ."
format-check = "ruff format --check ."
test = "pytest"
ci = {composite = ["check", "test", "lint", "format-check"]}
78 changes: 75 additions & 3 deletions src/genotypeprediction/models/gblup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
from scipy.linalg import solve

from genotypeprediction.data.preprocessing import GenotypeStandardizer
from genotypeprediction.evaluation.metrics import r2
from genotypeprediction.evaluation.metrics import r2, pearson_corr
from genotypeprediction.inference.reml import estimate_reml_variance_components

from sklearn.metrics import mean_squared_error, mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.exceptions import NotFittedError


class GBLUPDual:
Expand Down Expand Up @@ -52,3 +51,76 @@ def _fit_standardized(
assume_a="pos",
)
return K_train, alpha_hat

def fit(
    self,
    X_train: np.ndarray,
    y_train: np.ndarray,
    lambda_value: float | None = None,
    estimate_lambda_reml: bool = True,
    feature_names: list[str] | None = None,
) -> "GBLUPDual":
    """Fit dual GBLUP.

    The ridge parameter ``lambda_g = sigma_e2 / sigma_g2`` is chosen as
    follows, in order of precedence:

    1. ``lambda_value`` when it is not ``None`` (variance estimates are
       then left as ``None``);
    2. the REML estimate when ``estimate_lambda_reml`` is true (the
       default path), which also stores ``sigma_e2_hat`` and
       ``sigma_g2_hat``;
    3. the fixed fallback ``1.0`` otherwise.

    Args:
        X_train: Training genotype matrix; standardized internally by a
            fresh ``GenotypeStandardizer``.
        y_train: Training phenotypes; centered internally before solving.
        lambda_value: Explicit ridge parameter overriding REML estimation.
        estimate_lambda_reml: Whether to estimate ``lambda_g`` by REML when
            ``lambda_value`` is not given.
        feature_names: Optional feature names forwarded to the
            standardizer's ``fit_transform``.

    Returns:
        The fitted estimator (``self``), with ``standardizer_``,
        ``X_train_``, ``K_train_``, ``lambda_g``, ``sigma_e2_hat``,
        ``sigma_g2_hat`` and ``alpha_hat`` set.
    """

    self.standardizer_ = GenotypeStandardizer()
    self.X_train_ = self.standardizer_.fit_transform(
        X_train, feature_names=feature_names
    )
    self.standardizer_.fit_y(y_train)
    self.K_train_ = self._relationship_matrix(self.X_train_)

    if lambda_value is not None:
        # A user-supplied lambda carries no variance-component estimates.
        self.lambda_g = float(lambda_value)
        self.sigma_e2_hat = None
        self.sigma_g2_hat = None
    elif estimate_lambda_reml:
        reml_estimates = estimate_reml_variance_components(
            K=self.K_train_, y=np.asarray(y_train, dtype=float)
        )
        self.lambda_g = float(reml_estimates["lambda_g"])
        self.sigma_e2_hat = float(reml_estimates["sigma_e2_hat"])
        # Genetic variance comes from its own entry; it was previously
        # read from "sigma_e2_hat", duplicating the residual variance.
        self.sigma_g2_hat = float(reml_estimates["sigma_g2_hat"])
    else:
        self.lambda_g = 1.0
        self.sigma_e2_hat = None
        self.sigma_g2_hat = None

    # Dual solution: (K + lambda_g * I) alpha = y - mean(y).
    y_centered = self.standardizer_.center_y(y_train)
    self.alpha_hat = solve(
        self.K_train_ + self.lambda_g * np.eye(self.K_train_.shape[0]),
        y_centered,
        assume_a="pos",
    )
    return self

def predict(self, X_test: np.ndarray) -> np.ndarray:
    """Predict phenotypes on the original scale.

    Args:
        X_test: Genotype matrix to predict for; it is passed through the
            fitted standardizer's ``transform``.

    Returns:
        Predictions mapped back to the original phenotype scale via the
        standardizer's ``restore_y``.

    Raises:
        NotFittedError: If the fitted state (``standardizer_``,
            ``X_train_``, ``alpha_hat``) is missing or ``None``.
    """

    # getattr keeps the guard valid even when fit() has never created the
    # attributes, so callers get NotFittedError rather than AttributeError.
    standardizer = getattr(self, "standardizer_", None)
    X_train = getattr(self, "X_train_", None)
    alpha_hat = getattr(self, "alpha_hat", None)
    if standardizer is None or X_train is None or alpha_hat is None:
        raise NotFittedError("The model must be fitted before prediction.")

    X_test_standardized = standardizer.transform(X_test)
    # Cross relationship matrix between test and training samples, scaled
    # by the number of columns of the standardized training matrix.
    K_test_train = (X_test_standardized @ X_train.T) / X_train.shape[1]
    y_pred_centered = K_test_train @ alpha_hat
    return standardizer.restore_y(y_pred_centered)

def score(
    self, X_test: np.ndarray, y_test: np.ndarray, method: str = "r2"
) -> float:
    """Return an out-of-sample performance measure for the fitted model.

    Args:
        X_test: Genotype matrix passed to ``predict``.
        y_test: Observed phenotypes compared against the predictions.
        method: Metric to compute: ``"r2"`` (uses ``r2``) or ``"corr"``
            (uses ``pearson_corr``).

    Returns:
        The value of the selected metric, called as
        ``metric(y_true=y_test, y_pred=self.predict(X_test))``.

    Raises:
        ValueError: If ``method`` is not one of the supported options.
    """
    map_performance = {"r2": r2, "corr": pearson_corr}
    # Validate before predicting so an invalid metric name fails fast.
    if method not in map_performance:
        raise ValueError("The method must be one of the options 'corr' or 'r2'.")
    performance = map_performance[method]

    return performance(y_true=y_test, y_pred=self.predict(X_test))
Loading