Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/compare_solvers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ fn main() {
.missing_observation(24.0, 0)
.build();

let spp = vec![0.1, 0.05, 0.03, 50.0]; // ke, kcp, kpc, V
let parameters = vec![0.1, 0.05, 0.03, 50.0]; // ke, kcp, kpc, V

// Run each solver and collect predictions
let bdf = two_cpt(OdeSolver::Bdf);
Expand All @@ -69,7 +69,7 @@ fn main() {
// ── Run all solvers and collect results ───────────────────────
let mut rows: Vec<(&str, u128, Vec<f64>)> = Vec::new();
for (name, ode) in &results {
let (preds, us) = timed(|| ode.estimate_predictions(&subject, &spp).unwrap());
let (preds, us) = timed(|| ode.estimate_predictions(&subject, &parameters).unwrap());
let preds: Vec<f64> = preds.flat_predictions().to_vec();
rows.push((name, us, preds));
}
Expand Down
10 changes: 5 additions & 5 deletions examples/gendata.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ fn main() {
// let v_dist = rand_distr::Normal::new(50.0, 10.0).unwrap();
// let ske_dist = rand_distr::Normal::new(0.1, 0.01).unwrap();

let mut support_points = vec![];
let mut file = File::create("spp.csv").unwrap();
let mut parameters = vec![];
let mut file = File::create("parameters.csv").unwrap();
for _ in 0..100 {
let ke = ke_dist.sample(&mut rand::rng());
// let ke = 1.2;
Expand All @@ -55,14 +55,14 @@ fn main() {
// let ske = ske_dist.sample(&mut rand::thread_rng());
// let v = v_dist.sample(&mut rand::thread_rng());
let v = 50.0;
support_points.push(vec![ke]);
parameters.push(vec![ke]);
println!("{ke}, {ske}, {v}");
writeln!(file, "{}, {}, {}", ke, ske, v).unwrap();
}

let mut data = vec![];
for (i, spp) in support_points.iter().enumerate() {
let trajectories = sde.estimate_predictions(&subject, spp).unwrap();
for (i, parameter) in parameters.iter().enumerate() {
let trajectories = sde.estimate_predictions(&subject, parameter).unwrap();
let trajectory = trajectories.row(0);
// dbg!(&trajectory);
let mut sb = Subject::builder(format!("id{}", i)).bolus(0.0, 20.0, 0);
Expand Down
14 changes: 7 additions & 7 deletions src/data/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -592,12 +592,12 @@ impl Occasion {
}

fn add_lagtime(&mut self, reorder: Option<(&Fa, &Lag, &[f64], &Covariates)>) {
if let Some((_, fn_lag, spp, covariates)) = reorder {
let spp = nalgebra::DVector::from_vec(spp.to_vec());
if let Some((_, fn_lag, parameters, covariates)) = reorder {
let parameters = nalgebra::DVector::from_vec(parameters.to_vec());
for event in self.events.iter_mut() {
let time = event.time();
if let Event::Bolus(bolus) = event {
let lagtime = fn_lag(&spp.clone().into(), time, covariates);
let lagtime = fn_lag(&parameters.clone().into(), time, covariates);
if let Some(l) = lagtime.get(&bolus.input()) {
*bolus.mut_time() += l;
}
Expand All @@ -609,12 +609,12 @@ impl Occasion {

fn add_bioavailability(&mut self, reorder: Option<(&Fa, &Lag, &[f64], &Covariates)>) {
// If lagtime is empty, return early
if let Some((fn_fa, _, spp, covariates)) = reorder {
let spp = nalgebra::DVector::from_vec(spp.to_vec());
if let Some((fn_fa, _, parameters, covariates)) = reorder {
let parameters = nalgebra::DVector::from_vec(parameters.to_vec());
for event in self.events.iter_mut() {
let time = event.time();
if let Event::Bolus(bolus) = event {
let fa = fn_fa(&spp.clone().into(), time, covariates);
let fa = fn_fa(&parameters.clone().into(), time, covariates);
if let Some(f) = fa.get(&bolus.input()) {
bolus.set_amount(bolus.amount() * f);
}
Expand Down Expand Up @@ -656,7 +656,7 @@ impl Occasion {
///
/// # Arguments
///
/// * `reorder` - Optional tuple containing references to (Fa, Lag, support point, covariates) for adjustments
/// * `reorder` - Optional tuple containing references to (Fa, Lag, parameter, covariates) for adjustments
/// * `ignore` - If true, filter out events marked as ignore
/// * `mappings` - Optional reference to an [equation::Mapper] for input remapping
///
Expand Down
2 changes: 1 addition & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub use crate::data::Interpolation::*;
pub use crate::data::*;
pub use crate::equation::*;
pub use crate::optimize::effect::get_e2;
pub use crate::optimize::spp::SppOptimizer;
pub use crate::optimize::parameters::parametersOptimizer;
pub use crate::simulator::equation::{
self,
ode::{ExplicitRkTableau, OdeSolver, SdirkTableau},
Expand Down
2 changes: 1 addition & 1 deletion src/optimize/mod.rs
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pub mod effect;
pub mod spp;
pub mod parameters;
18 changes: 9 additions & 9 deletions src/optimize/spp.rs → src/optimize/parameters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,28 @@ use ndarray::{Array1, Axis};

use crate::{prelude::simulator::log_likelihood_matrix, AssayErrorModels, Data, Equation};

pub struct SppOptimizer<'a, E: Equation> {
pub struct parametersOptimizer<'a, E: Equation> {
equation: &'a E,
data: &'a Data,
sig: &'a AssayErrorModels,
pyl: &'a Array1<f64>,
}

impl<E: Equation> CostFunction for SppOptimizer<'_, E> {
impl<E: Equation> CostFunction for parametersOptimizer<'_, E> {
type Param = Vec<f64>;
type Output = f64;
fn cost(&self, spp: &Self::Param) -> Result<Self::Output, Error> {
let theta = Array1::from(spp.clone()).insert_axis(Axis(0));
fn cost(&self, parameters: &Self::Param) -> Result<Self::Output, Error> {
let theta = Array1::from(parameters.clone()).insert_axis(Axis(0));

let log_psi = log_likelihood_matrix(self.equation, self.data, &theta, self.sig, false)?;
let psi = log_psi.mapv(f64::exp);

if psi.ncols() > 1 {
tracing::error!("Psi in SppOptimizer has more than one column");
tracing::error!("Psi in parametersOptimizer has more than one column");
}
if psi.nrows() != self.pyl.len() {
tracing::error!(
"Psi in SppOptimizer has {} rows, but spp has {}",
"Psi in parametersOptimizer has {} rows, but parameters has {}",
psi.nrows(),
self.pyl.len()
);
Expand All @@ -42,7 +42,7 @@ impl<E: Equation> CostFunction for SppOptimizer<'_, E> {
}
}

impl<'a, E: Equation> SppOptimizer<'a, E> {
impl<'a, E: Equation> parametersOptimizer<'a, E> {
pub fn new(
equation: &'a E,
data: &'a Data,
Expand All @@ -56,8 +56,8 @@ impl<'a, E: Equation> SppOptimizer<'a, E> {
pyl,
}
}
pub fn optimize_point(self, spp: Array1<f64>) -> Result<Array1<f64>, Error> {
let simplex = create_initial_simplex(&spp.to_vec());
pub fn optimize_point(self, parameters: Array1<f64>) -> Result<Array1<f64>, Error> {
let simplex = create_initial_simplex(&parameters.to_vec());
let solver: NelderMead<Vec<f64>, f64> = NelderMead::new(simplex).with_sd_tolerance(1e-2)?;
let res = Executor::new(self, solver)
.configure(|state| state.max_iters(5))
Expand Down
4 changes: 2 additions & 2 deletions src/simulator/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,10 @@ use crate::simulator::likelihood::SubjectPredictions;
/// Default maximum number of entries per cache.
pub const DEFAULT_CACHE_SIZE: u64 = 100_000;

/// Cache key: (subject_hash, support_point_hash)
/// Cache key: (subject_hash, parameters_hash)
pub(crate) type PredictionKey = (u64, u64);

/// Cache key for SDE: (subject_hash, support_point_hash, error_model_hash)
/// Cache key for SDE: (subject_hash, parameters_hash, error_model_hash)
pub(crate) type SdeKey = (u64, u64, u64);

/// Thread-safe LRU cache for subject predictions.
Expand Down
47 changes: 26 additions & 21 deletions src/simulator/equation/analytical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub use three_compartment_models::*;
pub use two_compartment_cl_models::*;
pub use two_compartment_models::*;

use super::spphash;
use super::parametershash;

use crate::data::error_model::AssayErrorModels;
use crate::simulator::cache::{PredictionCache, DEFAULT_CACHE_SIZE};
Expand Down Expand Up @@ -151,13 +151,13 @@ impl EquationPriv for Analytical {
// }

// #[inline(always)]
// fn get_lag(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
// Some((self.lag)(&V::from_vec(spp.to_owned())))
// fn get_lag(&self, parameters: &[f64]) -> Option<HashMap<usize, f64>> {
// Some((self.lag)(&V::from_vec(parameters.to_owned())))
// }

// #[inline(always)]
// fn get_fa(&self, spp: &[f64]) -> Option<HashMap<usize, f64>> {
// Some((self.fa)(&V::from_vec(spp.to_owned())))
// fn get_fa(&self, parameters: &[f64]) -> Option<HashMap<usize, f64>> {
// Some((self.fa)(&V::from_vec(parameters.to_owned())))
// }

#[inline(always)]
Expand Down Expand Up @@ -188,7 +188,7 @@ impl EquationPriv for Analytical {
fn solve(
&self,
x: &mut Self::S,
support_point: &[f64],
parameters: &[f64],
covariates: &Covariates,
infusions: &[Infusion],
ti: f64,
Expand Down Expand Up @@ -217,7 +217,7 @@ impl EquationPriv for Analytical {

// 2) March over each sub-interval
let mut current_t = ts[0];
let mut sp = V::from_vec(support_point.to_vec(), NalgebraContext);
let mut sp = V::from_vec(parameters.to_vec(), NalgebraContext);
let mut rateiv = V::zeros(self.get_ndrugs(), NalgebraContext);

for &next_t in &ts[1..] {
Expand Down Expand Up @@ -253,7 +253,7 @@ impl EquationPriv for Analytical {
#[inline(always)]
fn process_observation(
&self,
support_point: &[f64],
parameters: &[f64],
observation: &Observation,
error_models: Option<&AssayErrorModels>,
_time: f64,
Expand All @@ -266,7 +266,7 @@ impl EquationPriv for Analytical {
let out = &self.out;
(out)(
x,
&V::from_vec(support_point.to_vec(), NalgebraContext),
&V::from_vec(parameters.to_vec(), NalgebraContext),
observation.time(),
covariates,
&mut y,
Expand All @@ -280,12 +280,17 @@ impl EquationPriv for Analytical {
Ok(())
}
#[inline(always)]
fn initial_state(&self, spp: &[f64], covariates: &Covariates, occasion_index: usize) -> V {
fn initial_state(
&self,
parameters: &[f64],
covariates: &Covariates,
occasion_index: usize,
) -> V {
let init = &self.init;
let mut x = V::zeros(self.get_nstates(), NalgebraContext);
if occasion_index == 0 {
(init)(
&V::from_vec(spp.to_vec(), NalgebraContext),
&V::from_vec(parameters.to_vec(), NalgebraContext),
0.0,
covariates,
&mut x,
Expand Down Expand Up @@ -554,19 +559,19 @@ impl Equation for Analytical {
fn estimate_likelihood(
&self,
subject: &Subject,
support_point: &[f64],
parameters: &[f64],
error_models: &AssayErrorModels,
) -> Result<f64, PharmsolError> {
_estimate_likelihood(self, subject, support_point, error_models)
_estimate_likelihood(self, subject, parameters, error_models)
}

fn estimate_log_likelihood(
&self,
subject: &Subject,
support_point: &[f64],
parameters: &[f64],
error_models: &AssayErrorModels,
) -> Result<f64, PharmsolError> {
let ypred = _subject_predictions(self, subject, support_point)?;
let ypred = _subject_predictions(self, subject, parameters)?;
ypred.log_likelihood(error_models)
}

Expand All @@ -579,28 +584,28 @@ impl Equation for Analytical {
fn _subject_predictions(
analytical: &Analytical,
subject: &Subject,
support_point: &[f64],
parameters: &[f64],
) -> Result<SubjectPredictions, PharmsolError> {
if let Some(cache) = &analytical.cache {
let key = (subject.hash(), spphash(support_point));
let key = (subject.hash(), parametershash(parameters));
if let Some(cached) = cache.get(&key) {
return Ok(cached);
}

let result = analytical.simulate_subject(subject, support_point, None)?.0;
let result = analytical.simulate_subject(subject, parameters, None)?.0;
cache.insert(key, result.clone());
Ok(result)
} else {
Ok(analytical.simulate_subject(subject, support_point, None)?.0)
Ok(analytical.simulate_subject(subject, parameters, None)?.0)
}
}

fn _estimate_likelihood(
ode: &Analytical,
subject: &Subject,
support_point: &[f64],
parameters: &[f64],
error_models: &AssayErrorModels,
) -> Result<f64, PharmsolError> {
let ypred = _subject_predictions(ode, subject, support_point)?;
let ypred = _subject_predictions(ode, subject, parameters)?;
Ok(ypred.log_likelihood(error_models)?.exp())
}
Loading