From a7aab6497ce69fb508d351e78367c91f503eeed9 Mon Sep 17 00:00:00 2001 From: EngineerDanny Date: Tue, 24 Mar 2026 10:24:14 -0700 Subject: [PATCH 1/4] fix conditional prediction var_par output for sparse covariance models --- R/PLNfit-class.R | 9 ++++++--- tests/testthat/test-plnfit.R | 24 ++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 3 deletions(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 3c5001b7..c5438f14 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -577,9 +577,12 @@ PLNfit <- R6Class( if (is.null(O)) O <- matrix(0, n_new, self$p) # Compute parameters of the law - vcov11 <- private$Sigma[cond , cond, drop = FALSE] - vcov22 <- private$Sigma[!cond, !cond, drop = FALSE] - vcov12 <- private$Sigma[cond , !cond, drop = FALSE] + # as.matrix() coerces sparse Matrix (returned by diagonal/spherical covariance + # models) to dense, so that simplify2array() in the map below produces a + # numeric array rather than a list of sparse Matrix objects. + vcov11 <- as.matrix(private$Sigma[cond , cond, drop = FALSE]) + vcov22 <- as.matrix(private$Sigma[!cond, !cond, drop = FALSE]) + vcov12 <- as.matrix(private$Sigma[cond , !cond, drop = FALSE]) prec11 <- solve(vcov11) A <- crossprod(vcov12, prec11) Sigma21 <- vcov22 - A %*% vcov12 diff --git a/tests/testthat/test-plnfit.R b/tests/testthat/test-plnfit.R index 7c749cb9..2ada59fa 100644 --- a/tests/testthat/test-plnfit.R +++ b/tests/testthat/test-plnfit.R @@ -173,6 +173,30 @@ test_that("PLN fit: Check conditional prediction", { }) +test_that("PLN fit: Check conditional prediction with sparse covariance models", { + + n_cond <- 10 + p_cond <- 2 + p <- ncol(trichoptera$Abundance) + Yc <- trichoptera$Abundance[1:n_cond, 1:p_cond, drop = FALSE] + newdata <- trichoptera[1:n_cond, , drop = FALSE] + + for (covariance in c("diagonal", "spherical")) { + model <- PLN( + Abundance ~ 1, + data = trichoptera, + control = PLN_param(covariance = covariance, trace = 0) + ) + + pred <- predict_cond(model, newdata, Yc, type = "response", var_par = TRUE) + expect_equal(dim(pred), c(n_cond, p - p_cond)) + expect_equal(dim(attr(pred, "M")), dim(pred)) + expect_equal(dim(attr(pred, "S")), c(p - p_cond, p - p_cond, n_cond)) + expect_true(is.array(attr(pred, "S"))) + expect_true(is.numeric(attr(pred, "S"))) + } +}) + test_that("PLN fit: Check number of parameters", { p <- ncol(trichoptera$Abundance) From f9b98c21df3ff5be8e65da0d09e8b875aa94265b Mon Sep 17 00:00:00 2001 From: EngineerDanny Date: Tue, 24 Mar 2026 11:11:44 -0700 Subject: [PATCH 2/4] Add torch implementation to PLNPCA --- R/PLNPCA.R | 2 +- R/PLNPCAfit-class.R | 189 +++++++++++++++++++++++++++++++- tests/testthat/test-plnpcafit.R | 33 ++++++ 3 files changed, 220 insertions(+), 4 deletions(-) diff --git a/R/PLNPCA.R b/R/PLNPCA.R index 505061d4..c0749232 100644 --- a/R/PLNPCA.R +++ b/R/PLNPCA.R @@ -74,7 +74,7 @@ PLNPCA <- function(formula, data, subset, weights, ranks = 1:5, control = PLNPCA #' @inherit PLN_param details #' @export PLNPCA_param <- function( - backend = "nlopt", + backend = c("nlopt", "torch"), trace = 1 , config_optim = list() , config_post = list() , diff --git a/R/PLNPCAfit-class.R b/R/PLNPCAfit-class.R index 7baa2c2a..d63bf1b5 100644 --- a/R/PLNPCAfit-class.R +++ b/R/PLNPCAfit-class.R @@ -47,7 +47,185 @@ PLNPCAfit <- R6Class( ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% private = list( C = NULL, - svdCM = NULL + svdCM = NULL, + + ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + ## PRIVATE TORCH METHODS FOR RANK-CONSTRAINED OPTIMIZATION + ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + torch_elbo_rank_core = function(data, M, S, B, C, index) { + S2 <- torch_square(S[index]) # (batch, q) + C2 <- torch_square(C) # (p, q) + Z <- data$O[index] + + torch_mm(M[index], torch_t(C)) + + torch_mm(data$X[index], B) # (batch, p) + A <- torch_exp(Z + 0.5 * torch_mm(S2, torch_t(C2))) + lik_part <- torch_sum(data$w[index, NULL] * (A - data$Y[index] * Z)) + kl_part <- 0.5 * torch_sum(data$w[index, NULL] * + (torch_square(M[index]) + S2 - torch_log(S2) - 1)) + lik_part + kl_part + }, + + torch_elbo_rank = function(data, params, index = torch_tensor(1:self$n)) { + private$torch_elbo_rank_core(data, params$M, params$S, params$B, params$C, index) + }, + + torch_vloglik_rank = function(data, params) { + S2 <- torch_square(params$S) + C2 <- torch_square(params$C) + Z <- data$O + torch_mm(params$M, torch_t(params$C)) + torch_mm(data$X, params$B) + A <- torch_exp(Z + 0.5 * torch_mm(S2, torch_t(C2))) + Ji <- - torch_sum(.logfactorial_torch(data$Y), dim = 2) + + torch_sum(data$Y * Z - A, dim = 2) - + 0.5 * torch_sum(torch_square(params$M) + S2 - torch_log(S2) - 1, dim = 2) + Ji <- .5 * self$p + as.numeric(Ji$cpu()) + attr(Ji, "weights") <- as.numeric(data$w$cpu()) + Ji + }, + + torch_optimize_rank_core = function(data, params, config, n_obs, loss_fn) { + optimizer <- switch(config$algorithm, + "RPROP" = optim_rprop(params, lr = config$lr, etas = config$etas, step_sizes = config$step_sizes), + "RMSPROP" = optim_rmsprop(params, lr = config$lr, weight_decay = config$weight_decay, momentum = config$momentum, centered = config$centered), + "ADAM" = optim_adam(params, lr = config$lr, weight_decay = config$weight_decay), + "ADAGRAD" = optim_adagrad(params, lr = config$lr, weight_decay = config$weight_decay) + ) + + status <- 5 + num_epoch <- config$num_epoch + num_batch <- config$num_batch + batch_size <- floor(n_obs / num_batch) + + objective <- double(length = config$num_epoch + 1) + for (iterate in 1:num_epoch) { + permute <- torch::torch_tensor(sample.int(n_obs), dtype = torch_long(), device = config$device) + for (batch_idx in 1:num_batch) { + index <- permute[(batch_size * (batch_idx - 1) + 1):(batch_idx * batch_size)] + optimizer$zero_grad() + loss <- loss_fn(index) + loss$backward() + optimizer$step() + } + + objective[iterate + 1] <- loss$item() + delta_f <- abs(objective[iterate] - objective[iterate + 1]) / abs(objective[iterate + 1]) + + if (!is.finite(loss$item())) { + stop(sprintf( + "The ELBO diverged during the optimization procedure.\nConsider using:\n* a different optimizer (current optimizer: %s)\n* a smaller learning rate (current rate: %.3f)\nwith `control = PLNPCA_param(backend = 'torch', config_optim = list(algorithm = ..., lr = ...))`", + config$algorithm, config$lr + )) + } + + if (config$trace > 1 && (iterate %% 50 == 1)) + cat('\niteration:', iterate, 'objective', objective[iterate + 1], + 'delta_f', round(delta_f, 6)) + + if (delta_f < config$ftol_rel) status <- 3 + if (status %in% c(3, 4)) { + objective <- objective[seq_len(iterate + 1)] + break + } + } + + list( + params = params, + objective = objective, + iterations = iterate, + status = status + ) + }, + + torch_optimize_vestep_rank = function(data, params, B, C, config) { + if (config$trace > 1) + message(paste("optimizing with device:", config$device)) + + n <- nrow(data$Y) + data <- lapply(data, torch_tensor, dtype = torch_float32(), device = config$device) + params <- lapply(params, torch_tensor, dtype = torch_float32(), requires_grad = TRUE, device = config$device) + B <- torch_tensor(B, dtype = torch_float32(), device = config$device) + C <- torch_tensor(C, dtype = torch_float32(), device = config$device) + + optim_out <- private$torch_optimize_rank_core( + data = data, + params = params, + config = config, + n_obs = n, + loss_fn = function(index) { + private$torch_elbo_rank_core(data, params$M, params$S, B, C, index) + } + ) + params_r <- lapply(optim_out$params, function(x) as.matrix(x$cpu())) + Ji_r <- private$torch_vloglik_rank(data, c(optim_out$params, list(B = B, C = C))) + + list( + M = params_r$M, + S = params_r$S, + Ji = Ji_r, + monitoring = list( + objective = optim_out$objective, + iterations = optim_out$iterations, + status = optim_out$status, + backend = "torch" + ) + ) + }, + + torch_optimize_rank = function(data, params, config) { + if (config$trace > 1) + message(paste("optimizing with device:", config$device)) + + data <- lapply(data, torch_tensor, dtype = torch_float32(), device = config$device) + params <- lapply(params, torch_tensor, dtype = torch_float32(), requires_grad = TRUE, device = config$device) + + optim_out <- private$torch_optimize_rank_core( + data = data, + params = params, + config = config, + n_obs = self$n, + loss_fn = function(index) { + private$torch_elbo_rank(data, params, index) + } + ) + + ## Compute derived quantities on CPU + params_r <- lapply(optim_out$params, function(x) as.matrix(x$cpu())) + data_r <- lapply(data, function(x) as.matrix(x$cpu())) + + q <- ncol(params_r$M) + S2_r <- params_r$S^2 + C2_r <- params_r$C^2 + Z_r <- data_r$O + params_r$M %*% t(params_r$C) + data_r$X %*% params_r$B + A_r <- exp(Z_r + 0.5 * S2_r %*% t(C2_r)) + w_r <- as.numeric(data_r$w) + + wM <- params_r$M * sqrt(w_r) + inner_q <- (crossprod(wM) + diag(colSums(S2_r * w_r), nrow = q)) / sum(w_r) + Sigma_r <- params_r$C %*% inner_q %*% t(params_r$C) + Omega_r <- params_r$C %*% solve(inner_q) %*% t(params_r$C) + + Ji_r <- .5 * self$p - rowSums(.logfactorial(as.matrix(data_r$Y))) + + rowSums(data_r$Y * Z_r - A_r) - + 0.5 * rowSums(params_r$M^2 + S2_r - log(S2_r) - 1) + attr(Ji_r, "weights") <- w_r + + list( + B = params_r$B, + C = params_r$C, + M = params_r$M, + S = params_r$S, + Z = Z_r, + A = A_r, + Sigma = Sigma_r, + Omega = Omega_r, + Ji = Ji_r, + monitoring = list( + objective = optim_out$objective, + iterations = optim_out$iterations, + status = optim_out$status, + backend = "torch" + ) + ) + } ), ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ## PUBLIC MEMBERS ---- @@ -58,7 +236,11 @@ PLNPCAfit <- R6Class( #' @description Initialize a [`PLNPCAfit`] object initialize = function(rank, responses, covariates, offsets, weights, formula, control) { super$initialize(responses, covariates, offsets, weights, formula, control) - private$optimizer$main <- nlopt_optimize_rank + if (control$backend == "torch") { + private$optimizer$main <- private$torch_optimize_rank + } else { + private$optimizer$main <- nlopt_optimize_rank + } private$optimizer$vestep <- nlopt_optimize_vestep_rank if (!is.null(control$svdM)) { svdM <- control$svdM @@ -125,7 +307,8 @@ PLNPCAfit <- R6Class( B = private$B, C = private$C, config = control$config_optim) - optim_out <- do.call(private$optimizer$vestep, args) + vestep_optimizer <- if (control$backend == "torch") private$torch_optimize_vestep_rank else private$optimizer$vestep + optim_out <- do.call(vestep_optimizer, args) optim_out }, diff --git a/tests/testthat/test-plnpcafit.R b/tests/testthat/test-plnpcafit.R index 96f83292..8b8d0978 100644 --- a/tests/testthat/test-plnpcafit.R +++ b/tests/testthat/test-plnpcafit.R @@ -80,6 +80,39 @@ test_that("PLNPCA fit: check classes, getters and field access", { expect_true(inherits(myPLNfit, "PCA")) }) +test_that("PLNPCA torch backend works for fit and project", { + skip_if_not_installed("torch") + + torch_control <- PLNPCA_param( + backend = "torch", + trace = 0, + config_optim = list(algorithm = "RPROP", lr = 0.01, num_epoch = 20, num_batch = 1) + ) + + torch_fit <- getModel( + PLNPCA( + Abundance ~ 1, + data = trichoptera, + ranks = 1, + control = torch_control + ), + 1 + ) + + Y <- as.matrix(trichoptera$Abundance) + expected_loglik_vec <- .5 * ncol(Y) - rowSums(PLNmodels:::.logfactorial(Y)) + + rowSums(Y * torch_fit$latent - fitted(torch_fit)) - + .5 * rowSums(torch_fit$var_par$M^2 + torch_fit$var_par$S^2 - log(torch_fit$var_par$S^2) - 1) + + expect_equal(torch_fit$loglik_vec, expected_loglik_vec, tolerance = 1e-4, check.attributes = FALSE) + + model1 <- getModel(models, 1) + expect_no_error(scores <- model1$project(newdata = trichoptera, control = torch_control)) + expect_false(is.null(scores)) + expect_equal(dim(scores), dim(model1$scores)) + expect_equal(dimnames(scores), dimnames(model1$scores)) +}) + test_that("Bindings for factoextra return sensible values", { ## $eig expect_gte(min(myPLNfit$eig[, "eigenvalue"]), 0) From 19d30b8ec2ff78a68a279d441ccf6dcd63a60953 Mon Sep 17 00:00:00 2001 From: EngineerDanny Date: Tue, 24 Mar 2026 11:46:32 -0700 Subject: [PATCH 3/4] fix torch implementation in the PLNnetwork path --- R/PLNfit-class.R | 2 +- tests/testthat/test-plnnetworkfit.R | 39 +++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index c5438f14..cf2817b4 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -934,7 +934,7 @@ PLNfit_fixedcov <- R6Class( }, torch_Omega = function(data, params) { - params$Omega <- torch_tensor(private$Omega) + params$Omega <- torch_tensor(private$Omega, dtype = params$B$dtype, device = params$B$device) }, ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% diff --git a/tests/testthat/test-plnnetworkfit.R b/tests/testthat/test-plnnetworkfit.R index a835dc96..f2ef1586 100644 --- a/tests/testthat/test-plnnetworkfit.R +++ b/tests/testthat/test-plnnetworkfit.R @@ -46,3 +46,42 @@ test_that("PLNnetwork fit: check classes, getters and field access", { expect_true(inherits(myPLNfit$plot_network(output = "corrplot", plot = FALSE), "Matrix")) }) + +test_that("PLNnetwork fit accepts torch backend", { + skip_if_not_installed("torch") + + data("trichoptera", package = "PLNmodels", envir = environment()) + trichoptera_small <- prepare_data( + trichoptera$Abundance[1:10, 1:4], + trichoptera$Covariate[1:10, , drop = FALSE] + ) + Y <- as.matrix(trichoptera_small$Abundance) + torch_control <- PLNnetwork_param( + backend = "torch", + trace = 0, + config_optim = list( + algorithm = "RPROP", + lr = 0.01, + num_epoch = 5, + num_batch = 1, + maxit_out = 2 + ) + ) + + models <- NULL + expect_no_error(models <- PLNnetwork( + Abundance ~ 1, + data = trichoptera_small, + penalties = 0.1, + control = torch_control + )) + expect_false(is.null(models)) + + myPLNfit <- getBestModel(models) + expect_equal(dim(myPLNfit$latent), dim(Y)) + expect_equal(dim(myPLNfit$model_par$B), c(1, ncol(Y))) + expect_equal(dim(myPLNfit$model_par$Omega), c(ncol(Y), ncol(Y))) + expect_equal(dim(myPLNfit$var_par$M), dim(Y)) + expect_equal(dim(myPLNfit$var_par$S), dim(Y)) + expect_equal(sum(myPLNfit$loglik_vec), myPLNfit$loglik, tolerance = 1e-4) +}) From e804dacd64c27600b11111ae907c41fe09f3649c Mon Sep 17 00:00:00 2001 From: EngineerDanny Date: Tue, 24 Mar 2026 13:34:21 -0700 Subject: [PATCH 4/4] Skip torch tests without runtime support --- man/PLNPCA_param.Rd | 2 +- tests/testthat/test-plnnetworkfit.R | 1 + tests/testthat/test-plnpcafit.R | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) diff --git a/man/PLNPCA_param.Rd b/man/PLNPCA_param.Rd index 0e081d76..b5c23c7d 100644 --- a/man/PLNPCA_param.Rd +++ b/man/PLNPCA_param.Rd @@ -5,7 +5,7 @@ \title{Control of PLNPCA fit} \usage{ PLNPCA_param( - backend = "nlopt", + backend = c("nlopt", "torch"), trace = 1, config_optim = list(), config_post = list(), diff --git a/tests/testthat/test-plnnetworkfit.R b/tests/testthat/test-plnnetworkfit.R index f2ef1586..412a2017 100644 --- a/tests/testthat/test-plnnetworkfit.R +++ b/tests/testthat/test-plnnetworkfit.R @@ -49,6 +49,7 @@ test_that("PLNnetwork fit: check classes, getters and field access", { test_that("PLNnetwork fit accepts torch backend", { skip_if_not_installed("torch") + skip_if_not(torch::torch_is_installed()) data("trichoptera", package = "PLNmodels", envir = environment()) trichoptera_small <- prepare_data( diff --git a/tests/testthat/test-plnpcafit.R b/tests/testthat/test-plnpcafit.R index 8b8d0978..f7fbcc1a 100644 --- a/tests/testthat/test-plnpcafit.R +++ b/tests/testthat/test-plnpcafit.R @@ -82,6 +82,7 @@ test_that("PLNPCA fit: check classes, getters and field access", { test_that("PLNPCA torch backend works for fit and project", { skip_if_not_installed("torch") + skip_if_not(torch::torch_is_installed()) torch_control <- PLNPCA_param( backend = "torch",