From 1e3cd341d4017da455db5ccacccea2e4971f55ee Mon Sep 17 00:00:00 2001 From: jhonatan Date: Tue, 5 May 2026 13:28:46 +0200 Subject: [PATCH 1/2] added gblup model --- pyproject.toml | 7 +++ src/genotypeprediction/models/gblup.py | 78 +++++++++++++++++++++++++- 2 files changed, 82 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2518858..54fcd0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,3 +23,10 @@ dev = [ "pytest>=9.0.3", "ruff>=0.15.12", ] + +[tool.pdm.scripts] +check = "python -m compileall src tests" +lint = "ruff check ." +format-check = "ruff format --check ." +test = "pytest" +ci = {composite = ["check", "lint", "format-check", "test"]} \ No newline at end of file diff --git a/src/genotypeprediction/models/gblup.py b/src/genotypeprediction/models/gblup.py index 02f84ed..395ff80 100644 --- a/src/genotypeprediction/models/gblup.py +++ b/src/genotypeprediction/models/gblup.py @@ -6,11 +6,10 @@ from scipy.linalg import solve from genotypeprediction.data.preprocessing import GenotypeStandardizer -from genotypeprediction.evaluation.metrics import r2 +from genotypeprediction.evaluation.metrics import r2, pearson_corr from genotypeprediction.inference.reml import estimate_reml_variance_components -from sklearn.metrics import mean_squared_error, mean_absolute_error -from sklearn.model_selection import train_test_split +from sklearn.exceptions import NotFittedError class GBLUPDual: @@ -52,3 +51,76 @@ def _fit_standardized( assume_a="pos", ) return K_train, alpha_hat + + def fit( + self, + X_train: np.ndarray, + y_train: np.ndarray, + lambda_value: float | None = None, + estimate_lambda_reml: bool = True, + feature_names: list[str] | None = None, + ) -> "GBLUPDual": + """Fit dual GBLUP. + + The default path estimates = ``lambda_g = sigma_e2 / sigma_g2`` by REML. + """ + + self.standardizer_ = GenotypeStandardizer() + self.X_train_ = self.standardizer_.fit_transform( + X_train, feature_names=feature_names + ) + self.standardizer_.fit_y(y_train) + self.K_train_ = self._relationship_matrix(self.X_train_) + + if lambda_value is not None: + self.lambda_g = float(lambda_value) + self.sigma_e2_hat = None + self.sigma_g2_hat = None + elif estimate_lambda_reml: + reml_estimates = estimate_reml_variance_components( + K=self.K_train_, y=np.asarray(y_train, dtype=float) + ) + self.lambda_g = float(reml_estimates["lambda_g"]) + self.sigma_e2_hat = float(reml_estimates["sigma_e2_hat"]) + self.sigma_g2_hat = float(reml_estimates["sigma_e2_hat"]) + else: + self.lambda_g = 1.0 + self.sigma_e2_hat = None + self.sigma_g2_hat = None + + y_centered = self.standardizer_.center_y(y_train) + self.alpha_hat = solve( + self.K_train_ + self.lambda_g * np.eye(self.K_train_.shape[0]), + y_centered, + assume_a="pos", + ) + return self + + def predict(self, X_test: np.ndarray) -> np.ndarray: + """Predict phenotypes on the original scale.""" + + if ( + self.standardizer_ is None + or self.X_train_ is None + or self.alpha_hat is None + ): + raise NotFittedError("The model must be fitted before prediction.") + + X_test_standardized = self.standardizer_.transform(X_test) + K_test_train = (X_test_standardized @ self.X_train_T) / self.X_train_.shape[1] + y_pred_centered = K_test_train @ self.alpha_hat + return self.standardizer_.restore_y(y_pred_centered) + + def score( + self, X_test: np.ndarray, y_test: np.ndarray, method: str = "r2" + ) -> float: + """Return the out-of-sample method to measure performance. + + method: "r2", "corr" + """ + map_performance = {"r2": r2, "corr": pearson_corr} + if method not in map_performance.keys(): + raise ValueError("The method must be one of the options 'corr' or 'r2.") + performance = map_performance[method] + + return performance(y_true=y_test, y_pred=self.predict(X_test)) From a28cc3cd0b699dcff9dd128fbe5bf0ef722f5a4e Mon Sep 17 00:00:00 2001 From: jhonatan Date: Tue, 5 May 2026 13:33:45 +0200 Subject: [PATCH 2/2] changed order of ci --- .github/workflows/ci.yml | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4d03549..11cceaf 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,12 +27,12 @@ jobs: - name: Compile Python files run: pdm run python -m compileall src tests + + - name: Run tests + run: pdm run python -m pytest - name: Lint run: pdm run ruff check . - name: Check formatting - run: pdm run ruff format --check . - - - name: Run tests - run: pdm run python -m pytest \ No newline at end of file + run: pdm run ruff format --check . \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 54fcd0d..9db03a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,4 +29,4 @@ check = "python -m compileall src tests" lint = "ruff check ." format-check = "ruff format --check ." test = "pytest" -ci = {composite = ["check", "lint", "format-check", "test"]} \ No newline at end of file +ci = {composite = ["check", "test", "lint", "format-check"]} \ No newline at end of file