Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ be directly added to this file to describe the related changes.
- Fixed typos in the help page for `objective_function`, and in the `add_norm`
function (defined in `R/objective_function_helpers.R`)

- Allowed user-supplied regularization functions

# Changes in BioCroValidation Version 0.2.0 (2025-05-23)

- Added 2002 and 2005 SoyFACE biomass and standard deviation data.
Expand Down
20 changes: 13 additions & 7 deletions R/objective_function.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,16 @@ objective_function <- function(
long_form_data,
normalization_method,
normalization_param,
length(data_driver_pairs)
length(data_driver_pairs),
verbose_startup
)

# Add variance-based weights
long_form_data <- add_w_var(
long_form_data,
stdev_weight_method,
stdev_weight_param
stdev_weight_param,
verbose_startup
)

# Print the long form data, if desired. Do this before checking the data,
Expand All @@ -109,11 +111,15 @@ objective_function <- function(
# Get the data-driver pair weights
ddp_weights <- get_ddp_weights(data_driver_pairs)

# Print the data-driver pair weights, if desired
if (verbose_startup) {
cat('\nThe user-supplied data-driver pair weights:\n\n')
utils::str(ddp_weights)
}
# Print additional startup information, if desired
print_misc_verbose_startup(
ddp_weights,
regularization_method,
dependent_arg_function,
post_process_function,
extra_penalty_function,
verbose_startup
)

# Create the objective function
obj_fun <- get_obj_fun(
Expand Down
144 changes: 120 additions & 24 deletions R/objective_function_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -188,9 +188,32 @@ add_norm <- function(
long_form_data,
normalization_method,
normalization_param,
n_ddp
n_ddp,
verbose_startup
)
{
eps_max <- get_param_default(normalization_param, 1e-1)
eps_obs <- get_param_default(normalization_param, 1e-1)

method <- toupper(normalization_method)

if (verbose_startup) {
method_info <- if (method %in% c('MAX', 'MEAN_MAX')) {
paste(method, 'with eps =', eps_max)
} else if (method %in% c('OBS', 'MEAN_OBS')) {
paste(method, 'with eps =', eps_obs)
} else {
method
}

cat(paste(
'\nNormalization method:',
method_info,
'\n',
collapse = ''
))
}

for (i in seq_along(long_form_data)) {
data_table <- long_form_data[[i]]

Expand All @@ -204,20 +227,17 @@ add_norm <- function(
qmax <- max(abs(qname_subset[['quantity_value']]))
qobs <- data_table[j, 'quantity_value']

eps_max <- get_param_default(normalization_param, 1e-1)
eps_obs <- get_param_default(normalization_param, 1e-1)

if (tolower(normalization_method) == 'equal') {
if (method == 'EQUAL') {
1.0
} else if (tolower(normalization_method) == 'mean') {
} else if (method == 'MEAN') {
npts * n_ddp
} else if (tolower(normalization_method) == 'max') {
} else if (method == 'MAX') {
qmax^2 + eps_max
} else if (tolower(normalization_method) == 'obs') {
} else if (method == 'OBS') {
qobs^2 + eps_obs
} else if (tolower(normalization_method) == 'mean_max') {
} else if (method == 'MEAN_MAX') {
npts * n_ddp * (qmax^2 + eps_max)
} else if (tolower(normalization_method) == 'mean_obs') {
} else if (method == 'MEAN_OBS') {
npts * n_ddp * (qobs^2 + eps_obs)
} else {
stop('Unsupported normalization_method: ', normalization_method)
Expand All @@ -231,20 +251,45 @@ add_norm <- function(
}

# Helping function for getting variance-based weights
add_w_var <- function(long_form_data, stdev_weight_method, stdev_weight_param) {
add_w_var <- function(
long_form_data,
stdev_weight_method,
stdev_weight_param,
verbose_startup
)
{
eps_log <- get_param_default(stdev_weight_param, 1e-5)
eps_inv <- get_param_default(stdev_weight_param, 1e-1)

method <- toupper(stdev_weight_method)

if (verbose_startup) {
method_info <- if (method == 'LOGARITHM') {
paste(method, 'with eps =', eps_log)
} else if (method == 'INVERSE') {
paste(method, 'with eps =', eps_inv)
} else {
method
}

cat(paste(
'\nStandard-deviation-based weight method:',
method_info,
'\n',
collapse = ''
))
}

for (i in seq_along(long_form_data)) {
data_table <- long_form_data[[i]]
data_stdev <- data_table[['quantity_stdev']]

eps_log <- get_param_default(stdev_weight_param, 1e-5)
eps_inv <- get_param_default(stdev_weight_param, 1e-1)

data_table[['w_var']] <-
if (tolower(stdev_weight_method) == 'equal') {
if (method == 'EQUAL') {
1.0
} else if (tolower(stdev_weight_method) == 'logarithm') {
} else if (method == 'LOGARITHM') {
log(1.0 / (data_stdev + eps_log))
} else if (tolower(stdev_weight_method) == 'inverse') {
} else if (method == 'INVERSE') {
1.0 / (data_stdev^2 + eps_inv)
} else {
stop('Unsupported stdev_weight_method: ', stdev_weight_method)
Expand Down Expand Up @@ -408,14 +453,20 @@ regularization_penalty <- function(
regularization_lambda
)
{
if (toupper(regularization_method) == 'NONE') {
0.0
} else if (toupper(regularization_method) == 'LASSO' || toupper(regularization_method) == 'L1') {
regularization_lambda * sum(abs(ind_arg_vals))
} else if (toupper(regularization_method) == 'RIDGE' || toupper(regularization_method) == 'L2') {
regularization_lambda * sum(ind_arg_vals^2)
if (is.function(regularization_method)) {
regularization_method(ind_arg_vals, regularization_lambda)
} else {
stop('Unsupported regularization method: ', regularization_method)
method <- toupper(regularization_method)

if (method == 'NONE') {
0.0
} else if (method %in% c('LASSO', 'L1')) {
regularization_lambda * sum(abs(ind_arg_vals))
} else if (method %in% c('RIDGE', 'L2')) {
regularization_lambda * sum(ind_arg_vals^2)
} else {
stop('Unsupported regularization method: ', regularization_method)
}
}
}

Expand Down Expand Up @@ -493,3 +544,48 @@ get_obj_fun <- function(
}
}
}

# Report miscellaneous startup details when verbose output has been requested:
# the data-driver pair weights, the regularization method (either the name of a
# pre-set method or a user-supplied function), and each optional user-supplied
# function (or "none" when one was not provided). Nothing is printed when
# `verbose_startup` is FALSE.
print_misc_verbose_startup <- function(
    ddp_weights,
    regularization_method,
    dependent_arg_function,
    post_process_function,
    extra_penalty_function,
    verbose_startup
)
{
    if (verbose_startup) {
        # Shared suffix used whenever a function object is about to be printed
        custom_func_suffix <- 'user-supplied function:\n\n'

        # Data-driver pair weights
        cat('\nThe user-supplied data-driver pair weights:\n\n')
        utils::str(ddp_weights)

        # Regularization method: print the function itself when one was
        # supplied; otherwise print the upper-cased method name
        cat('\nRegularization method: ')

        if (!is.function(regularization_method)) {
            method_line <- paste0(toupper(regularization_method), '\n')
            cat(method_line)
        } else {
            cat(custom_func_suffix)
            print(regularization_method)
        }

        # Optional functions, keyed by the label used in the printed heading;
        # `list` retains NULL entries, so missing functions are still reported
        optional_funcs <- list(
            'Dependent argument' = dependent_arg_function,
            'Post-processing' = post_process_function,
            'Extra penalty' = extra_penalty_function
        )

        for (label in names(optional_funcs)) {
            func <- optional_funcs[[label]]

            cat(paste0('\n', label, ' function: '))

            if (!is.null(func)) {
                cat(custom_func_suffix)
                print(func)
            } else {
                cat('none\n')
            }
        }
    }
}
40 changes: 28 additions & 12 deletions man/objective_function.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@

\item{regularization_method}{
A string indicating the regularization method to be used when calculating
the regularization penalty term; see below for more details.
the regularization penalty term, or a function that calculates the penalty;
see below for more details.
}

\item{dependent_arg_function}{
Expand Down Expand Up @@ -333,9 +334,10 @@

\strong{Standard-deviation-based weight methods}

The following methods are available for determining weight factors from values
of the standard deviation (\eqn{\sigma}), which can be (optionally) supplied
via the \code{data_stdev} elements of the \code{data_driver_pairs}:
The following pre-set methods are available for determining weight factors
from values of the standard deviation (\eqn{\sigma}), which can be
(optionally) supplied via the \code{data_stdev} elements of the
\code{data_driver_pairs}:

\itemize{
\item \code{'equal'}: For this method, \eqn{w_i^{stdev}} is always set to
Expand Down Expand Up @@ -381,7 +383,7 @@

\strong{Normalization methods}

The following normalization methods are available:
The following pre-set normalization methods are available:

\itemize{
\item \code{'equal'}: For this method, \eqn{N_i} is always set to 1. In
Expand Down Expand Up @@ -466,7 +468,7 @@

\strong{Regularization methods}

The following regularization methods are available:
The following pre-set regularization methods are available:

\itemize{
\item \code{'none'}: For this method, \eqn{P_{regularization}} is always set
Expand All @@ -493,6 +495,10 @@
section below for details of how to specify \eqn{\lambda}.
}

It is also possible to supply a function that accepts two input arguments
(\code{x} and \code{lambda}, as described in the "Value" section below) and
returns a numeric penalty value.

\strong{Input checks}

Several checks are made to ensure that the objective function is properly
Expand Down Expand Up @@ -649,6 +655,8 @@ if (require(BioCro)) {
# original Soybean-BioCro model.
independent_args <- BioCro::soybean[['parameters']][c('alphaLeaf', 'betaLeaf')]

initial_guess <- as.numeric(independent_args)

# Derive the dependent argument values from the independent ones: `alphaStem`
# is set to the current value of `alphaLeaf`
dependent_arg_function <- function(ind_args) {
    alpha_leaf <- ind_args[['alphaLeaf']]
    list(alphaStem = alpha_leaf)
}
Expand All @@ -675,6 +683,12 @@ if (require(BioCro)) {
}
}

# Custom L2-style regularization: instead of penalizing large parameter values,
# penalize the squared distance between the parameter values and the initial
# guess, scaled by lambda
regularization_function <- function(x, lambda) {
    offset_from_guess <- x - initial_guess
    lambda * sum(offset_from_guess^2)
}

# Now we can finally create the objective function
obj_fun <- objective_function(
base_model_definition,
Expand All @@ -683,6 +697,7 @@ if (require(BioCro)) {
quantity_weights,
data_definitions = data_definitions,
stdev_weight_method = 'logarithm',
regularization_method = regularization_function,
dependent_arg_function = dependent_arg_function,
post_process_function = post_process_function,
extra_penalty_function = extra_penalty_function
Expand All @@ -691,17 +706,18 @@ if (require(BioCro)) {
# This function could now be passed to an optimizer; here we will simply
# evaluate it for two sets of parameter values.

# Try doubling each parameter value; in this case, the value of the
# objective function increases, indicating a lower degree of agreement between
# the model and the observed data. Here we will call `obj_fun` in debug mode,
# which will automatically print the value of the error metric.
# Try doubling each parameter value and setting lambda to a nonzero value; in
# this case, the value of the objective function increases, reflecting both a
# lower degree of agreement between the model and the observed data and a
# nonzero regularization penalty for moving away from the initial guess. Here
# we will call `obj_fun` in debug mode, which will automatically print the
# value of the error metric.
cat('\nError metric calculated by doubling the original argument values:\n')
error_metric <- obj_fun(2 * as.numeric(independent_args), debug_mode = TRUE)
error_metric <- obj_fun(2 * initial_guess, 0.001, debug_mode = TRUE)

# We can also see the values of each term that makes up the error metric;
# again, we will call `obj_fun` in debug mode for automatic printing.
cat('\nError metric terms calculated by doubling the original argument values:\n')
error_terms <-
obj_fun(2 * as.numeric(independent_args), return_terms = TRUE, debug_mode = TRUE)
obj_fun(2 * initial_guess, 0.001, return_terms = TRUE, debug_mode = TRUE)
}
}
19 changes: 19 additions & 0 deletions tests/testthat/test-objective_function.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,25 @@ test_that('Objective functions can be created and behave as expected', {
obj_fun(as.numeric(independent_args), lambda = 0.5)
)

# Two data-driver pairs, with dependent arguments and L4 regularization
obj_fun <- expect_silent(
objective_function(
model,
ddps,
independent_args,
quantity_weights,
data_definitions = data_definitions,
dependent_arg_function = dependent_arg_function,
post_process_function = post_process_function,
regularization_method = function(x, lambda) {lambda * sum(x^4)},
verbose_startup = verbose_startup
)
)

expect_silent(
obj_fun(as.numeric(independent_args), lambda = 0.5)
)

expect_true(
is.list(
obj_fun(as.numeric(independent_args), lambda = 0.5, return_terms = TRUE)
Expand Down