From e8c566fe685478fe1f29b690d965b036dcf9bfbe Mon Sep 17 00:00:00 2001 From: eloch216 <48919455+eloch216@users.noreply.github.com> Date: Sat, 31 May 2025 11:06:00 -0500 Subject: [PATCH 1/4] Standardize and simplify method selection --- R/objective_function_helpers.R | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/R/objective_function_helpers.R b/R/objective_function_helpers.R index bfbae4d..d891e15 100644 --- a/R/objective_function_helpers.R +++ b/R/objective_function_helpers.R @@ -207,17 +207,19 @@ add_norm <- function( eps_max <- get_param_default(normalization_param, 1e-1) eps_obs <- get_param_default(normalization_param, 1e-1) - if (tolower(normalization_method) == 'equal') { + method <- toupper(normalization_method) + + if (method == 'EQUAL') { 1.0 - } else if (tolower(normalization_method) == 'mean') { + } else if (method == 'MEAN') { npts * n_ddp - } else if (tolower(normalization_method) == 'max') { + } else if (method == 'MAX') { qmax^2 + eps_max - } else if (tolower(normalization_method) == 'obs') { + } else if (method == 'OBS') { qobs^2 + eps_obs - } else if (tolower(normalization_method) == 'mean_max') { + } else if (method == 'MEAN_MAX') { npts * n_ddp * (qmax^2 + eps_max) - } else if (tolower(normalization_method) == 'mean_obs') { + } else if (method == 'MEAN_OBS') { npts * n_ddp * (qobs^2 + eps_obs) } else { stop('Unsupported normalization_method: ', normalization_method) @@ -239,12 +241,14 @@ add_w_var <- function(long_form_data, stdev_weight_method, stdev_weight_param) { eps_log <- get_param_default(stdev_weight_param, 1e-5) eps_inv <- get_param_default(stdev_weight_param, 1e-1) + method <- toupper(stdev_weight_method) + data_table[['w_var']] <- - if (tolower(stdev_weight_method) == 'equal') { + if (method == 'EQUAL') { 1.0 - } else if (tolower(stdev_weight_method) == 'logarithm') { + } else if (method == 'LOGARITHM') { log(1.0 / (data_stdev + eps_log)) - } else if (tolower(stdev_weight_method) == 'inverse') { + } else if (method == 'INVERSE') { 1.0 / (data_stdev^2 + eps_inv) } else { stop('Unsupported stdev_weight_method: ', stdev_weight_method) @@ -408,11 +412,13 @@ regularization_penalty <- function( regularization_lambda ) { - if (toupper(regularization_method) == 'NONE') { + method <- toupper(regularization_method) + + if (method == 'NONE') { 0.0 - } else if (toupper(regularization_method) == 'LASSO' || toupper(regularization_method) == 'L1') { + } else if (method %in% c('LASSO', 'L1')) { regularization_lambda * sum(abs(ind_arg_vals)) - } else if (toupper(regularization_method) == 'RIDGE' || toupper(regularization_method) == 'L2') { + } else if (method %in% c('RIDGE', 'L2')) { regularization_lambda * sum(ind_arg_vals^2) } else { stop('Unsupported regularization method: ', regularization_method) From 665a963a09ea7c0661a5851efc9a7e9a7463c320 Mon Sep 17 00:00:00 2001 From: eloch216 <48919455+eloch216@users.noreply.github.com> Date: Sat, 31 May 2025 12:09:41 -0500 Subject: [PATCH 2/4] Allow custom regularization function --- R/objective_function_helpers.R | 22 +++++++------ man/objective_function.Rd | 40 +++++++++++++++++------- tests/testthat/test-objective_function.R | 19 +++++++++++ 3 files changed, 60 insertions(+), 21 deletions(-) diff --git a/R/objective_function_helpers.R b/R/objective_function_helpers.R index d891e15..fc65cee 100644 --- a/R/objective_function_helpers.R +++ b/R/objective_function_helpers.R @@ -412,16 +412,20 @@ regularization_penalty <- function( regularization_lambda ) { - method <- toupper(regularization_method) - - if (method == 'NONE') { - 0.0 - } else if (method %in% c('LASSO', 'L1')) { - regularization_lambda * sum(abs(ind_arg_vals)) - } else if (method %in% c('RIDGE', 'L2')) { - regularization_lambda * sum(ind_arg_vals^2) + if (is.function(regularization_method)) { + regularization_method(ind_arg_vals, regularization_lambda) } else { - stop('Unsupported regularization method: ', regularization_method) + method <- toupper(regularization_method) + + if (method == 'NONE') { + 0.0 + } else if (method %in% c('LASSO', 'L1')) { + regularization_lambda * sum(abs(ind_arg_vals)) + } else if (method %in% c('RIDGE', 'L2')) { + regularization_lambda * sum(ind_arg_vals^2) + } else { + stop('Unsupported regularization method: ', regularization_method) + } } } diff --git a/man/objective_function.Rd b/man/objective_function.Rd index 269fbf9..cd11df9 100644 --- a/man/objective_function.Rd +++ b/man/objective_function.Rd @@ -128,7 +128,8 @@ \item{regularization_method}{ A string indicating the regularization method to be used when calculating - the regularization penalty term; see below for more details. + the regularization penalty term, or a function that calculates the penalty; + see below for more details. } \item{dependent_arg_function}{ @@ -333,9 +334,10 @@ \strong{Standard-deviation-based weight methods} - The following methods are available for determining weight factors from values - of the standard deviation (\eqn{\sigma}), which can be (optionally) supplied - via the \code{data_stdev} elements of the \code{data_driver_pairs}: + The following pre-set methods are available for determining weight factors + from values of the standard deviation (\eqn{\sigma}), which can be + (optionally) supplied via the \code{data_stdev} elements of the + \code{data_driver_pairs}: \itemize{ \item \code{'equal'}: For this method, \eqn{w_i^{stdev}} is always set to @@ -381,7 +383,7 @@ \strong{Normalization methods} - The following normalization methods are available: + The following pre-set normalization methods are available: \itemize{ \item \code{'equal'}: For this method, \eqn{N_i} is always set to 1. In @@ -466,7 +468,7 @@ \strong{Regularization methods} - The following regularization methods are available: + The following pre-set regularization methods are available: \itemize{ \item \code{'none'}: For this method, \eqn{P_{regularization}} is always set @@ -493,6 +495,10 @@ section below for details of how to specify \eqn{\lambda}. } + It is also possible to supply a function that accepts two input arguments + (\code{x} and \code{lambda}, as described in the "Value" section below) and + returns a numeric penalty value. + \strong{Input checks} Several checks are made to ensure that the objective function is properly @@ -649,6 +655,8 @@ if (require(BioCro)) { # original Soybean-BioCro model. independent_args <- BioCro::soybean[['parameters']][c('alphaLeaf', 'betaLeaf')] + initial_guess <- as.numeric(independent_args) + dependent_arg_function <- function(ind_args) { list(alphaStem = ind_args[['alphaLeaf']]) } @@ -675,6 +683,12 @@ if (require(BioCro)) { } } + # We want to use the regularization term to penalize deviations away from the + # initial guess, so we will define a custom L2 regularization function + regularization_function <- function(x, lambda) { + lambda * sum((x - initial_guess)^2) + } + # Now we can finally create the objective function obj_fun <- objective_function( base_model_definition, @@ -683,6 +697,7 @@ if (require(BioCro)) { quantity_weights, data_definitions = data_definitions, stdev_weight_method = 'logarithm', + regularization_method = regularization_function, dependent_arg_function = dependent_arg_function, post_process_function = post_process_function, extra_penalty_function = extra_penalty_function @@ -691,17 +706,18 @@ if (require(BioCro)) { # This function could now be passed to an optimizer; here we will simply # evaluate it for two sets of parameter values. - # Try doubling each parameter value; in this case, the value of the - # objective function increases, indicating a lower degree of agreement between - # the model and the observed data. Here we will call `obj_fun` in debug mode, - # which will automatically print the value of the error metric. + # Try doubling each parameter value and setting lambda to a nonzero value; in + # this case, the value of the objective function increases, indicating a lower + # degree of agreement between the model and the observed data. Here we will + # call `obj_fun` in debug mode, which will automatically print the value of + # the error metric. cat('\nError metric calculated by doubling the original argument values:\n') - error_metric <- obj_fun(2 * as.numeric(independent_args), debug_mode = TRUE) + error_metric <- obj_fun(2 * initial_guess, 0.001, debug_mode = TRUE) # We can also see the values of each term that makes up the error metric; # again, we will call `obj_fun` in debug mode for automatic printing. cat('\nError metric terms calculated by doubling the original argument values:\n') error_terms <- - obj_fun(2 * as.numeric(independent_args), return_terms = TRUE, debug_mode = TRUE) + obj_fun(2 * initial_guess, 0.001, return_terms = TRUE, debug_mode = TRUE) } } diff --git a/tests/testthat/test-objective_function.R b/tests/testthat/test-objective_function.R index dcfccf1..02df6d1 100644 --- a/tests/testthat/test-objective_function.R +++ b/tests/testthat/test-objective_function.R @@ -103,6 +103,25 @@ test_that('Objective functions can be created and behave as expected', { obj_fun(as.numeric(independent_args), lambda = 0.5) ) + # Two data-driver pairs, with dependent arguments and L4 regularization + obj_fun <- expect_silent( + objective_function( + model, + ddps, + independent_args, + quantity_weights, + data_definitions = data_definitions, + dependent_arg_function = dependent_arg_function, + post_process_function = post_process_function, + regularization_method = function(x, lambda) {lambda * sum(x^4)}, + verbose_startup = verbose_startup + ) + ) + + expect_silent( + obj_fun(as.numeric(independent_args), lambda = 0.5) + ) + expect_true( is.list( obj_fun(as.numeric(independent_args), lambda = 0.5, return_terms = TRUE) From b8f71c3219d3ef01ccec494cc9db170e3a0f1649 Mon Sep 17 00:00:00 2001 From: eloch216 <48919455+eloch216@users.noreply.github.com> Date: Sat, 31 May 2025 13:44:40 -0500 Subject: [PATCH 3/4] Print more startup information --- R/objective_function.R | 20 +++--- R/objective_function_helpers.R | 110 +++++++++++++++++++++++++++++---- 2 files changed, 111 insertions(+), 19 deletions(-) diff --git a/R/objective_function.R b/R/objective_function.R index 6c21e00..c03ee8a 100644 --- a/R/objective_function.R +++ b/R/objective_function.R @@ -76,14 +76,16 @@ objective_function <- function( long_form_data, normalization_method, normalization_param, - length(data_driver_pairs) + length(data_driver_pairs), + verbose_startup ) # Add variance-based weights long_form_data <- add_w_var( long_form_data, stdev_weight_method, - stdev_weight_param + stdev_weight_param, + verbose_startup ) # Print the long form data, if desired. Do this before checking the data, @@ -109,11 +111,15 @@ objective_function <- function( # Get the data-driver pair weights ddp_weights <- get_ddp_weights(data_driver_pairs) - # Print the data-driver pair weights, if desired - if (verbose_startup) { - cat('\nThe user-supplied data-driver pair weights:\n\n') - utils::str(ddp_weights) - } + # Print additional startup information, if desired + print_misc_verbose_startup( + ddp_weights, + regularization_method, + dependent_arg_function, + post_process_function, + extra_penalty_function, + verbose_startup + ) # Create the objective function obj_fun <- get_obj_fun( diff --git a/R/objective_function_helpers.R b/R/objective_function_helpers.R index fc65cee..0b6c4f6 100644 --- a/R/objective_function_helpers.R +++ b/R/objective_function_helpers.R @@ -188,9 +188,32 @@ add_norm <- function( long_form_data, normalization_method, normalization_param, - n_ddp + n_ddp, + verbose_startup ) { + eps_max <- get_param_default(normalization_param, 1e-1) + eps_obs <- get_param_default(normalization_param, 1e-1) + + method <- toupper(normalization_method) + + if (verbose_startup) { + method_info <- if (method %in% c('MAX', 'MEAN_MAX')) { + paste(method, 'with eps =', eps_max) + } else if (method %in% c('OBS', 'MEAN_OBS')) { + paste(method, 'with eps =', eps_obs) + } else { + method + } + + cat(paste( + '\nNormalization method:', + method_info, + '\n', + collapse = '' + )) + } + for (i in seq_along(long_form_data)) { data_table <- long_form_data[[i]] @@ -204,11 +227,6 @@ add_norm <- function( qmax <- max(abs(qname_subset[['quantity_value']])) qobs <- data_table[j, 'quantity_value'] - eps_max <- get_param_default(normalization_param, 1e-1) - eps_obs <- get_param_default(normalization_param, 1e-1) - - method <- toupper(normalization_method) - if (method == 'EQUAL') { 1.0 } else if (method == 'MEAN') { @@ -233,16 +251,39 @@ add_norm <- function( } # Helping function for getting variance-based weights -add_w_var <- function(long_form_data, stdev_weight_method, stdev_weight_param) { +add_w_var <- function( + long_form_data, + stdev_weight_method, + stdev_weight_param, + verbose_startup +) +{ + eps_log <- get_param_default(stdev_weight_param, 1e-5) + eps_inv <- get_param_default(stdev_weight_param, 1e-1) + + method <- toupper(stdev_weight_method) + + if (verbose_startup) { + method_info <- if (method == 'LOGARITHM') { + paste(method, 'with eps =', eps_log) + } else if (method == 'INVERSE') { + paste(method, 'with eps =', eps_inv) + } else { + method + } + + cat(paste( + '\nStandard-deviation-based weight method:', + method_info, + '\n', + collapse = '' + )) + } + for (i in seq_along(long_form_data)) { data_table <- long_form_data[[i]] data_stdev <- data_table[['quantity_stdev']] - eps_log <- get_param_default(stdev_weight_param, 1e-5) - eps_inv <- get_param_default(stdev_weight_param, 1e-1) - - method <- toupper(stdev_weight_method) - data_table[['w_var']] <- if (method == 'EQUAL') { 1.0 @@ -503,3 +544,48 @@ get_obj_fun <- function( } } } + +# Print the data-driver pair weights, information about the regularization +# method, and information about optional functions, if desired +print_misc_verbose_startup <- function( + ddp_weights, + regularization_method, + dependent_arg_function, + post_process_function, + extra_penalty_function, + verbose_startup +) +{ + if (verbose_startup){ + user_func_msg <- 'user-supplied function:\n\n' + + cat('\nThe user-supplied data-driver pair weights:\n\n') + utils::str(ddp_weights) + + cat('\nRegularization method: ') + + if (is.function(regularization_method)) { + cat(user_func_msg) + print(regularization_method) + } else { + cat(paste0(toupper(regularization_method), '\n')) + } + + func_to_print <- list( + list(info = 'Dependent argument', func = dependent_arg_function), + list(info = 'Post-processing', func = post_process_function), + list(info = 'Extra penalty', func = extra_penalty_function) + ) + + for (x in func_to_print) { + cat(paste0('\n', x$info, ' function: ')) + + if (is.null(x$func)) { + cat('none\n') + } else { + cat(user_func_msg) + print(x$func) + } + } + } +} From f55b3067045d7d1acef52c52e44ff832cdf10754 Mon Sep 17 00:00:00 2001 From: eloch216 <48919455+eloch216@users.noreply.github.com> Date: Sat, 31 May 2025 13:46:26 -0500 Subject: [PATCH 4/4] Update NEWS.md --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index d90cf8f..b0ba5f4 100644 --- a/NEWS.md +++ b/NEWS.md @@ -38,6 +38,8 @@ be directly added to this file to describe the related changes. - Fixed typos in the help page for `objective_function`, and in the `add_norm` function (defined in `R/objective_function_helpers.R`) +- Allowed user-supplied regularization functions + # Changes in BioCroValidation Version 0.2.0 (2025-05-23) - Added 2002 and 2005 SoyFACE biomass and standard deviation data.