diff --git a/.gitignore b/.gitignore index a34b170b..5d7e9858 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ Cargo.lock /paper_files paper.html /tests/browser-e2e/node_modules/ +docs/ diff --git a/README.md b/README.md index b1620b8e..d24aea87 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ A high-performance Rust library for pharmacokinetic/pharmacodynamic (PK/PD) simu ## Installation -Add `pharmsol` to your `Cargo.toml`, either manually or using +Add `pharmsol` to `Cargo.toml`: ```bash cargo add pharmsol @@ -16,65 +16,75 @@ cargo add pharmsol ## Quick Start +Most Rust-first workflows start with one of the equation macros: `analytical!`, +`ode!`, or `sde!`. Here is a simple one-compartment IV infusion model using `analytical!`: + ```rust -use pharmsol::*; +use pharmsol::prelude::*; + +let analytical = analytical! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + structure: one_compartment, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, +}; + +let iv = analytical.route_index("iv").unwrap(); +let cp = analytical.output_index("cp").unwrap(); -// Create a subject with an IV infusion and observations let subject = Subject::builder("patient_001") - .infusion(0.0, 500.0, 0, 0.5) // 500 units over 0.5 hours - .observation(0.5, 1.645, 0) - .observation(1.0, 1.216, 0) - .observation(2.0, 0.462, 0) - .observation(4.0, 0.063, 0) + .infusion(0.0, 500.0, iv, 0.5) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) .build(); -// Define parameters: ke (elimination rate), v (volume) -let ke = 1.022; -let v = 194.0; - -// Use the built-in one-compartment analytical solution -let analytical = equation::Analytical::new( - one_compartment, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; // Concentration = Amount / Volume - }, - (1, 1), // (compartments, outputs) -); - -// Get predictions -let predictions = analytical.estimate_predictions(&subject, &vec![ke, v]).unwrap(); +let predictions = analytical + .estimate_predictions(&subject, &[1.022, 194.0]) + .unwrap(); ``` -## ODE-Based Models +## Modeling Surfaces -For custom or complex models, define your own ODEs: +Here is the same one-compartment IV setup written as an ODE: ```rust -use pharmsol::*; +use pharmsol::prelude::*; -let ode = equation::ODE::new( - |x, p, _t, dx, _b, rateiv, _cov| { - fetch_params!(p, ke, _v); - // One-compartment model with IV infusion support - dx[0] = -ke * x[0] + rateiv[0]; +let ode = ode! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - (1, 1), -); +}; ``` -## Supported Analytical Models +See [examples/analytical_readme.rs](examples/analytical_readme.rs), +[examples/ode_readme.rs](examples/ode_readme.rs), +[examples/sde_readme.rs](examples/sde_readme.rs), +[examples/analytical_vs_ode.rs](examples/analytical_vs_ode.rs), and +[examples/compare_solvers.rs](examples/compare_solvers.rs). For migration-oriented notes, +see [docs/analytical-authoring-migration.md](docs/analytical-authoring-migration.md) and +[docs/ode-authoring-migration.md](docs/ode-authoring-migration.md). + +### Built-In Analytical Kernels - [x] One-compartment with IV infusion - [x] One-compartment with IV infusion and oral absorption @@ -83,6 +93,21 @@ let ode = equation::ODE::new( - [x] Three-compartment with IV infusion - [x] Three-compartment with IV infusion and oral absorption +## DSL and Runtime Targets + +If the model needs to be loaded or compiled at runtime, pharmsol also provides a DSL with +the same broad modeling coverage: ODE, analytical, and SDE authoring. The DSL can target +an in-process JIT runtime, native ahead-of-time artifacts, or WASM bundles depending on +how you want to ship and execute the model. + +- `dsl-jit`: compile DSL source into a runtime model inside the current process. +- `dsl-aot` and `dsl-aot-load`: emit a native artifact and load it later. +- `dsl-wasm`: compile and execute portable WASM model artifacts. + +See [examples/dsl_runtime_jit.rs](examples/dsl_runtime_jit.rs) for the in-repo JIT flow. +The companion `pharmsol-examples` crate includes end-to-end native AOT and WASM runtime +examples. + ## Performance Analytical solutions provide 20-33× speedups compared to equivalent ODE formulations. See [benchmarks](benches/) for details. diff --git a/examples/analytical_readme.rs b/examples/analytical_readme.rs new file mode 100644 index 00000000..8e5b97f7 --- /dev/null +++ b/examples/analytical_readme.rs @@ -0,0 +1,35 @@ +fn main() -> Result<(), pharmsol::PharmsolError> { + use pharmsol::prelude::*; + + let analytical = analytical! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + structure: one_compartment, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let iv = analytical.route_index("iv").expect("iv route exists"); + let cp = analytical.output_index("cp").expect("cp output exists"); + + let subject = Subject::builder("analytical_readme") + .infusion(0.0, 500.0, iv, 0.5) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) + .build(); + + let predictions = analytical.estimate_predictions(&subject, &[1.022, 194.0])?; + + println!("times => {:?}", predictions.flat_times()); + println!("predictions => {:?}", predictions.flat_predictions()); + + Ok(()) +} diff --git a/examples/analytical_vs_ode.rs b/examples/analytical_vs_ode.rs index 97112f15..290d6632 100644 --- a/examples/analytical_vs_ode.rs +++ b/examples/analytical_vs_ode.rs @@ -4,6 +4,8 @@ //! two-compartment IV, two-compartment oral), this example runs both the //! closed-form analytical solution and the equivalent ODE, then prints //! the predictions side by side so you can verify they match. +//! Both authoring paths use the declaration-first macro surface so the +//! example stays on the preferred public authoring story. //! //! cargo run --release --example analytical_vs_ode @@ -11,29 +13,29 @@ use pharmsol::prelude::*; // ── Subjects ─────────────────────────────────────────────────────── -fn subject_iv() -> Subject { +fn subject_iv(input: usize, output: usize) -> Subject { Subject::builder("1") - .infusion(0.0, 500.0, 0, 0.5) - .observation(0.5, 0.0, 0) - .observation(1.0, 0.0, 0) - .observation(2.0, 0.0, 0) - .observation(4.0, 0.0, 0) - .observation(8.0, 0.0, 0) - .observation(12.0, 0.0, 0) - .observation(24.0, 0.0, 0) + .infusion(0.0, 500.0, input, 0.5) + .observation(0.5, 0.0, output) + .observation(1.0, 0.0, output) + .observation(2.0, 0.0, output) + .observation(4.0, 0.0, output) + .observation(8.0, 0.0, output) + .observation(12.0, 0.0, output) + .observation(24.0, 0.0, output) .build() } -fn subject_oral() -> Subject { +fn subject_oral(input: usize, output: usize) -> Subject { Subject::builder("1") - .bolus(0.0, 500.0, 0) - .observation(0.5, 0.0, 0) - .observation(1.0, 0.0, 0) - .observation(2.0, 0.0, 0) - .observation(4.0, 0.0, 0) - .observation(8.0, 0.0, 0) - .observation(12.0, 0.0, 0) - .observation(24.0, 0.0, 0) + .bolus(0.0, 500.0, input) + .observation(0.5, 0.0, output) + .observation(1.0, 0.0, output) + .observation(2.0, 0.0, output) + .observation(4.0, 0.0, output) + .observation(8.0, 0.0, output) + .observation(12.0, 0.0, output) + .observation(24.0, 0.0, output) .build() } @@ -64,168 +66,181 @@ fn print_comparison(label: &str, analytical: &SubjectPredictions, ode: &SubjectP // ── One-compartment IV ───────────────────────────────────────────── -fn one_cmt_iv(subject: &Subject, params: &[f64]) { - let analytical = equation::Analytical::new( - one_compartment, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; - }, - ) - .with_nstates(1) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, rateiv, _cov| { - fetch_params!(p, ke, _v); - dx[0] = -ke * x[0] + rateiv[0]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; - }, - ) - .with_nstates(1) - .with_nout(1); - - let pred_a = analytical.estimate_predictions(subject, params).unwrap(); - let pred_o = ode.estimate_predictions(subject, params).unwrap(); +fn one_cmt_iv(params: &[f64]) { + let analytical = analytical! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + structure: one_compartment, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let ode = ode! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let iv = analytical.route_index("iv").expect("iv route exists"); + let cp = analytical.output_index("cp").expect("cp output exists"); + let subject = subject_iv(iv, cp); + + let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); + let pred_o = ode.estimate_predictions(&subject, params).unwrap(); print_comparison("One-compartment IV", &pred_a, &pred_o); } // ── One-compartment oral ─────────────────────────────────────────── -fn one_cmt_oral(subject: &Subject, params: &[f64]) { - let analytical = equation::Analytical::new( - one_compartment_with_absorption, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, v); - y[0] = x[1] / v; - }, - ) - .with_nstates(2) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, _rateiv, _cov| { - fetch_params!(p, ka, ke, _v); - dx[0] = -ka * x[0]; - dx[1] = ka * x[0] - ke * x[1]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, v); - y[0] = x[1] / v; - }, - ) - .with_nstates(2) - .with_nout(1); - - let pred_a = analytical.estimate_predictions(subject, params).unwrap(); - let pred_o = ode.estimate_predictions(subject, params).unwrap(); +fn one_cmt_oral(params: &[f64]) { + let analytical = analytical! { + name: "one_cmt_oral", + params: [ka, ke, v], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + structure: one_compartment_with_absorption, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let ode = ode! { + name: "one_cmt_oral", + params: [ka, ke, v], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let oral = analytical.route_index("oral").expect("oral route exists"); + let cp = analytical.output_index("cp").expect("cp output exists"); + let subject = subject_oral(oral, cp); + + let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); + let pred_o = ode.estimate_predictions(&subject, params).unwrap(); print_comparison("One-compartment oral", &pred_a, &pred_o); } // ── Two-compartment IV ───────────────────────────────────────────── -fn two_cmt_iv(subject: &Subject, params: &[f64]) { - let analytical = equation::Analytical::new( - two_compartments, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, _k12, _k21, v); - y[0] = x[0] / v; - }, - ) - .with_nstates(2) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, rateiv, _cov| { - fetch_params!(p, ke, k12, k21, _v); - dx[0] = -ke * x[0] - k12 * x[0] + k21 * x[1] + rateiv[0]; - dx[1] = k12 * x[0] - k21 * x[1]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, _k12, _k21, v); - y[0] = x[0] / v; - }, - ) - .with_nstates(2) - .with_nout(1); - - let pred_a = analytical.estimate_predictions(subject, params).unwrap(); - let pred_o = ode.estimate_predictions(subject, params).unwrap(); +fn two_cmt_iv(params: &[f64]) { + let analytical = analytical! { + name: "two_cmt_iv", + params: [ke, k12, k21, v], + states: [central, peripheral], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + structure: two_compartments, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let ode = ode! { + name: "two_cmt_iv", + params: [ke, k12, k21, v], + states: [central, peripheral], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central] - k12 * x[central] + k21 * x[peripheral]; + dx[peripheral] = k12 * x[central] - k21 * x[peripheral]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let iv = analytical.route_index("iv").expect("iv route exists"); + let cp = analytical.output_index("cp").expect("cp output exists"); + let subject = subject_iv(iv, cp); + + let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); + let pred_o = ode.estimate_predictions(&subject, params).unwrap(); print_comparison("Two-compartment IV", &pred_a, &pred_o); } // ── Two-compartment oral ─────────────────────────────────────────── -fn two_cmt_oral(subject: &Subject, params: &[f64]) { - let analytical = equation::Analytical::new( - two_compartments_with_absorption, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _k12, _k21, v); - y[0] = x[1] / v; - }, - ) - .with_nstates(3) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, _rateiv, _cov| { - fetch_params!(p, ka, ke, k12, k21, _v); - dx[0] = -ka * x[0]; - dx[1] = ka * x[0] - ke * x[1] - k12 * x[1] + k21 * x[2]; - dx[2] = k12 * x[1] - k21 * x[2]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _k12, _k21, v); - y[0] = x[1] / v; - }, - ) - .with_nstates(3) - .with_nout(1); - - let pred_a = analytical.estimate_predictions(subject, params).unwrap(); - let pred_o = ode.estimate_predictions(subject, params).unwrap(); +fn two_cmt_oral(params: &[f64]) { + let analytical = analytical! { + name: "two_cmt_oral", + params: [ka, ke, k12, k21, v], + states: [gut, central, peripheral], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + structure: two_compartments_with_absorption, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let ode = ode! { + name: "two_cmt_oral", + params: [ka, ke, k12, k21, v], + states: [gut, central, peripheral], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central] - k12 * x[central] + k21 * x[peripheral]; + dx[peripheral] = k12 * x[central] - k21 * x[peripheral]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let oral = analytical.route_index("oral").expect("oral route exists"); + let cp = analytical.output_index("cp").expect("cp output exists"); + let subject = subject_oral(oral, cp); + + let pred_a = analytical.estimate_predictions(&subject, params).unwrap(); + let pred_o = ode.estimate_predictions(&subject, params).unwrap(); print_comparison("Two-compartment oral", &pred_a, &pred_o); } // ── Main ─────────────────────────────────────────────────────────── fn main() { - let iv = subject_iv(); - let oral = subject_oral(); - - one_cmt_iv(&iv, &[0.1, 50.0]); // ke, v - one_cmt_oral(&oral, &[1.0, 0.1, 50.0]); // ka, ke, v - two_cmt_iv(&iv, &[0.1, 0.3, 0.2, 50.0]); // ke, k12, k21, v - two_cmt_oral(&oral, &[1.0, 0.1, 0.3, 0.2, 50.0]); // ka, ke, k12, k21, v + one_cmt_iv(&[0.1, 50.0]); // ke, v + one_cmt_oral(&[1.0, 0.1, 50.0]); // ka, ke, v + two_cmt_iv(&[0.1, 0.3, 0.2, 50.0]); // ke, k12, k21, v + two_cmt_oral(&[1.0, 0.1, 0.3, 0.2, 50.0]); // ka, ke, k12, k21, v } diff --git a/examples/compare_solvers.rs b/examples/compare_solvers.rs index 3a34b424..ebec4caa 100644 --- a/examples/compare_solvers.rs +++ b/examples/compare_solvers.rs @@ -1,4 +1,4 @@ -//! Shows how to select different ODE solvers for the same model. +//! Shows how to select different ODE solvers for the same declaration-first model. //! //! pharmsol wraps diffsol's solver families: //! @@ -14,19 +14,26 @@ use std::time::Instant; use pharmsol::prelude::*; // ── Model ────────────────────────────────────────────────────────── -// Two-compartment IV model. The solver is the only thing that changes -// between runs — the ODE, output equation and dimensions stay the same. +// Two-compartment IV model. The solver is the only thing that changes +// between runs; the declaration-first `ode!` surface and the generated +// metadata stay the same. fn two_cpt(solver: OdeSolver) -> equation::ODE { ode! { - diffeq: |x, p, _t, dx, b, rateiv, _cov| { - fetch_params!(p, ke, kcp, kpc, _v); - dx[0] = rateiv[0] + b[0] - ke * x[0] - kcp * x[0] + kpc * x[1]; - dx[1] = kcp * x[0] - kpc * x[1]; + name: "two_cpt", + params: [ke, kcp, kpc, v], + states: [central, peripheral], + outputs: [cp], + routes: { + bolus(load) -> central, + infusion(iv) -> central, }, - out: |x, p, _t, _cov, y| { - fetch_params!(p, _ke, _kcp, _kpc, v); - y[0] = x[0] / v; + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; + dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, } .with_solver(solver) @@ -35,30 +42,42 @@ fn two_cpt(solver: OdeSolver) -> equation::ODE { // ── Main ─────────────────────────────────────────────────────────── fn main() { - let subject = Subject::builder("id1") - .bolus(0.0, 100.0, 0) - .infusion(12.0, 200.0, 0, 2.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) - .missing_observation(4.0, 0) - .missing_observation(8.0, 0) - .missing_observation(12.0, 0) - .missing_observation(12.5, 0) - .missing_observation(13.0, 0) - .missing_observation(14.0, 0) - .missing_observation(16.0, 0) - .missing_observation(24.0, 0) - .build(); - - let spp = vec![0.1, 0.05, 0.03, 50.0]; // ke, kcp, kpc, V - // Run each solver and collect predictions let bdf = two_cpt(OdeSolver::Bdf); let tsit45 = two_cpt(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); let trbdf2 = two_cpt(OdeSolver::Sdirk(SdirkTableau::TrBdf2)); let esdirk34 = two_cpt(OdeSolver::Sdirk(SdirkTableau::Esdirk34)); + // Both declarations resolve to the same shared input channel, so subject + // authoring still uses one numeric index for the loading bolus and the + // maintenance infusion. + let load = bdf.route_index("load").expect("load route exists"); + let iv = bdf.route_index("iv").expect("iv route exists"); + let cp = bdf.output_index("cp").expect("cp output exists"); + + assert_eq!( + load, iv, + "mixed IV declarations should share one numeric channel" + ); + + let subject = Subject::builder("id1") + .bolus(0.0, 100.0, iv) + .infusion(12.0, 200.0, iv, 2.0) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) + .missing_observation(8.0, cp) + .missing_observation(12.0, cp) + .missing_observation(12.5, cp) + .missing_observation(13.0, cp) + .missing_observation(14.0, cp) + .missing_observation(16.0, cp) + .missing_observation(24.0, cp) + .build(); + + let spp = vec![0.1, 0.05, 0.03, 50.0]; // ke, kcp, kpc, V + let results: Vec<(&str, equation::ODE)> = vec![ ("Bdf", bdf), ("Sdirk(TrBdf2)", trbdf2), diff --git a/examples/covariates.rs b/examples/covariates.rs index f9b97f29..9aabf491 100644 --- a/examples/covariates.rs +++ b/examples/covariates.rs @@ -1,61 +1,51 @@ fn main() { use pharmsol::prelude::*; - // Create a subject with a bolus dose, observations, and covariates - let subject = Subject::builder("id1") - // Administer a bolus dose of 100 units at time 0 - .bolus(0.0, 100.0, 0) - // Give two additional doses at 2-hour intervals - .repeat(2, 2.0) - .observation(0.5, 0.1, 0) - .observation(1.0, 0.4, 0) - .observation(2.0, 1.0, 0) - .observation(2.5, 1.1, 0) - // Creatinine covariate changes over time, with initial value of 80 at time 0 - .covariate("creatinine", 0.0, 80.0) - // New obseration of creatinine at time 6 hours - // The value will be linearly interpolated between time 0 and time 6 - .covariate("creatinine", 1.0, 40.0) - // For age, the covariate is constant over time, as there are no changes - .covariate("age", 0.0, 25.0) - .missing_observation(8.0, 0) - .build(); - - let ode = equation::ODE::new( - |x, p, t, dx, b, _rateiv, cov| { - // Macro to get the (possibly interpolated) covariate values at time `t` - fetch_cov!(cov, t, creatinine, age); - // Macro to fetch parameter values from `p` - // Note the order must match the order in which parameters are defined later - fetch_params!(p, ka, ke, _tlag, _v); - - let ke = ke * (creatinine / 75.0).powf(0.75) * (age / 25.0).powf(0.5); + let ode = ode! { + name: "one_cmt_covariates", + params: [ka, ke, tlag, v], + covariates: [creatinine, age], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + diffeq: |x, _t, dx| { + let scaled_ke = ke * (creatinine / 75.0).powf(0.75) * (age / 25.0).powf(0.5); - //Struct - dx[0] = -ka * x[0] + b[0]; - dx[1] = ka * x[0] - ke * x[1]; + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - scaled_ke * x[central]; }, // This blocks defines the lag-time of the bolus dose - |p, _t, _cov| { - fetch_params!(p, _ka, _ke, tlag, _v); + lag: |_t| { // Macro used to define the lag-time for the input of the bolus dose - lag! {0=>tlag} + lag! { oral => tlag } }, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _tlag, v); - + out: |x, _t, y| { // Define the predicted concentration as the amount in the central compartment divided by volume - y[0] = x[1] / v; + y[cp] = x[central] / v; }, - ) - .with_nstates(2) - .with_nout(1); + }; + + let oral = ode.route_index("oral").expect("oral route exists"); + let cp = ode.output_index("cp").expect("cp output exists"); + + // Create a subject with metadata-backed route and output names instead of + // hard-coded numeric indices. + let subject = Subject::builder("id1") + .bolus(0.0, 100.0, oral) + .repeat(2, 2.0) + .observation(0.5, 0.1, cp) + .observation(1.0, 0.4, cp) + .observation(2.0, 1.0, cp) + .observation(2.5, 1.1, cp) + .covariate("creatinine", 0.0, 80.0) + .covariate("creatinine", 1.0, 40.0) + .covariate("age", 0.0, 25.0) + .missing_observation(8.0, cp) + .build(); // Define parameter values - // Note that the order matters and should correspond to the order in which parameters are fetched in the model - // This is subject to change in future versions let ka = 1.0; // Absorption rate constant let ke = 0.2; // Elimination rate constant let tlag = 0.0; // Lag time diff --git a/examples/dsl_runtime_jit.rs b/examples/dsl_runtime_jit.rs index 655981bc..932acaae 100644 --- a/examples/dsl_runtime_jit.rs +++ b/examples/dsl_runtime_jit.rs @@ -8,7 +8,7 @@ fn main() -> Result<(), Box> { use pharmsol::prelude::*; let model_source = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v diff --git a/examples/macro_vs_handwritten_one_cpt.rs b/examples/macro_vs_handwritten_one_cpt.rs new file mode 100644 index 00000000..4d8f74d0 --- /dev/null +++ b/examples/macro_vs_handwritten_one_cpt.rs @@ -0,0 +1,110 @@ +//! Compares a declaration-first macro ODE with the equivalent handwritten ODE. +//! +//! This is the advanced comparison path for users who want to confirm that the +//! preferred macro surface and the low-level API produce the same metadata and +//! predictions on the same one-compartment IV problem. + +use pharmsol::prelude::*; + +fn macro_model() -> equation::ODE { + ode! { + name: "one_cpt_macro_parity", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_model() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cpt_macro_parity") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ), + ) + .expect("handwritten one-compartment metadata should validate") +} + +fn max_abs_diff(left: &[f64], right: &[f64]) -> f64 { + left.iter() + .zip(right.iter()) + .map(|(lhs, rhs)| (lhs - rhs).abs()) + .fold(0.0_f64, f64::max) +} + +fn main() -> Result<(), pharmsol::PharmsolError> { + let macro_ode = macro_model(); + let handwritten_ode = handwritten_model(); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + + let iv = macro_ode.route_index("iv").expect("iv route exists"); + let cp = macro_ode.output_index("cp").expect("cp output exists"); + + assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); + assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); + + let subject = Subject::builder("macro-vs-handwritten-one-cpt") + .infusion(0.0, 500.0, iv, 0.5) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) + .missing_observation(8.0, cp) + .build(); + + let params = [1.022, 194.0]; + let macro_predictions = macro_ode.estimate_predictions(&subject, ¶ms)?; + let handwritten_predictions = handwritten_ode.estimate_predictions(&subject, ¶ms)?; + + let macro_flat = macro_predictions.flat_predictions(); + let handwritten_flat = handwritten_predictions.flat_predictions(); + let diff = max_abs_diff(¯o_flat, &handwritten_flat); + + assert!( + diff <= 1e-10, + "macro and handwritten one-compartment predictions diverged: {diff:e}" + ); + + println!("one-compartment parity max abs diff: {diff:e}"); + for ((time, macro_pred), handwritten_pred) in macro_predictions + .flat_times() + .iter() + .zip(macro_flat.iter()) + .zip(handwritten_flat.iter()) + { + println!("t={time:>4.1} macro={macro_pred:>12.8} handwritten={handwritten_pred:>12.8}"); + } + + Ok(()) +} diff --git a/examples/macro_vs_handwritten_two_cpt.rs b/examples/macro_vs_handwritten_two_cpt.rs new file mode 100644 index 00000000..9ab1a675 --- /dev/null +++ b/examples/macro_vs_handwritten_two_cpt.rs @@ -0,0 +1,130 @@ +//! Compares a declaration-first macro ODE with the equivalent handwritten ODE +//! on a two-compartment IV problem that shares one numeric input channel across +//! a loading bolus and a maintenance infusion. +//! +//! This keeps the macro story as the default surface while showing the +//! low-level API as an explicit advanced comparison path. + +use pharmsol::prelude::*; + +fn macro_model() -> equation::ODE { + ode! { + name: "two_cpt_shared_channel_parity", + params: [ke, kcp, kpc, v], + states: [central, peripheral], + outputs: [cp], + routes: { + bolus(load) -> central, + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; + dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_model() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, bolus, rateiv, _cov| { + fetch_params!(p, ke, kcp, kpc, _v); + dx[0] = -ke * x[0] - kcp * x[0] + kpc * x[1] + rateiv[0] + bolus[0]; + dx[1] = kcp * x[0] - kpc * x[1]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, _kcp, _kpc, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("two_cpt_shared_channel_parity") + .parameters(["ke", "kcp", "kpc", "v"]) + .states(["central", "peripheral"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("load") + .to_state("central") + .inject_input_to_destination(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]), + ) + .expect("handwritten two-compartment metadata should validate") +} + +fn max_abs_diff(left: &[f64], right: &[f64]) -> f64 { + left.iter() + .zip(right.iter()) + .map(|(lhs, rhs)| (lhs - rhs).abs()) + .fold(0.0_f64, f64::max) +} + +fn main() -> Result<(), pharmsol::PharmsolError> { + let macro_ode = macro_model(); + let handwritten_ode = handwritten_model(); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + + let load = macro_ode.route_index("load").expect("load route exists"); + let iv = macro_ode.route_index("iv").expect("iv route exists"); + let cp = macro_ode.output_index("cp").expect("cp output exists"); + + assert_eq!( + load, iv, + "load and iv should share one numeric input channel" + ); + assert_eq!(handwritten_ode.route_index("load"), Some(load)); + assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); + assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); + + let subject = Subject::builder("macro-vs-handwritten-two-cpt") + .bolus(0.0, 100.0, load) + .infusion(12.0, 200.0, iv, 2.0) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) + .missing_observation(8.0, cp) + .missing_observation(12.0, cp) + .missing_observation(12.5, cp) + .missing_observation(13.0, cp) + .missing_observation(14.0, cp) + .missing_observation(16.0, cp) + .missing_observation(24.0, cp) + .build(); + + let params = [0.1, 0.05, 0.03, 50.0]; + let macro_predictions = macro_ode.estimate_predictions(&subject, ¶ms)?; + let handwritten_predictions = handwritten_ode.estimate_predictions(&subject, ¶ms)?; + + let macro_flat = macro_predictions.flat_predictions(); + let handwritten_flat = handwritten_predictions.flat_predictions(); + let diff = max_abs_diff(¯o_flat, &handwritten_flat); + + assert!( + diff <= 1e-10, + "macro and handwritten two-compartment predictions diverged: {diff:e}" + ); + + println!("two-compartment parity max abs diff: {diff:e}"); + for ((time, macro_pred), handwritten_pred) in macro_predictions + .flat_times() + .iter() + .zip(macro_flat.iter()) + .zip(handwritten_flat.iter()) + { + println!("t={time:>5.1} macro={macro_pred:>12.8} handwritten={handwritten_pred:>12.8}"); + } + + Ok(()) +} diff --git a/examples/ode_readme.rs b/examples/ode_readme.rs index 51765af1..a0174801 100644 --- a/examples/ode_readme.rs +++ b/examples/ode_readme.rs @@ -1,43 +1,40 @@ -fn main() { +fn main() -> Result<(), pharmsol::PharmsolError> { use pharmsol::prelude::*; - let subject = Subject::builder("id1") - .bolus(0.0, 100.0, 0) - .repeat(2, 0.5) - .observation(0.5, 0.1, 0) - .observation(1.0, 0.4, 0) - .observation(2.0, 1.0, 0) - .observation(2.5, 1.1, 0) - .covariate("wt", 0.0, 80.0) - .covariate("wt", 1.0, 83.0) - .covariate("age", 0.0, 25.0) - .build(); - println!("{subject}"); - let ode = equation::ODE::new( - |x, p, _t, dx, b, _rateiv, _cov| { - // fetch_cov!(cov, t,); - fetch_params!(p, ka, ke, _tlag, _v); - //Struct - dx[0] = -ka * x[0] + b[0]; - dx[1] = ka * x[0] - ke * x[1]; + let ode = ode! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, }, - |p, _t, _cov| { - fetch_params!(p, _ka, _ke, tlag, _v); - lag! {0=>tlag} + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; }, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ka, _ke, _tlag, v); - y[0] = x[1] / v; + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; }, - ) - .with_nstates(2) - .with_ndrugs(5) - .with_nout(1); + }; + + let iv = ode.route_index("iv").expect("iv route exists"); + let cp = ode.output_index("cp").expect("cp output exists"); + + let subject = Subject::builder("id1") + .infusion(0.0, 100.0, iv, 0.5) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) + .build(); + + let predictions = ode.estimate_predictions(&subject, &[1.022, 194.0])?; + println!( + "state central => {}", + ode.state_index("central").expect("central state exists") + ); + println!("prediction times => {:?}", predictions.flat_times()); + println!("predictions => {:?}", predictions.flat_predictions()); - let op = ode - .estimate_predictions(&subject, &[0.3, 0.5, 0.1, 70.0]) - .unwrap(); - println!("{:#?}", op.flat_predictions()); + Ok(()) } diff --git a/examples/one_compartment.rs b/examples/one_compartment.rs index a66112b5..aafdf2b2 100644 --- a/examples/one_compartment.rs +++ b/examples/one_compartment.rs @@ -1,67 +1,58 @@ fn main() -> Result<(), pharmsol::PharmsolError> { use pharmsol::prelude::*; - // Create a subject using the builder pattern - let subject = Subject::builder("Nikola Tesla") - // An initial infusion of 500 units over 0.5 time units - .infusion(0., 500.0, 0, 0.5) - // Observations at various time points - .observation(0.5, 1.645, 0) - .observation(1., 1.216, 0) - .observation(2., 0.462, 0) - .observation(3., 0.169, 0) - .observation(4., 0.063, 0) - .observation(6., 0.009, 0) - .observation(8., 0.001, 0) - // A missing observation, to force the simulator to predict to this time point - // For missing observations, predictions are made but no likelihood contribution is computed - .missing_observation(12.0, 0) - // Build the subject - .build(); - - // Define the one-compartment analytical solution function - let an = equation::Analytical::new( - one_compartment, - |_p, _t, _cov| {}, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - // Calculate the output concentration, here defined as amount over volume - y[0] = x[0] / v; + let analytical = analytical! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, }, - ) - .with_nstates(1) - .with_nout(1); - - let ode = equation::ODE::new( - |x, p, _t, dx, _b, rateiv, _cov| { - // Macro to fetch parameters from the parameter vector - // This exposes them as local variables - fetch_params!(p, ke, _v); + structure: one_compartment, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; - // Define the ODE for the one-compartment model - // Note that rateiv is used to include infusion rates - dx[0] = -ke * x[0] + rateiv[0]; + let ode = ode! { + name: "one_cmt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - // Calculate the output concentration, here defined as amount over volume - y[0] = x[0] / v; + diffeq: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; }, - ) - .with_nstates(1) - .with_nout(1); + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let iv = analytical.route_index("iv").expect("iv route exists"); + let cp = analytical.output_index("cp").expect("cp output exists"); + + // Create a subject using metadata-backed route and output names instead of + // hard-coded indices. + let subject = Subject::builder("Nikola Tesla") + .infusion(0., 500.0, iv, 0.5) + .observation(0.5, 1.645, cp) + .observation(1., 1.216, cp) + .observation(2., 0.462, cp) + .observation(3., 0.169, cp) + .observation(4., 0.063, cp) + .observation(6., 0.009, cp) + .observation(8., 0.001, cp) + .missing_observation(12.0, cp) + .build(); // Define the error models for the observations let ems = AssayErrorModels::new(). // For this example, we use a simple additive error model with 5% error add( - 0, + cp, AssayErrorModel::additive(ErrorPoly::new(0.0, 0.05, 0.0, 0.0), 0.0), )?; @@ -70,9 +61,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { let v = 194.0; // Volume of distribution // Compute likelihoods and predictions for both models - let analytical_likelihoods = an.estimate_log_likelihood(&subject, &[ke, v], &ems)?; + let analytical_likelihoods = analytical.estimate_log_likelihood(&subject, &[ke, v], &ems)?; - let analytical_predictions = an.estimate_predictions(&subject, &[ke, v])?; + let analytical_predictions = analytical.estimate_predictions(&subject, &[ke, v])?; let ode_likelihoods = ode.estimate_log_likelihood(&subject, &[ke, v], &ems)?; diff --git a/examples/sde_readme.rs b/examples/sde_readme.rs new file mode 100644 index 00000000..6106b17a --- /dev/null +++ b/examples/sde_readme.rs @@ -0,0 +1,41 @@ +fn main() -> Result<(), pharmsol::PharmsolError> { + use pharmsol::prelude::*; + + let sde = sde! { + name: "one_cmt_sde", + params: [ke, sigma_ke, v], + states: [central], + outputs: [cp], + particles: 16, + routes: { + infusion(iv) -> central, + }, + drift: |x, _p, _t, dx, _cov| { + dx[central] = -ke * x[central]; + }, + diffusion: |_p, sigma| { + sigma[central] = sigma_ke; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + }; + + let iv = sde.route_index("iv").expect("iv route exists"); + let cp = sde.output_index("cp").expect("cp output exists"); + + let subject = Subject::builder("sde_readme") + .infusion(0.0, 500.0, iv, 0.5) + .missing_observation(0.5, cp) + .missing_observation(1.0, cp) + .missing_observation(2.0, cp) + .missing_observation(4.0, cp) + .build(); + + let predictions = sde.estimate_predictions(&subject, &[1.022, 0.0, 194.0])?; + + println!("first prediction => {}", predictions[[0, 0]].prediction()); + println!("prediction grid shape => {:?}", predictions.dim()); + + Ok(()) +} diff --git a/examples/two_compartment.rs b/examples/two_compartment.rs index d81aa1f8..64d554af 100644 --- a/examples/two_compartment.rs +++ b/examples/two_compartment.rs @@ -3,6 +3,9 @@ /// This example demonstrates how to implement a two-compartment pharmacokinetic model /// with weight-based covariate scaling using pharmsol. /// +/// It uses the declaration-first `ode!` surface so the route, covariate, +/// state, and output metadata stay aligned with the generated execution path. +/// /// The two-compartment model describes drug distribution between: /// - Central compartment (x[0]): where drug enters and is eliminated /// - Peripheral compartment (x[1]): a tissue compartment in equilibrium with central @@ -18,36 +21,18 @@ fn main() -> Result<(), pharmsol::PharmsolError> { use pharmsol::prelude::*; - // Create a subject using the builder pattern - let subject = Subject::builder("subject_001") - // An infusion of 500 mg over 0.5 hours (1000 mg/hr rate) - .infusion(0.0, 500.0, 0, 0.5) - // Weight covariate at baseline (85 kg reference weight) - .covariate("wt", 0.0, 70.0) - // Observations at various time points (concentration in mg/L) - .observation(0.5, 8.5, 0) - .observation(1.0, 6.2, 0) - .observation(2.0, 4.1, 0) - .observation(4.0, 2.3, 0) - .observation(6.0, 1.5, 0) - .observation(8.0, 1.1, 0) - .observation(12.0, 0.7, 0) - // Missing observation to force prediction at this time point - .missing_observation(24.0, 0) - .build(); - - // Define the two-compartment ODE model - let ode = equation::ODE::new( - // Primary differential equation block - |x, p, t, dx, _b, rateiv, cov| { - // Fetch the (possibly interpolated) weight covariate at time t - fetch_cov!(cov, t, wt); - - // Fetch parameters from the parameter vector + let ode = ode! { + name: "two_cmt_wt", + params: [cl, v, vp, q], + covariates: [wt], + states: [central, peripheral], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _t, dx| { // CL: Clearance (L/hr), V: Central volume (L) // Vp: Peripheral volume (L), Q: Inter-compartmental clearance (L/hr) - fetch_params!(p, cl, v, vp, q); - // Weight-based allometric scaling // Reference weight is 85 kg let wt_ratio = wt / 85.0; @@ -64,36 +49,41 @@ fn main() -> Result<(), pharmsol::PharmsolError> { let kpc = q_scaled / vp_scaled; // Peripheral to central rate constant // Two-compartment model differential equations - // Central compartment: elimination + distribution + infusion input - dx[0] = -ke * x[0] - kcp * x[0] + kpc * x[1] + rateiv[0]; + // Central compartment: elimination + distribution + dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; // Peripheral compartment: distribution equilibrium - dx[1] = kcp * x[0] - kpc * x[1]; + dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; }, - // Lag time block (no lag in this model) - |_p, _t, _cov| lag! {}, - // Bioavailability block (100% for IV, so not needed) - |_p, _t, _cov| fa! {}, - // Secondary equations block (not used here) - |_p, _t, _cov, _x| {}, // Output equation block - calculates observed concentration - |x, p, t, cov, y| { - fetch_cov!(cov, t, wt); - fetch_params!(p, _cl, v, _vp, _q); - + out: |x, _t, y| { // Calculate scaled volume for concentration let wt_ratio = wt / 85.0; let v_scaled = v * wt_ratio; // Concentration = Amount / Volume - y[0] = x[0] / v_scaled; + y[cp] = x[central] / v_scaled; }, - // Model dimensions: (number of compartments, number of outputs) - ) - .with_nstates(2) - .with_nout(1); + }; + + let iv = ode.route_index("iv").expect("iv route exists"); + let cp = ode.output_index("cp").expect("cp output exists"); + + // Create a subject using metadata-backed route and output names instead of + // hard-coded numeric indices. + let subject = Subject::builder("subject_001") + .infusion(0.0, 500.0, iv, 0.5) + .covariate("wt", 0.0, 70.0) + .observation(0.5, 8.5, cp) + .observation(1.0, 6.2, cp) + .observation(2.0, 4.1, cp) + .observation(4.0, 2.3, cp) + .observation(6.0, 1.5, cp) + .observation(8.0, 1.1, cp) + .observation(12.0, 0.7, cp) + .missing_observation(24.0, cp) + .build(); // Define parameter values - // Note: order must match the fetch_params! macro order let cl = 5.0; // Clearance (L/hr) let v = 50.0; // Central volume of distribution (L) let vp = 100.0; // Peripheral volume of distribution (L) diff --git a/pharmsol-dsl/src/ast.rs b/pharmsol-dsl/src/ast.rs index d43e7404..6cff4483 100644 --- a/pharmsol-dsl/src/ast.rs +++ b/pharmsol-dsl/src/ast.rs @@ -111,10 +111,17 @@ pub struct RoutesBlock { pub span: Span, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum RouteKind { + Bolus, + Infusion, +} + #[derive(Debug, Clone, PartialEq)] pub struct RouteDecl { pub input: Ident, pub destination: Place, + pub kind: Option, pub properties: Vec, pub span: Span, } @@ -141,7 +148,7 @@ pub struct StatementBlock { #[derive(Debug, Clone, PartialEq)] pub struct AnalyticalBlock { - pub kernel: Ident, + pub structure: Ident, pub span: Span, } @@ -491,7 +498,7 @@ fn write_analytical_block( indent(out, indent_level); writeln!(out, "analytical {{")?; indent(out, indent_level + 1); - writeln!(out, "kernel = {}", block.kernel.text)?; + writeln!(out, "structure = {}", block.structure.text)?; indent(out, indent_level); write!(out, "}}") } diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index db32b9c8..c81c19eb 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -12,7 +12,7 @@ pub(super) fn parse_module(src: &str) -> Result { struct AuthoringParser<'a> { src: &'a str, - model_name: Option, + name: Option, explicit_kind: Option<(ModelKind, Span)>, parameters: Vec, constants: Vec, @@ -24,6 +24,7 @@ struct AuthoringParser<'a> { assigned_outputs: BTreeMap, declared_outputs_span: Option, routes: BTreeMap, + route_order: Vec, route_modifiers: BTreeMap>, derive_statements: Vec, derivative_statements: Vec, @@ -68,7 +69,7 @@ impl<'a> AuthoringParser<'a> { fn new(src: &'a str) -> Self { Self { src, - model_name: None, + name: None, explicit_kind: None, parameters: Vec::new(), constants: Vec::new(), @@ -80,6 +81,7 @@ impl<'a> AuthoringParser<'a> { assigned_outputs: BTreeMap::new(), declared_outputs_span: None, routes: BTreeMap::new(), + route_order: Vec::new(), route_modifiers: BTreeMap::new(), derive_statements: Vec::new(), derivative_statements: Vec::new(), @@ -132,11 +134,15 @@ impl<'a> AuthoringParser<'a> { } let surface_routes = std::mem::take(&mut self.routes); + let route_order = std::mem::take(&mut self.route_order); let mut route_modifiers = std::mem::take(&mut self.route_modifiers); let mut routes = Vec::with_capacity(surface_routes.len()); - for (route_name, route) in &surface_routes { + for route_name in route_order { + let Some(route) = surface_routes.get(&route_name) else { + continue; + }; let mut span = route.span; - let properties = route_modifiers.remove(route_name).unwrap_or_default(); + let properties = route_modifiers.remove(&route_name).unwrap_or_default(); if !properties.is_empty() { span = properties .iter() @@ -145,6 +151,10 @@ impl<'a> AuthoringParser<'a> { routes.push(RouteDecl { input: route.input.clone(), destination: route.destination.clone(), + kind: Some(match route.kind { + SurfaceRouteKind::Bolus => RouteKind::Bolus, + SurfaceRouteKind::Infusion => RouteKind::Infusion, + }), properties, span, }); @@ -169,7 +179,7 @@ impl<'a> AuthoringParser<'a> { inject_infusion_rates(&surface_routes, &routes, &mut derivative_statements); let name = self - .model_name + .name .unwrap_or_else(|| Ident::new(DEFAULT_MODEL_NAME, module_span)); let mut items = Vec::new(); @@ -298,9 +308,19 @@ impl<'a> AuthoringParser<'a> { if let Some(rest) = lhs_trimmed.strip_prefix("model") { if !rest.trim().is_empty() { - return Err(ParseError::new("expected `model = `", span)); + return Err(ParseError::new("expected `name = `", span)); } - self.model_name = Some(parse_ident_segment(rhs, rhs_abs)?); + return Err(ParseError::new( + "`model = ...` has been renamed to `name = ...`", + span, + )); + } + + if let Some(rest) = lhs_trimmed.strip_prefix("name") { + if !rest.trim().is_empty() { + return Err(ParseError::new("expected `name = `", span)); + } + self.name = Some(parse_ident_segment(rhs, rhs_abs)?); return Ok(()); } @@ -365,8 +385,15 @@ impl<'a> AuthoringParser<'a> { } if lhs_trimmed == "kernel" { - let kernel = parse_ident_segment(rhs, rhs_abs)?; - self.analytical = Some(AnalyticalBlock { span, kernel }); + return Err(ParseError::new( + "`kernel = ...` has been renamed to `structure = ...`", + span, + )); + } + + if lhs_trimmed == "structure" { + let structure = parse_ident_segment(rhs, rhs_abs)?; + self.analytical = Some(AnalyticalBlock { span, structure }); return Ok(()); } @@ -428,15 +455,16 @@ impl<'a> AuthoringParser<'a> { }; let input = parse_ident_segment(call.argument, call.argument_start)?; + let route_name = input.text.clone(); let destination = parse_place_at(rhs, line_start + arrow + 2)?; - if self.routes.contains_key(&input.text) { + if self.routes.contains_key(&route_name) { return Err(ParseError::new( format!("duplicate route `{}`", input.text), input.span, )); } self.routes.insert( - input.text.clone(), + route_name.clone(), SurfaceRoute { input, destination, @@ -444,6 +472,7 @@ impl<'a> AuthoringParser<'a> { span, }, ); + self.route_order.push(route_name); Ok(()) } @@ -569,12 +598,10 @@ impl<'a> AuthoringParser<'a> { .unwrap_or(module_span); if matches!(kind, ModelKind::Analytical) - && (!self.diffusion_statements.is_empty() - || self.particles.is_some() - || !self.init_statements.is_empty()) + && (!self.diffusion_statements.is_empty() || self.particles.is_some()) { return Err(ParseError::new( - "analytical authoring models cannot declare particles, init, or noise equations", + "analytical authoring models cannot declare particles or noise equations", kind_span, )); } @@ -589,7 +616,7 @@ impl<'a> AuthoringParser<'a> { if matches!(kind, ModelKind::Sde) { if let Some(analytical) = &self.analytical { return Err(ParseError::new( - "SDE authoring models cannot declare an analytical kernel", + "SDE authoring models cannot declare an analytical structure", analytical.span, )); } diff --git a/pharmsol-dsl/src/execution.rs b/pharmsol-dsl/src/execution.rs index 26e0a3cb..886d570a 100644 --- a/pharmsol-dsl/src/execution.rs +++ b/pharmsol-dsl/src/execution.rs @@ -4,8 +4,8 @@ use std::sync::Arc; use crate::{ AnalyticalKernel, ConstValue, CovariateInterpolation, Diagnostic, DiagnosticPhase, - DiagnosticReport, MathIntrinsic, ModelKind, RoutePropertyKind, Span, Symbol, SymbolId, - SymbolKind, SymbolType, TypedAssignTargetKind, TypedBinaryOp, TypedCall, TypedExpr, + DiagnosticReport, MathIntrinsic, ModelKind, RouteKind, RoutePropertyKind, Span, Symbol, + SymbolId, SymbolKind, SymbolType, TypedAssignTargetKind, TypedBinaryOp, TypedCall, TypedExpr, TypedExprKind, TypedModel, TypedModule, TypedRangeExpr, TypedStatePlace, TypedStatementBlock, TypedStmt, TypedStmtKind, TypedUnaryOp, ValueType, DSL_LOWERING_GENERIC, }; @@ -98,7 +98,9 @@ pub struct ExecutionState { pub struct ExecutionRoute { pub symbol: SymbolId, pub name: String, + pub declaration_index: usize, pub index: usize, + pub kind: Option, pub destination: RouteDestination, pub has_lag: bool, pub has_bioavailability: bool, @@ -349,7 +351,7 @@ pub enum ExecutionLoad { State(ExecutionStateRef), Derived(usize), Local(usize), - RouteInput(usize), + RouteInput { route: SymbolId, index: usize }, } #[derive(Debug, Clone, PartialEq)] @@ -531,33 +533,67 @@ impl<'a> ExecutionLowerer<'a> { next_state_offset += len; } + let uses_authoring_route_kinds = + !model.routes.is_empty() && model.routes.iter().all(|route| route.kind.is_some()); let mut route_slots = BTreeMap::new(); - let routes = model - .routes - .iter() - .enumerate() - .map(|(index, route)| { - let symbol = lookup_symbol(&symbol_map, route.symbol, route.span)?; - route_slots.insert(route.symbol, index); - let destination = - lower_route_destination(&symbol_map, &state_slots, &route.destination)?; - Ok(ExecutionRoute { - symbol: route.symbol, - name: symbol.name.clone(), - index, - destination, - has_lag: route - .properties - .iter() - .any(|property| property.kind == RoutePropertyKind::Lag), - has_bioavailability: route - .properties - .iter() - .any(|property| property.kind == RoutePropertyKind::Bioavailability), - span: route.span, - }) - }) - .collect::, LoweringError>>()?; + let mut routes = Vec::with_capacity(model.routes.len()); + let mut next_bolus_index = 0usize; + let mut next_infusion_index = 0usize; + for (declaration_index, route) in model.routes.iter().enumerate() { + let symbol = lookup_symbol(&symbol_map, route.symbol, route.span)?; + if route.kind == Some(RouteKind::Infusion) { + if let Some(property) = route.properties.first() { + let label = match property.kind { + RoutePropertyKind::Lag => "lag", + RoutePropertyKind::Bioavailability => "bioavailability", + }; + return Err(LoweringError::new( + format!( + "DSL authoring does not allow `{label}` on infusion route `{}`", + symbol.name + ), + property.span, + ) + .with_note("lag and bioavailability are bolus-only route properties")); + } + } + let index = if uses_authoring_route_kinds { + match route.kind.expect("authoring routes must preserve kind") { + RouteKind::Bolus => { + let index = next_bolus_index; + next_bolus_index += 1; + index + } + RouteKind::Infusion => { + let index = next_infusion_index; + next_infusion_index += 1; + index + } + } + } else { + declaration_index + }; + route_slots.insert(route.symbol, index); + let destination = + lower_route_destination(&symbol_map, &state_slots, &route.destination)?; + routes.push(ExecutionRoute { + symbol: route.symbol, + name: symbol.name.clone(), + declaration_index, + index, + kind: route.kind, + destination, + has_lag: route + .properties + .iter() + .any(|property| property.kind == RoutePropertyKind::Lag), + has_bioavailability: route + .properties + .iter() + .any(|property| property.kind == RoutePropertyKind::Bioavailability), + span: route.span, + }); + } let mut derived_slots = BTreeMap::new(); let derived = model @@ -607,7 +643,7 @@ impl<'a> ExecutionLowerer<'a> { analytical: model .analytical .as_ref() - .map(|analytical| analytical.kernel), + .map(|analytical| analytical.structure), }, symbol_map, parameter_slots, @@ -653,7 +689,7 @@ impl<'a> ExecutionLowerer<'a> { kernels.push(ExecutionKernel { role: KernelRole::Analytical, signature: signature_for(KernelRole::Analytical), - implementation: KernelImplementation::AnalyticalBuiltin(analytical.kernel), + implementation: KernelImplementation::AnalyticalBuiltin(analytical.structure), span: analytical.span, }); } @@ -745,7 +781,13 @@ impl<'a> ExecutionLowerer<'a> { }, route_buffer: DenseBufferLayout { kind: BufferKind::Routes, - len: self.metadata.routes.len(), + len: self + .metadata + .routes + .iter() + .map(|route| route.index + 1) + .max() + .unwrap_or(0), slots: self .metadata .routes @@ -858,7 +900,39 @@ impl<'a> ExecutionLowerer<'a> { let mut statements = Vec::with_capacity(self.model.routes.len()); let mut locals = LocalLowering::default(); + let default_value = match property_kind { + RoutePropertyKind::Lag => literal_real(0.0, self.model.span), + RoutePropertyKind::Bioavailability => literal_real(1.0, self.model.span), + }; + let route_len = self + .metadata + .routes + .iter() + .map(|route| route.index + 1) + .max() + .unwrap_or(0); + for route_index in 0..route_len { + let target_kind = match property_kind { + RoutePropertyKind::Lag => ExecutionTargetKind::RouteLag(route_index), + RoutePropertyKind::Bioavailability => { + ExecutionTargetKind::RouteBioavailability(route_index) + } + }; + statements.push(ExecutionStmt { + kind: ExecutionStmtKind::Assign(ExecutionAssignStmt { + target: ExecutionTarget { + kind: target_kind, + span: self.model.span, + }, + value: default_value.clone(), + }), + span: self.model.span, + }); + } for route in &self.model.routes { + if route.kind == Some(RouteKind::Infusion) { + continue; + } let route_name = self.symbol_name(route.symbol)?.to_string(); let route_index = *self.route_slots.get(&route.symbol).ok_or_else(|| { LoweringError::new( @@ -872,8 +946,7 @@ impl<'a> ExecutionLowerer<'a> { .find(|property| property.kind == property_kind) { Some(property) => self.lower_expr(&property.value, &mut locals)?, - None if property_kind == RoutePropertyKind::Lag => literal_real(0.0, route.span), - None => literal_real(1.0, route.span), + None => continue, }; let target_kind = match property_kind { RoutePropertyKind::Lag => ExecutionTargetKind::RouteLag(route_index), @@ -1098,7 +1171,10 @@ impl<'a> ExecutionLowerer<'a> { expr.span, ) })?; - ExecutionExprKind::Load(ExecutionLoad::RouteInput(route_index)) + ExecutionExprKind::Load(ExecutionLoad::RouteInput { + route: *route, + index: route_index, + }) } }, }; @@ -1439,6 +1515,100 @@ mod tests { ); } + #[test] + fn authoring_routes_share_channel_indices_by_kind_local_ordinal() { + let src = r#"name = shared_authoring +kind = ode + +params = ka, ke, v, tlag, f_oral +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central + +out(cp) = central / v ~ continuous() +"#; + + let model = crate::parse_model(src).expect("authoring model parses"); + let typed = crate::analyze_model(&model).expect("authoring model analyzes"); + let lowered = crate::lower_typed_model(&typed).expect("authoring model lowers"); + + assert_eq!(lowered.abi.route_buffer.len, 1); + assert_eq!(lowered.metadata.routes.len(), 2); + assert_eq!(lowered.metadata.routes[0].kind, Some(RouteKind::Bolus)); + assert_eq!(lowered.metadata.routes[1].kind, Some(RouteKind::Infusion)); + assert_eq!(lowered.metadata.routes[0].declaration_index, 0); + assert_eq!(lowered.metadata.routes[1].declaration_index, 1); + assert_eq!(lowered.metadata.routes[0].index, 0); + assert_eq!(lowered.metadata.routes[1].index, 0); + assert!(lowered.metadata.routes[0].has_lag); + assert!(lowered.metadata.routes[0].has_bioavailability); + assert!(!lowered.metadata.routes[1].has_lag); + assert!(!lowered.metadata.routes[1].has_bioavailability); + } + + #[test] + fn authoring_routes_reject_infusion_lag_properties() { + let src = r#"name = invalid_infusion_lag +kind = ode + +params = ke, v, tlag +states = central +outputs = cp + +infusion(iv) -> central +lag(iv) = tlag + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + + let model = crate::parse_model(src).expect("authoring model parses"); + let typed = crate::analyze_model(&model).expect("authoring model analyzes"); + let error = crate::lower_typed_model(&typed) + .err() + .expect("infusion lag should fail during lowering"); + + assert!(error + .to_string() + .contains("DSL authoring does not allow `lag` on infusion route `iv`")); + } + + #[test] + fn authoring_routes_reject_infusion_bioavailability_properties() { + let src = r#"name = invalid_infusion_fa +kind = ode + +params = ke, v, f_iv +states = central +outputs = cp + +infusion(iv) -> central +fa(iv) = f_iv + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + + let model = crate::parse_model(src).expect("authoring model parses"); + let typed = crate::analyze_model(&model).expect("authoring model analyzes"); + let error = crate::lower_typed_model(&typed) + .err() + .expect("infusion bioavailability should fail during lowering"); + + assert!(error + .to_string() + .contains("DSL authoring does not allow `bioavailability` on infusion route `iv`")); + } + #[test] fn flattens_array_states_and_preserves_loop_structure() { let execution = structured_block_execution(); @@ -1538,8 +1708,8 @@ mod tests { panic!("expected statement bioavailability kernel"); }; - assert_eq!(lag_program.body.statements.len(), 2); - assert_eq!(bio_program.body.statements.len(), 2); + assert_eq!(lag_program.body.statements.len(), 3); + assert_eq!(bio_program.body.statements.len(), 3); assert!(matches!( lag_program.body.statements[1].kind, ExecutionStmtKind::Assign(ExecutionAssignStmt { diff --git a/pharmsol-dsl/src/ir.rs b/pharmsol-dsl/src/ir.rs index d1c54c90..5998431c 100644 --- a/pharmsol-dsl/src/ir.rs +++ b/pharmsol-dsl/src/ir.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; -use crate::{ModelKind, Span}; +use crate::{ModelKind, RouteKind, Span}; pub type SymbolId = usize; @@ -145,6 +145,7 @@ pub struct TypedState { #[derive(Debug, Clone, PartialEq)] pub struct TypedRoute { pub symbol: SymbolId, + pub kind: Option, pub destination: TypedStatePlace, pub properties: Vec, pub span: Span, @@ -165,7 +166,7 @@ pub enum RoutePropertyKind { #[derive(Debug, Clone, PartialEq)] pub struct TypedAnalytical { - pub kernel: AnalyticalKernel, + pub structure: AnalyticalKernel, pub span: Span, } diff --git a/pharmsol-dsl/src/parser.rs b/pharmsol-dsl/src/parser.rs index 7af6c681..f07fbd50 100644 --- a/pharmsol-dsl/src/parser.rs +++ b/pharmsol-dsl/src/parser.rs @@ -607,6 +607,7 @@ impl Parser { Ok(RouteDecl { input: input.clone(), destination, + kind: None, properties, span: input.span.join(end_span), }) @@ -616,19 +617,19 @@ impl Parser { let start = self.bump().unwrap().span; let open = self.expect_simple(|kind| matches!(kind, TokenKind::LBrace), "`{`")?; - let kernel_name = self.parse_ident()?; - if kernel_name.text != "kernel" { + let structure_name = self.parse_ident()?; + if structure_name.text != "structure" { return Err(ParseError::new( format!( - "expected `kernel = ` inside analytical block, found `{}`", - kernel_name.text + "expected `structure = ` inside analytical block, found `{}`", + structure_name.text ), - kernel_name.span, + structure_name.span, )); } let eq = self.expect_simple(|kind| matches!(kind, TokenKind::Eq), "`=`")?; - let kernel = self.parse_continuation_ident_after(&eq, "kernel identifier")?; + let structure = self.parse_continuation_ident_after(&eq, "structure identifier")?; self.consume_separators(); let end = self.expect_closing( |kind| matches!(kind, TokenKind::RBrace), @@ -637,7 +638,7 @@ impl Parser { "`analytical` block", )?; Ok(AnalyticalBlock { - kernel, + structure, span: start.join(end.span), }) } @@ -1635,14 +1636,14 @@ out(cp) = gut ~ continuous() #[test] fn authoring_output_annotation_is_optional() { let annotated = r#" -model = optional_output_annotation +name = optional_output_annotation kind = ode states = central ddt(central) = 0 out(cp) = central ~ continuous() "#; let plain = r#" -model = optional_output_annotation +name = optional_output_annotation kind = ode states = central ddt(central) = 0 @@ -1658,14 +1659,14 @@ out(cp) = central #[test] fn authoring_dx_and_ddt_lower_equivalently() { let dx_src = r#" -model = derivative_alias +name = derivative_alias kind = ode states = central dx(central) = -ke * central out(cp) = central "#; let ddt_src = r#" -model = derivative_alias +name = derivative_alias kind = ode states = central ddt(central) = -ke * central @@ -1681,7 +1682,7 @@ out(cp) = central #[test] fn authoring_rejects_out_target_not_in_declared_outputs() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v states = central diff --git a/pharmsol-dsl/src/semantic.rs b/pharmsol-dsl/src/semantic.rs index 9f46500a..ac9223dd 100644 --- a/pharmsol-dsl/src/semantic.rs +++ b/pharmsol-dsl/src/semantic.rs @@ -345,29 +345,30 @@ impl<'a> Analyzer<'a> { }; let analytical = if let Some(block) = sections.analytical { - let kernel = AnalyticalKernel::from_name(&block.kernel.text).ok_or_else(|| { - SemanticError::new( - format!("unknown analytical kernel `{}`", block.kernel.text), - block.kernel.span, - ) - })?; + let structure = + AnalyticalKernel::from_name(&block.structure.text).ok_or_else(|| { + SemanticError::new( + format!("unknown analytical structure `{}`", block.structure.text), + block.structure.span, + ) + })?; let state_components = states .iter() .map(|state| state.size.unwrap_or(1)) .sum::(); - if state_components != kernel.state_count() { + if state_components != structure.state_count() { return Err(SemanticError::new( format!( - "analytical kernel `{}` expects {} state value(s), but model declares {}", - block.kernel.text, - kernel.state_count(), + "analytical structure `{}` expects {} state value(s), but model declares {}", + block.structure.text, + structure.state_count(), state_components ), - block.kernel.span, + block.structure.span, )); } Some(TypedAnalytical { - kernel, + structure, span: block.span, }) } else { @@ -624,6 +625,7 @@ impl<'a> Analyzer<'a> { } routes.push(TypedRoute { symbol: id, + kind: route.kind, destination, properties, span: route.span, @@ -2651,6 +2653,7 @@ mod tests { use crate::test_fixtures::{ RECOMMENDED_STYLE_AUTHORING, RECOMMENDED_STYLE_CANONICAL, STRUCTURED_BLOCK_CORPUS, }; + use crate::RouteKind; use crate::{parse_model, parse_module}; #[test] @@ -2667,7 +2670,7 @@ mod tests { let analytical = &typed.models[2]; assert!(matches!( - analytical.analytical.as_ref().map(|value| value.kernel), + analytical.analytical.as_ref().map(|value| value.structure), Some(AnalyticalKernel::OneCompartmentWithAbsorption) )); @@ -2691,7 +2694,7 @@ mod tests { } #[test] - fn authoring_fixture_lowers_to_equivalent_typed_ir() { + fn authoring_fixture_preserves_route_kind_while_remaining_equivalent() { let authoring_surface = RECOMMENDED_STYLE_AUTHORING; let canonical = RECOMMENDED_STYLE_CANONICAL; @@ -2705,6 +2708,8 @@ mod tests { typed_model_signature(&authoring_typed), typed_model_signature(&canonical_typed) ); + assert_eq!(authoring_typed.routes[0].kind, Some(RouteKind::Bolus)); + assert_eq!(canonical_typed.routes[0].kind, None); } #[test] @@ -2977,7 +2982,7 @@ model broken { lines.push(format!("particles:{:?}", model.particles)); lines.push(format!( "analytical:{:?}", - model.analytical.as_ref().map(|value| value.kernel) + model.analytical.as_ref().map(|value| value.structure) )); lines.push(format!( "derive:{}", diff --git a/pharmsol-dsl/src/test_fixtures.rs b/pharmsol-dsl/src/test_fixtures.rs index f26181e4..281a268e 100644 --- a/pharmsol-dsl/src/test_fixtures.rs +++ b/pharmsol-dsl/src/test_fixtures.rs @@ -83,7 +83,7 @@ model one_cmt_abs { oral -> depot } analytical { - kernel = one_compartment_with_absorption + structure = one_compartment_with_absorption } outputs { cp = central / v @@ -132,7 +132,7 @@ model vanco_sde { } "#; -pub(crate) const RECOMMENDED_STYLE_AUTHORING: &str = r#"model = recommended_style +pub(crate) const RECOMMENDED_STYLE_AUTHORING: &str = r#"name = recommended_style kind = ode params = ka, ke, v diff --git a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs index 941ffa77..797be3e9 100644 --- a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs +++ b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs @@ -3,14 +3,14 @@ use pharmsol_dsl::{analyze_model, parse_model, parse_module}; #[test] fn output_annotation_is_optional() { let annotated = r#" -model = optional_output_annotation +name = optional_output_annotation kind = ode states = central ddt(central) = 0 out(cp) = central ~ continuous() "#; let plain = r#" -model = optional_output_annotation +name = optional_output_annotation kind = ode states = central ddt(central) = 0 @@ -26,7 +26,7 @@ out(cp) = central #[test] fn dx_and_ddt_lower_equivalently() { let dx_src = r#" -model = derivative_alias +name = derivative_alias kind = ode params = ke states = central @@ -34,7 +34,7 @@ dx(central) = -ke * central out(cp) = central "#; let ddt_src = r#" -model = derivative_alias +name = derivative_alias kind = ode params = ke states = central @@ -51,7 +51,7 @@ out(cp) = central #[test] fn rejects_out_target_not_in_declared_outputs() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v states = central @@ -95,7 +95,7 @@ out(cp) = central / v ~ continuous() #[test] fn rejects_out_target_not_in_declared_outputs_when_declared_later() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v states = central @@ -122,7 +122,7 @@ ddt(central) = -ke * central #[test] fn rejects_declared_output_without_assignment() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v states = central @@ -144,7 +144,7 @@ out(cp) = central / v #[test] fn rejects_unknown_output_annotation_name() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode states = central ddt(central) = 0 @@ -164,7 +164,7 @@ out(cp) = central ~ continous() #[test] fn unknown_route_destination_state_suggests_declared_state() { let src = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v diff --git a/pharmsol-macros/Cargo.toml b/pharmsol-macros/Cargo.toml index 291b888c..d9fe58ec 100644 --- a/pharmsol-macros/Cargo.toml +++ b/pharmsol-macros/Cargo.toml @@ -13,4 +13,4 @@ proc-macro = true [dependencies] proc-macro2 = "1.0.106" quote = "1.0.45" -syn = { version = "2.0.117", features = ["full"] } +syn = { version = "2.0.117", features = ["full", "visit"] } diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index 0fa320a3..54a79fe3 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -4,11 +4,15 @@ //! `pharmsol` crate instead. use proc_macro::TokenStream; -use proc_macro2::TokenTree; +use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, ToTokens}; +use std::collections::HashSet; use syn::{ - parse::{Parse, ParseStream}, - ExprClosure, Ident, Pat, Token, + parse::{Parse, ParseStream, Parser}, + punctuated::Punctuated, + token, + visit::Visit, + Expr, ExprClosure, Ident, LitStr, Pat, Stmt, Token, }; // --------------------------------------------------------------------------- @@ -16,6 +20,13 @@ use syn::{ // --------------------------------------------------------------------------- struct OdeInput { + name: LitStr, + params: Vec, + covariates: Vec, + states: Vec, + outputs: Vec, + routes: Vec, + diffeq_mode: OdeDiffeqMode, diffeq: ExprClosure, lag: Option, fa: Option, @@ -23,45 +34,489 @@ struct OdeInput { out: ExprClosure, } +struct AnalyticalInput { + name: LitStr, + params: Vec, + covariates: Vec, + states: Vec, + outputs: Vec, + routes: Vec, + structure: Ident, + sec: Option, + lag: Option, + fa: Option, + init: Option, + out: ExprClosure, +} + +struct SdeInput { + name: LitStr, + params: Vec, + covariates: Vec, + states: Vec, + outputs: Vec, + routes: Vec, + particles: Expr, + drift: ExprClosure, + diffusion: ExprClosure, + lag: Option, + fa: Option, + init: Option, + out: ExprClosure, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum OdeDiffeqMode { + InjectedRouteInputs, + ExplicitRouteVectors, +} + +struct OdeRouteDecl { + kind: OdeRouteKind, + input: Ident, + destination: Ident, +} + +#[derive(Clone, Copy)] +enum OdeRouteKind { + Bolus, + Infusion, +} + +struct AnalyticalKernelSpec { + runtime_path: TokenStream2, + metadata_kernel: TokenStream2, + parameter_arity: usize, + state_count: usize, +} + +struct RoutePropertyEntry { + route: Ident, + value: Expr, +} + +impl Parse for OdeRouteDecl { + fn parse(input: ParseStream) -> syn::Result { + let kind_ident: Ident = input.parse()?; + let kind = match kind_ident.to_string().as_str() { + "bolus" => OdeRouteKind::Bolus, + "infusion" => OdeRouteKind::Infusion, + other => { + return Err(syn::Error::new_spanned( + &kind_ident, + format!("unknown route kind `{other}`, expected `bolus` or `infusion`"), + )); + } + }; + + let content; + syn::parenthesized!(content in input); + let route_input: Ident = content.parse()?; + if !content.is_empty() { + return Err(content.error("expected a single route input name inside `(...)`")); + } + + if !input.peek(Token![->]) { + return Err( + input.error("expected `->` followed by a destination state in route declaration") + ); + } + input.parse::]>()?; + let destination: Ident = input.parse()?; + + if input.peek(token::Brace) { + return Err( + input.error("route properties are not supported in declaration-first `ode!` yet") + ); + } + + Ok(Self { + kind, + input: route_input, + destination, + }) + } +} + impl Parse for OdeInput { fn parse(input: ParseStream) -> syn::Result { + let mut name = None; + let mut params = None; + let mut covariates = None; + let mut states = None; + let mut outputs = None; + let mut routes = None; let mut diffeq = None; let mut lag = None; - let mut fa_val = None; + let mut fa = None; + let mut init = None; + let mut out = None; + + while !input.is_empty() { + let key: Ident = input.parse()?; + input.parse::()?; + + match key.to_string().as_str() { + "name" => set_once_ode(&mut name, input.parse()?, &key, "name")?, + "params" => set_once_ode(&mut params, parse_ident_list(input)?, &key, "params")?, + "covariates" => set_once_ode( + &mut covariates, + parse_ident_list(input)?, + &key, + "covariates", + )?, + "states" => set_once_ode(&mut states, parse_ident_list(input)?, &key, "states")?, + "outputs" => set_once_ode(&mut outputs, parse_ident_list(input)?, &key, "outputs")?, + "routes" => set_once_ode(&mut routes, parse_route_list(input)?, &key, "routes")?, + "diffeq" => set_once_ode(&mut diffeq, input.parse()?, &key, "diffeq")?, + "lag" => set_once_ode(&mut lag, input.parse()?, &key, "lag")?, + "fa" => set_once_ode(&mut fa, input.parse()?, &key, "fa")?, + "init" => set_once_ode(&mut init, input.parse()?, &key, "init")?, + "out" => set_once_ode(&mut out, input.parse()?, &key, "out")?, + other => { + return Err(syn::Error::new_spanned( + &key, + format!( + "unknown field `{other}`, expected one of: name, params, covariates, states, outputs, routes, diffeq, lag, fa, init, out" + ), + )); + } + } + + if !input.is_empty() { + input.parse::()?; + } + } + + let name = name.ok_or_else(|| { + syn::Error::new( + Span::call_site(), + "declaration-first `ode!` requires `name`, `params`, `states`, `outputs`, and `routes`; the old inferred-dimensions form has been removed", + ) + })?; + let params = params.ok_or_else(|| missing_required_ode_field("params"))?; + let covariates = covariates.unwrap_or_default(); + let states = states.ok_or_else(|| missing_required_ode_field("states"))?; + let outputs = outputs.ok_or_else(|| missing_required_ode_field("outputs"))?; + let routes = routes.ok_or_else(|| missing_required_ode_field("routes"))?; + let diffeq = diffeq.ok_or_else(|| missing_required_ode_field("diffeq"))?; + let out = out.ok_or_else(|| missing_required_ode_field("out"))?; + let diffeq_mode = classify_diffeq_mode(&diffeq, &routes)?; + + validate_unique_idents("parameter", ¶ms, "ode!")?; + validate_unique_idents("covariate", &covariates, "ode!")?; + validate_unique_idents("state", &states, "ode!")?; + validate_unique_idents("output", &outputs, "ode!")?; + validate_routes(&routes, &states, "ode!")?; + validate_named_binding_compatibility( + NamedBindingSets { + params: ¶ms, + covariates: &covariates, + states: &states, + outputs: &outputs, + routes: &routes, + }, + OdeBindingClosures { + diffeq: &diffeq, + common: CommonBindingClosures { + lag: lag.as_ref(), + fa: fa.as_ref(), + init: init.as_ref(), + out: &out, + }, + diffeq_mode, + }, + )?; + + Ok(Self { + name, + params, + covariates, + states, + outputs, + routes, + diffeq_mode, + diffeq, + lag, + fa, + init, + out, + }) + } +} + +impl Parse for RoutePropertyEntry { + fn parse(input: ParseStream) -> syn::Result { + let route: Ident = input.parse()?; + input.parse::]>()?; + let value: Expr = input.parse()?; + Ok(Self { route, value }) + } +} + +impl Parse for AnalyticalInput { + fn parse(input: ParseStream) -> syn::Result { + let mut name = None; + let mut params = None; + let mut covariates = None; + let mut states = None; + let mut outputs = None; + let mut routes = None; + let mut structure = None; + let mut sec = None; + let mut lag = None; + let mut fa = None; + let mut init = None; + let mut out = None; + + while !input.is_empty() { + let key: Ident = input.parse()?; + input.parse::()?; + + match key.to_string().as_str() { + "name" => set_once_analytical(&mut name, input.parse()?, &key, "name")?, + "params" => { + set_once_analytical(&mut params, parse_ident_list(input)?, &key, "params")? + } + "covariates" => set_once_analytical( + &mut covariates, + parse_ident_list(input)?, + &key, + "covariates", + )?, + "states" => { + set_once_analytical(&mut states, parse_ident_list(input)?, &key, "states")? + } + "outputs" => { + set_once_analytical(&mut outputs, parse_ident_list(input)?, &key, "outputs")? + } + "routes" => { + set_once_analytical(&mut routes, parse_route_list(input)?, &key, "routes")? + } + "structure" => { + set_once_analytical(&mut structure, input.parse()?, &key, "structure")? + } + "sec" => set_once_analytical(&mut sec, input.parse()?, &key, "sec")?, + "lag" => set_once_analytical(&mut lag, input.parse()?, &key, "lag")?, + "fa" => set_once_analytical(&mut fa, input.parse()?, &key, "fa")?, + "init" => set_once_analytical(&mut init, input.parse()?, &key, "init")?, + "out" => set_once_analytical(&mut out, input.parse()?, &key, "out")?, + other => { + return Err(syn::Error::new_spanned( + &key, + format!( + "unknown field `{other}`, expected one of: name, params, covariates, states, outputs, routes, structure, sec, lag, fa, init, out" + ), + )); + } + } + + if !input.is_empty() { + input.parse::()?; + } + } + + let name = name.ok_or_else(|| missing_required_analytical_field("name"))?; + let params = params.ok_or_else(|| missing_required_analytical_field("params"))?; + let covariates = covariates.unwrap_or_default(); + let states = states.ok_or_else(|| missing_required_analytical_field("states"))?; + let outputs = outputs.ok_or_else(|| missing_required_analytical_field("outputs"))?; + let routes = routes.ok_or_else(|| missing_required_analytical_field("routes"))?; + let structure = structure.ok_or_else(|| missing_required_analytical_field("structure"))?; + let out = out.ok_or_else(|| missing_required_analytical_field("out"))?; + + validate_unique_idents("parameter", ¶ms, "analytical!")?; + validate_unique_idents("covariate", &covariates, "analytical!")?; + validate_unique_idents("state", &states, "analytical!")?; + validate_unique_idents("output", &outputs, "analytical!")?; + validate_routes(&routes, &states, "analytical!")?; + + let kernel_spec = resolve_analytical_structure(&structure)?; + if params.len() < kernel_spec.parameter_arity { + return Err(syn::Error::new_spanned( + &structure, + format!( + "analytical structure `{}` requires at least {} parameter value(s), but `params` declares {}", + structure, kernel_spec.parameter_arity, params.len() + ), + )); + } + if states.len() != kernel_spec.state_count { + return Err(syn::Error::new_spanned( + &structure, + format!( + "analytical structure `{}` expects {} state value(s), but `states` declares {}", + structure, + kernel_spec.state_count, + states.len() + ), + )); + } + + validate_analytical_named_binding_compatibility( + NamedBindingSets { + params: ¶ms, + covariates: &covariates, + states: &states, + outputs: &outputs, + routes: &routes, + }, + AnalyticalBindingClosures { + sec: sec.as_ref(), + common: CommonBindingClosures { + lag: lag.as_ref(), + fa: fa.as_ref(), + init: init.as_ref(), + out: &out, + }, + }, + )?; + + if let Some(lag) = lag.as_ref() { + let lag_routes = + extract_route_property_routes("built-in `analytical!`", "lag", lag, &routes)?; + validate_route_property_kinds("built-in `analytical!`", "lag", &routes, &lag_routes)?; + } + + if let Some(fa) = fa.as_ref() { + let fa_routes = + extract_route_property_routes("built-in `analytical!`", "fa", fa, &routes)?; + validate_route_property_kinds("built-in `analytical!`", "fa", &routes, &fa_routes)?; + } + + Ok(Self { + name, + params, + covariates, + states, + outputs, + routes, + structure, + sec, + lag, + fa, + init, + out, + }) + } +} + +impl Parse for SdeInput { + fn parse(input: ParseStream) -> syn::Result { + let mut name = None; + let mut params = None; + let mut covariates = None; + let mut states = None; + let mut outputs = None; + let mut routes = None; + let mut particles = None; + let mut drift = None; + let mut diffusion = None; + let mut lag = None; + let mut fa = None; let mut init = None; let mut out = None; while !input.is_empty() { let key: Ident = input.parse()?; input.parse::()?; - let closure: ExprClosure = input.parse()?; match key.to_string().as_str() { - "diffeq" => diffeq = Some(closure), - "lag" => lag = Some(closure), - "fa" => fa_val = Some(closure), - "init" => init = Some(closure), - "out" => out = Some(closure), + "name" => set_once_sde(&mut name, input.parse()?, &key, "name")?, + "params" => set_once_sde(&mut params, parse_ident_list(input)?, &key, "params")?, + "covariates" => set_once_sde( + &mut covariates, + parse_ident_list(input)?, + &key, + "covariates", + )?, + "states" => set_once_sde(&mut states, parse_ident_list(input)?, &key, "states")?, + "outputs" => set_once_sde(&mut outputs, parse_ident_list(input)?, &key, "outputs")?, + "routes" => set_once_sde(&mut routes, parse_route_list(input)?, &key, "routes")?, + "particles" => set_once_sde(&mut particles, input.parse()?, &key, "particles")?, + "drift" => set_once_sde(&mut drift, input.parse()?, &key, "drift")?, + "diffusion" => set_once_sde(&mut diffusion, input.parse()?, &key, "diffusion")?, + "lag" => set_once_sde(&mut lag, input.parse()?, &key, "lag")?, + "fa" => set_once_sde(&mut fa, input.parse()?, &key, "fa")?, + "init" => set_once_sde(&mut init, input.parse()?, &key, "init")?, + "out" => set_once_sde(&mut out, input.parse()?, &key, "out")?, other => { return Err(syn::Error::new_spanned( &key, - format!("unknown field `{other}`, expected: diffeq, lag, fa, init, out"), + format!( + "unknown field `{other}`, expected one of: name, params, covariates, states, outputs, routes, particles, drift, diffusion, lag, fa, init, out" + ), )); } } - // optional trailing comma if !input.is_empty() { input.parse::()?; } } - Ok(OdeInput { - diffeq: diffeq.ok_or_else(|| input.error("missing required field `diffeq`"))?, + let name = name.ok_or_else(|| missing_required_sde_field("name"))?; + let params = params.ok_or_else(|| missing_required_sde_field("params"))?; + let covariates = covariates.unwrap_or_default(); + let states = states.ok_or_else(|| missing_required_sde_field("states"))?; + let outputs = outputs.ok_or_else(|| missing_required_sde_field("outputs"))?; + let routes = routes.ok_or_else(|| missing_required_sde_field("routes"))?; + let particles = particles.ok_or_else(|| missing_required_sde_field("particles"))?; + let drift = drift.ok_or_else(|| missing_required_sde_field("drift"))?; + let diffusion = diffusion.ok_or_else(|| missing_required_sde_field("diffusion"))?; + let out = out.ok_or_else(|| missing_required_sde_field("out"))?; + + validate_unique_idents("parameter", ¶ms, "sde!")?; + validate_unique_idents("covariate", &covariates, "sde!")?; + validate_unique_idents("state", &states, "sde!")?; + validate_unique_idents("output", &outputs, "sde!")?; + validate_routes(&routes, &states, "sde!")?; + validate_sde_named_binding_compatibility( + NamedBindingSets { + params: ¶ms, + covariates: &covariates, + states: &states, + outputs: &outputs, + routes: &routes, + }, + SdeBindingClosures { + drift: &drift, + diffusion: &diffusion, + common: CommonBindingClosures { + lag: lag.as_ref(), + fa: fa.as_ref(), + init: init.as_ref(), + out: &out, + }, + }, + )?; + + if let Some(lag) = lag.as_ref() { + let lag_routes = + extract_route_property_routes("declaration-first `sde!`", "lag", lag, &routes)?; + validate_route_property_kinds("declaration-first `sde!`", "lag", &routes, &lag_routes)?; + } + + if let Some(fa) = fa.as_ref() { + let fa_routes = + extract_route_property_routes("declaration-first `sde!`", "fa", fa, &routes)?; + validate_route_property_kinds("declaration-first `sde!`", "fa", &routes, &fa_routes)?; + } + + Ok(Self { + name, + params, + covariates, + states, + outputs, + routes, + particles, + drift, + diffusion, lag, - fa: fa_val, + fa, init, - out: out.ok_or_else(|| input.error("missing required field `out`"))?, + out, }) } } @@ -70,7 +525,86 @@ impl Parse for OdeInput { // Helpers // --------------------------------------------------------------------------- -/// Return the identifier string for a closure parameter (empty for wildcards). +fn missing_required_ode_field(name: &str) -> syn::Error { + syn::Error::new( + Span::call_site(), + format!("missing required field `{name}` in declaration-first `ode!`"), + ) +} + +fn missing_required_analytical_field(name: &str) -> syn::Error { + syn::Error::new( + Span::call_site(), + format!("missing required field `{name}` in built-in `analytical!`"), + ) +} + +fn missing_required_sde_field(name: &str) -> syn::Error { + syn::Error::new( + Span::call_site(), + format!("missing required field `{name}` in declaration-first `sde!`"), + ) +} + +fn set_once_ode(slot: &mut Option, value: T, key: &Ident, name: &str) -> syn::Result<()> { + if slot.is_some() { + Err(syn::Error::new_spanned( + key, + format!("duplicate field `{name}` in `ode!`"), + )) + } else { + *slot = Some(value); + Ok(()) + } +} + +fn set_once_analytical( + slot: &mut Option, + value: T, + key: &Ident, + name: &str, +) -> syn::Result<()> { + if slot.is_some() { + Err(syn::Error::new_spanned( + key, + format!("duplicate field `{name}` in `analytical!`"), + )) + } else { + *slot = Some(value); + Ok(()) + } +} + +fn set_once_sde(slot: &mut Option, value: T, key: &Ident, name: &str) -> syn::Result<()> { + if slot.is_some() { + Err(syn::Error::new_spanned( + key, + format!("duplicate field `{name}` in `sde!`"), + )) + } else { + *slot = Some(value); + Ok(()) + } +} + +fn parse_ident_list(input: ParseStream) -> syn::Result> { + let content; + syn::bracketed!(content in input); + Ok(Punctuated::::parse_terminated(&content)? + .into_iter() + .collect()) +} + +fn parse_route_list(input: ParseStream) -> syn::Result> { + let content; + syn::braced!(content in input); + Ok( + Punctuated::::parse_terminated(&content)? + .into_iter() + .collect(), + ) +} + fn param_name(pat: &Pat) -> String { match pat { Pat::Ident(p) => p.ident.to_string(), @@ -82,208 +616,2571 @@ fn closure_param_names(c: &ExprClosure) -> Vec { c.inputs.iter().map(param_name).collect() } -/// Recursively scan `tokens` for `ident[literal_int]` patterns where the -/// ident matches one of `names`. Returns the maximum literal integer found. -fn max_literal_index(tokens: proc_macro2::TokenStream, names: &[&str]) -> Option { - let tts: Vec = tokens.into_iter().collect(); - let mut best: Option = None; - - for (i, tt) in tts.iter().enumerate() { - match tt { - TokenTree::Ident(ident) => { - let s = ident.to_string(); - if names.contains(&s.as_str()) { - if let Some(TokenTree::Group(g)) = tts.get(i + 1) { - if g.delimiter() == proc_macro2::Delimiter::Bracket { - let inner: Vec = g.stream().into_iter().collect(); - if inner.len() == 1 { - if let TokenTree::Literal(lit) = &inner[0] { - if let Ok(n) = lit.to_string().parse::() { - best = Some(best.map_or(n, |m: usize| m.max(n))); - } - } - } - } - } - } - } - // recurse into brace / paren groups (bracket groups are indexing, handled above) - TokenTree::Group(g) - if matches!( - g.delimiter(), - proc_macro2::Delimiter::Brace | proc_macro2::Delimiter::Parenthesis - ) => +fn closure_param_ident(c: &ExprClosure, index: usize) -> Option { + c.inputs.get(index).and_then(|pat| match pat { + Pat::Ident(pat_ident) => Some(pat_ident.ident.clone()), + _ => None, + }) +} + +fn generated_ident(name: &str) -> Ident { + Ident::new(name, Span::call_site()) +} + +#[derive(Default)] +struct ClosureBodyUsage { + idents: HashSet, + indexed_idents: HashSet, + assigned_indexed_idents: HashSet, + contains_macro: bool, +} + +impl ClosureBodyUsage { + fn analyze(expr: &Expr) -> Self { + let mut usage = Self::default(); + usage.visit_expr(expr); + usage + } + + fn uses(&self, ident: &Ident) -> bool { + self.contains_macro || self.idents.contains(&ident.to_string()) + } + + fn mentions(&self, ident: &Ident) -> bool { + self.idents.contains(&ident.to_string()) + } + + fn indexes(&self, ident: &Ident) -> bool { + self.indexed_idents.contains(&ident.to_string()) + } + + fn assigns_index(&self, ident: &Ident) -> bool { + self.assigned_indexed_idents.contains(&ident.to_string()) + } +} + +impl<'ast> Visit<'ast> for ClosureBodyUsage { + fn visit_expr_path(&mut self, expr_path: &'ast syn::ExprPath) { + if expr_path.qself.is_none() + && expr_path.path.leading_colon.is_none() + && expr_path.path.segments.len() == 1 + { + self.idents + .insert(expr_path.path.segments[0].ident.to_string()); + } + + syn::visit::visit_expr_path(self, expr_path); + } + + fn visit_expr_macro(&mut self, expr_macro: &'ast syn::ExprMacro) { + self.contains_macro = true; + syn::visit::visit_expr_macro(self, expr_macro); + } + + fn visit_stmt_macro(&mut self, stmt_macro: &'ast syn::StmtMacro) { + self.contains_macro = true; + syn::visit::visit_stmt_macro(self, stmt_macro); + } + + fn visit_expr_index(&mut self, expr_index: &'ast syn::ExprIndex) { + if let Expr::Path(expr_path) = expr_index.expr.as_ref() { + if expr_path.qself.is_none() + && expr_path.path.leading_colon.is_none() + && expr_path.path.segments.len() == 1 { - if let Some(n) = max_literal_index(g.stream(), names) { - best = Some(best.map_or(n, |m: usize| m.max(n))); + self.indexed_idents + .insert(expr_path.path.segments[0].ident.to_string()); + } + } + + syn::visit::visit_expr_index(self, expr_index); + } + + fn visit_expr_assign(&mut self, expr_assign: &'ast syn::ExprAssign) { + if let Expr::Index(expr_index) = expr_assign.left.as_ref() { + if let Expr::Path(expr_path) = expr_index.expr.as_ref() { + if expr_path.qself.is_none() + && expr_path.path.leading_colon.is_none() + && expr_path.path.segments.len() == 1 + { + self.assigned_indexed_idents + .insert(expr_path.path.segments[0].ident.to_string()); } } - _ => {} } + + syn::visit::visit_expr_assign(self, expr_assign); + } +} + +fn generate_closure_input_aliases( + closure: &ExprClosure, + internal_names: &[Ident], +) -> syn::Result { + if closure.inputs.len() != internal_names.len() { + return Err(syn::Error::new_spanned( + closure, + "internal named binding generation error: closure arity mismatch", + )); } - best + let aliases = + closure + .inputs + .iter() + .zip(internal_names.iter()) + .map(|(pattern, internal_name)| { + quote! { + let #pattern = #internal_name; + } + }); + + Ok(quote! { + #(#aliases)* + }) } -// --------------------------------------------------------------------------- -// Proc macro -// --------------------------------------------------------------------------- +fn generate_supported_input_aliases( + closure: &ExprClosure, + supported_internal_names: &[&[Ident]], + error_message: &str, +) -> syn::Result { + for internal_names in supported_internal_names { + if closure.inputs.len() == internal_names.len() { + return generate_closure_input_aliases(closure, internal_names); + } + } -/// Build an `equation::ODE` while **inferring** `nstates`, `ndrugs` and -/// `nout` from the maximum literal bracket-indices used in the closures. -/// -/// # Fields (any order, comma-separated) -/// -/// | Field | Required | Signature | -/// |----------|----------|-------------------------------------------------| -/// | `diffeq` | **yes** | `\|x, p, t, dx, bolus, rateiv, cov\| { … }` | -/// | `out` | **yes** | `\|x, p, t, cov, y\| { … }` | -/// | `init` | no | `\|p, t, cov, x\| { … }` | -/// | `lag` | no | `\|p, t, cov\| lag! { … }` | -/// | `fa` | no | `\|p, t, cov\| fa! { … }` | -/// -/// # Inference rules -/// -/// * **nstates** = max literal index of the state / derivative vectors + 1 -/// * **ndrugs** = max literal index of bolus / rateiv vectors + 1 -/// * **nout** = max literal index of the output vector + 1 -/// -/// Parameter names are taken from the closure signatures so you can name them -/// however you like. Only **literal** integer indices (e.g. `x[2]`) are -/// detected; computed indices require manual `.with_nstates()` etc. -/// -/// # Example -/// -/// ```ignore -/// use pharmsol::prelude::*; -/// -/// let ode = ode! { -/// diffeq: |x, p, _t, dx, b, rateiv, _cov| { -/// fetch_params!(p, ke, kcp, kpc, _v); -/// dx[0] = rateiv[0] + b[0] - ke * x[0] - kcp * x[0] + kpc * x[1]; -/// dx[1] = kcp * x[0] - kpc * x[1]; -/// }, -/// out: |x, p, _t, _cov, y| { -/// fetch_params!(p, _ke, _kcp, _kpc, v); -/// y[0] = x[0] / v; -/// }, -/// }; -/// // Inferred: nstates=2, ndrugs=1, nout=1 -/// ``` -#[proc_macro] -pub fn ode(input: TokenStream) -> TokenStream { - let input = syn::parse_macro_input!(input as OdeInput); + Err(syn::Error::new_spanned(closure, error_message)) +} - // ── Validate parameter counts ──────────────────────────────── - let de_params = closure_param_names(&input.diffeq); - if de_params.len() != 7 { - return syn::Error::new_spanned( - &input.diffeq, - "diffeq closure must have 7 parameters: |x, p, t, dx, bolus, rateiv, cov|", - ) - .to_compile_error() - .into(); +fn generate_parameter_bindings( + params: &[Ident], + closure: &ExprClosure, + parameter_vector: &Ident, +) -> TokenStream2 { + let usage = ClosureBodyUsage::analyze(closure.body.as_ref()); + let bindings = params + .iter() + .enumerate() + .filter(|(_, ident)| usage.uses(ident)) + .map(|(index, ident)| { + quote! { + #[allow(unused_variables)] + let #ident = #parameter_vector[#index]; + } + }); + + quote! { + #(#bindings)* } +} - let out_params = closure_param_names(&input.out); - if out_params.len() != 5 { - return syn::Error::new_spanned( - &input.out, - "out closure must have 5 parameters: |x, p, t, cov, y|", - ) - .to_compile_error() - .into(); +fn generate_mutable_parameter_bindings( + params: &[Ident], + closure: &ExprClosure, + parameter_vector: &Ident, +) -> (TokenStream2, TokenStream2) { + let usage = ClosureBodyUsage::analyze(closure.body.as_ref()); + let used_params = params + .iter() + .enumerate() + .filter(|(_, ident)| usage.uses(ident)) + .collect::>(); + + let bindings = used_params.iter().map(|(index, ident)| { + quote! { + #[allow(unused_mut, unused_variables)] + let mut #ident = #parameter_vector[#index]; + } + }); + let writebacks = used_params.iter().map(|(index, ident)| { + quote! { + #parameter_vector[#index] = #ident; + } + }); + + (quote! { #(#bindings)* }, quote! { #(#writebacks)* }) +} + +fn generate_covariate_bindings( + covariates: &[Ident], + closure: &ExprClosure, + covariate_map: &Ident, + time: &Ident, +) -> TokenStream2 { + let usage = ClosureBodyUsage::analyze(closure.body.as_ref()); + let used_covariates = covariates + .iter() + .filter(|ident| usage.uses(ident)) + .collect::>(); + + if used_covariates.is_empty() { + quote! {} + } else { + quote! { + ::pharmsol::fetch_cov!(#covariate_map, #time, #(#used_covariates),*); + } } +} - // ── Collect names by role ──────────────────────────────────── - // diffeq positions: 0=x 3=dx 4=bolus 5=rateiv - // out positions: 0=x 4=y - // init positions: 3=x - let mut state_names: Vec = vec![ - de_params[0].clone(), - de_params[3].clone(), - out_params[0].clone(), - ]; - if let Some(ref ic) = input.init { - let ip = closure_param_names(ic); - if ip.len() >= 4 { - state_names.push(ip[3].clone()); +fn classify_diffeq_mode( + diffeq: &ExprClosure, + routes: &[OdeRouteDecl], +) -> syn::Result { + match closure_param_names(diffeq).len() { + 3 => Ok(OdeDiffeqMode::InjectedRouteInputs), + 7 => Ok(OdeDiffeqMode::ExplicitRouteVectors), + 5 => { + let usage = ClosureBodyUsage::analyze(diffeq.body.as_ref()); + let route_inputs = route_input_idents(routes); + let fourth_param = closure_param_ident(diffeq, 3); + let fifth_param = closure_param_ident(diffeq, 4); + let mentions_route_inputs = route_inputs.iter().any(|route| usage.mentions(route)); + let indexes_fifth_param = fifth_param.as_ref().is_some_and(|ident| usage.indexes(ident)); + let reads_fourth_param_as_input = fourth_param + .as_ref() + .is_some_and(|ident| usage.indexes(ident) && !usage.assigns_index(ident)); + + if mentions_route_inputs || indexes_fifth_param || reads_fourth_param_as_input { + Ok(OdeDiffeqMode::ExplicitRouteVectors) + } else { + Ok(OdeDiffeqMode::InjectedRouteInputs) + } } + _ => Err(syn::Error::new_spanned( + diffeq, + "declaration-first `ode!` requires `diffeq` to have either 3 parameters: |x, t, dx|, 5 parameters: |x, p, t, dx, cov| or |x, t, dx, bolus, rateiv|, or 7 parameters: |x, p, t, dx, bolus, rateiv, cov|", + )), } - state_names.sort(); - state_names.dedup(); +} + +fn route_input_idents(routes: &[OdeRouteDecl]) -> Vec { + routes.iter().map(|route| route.input.clone()).collect() +} - let drug_names = [de_params[4].clone(), de_params[5].clone()]; - let output_names = [out_params[4].clone()]; +fn ode_route_channel_bindings(routes: &[OdeRouteDecl]) -> Vec<(Ident, usize)> { + let mut next_bolus_index = 0usize; + let mut next_infusion_index = 0usize; - // filter empties (from wildcard `_` params) - let state_refs: Vec<&str> = state_names - .iter() - .map(String::as_str) - .filter(|s| !s.is_empty()) - .collect(); - let drug_refs: Vec<&str> = drug_names + routes .iter() - .map(String::as_str) - .filter(|s| !s.is_empty()) - .collect(); - let output_refs: Vec<&str> = output_names + .map(|route| { + let index = match route.kind { + OdeRouteKind::Bolus => { + let index = next_bolus_index; + next_bolus_index += 1; + index + } + OdeRouteKind::Infusion => { + let index = next_infusion_index; + next_infusion_index += 1; + index + } + }; + (route.input.clone(), index) + }) + .collect() +} + +fn dense_index_len(bindings: &[(Ident, usize)]) -> usize { + bindings .iter() - .map(String::as_str) - .filter(|s| !s.is_empty()) - .collect(); - - // ── Scan closure bodies ────────────────────────────────────── - let de_tokens = input.diffeq.body.to_token_stream(); - let out_tokens = input.out.body.to_token_stream(); - let init_tokens = input.init.as_ref().map(|c| c.body.to_token_stream()); - - let max_state = [ - max_literal_index(de_tokens.clone(), &state_refs), - max_literal_index(out_tokens.clone(), &state_refs), - init_tokens.and_then(|t| max_literal_index(t, &state_refs)), - ] - .into_iter() - .flatten() - .max(); - - let max_drug = max_literal_index(de_tokens, &drug_refs); - let max_out = max_literal_index(out_tokens, &output_refs); - - let nstates = max_state.map_or(1, |n| n + 1); - let ndrugs = max_drug.map_or(1, |n| n + 1); - let nout = max_out.map_or(1, |n| n + 1); - - // ── Generate output ────────────────────────────────────────── - let diffeq = &input.diffeq; - let out = &input.out; - - let lag = input.lag.as_ref().map_or_else( - || quote! { |_, _, _| ::std::collections::HashMap::new() }, - |c| quote! { #c }, - ); - - let fa = input.fa.as_ref().map_or_else( - || quote! { |_, _, _| ::std::collections::HashMap::new() }, - |c| quote! { #c }, - ); - - let init = input - .init - .as_ref() - .map_or_else(|| quote! { |_, _, _, _| {} }, |c| quote! { #c }); + .map(|(_, index)| index + 1) + .max() + .unwrap_or(0) +} - quote! { - equation::ODE::new( - #diffeq, - #lag, - #fa, - #init, - #out, - ) - .with_nstates(#nstates) - .with_ndrugs(#ndrugs) - .with_nout(#nout) +fn validate_binding_conflicts( + left_label: &str, + left: &[Ident], + right_label: &str, + right: &[Ident], + context: &str, +) -> syn::Result<()> { + let right_names = right.iter().map(Ident::to_string).collect::>(); + + for ident in left { + let name = ident.to_string(); + if right_names.contains(&name) { + return Err(syn::Error::new_spanned( + ident, + format!( + "named {left_label} binding `{name}` conflicts with named {right_label} binding in {context}" + ), + )); + } + } + + Ok(()) +} + +fn validate_closure_param_conflicts( + closure_label: &str, + closure: &ExprClosure, + bindings: &[Ident], + binding_label: &str, +) -> syn::Result<()> { + let parameter_names = closure_param_names(closure) + .into_iter() + .filter(|name| !name.is_empty()) + .collect::>(); + + for ident in bindings { + let name = ident.to_string(); + if parameter_names.contains(&name) { + return Err(syn::Error::new_spanned( + ident, + format!( + "named {binding_label} binding `{name}` conflicts with `{closure_label}` closure parameter `{name}`" + ), + )); + } + } + + Ok(()) +} + +#[derive(Clone, Copy)] +struct NamedBindingSets<'a> { + params: &'a [Ident], + covariates: &'a [Ident], + states: &'a [Ident], + outputs: &'a [Ident], + routes: &'a [OdeRouteDecl], +} + +#[derive(Clone, Copy)] +struct CommonBindingClosures<'a> { + lag: Option<&'a ExprClosure>, + fa: Option<&'a ExprClosure>, + init: Option<&'a ExprClosure>, + out: &'a ExprClosure, +} + +#[derive(Clone, Copy)] +struct AnalyticalBindingClosures<'a> { + sec: Option<&'a ExprClosure>, + common: CommonBindingClosures<'a>, +} + +#[derive(Clone, Copy)] +struct OdeBindingClosures<'a> { + diffeq: &'a ExprClosure, + common: CommonBindingClosures<'a>, + diffeq_mode: OdeDiffeqMode, +} + +#[derive(Clone, Copy)] +struct SdeBindingClosures<'a> { + drift: &'a ExprClosure, + diffusion: &'a ExprClosure, + common: CommonBindingClosures<'a>, +} + +fn validate_named_binding_compatibility( + bindings: NamedBindingSets<'_>, + closures: OdeBindingClosures<'_>, +) -> syn::Result<()> { + let NamedBindingSets { + params, + covariates, + states, + outputs, + routes, + } = bindings; + let OdeBindingClosures { + diffeq, + common: CommonBindingClosures { lag, fa, init, out }, + diffeq_mode, + } = closures; + let route_inputs = route_input_idents(routes); + + validate_binding_conflicts( + "parameter", + params, + "covariate", + covariates, + "declaration-first `ode!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "state", + states, + "`diffeq` and `out` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "output", + outputs, + "`out` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "output", + outputs, + "`out` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "state", + states, + "declaration-first `ode!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "output", + outputs, + "declaration-first `ode!` named binding generation", + )?; + + validate_closure_param_conflicts("diffeq", diffeq, params, "parameter")?; + validate_closure_param_conflicts("diffeq", diffeq, covariates, "covariate")?; + validate_closure_param_conflicts("diffeq", diffeq, states, "state")?; + + if diffeq_mode == OdeDiffeqMode::ExplicitRouteVectors { + validate_binding_conflicts( + "parameter", + params, + "route", + &route_inputs, + "`diffeq` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "route", + &route_inputs, + "`diffeq` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`diffeq` named binding generation", + )?; + validate_closure_param_conflicts("diffeq", diffeq, &route_inputs, "route")?; + } + + if let Some(lag) = lag { + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`lag` named binding generation", + )?; + validate_closure_param_conflicts("lag", lag, params, "parameter")?; + validate_closure_param_conflicts("lag", lag, covariates, "covariate")?; + validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?; + } + + if let Some(fa) = fa { + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`fa` named binding generation", + )?; + validate_closure_param_conflicts("fa", fa, params, "parameter")?; + validate_closure_param_conflicts("fa", fa, covariates, "covariate")?; + validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?; + } + + if let Some(init) = init { + validate_closure_param_conflicts("init", init, params, "parameter")?; + validate_closure_param_conflicts("init", init, covariates, "covariate")?; + validate_closure_param_conflicts("init", init, states, "state")?; + } + + validate_closure_param_conflicts("out", out, params, "parameter")?; + validate_closure_param_conflicts("out", out, covariates, "covariate")?; + validate_closure_param_conflicts("out", out, states, "state")?; + validate_closure_param_conflicts("out", out, outputs, "output")?; + + Ok(()) +} + +fn validate_analytical_named_binding_compatibility( + bindings: NamedBindingSets<'_>, + closures: AnalyticalBindingClosures<'_>, +) -> syn::Result<()> { + let NamedBindingSets { + params, + covariates, + states, + outputs, + routes, + } = bindings; + let AnalyticalBindingClosures { + sec, + common: CommonBindingClosures { lag, fa, init, out }, + } = closures; + let route_inputs = route_input_idents(routes); + + validate_binding_conflicts( + "parameter", + params, + "covariate", + covariates, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "state", + states, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "output", + outputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "state", + states, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "output", + outputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "route", + &route_inputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "output", + outputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "route", + &route_inputs, + "`analytical!` named binding generation", + )?; + validate_binding_conflicts( + "output", + outputs, + "route", + &route_inputs, + "`analytical!` named binding generation", + )?; + + if let Some(sec) = sec { + validate_closure_param_conflicts("sec", sec, params, "parameter")?; + validate_closure_param_conflicts("sec", sec, covariates, "covariate")?; + } + + if let Some(lag) = lag { + validate_closure_param_conflicts("lag", lag, params, "parameter")?; + validate_closure_param_conflicts("lag", lag, covariates, "covariate")?; + validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?; + } + + if let Some(fa) = fa { + validate_closure_param_conflicts("fa", fa, params, "parameter")?; + validate_closure_param_conflicts("fa", fa, covariates, "covariate")?; + validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?; + } + + if let Some(init) = init { + validate_closure_param_conflicts("init", init, params, "parameter")?; + validate_closure_param_conflicts("init", init, covariates, "covariate")?; + validate_closure_param_conflicts("init", init, states, "state")?; + } + + validate_closure_param_conflicts("out", out, params, "parameter")?; + validate_closure_param_conflicts("out", out, covariates, "covariate")?; + validate_closure_param_conflicts("out", out, states, "state")?; + validate_closure_param_conflicts("out", out, outputs, "output")?; + + Ok(()) +} + +fn validate_sde_named_binding_compatibility( + bindings: NamedBindingSets<'_>, + closures: SdeBindingClosures<'_>, +) -> syn::Result<()> { + let NamedBindingSets { + params, + covariates, + states, + outputs, + routes, + } = bindings; + let SdeBindingClosures { + drift, + diffusion, + common: CommonBindingClosures { lag, fa, init, out }, + } = closures; + let route_inputs = route_input_idents(routes); + + validate_binding_conflicts( + "parameter", + params, + "covariate", + covariates, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "state", + states, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "output", + outputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "state", + states, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "output", + outputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "covariate", + covariates, + "route", + &route_inputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "parameter", + params, + "route", + &route_inputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "output", + outputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "state", + states, + "route", + &route_inputs, + "`sde!` named binding generation", + )?; + validate_binding_conflicts( + "output", + outputs, + "route", + &route_inputs, + "`sde!` named binding generation", + )?; + + validate_closure_param_conflicts("drift", drift, params, "parameter")?; + validate_closure_param_conflicts("drift", drift, covariates, "covariate")?; + validate_closure_param_conflicts("drift", drift, states, "state")?; + validate_closure_param_conflicts("diffusion", diffusion, params, "parameter")?; + validate_closure_param_conflicts("diffusion", diffusion, states, "state")?; + + if let Some(lag) = lag { + validate_closure_param_conflicts("lag", lag, params, "parameter")?; + validate_closure_param_conflicts("lag", lag, covariates, "covariate")?; + validate_closure_param_conflicts("lag", lag, &route_inputs, "route")?; + } + + if let Some(fa) = fa { + validate_closure_param_conflicts("fa", fa, params, "parameter")?; + validate_closure_param_conflicts("fa", fa, covariates, "covariate")?; + validate_closure_param_conflicts("fa", fa, &route_inputs, "route")?; + } + + if let Some(init) = init { + validate_closure_param_conflicts("init", init, params, "parameter")?; + validate_closure_param_conflicts("init", init, covariates, "covariate")?; + validate_closure_param_conflicts("init", init, states, "state")?; + } + + validate_closure_param_conflicts("out", out, params, "parameter")?; + validate_closure_param_conflicts("out", out, covariates, "covariate")?; + validate_closure_param_conflicts("out", out, states, "state")?; + validate_closure_param_conflicts("out", out, outputs, "output")?; + + Ok(()) +} + +fn generate_index_consts(idents: &[Ident]) -> TokenStream2 { + let bindings = idents.iter().enumerate().map(|(index, ident)| { + quote! { + #[allow(non_upper_case_globals, dead_code)] + const #ident: usize = #index; + } + }); + + quote! { + #(#bindings)* + } +} + +fn generate_mapped_index_consts(bindings: &[(Ident, usize)]) -> TokenStream2 { + let bindings = bindings.iter().map(|(ident, index)| { + quote! { + #[allow(non_upper_case_globals, dead_code)] + const #ident: usize = #index; + } + }); + + quote! { + #(#bindings)* + } +} + +fn expand_out( + out: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], + outputs: &[Ident], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let output_consts = generate_index_consts(outputs); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let y = generated_ident("__pharmsol_y"); + let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()]; + let reduced_inputs = [x.clone(), t.clone(), y.clone()]; + let input_aliases = generate_supported_input_aliases( + out, + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|", + )?; + let parameter_bindings = generate_parameter_bindings(params, out, &p); + let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); + let body = &out.body; + + Ok(quote! {{ + let __pharmsol_out: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #y: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #output_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_out + }}) +} + +fn route_property_error(macro_name: &str, label: &str, node: T) -> syn::Error { + syn::Error::new_spanned( + node, + format!( + "{macro_name} requires `{label}` to return `{label}! {{ ... }}` so route-property metadata can be synthesized" + ), + ) +} + +fn find_terminal_macro_invocation( + macro_name: &str, + label: &str, + closure: &ExprClosure, +) -> syn::Result { + match closure.body.as_ref() { + Expr::Macro(expr_macro) if expr_macro.mac.path.is_ident(label) => { + Ok(expr_macro.mac.clone()) + } + Expr::Macro(expr_macro) => Err(route_property_error(macro_name, label, expr_macro)), + Expr::Block(expr_block) => { + for stmt in expr_block.block.stmts.iter().rev() { + match stmt { + Stmt::Expr(Expr::Macro(expr_macro), _) + if expr_macro.mac.path.is_ident(label) => + { + return Ok(expr_macro.mac.clone()); + } + Stmt::Expr(Expr::Macro(expr_macro), _) => { + return Err(route_property_error(macro_name, label, expr_macro)); + } + Stmt::Expr(other, _) => { + return Err(route_property_error(macro_name, label, other)); + } + Stmt::Macro(stmt_macro) if stmt_macro.mac.path.is_ident(label) => { + return Ok(stmt_macro.mac.clone()); + } + Stmt::Macro(stmt_macro) => { + return Err(route_property_error(macro_name, label, stmt_macro)); + } + _ => continue, + } + } + + Err(route_property_error(macro_name, label, expr_block)) + } + other => Err(route_property_error(macro_name, label, other)), + } +} + +fn extract_route_property_routes( + macro_name: &str, + label: &str, + closure: &ExprClosure, + routes: &[OdeRouteDecl], +) -> syn::Result> { + let macro_expr = find_terminal_macro_invocation(macro_name, label, closure)?; + let entries = Punctuated::::parse_terminated + .parse2(macro_expr.tokens.clone())?; + let known_routes = route_input_idents(routes) + .into_iter() + .map(|route| route.to_string()) + .collect::>(); + let mut seen = HashSet::new(); + + for entry in entries { + let route_name = entry.route.to_string(); + if !known_routes.contains(&route_name) { + return Err(syn::Error::new_spanned( + &entry.route, + format!( + "route `{route_name}` in `{label}!` is not declared in the `routes` section" + ), + )); + } + if !seen.insert(route_name.clone()) { + return Err(syn::Error::new_spanned( + &entry.route, + format!("duplicate route `{route_name}` in `{label}!`"), + )); + } + let _ = entry.value; + } + + Ok(seen) +} + +fn validate_route_property_kinds( + macro_name: &str, + label: &str, + routes: &[OdeRouteDecl], + property_routes: &HashSet, +) -> syn::Result<()> { + for route in routes { + if property_routes.contains(&route.input.to_string()) + && matches!(route.kind, OdeRouteKind::Infusion) + { + return Err(syn::Error::new_spanned( + &route.input, + format!( + "{macro_name} does not allow `{label}` on infusion route `{}`", + route.input + ), + )); + } + } + + Ok(()) +} + +fn expand_ode_route_map( + label: &str, + closure: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + route_bindings: &[(Ident, usize)], +) -> syn::Result { + let route_consts = generate_mapped_index_consts(route_bindings); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [p.clone(), t.clone(), cov.clone()]; + let reduced_inputs = [t.clone()]; + let input_aliases = generate_supported_input_aliases( + closure, + &[&full_inputs, &reduced_inputs], + &format!( + "declaration-first `ode!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|" + ), + )?; + let parameter_bindings = generate_parameter_bindings(params, closure, &p); + let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); + let body = &closure.body; + + Ok(quote! {{ + let __pharmsol_route_map: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + ) -> ::std::collections::HashMap = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #route_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_route_map + }}) +} + +fn expand_ode_init( + init: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let x = generated_ident("__pharmsol_x"); + let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()]; + let reduced_inputs = [t.clone(), x.clone()]; + let input_aliases = generate_supported_input_aliases( + init, + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|", + )?; + let parameter_bindings = generate_parameter_bindings(params, init, &p); + let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t); + let body = &init.body; + + Ok(quote! {{ + let __pharmsol_init: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #x: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_init + }}) +} + +fn expand_route_metadata( + routes: &[OdeRouteDecl], + diffeq_mode: OdeDiffeqMode, + lag_routes: &HashSet, + fa_routes: &HashSet, +) -> Vec { + routes + .iter() + .map(|route| { + let input = &route.input; + let destination = &route.destination; + let route_name = route.input.to_string(); + let route_builder = match route.kind { + OdeRouteKind::Bolus => { + quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } + } + OdeRouteKind::Infusion => { + quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } + } + }; + let input_policy = match diffeq_mode { + OdeDiffeqMode::InjectedRouteInputs => quote! { .inject_input_to_destination() }, + OdeDiffeqMode::ExplicitRouteVectors => quote! { .expect_explicit_input() }, + }; + let lag_flag = if lag_routes.contains(&route_name) { + quote! { .with_lag() } + } else { + quote! {} + }; + let fa_flag = if fa_routes.contains(&route_name) { + quote! { .with_bioavailability() } + } else { + quote! {} + }; + + quote! { + #route_builder + .to_state(stringify!(#destination)) + #lag_flag + #fa_flag + #input_policy + } + }) + .collect() +} + +fn expand_analytical_route_metadata( + routes: &[OdeRouteDecl], + lag_routes: &HashSet, + fa_routes: &HashSet, +) -> Vec { + routes + .iter() + .map(|route| { + let input = &route.input; + let destination = &route.destination; + let route_name = route.input.to_string(); + let route_builder = match route.kind { + OdeRouteKind::Bolus => { + quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } + } + OdeRouteKind::Infusion => { + quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } + } + }; + let lag_flag = if lag_routes.contains(&route_name) { + quote! { .with_lag() } + } else { + quote! {} + }; + let fa_flag = if fa_routes.contains(&route_name) { + quote! { .with_bioavailability() } + } else { + quote! {} + }; + + quote! { + #route_builder + .to_state(stringify!(#destination)) + #lag_flag + #fa_flag + } + }) + .collect() +} + +fn expand_sde_route_metadata( + routes: &[OdeRouteDecl], + lag_routes: &HashSet, + fa_routes: &HashSet, +) -> Vec { + routes + .iter() + .map(|route| { + let input = &route.input; + let destination = &route.destination; + let route_name = route.input.to_string(); + let route_builder = match route.kind { + OdeRouteKind::Bolus => { + quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } + } + OdeRouteKind::Infusion => { + quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } + } + }; + let lag_flag = if lag_routes.contains(&route_name) { + quote! { .with_lag() } + } else { + quote! {} + }; + let fa_flag = if fa_routes.contains(&route_name) { + quote! { .with_bioavailability() } + } else { + quote! {} + }; + + quote! { + #route_builder + .to_state(stringify!(#destination)) + .inject_input_to_destination() + #lag_flag + #fa_flag + } + }) + .collect() +} + +fn route_destination_index(route: &OdeRouteDecl, states: &[Ident]) -> usize { + states + .iter() + .position(|state| state == &route.destination) + .expect("validated route destination should exist") +} + +fn expand_injected_ode_route_terms( + routes: &[OdeRouteDecl], + states: &[Ident], + route_bindings: &[(Ident, usize)], + dx: &Ident, + bolus: &Ident, + rateiv: &Ident, +) -> TokenStream2 { + let terms = routes + .iter() + .zip(route_bindings.iter()) + .map(|(route, (_, channel_index))| { + let destination = route_destination_index(route, states); + match route.kind { + OdeRouteKind::Bolus => quote! { + #dx[#destination] += #bolus[#channel_index]; + }, + OdeRouteKind::Infusion => quote! { + #dx[#destination] += #rateiv[#channel_index]; + }, + } + }); + + quote! { + #(#terms)* + } +} + +fn expand_injected_sde_rate_terms( + routes: &[OdeRouteDecl], + states: &[Ident], + route_bindings: &[(Ident, usize)], + dx: &Ident, + rateiv: &Ident, +) -> TokenStream2 { + let terms = + routes + .iter() + .zip(route_bindings.iter()) + .filter_map(|(route, (_, channel_index))| match route.kind { + OdeRouteKind::Bolus => None, + OdeRouteKind::Infusion => { + let destination = route_destination_index(route, states); + Some(quote! { + #dx[#destination] += #rateiv[#channel_index]; + }) + } + }); + + quote! { + #(#terms)* + } +} + +fn expand_injected_sde_bolus_mappings( + routes: &[OdeRouteDecl], + states: &[Ident], + route_bindings: &[(Ident, usize)], +) -> TokenStream2 { + let mut destinations = vec![quote! { None }; dense_index_len(route_bindings)]; + + for (route, (_, channel_index)) in routes.iter().zip(route_bindings.iter()) { + if let OdeRouteKind::Bolus = route.kind { + let destination = route_destination_index(route, states); + destinations[*channel_index] = quote! { Some(#destination) }; + } + } + + quote! { + .with_injected_bolus_inputs(&[#(#destinations),*]) + } +} + +fn validate_unique_idents(kind: &str, idents: &[Ident], macro_name: &str) -> syn::Result<()> { + let mut seen = HashSet::new(); + for ident in idents { + let name = ident.to_string(); + if !seen.insert(name.clone()) { + return Err(syn::Error::new_spanned( + ident, + format!("duplicate {kind} `{name}` in declaration-first `{macro_name}`"), + )); + } + } + Ok(()) +} + +fn validate_routes(routes: &[OdeRouteDecl], states: &[Ident], macro_name: &str) -> syn::Result<()> { + let known_states = states.iter().map(Ident::to_string).collect::>(); + let mut seen_routes = HashSet::new(); + + for route in routes { + let route_name = route.input.to_string(); + if !seen_routes.insert(route_name.clone()) { + return Err(syn::Error::new_spanned( + &route.input, + format!("duplicate route `{route_name}` in declaration-first `{macro_name}`"), + )); + } + + if !known_states.contains(&route.destination.to_string()) { + return Err(syn::Error::new_spanned( + &route.destination, + format!( + "route destination `{}` is not declared in the `states` section", + route.destination + ), + )); + } + } + + Ok(()) +} + +fn expand_diffeq( + diffeq: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], + routes: &[OdeRouteDecl], + route_bindings: &[(Ident, usize)], + diffeq_mode: OdeDiffeqMode, +) -> syn::Result { + let state_consts = generate_index_consts(states); + + match diffeq_mode { + OdeDiffeqMode::ExplicitRouteVectors => { + let route_consts = generate_mapped_index_consts(route_bindings); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let dx = generated_ident("__pharmsol_dx"); + let bolus = generated_ident("__pharmsol_bolus"); + let rateiv = generated_ident("__pharmsol_rateiv"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [ + x.clone(), + p.clone(), + t.clone(), + dx.clone(), + bolus.clone(), + rateiv.clone(), + cov.clone(), + ]; + let reduced_inputs = [ + x.clone(), + t.clone(), + dx.clone(), + bolus.clone(), + rateiv.clone(), + ]; + let input_aliases = generate_supported_input_aliases( + diffeq, + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` explicit-route `diffeq` requires either 7 parameters: |x, p, t, dx, bolus, rateiv, cov| or 5 parameters: |x, t, dx, bolus, rateiv|", + )?; + let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); + let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); + let body = &diffeq.body; + + Ok(quote! {{ + let __pharmsol_diffeq: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &mut ::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::data::Covariates, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #dx: &mut ::pharmsol::simulator::V, + #bolus: &::pharmsol::simulator::V, + #rateiv: &::pharmsol::simulator::V, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #state_consts + #route_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_diffeq + }}) + } + OdeDiffeqMode::InjectedRouteInputs => { + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let dx = generated_ident("__pharmsol_dx"); + let bolus = generated_ident("__pharmsol_bolus"); + let rateiv = generated_ident("__pharmsol_rateiv"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()]; + let reduced_inputs = [x.clone(), t.clone(), dx.clone()]; + let input_aliases = generate_supported_input_aliases( + diffeq, + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` injected-route `diffeq` requires either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", + )?; + let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); + let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); + let body = &diffeq.body; + let dx_binding = if diffeq.inputs.len() == full_inputs.len() { + closure_param_ident(diffeq, 3).unwrap_or_else(|| dx.clone()) + } else { + closure_param_ident(diffeq, 2).unwrap_or_else(|| dx.clone()) + }; + let route_terms = expand_injected_ode_route_terms( + routes, + states, + route_bindings, + &dx_binding, + &bolus, + &rateiv, + ); + + Ok(quote! {{ + let __pharmsol_diffeq: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &mut ::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::data::Covariates, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #dx: &mut ::pharmsol::simulator::V, + #bolus: &::pharmsol::simulator::V, + #rateiv: &::pharmsol::simulator::V, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + #route_terms + }; + __pharmsol_diffeq + }}) + } + } +} + +fn resolve_analytical_structure(structure: &Ident) -> syn::Result { + let structure_name = structure.to_string(); + let (runtime_path, metadata_kernel, parameter_arity, state_count) = match structure_name + .as_str() + { + "one_compartment" => ( + quote! { ::pharmsol::equation::one_compartment }, + quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartment }, + 1, + 1, + ), + "one_compartment_cl" => ( + quote! { ::pharmsol::equation::one_compartment_cl }, + quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentCl }, + 2, + 1, + ), + "one_compartment_cl_with_absorption" => ( + quote! { ::pharmsol::equation::one_compartment_cl_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentClWithAbsorption }, + 3, + 2, + ), + "one_compartment_with_absorption" => ( + quote! { ::pharmsol::equation::one_compartment_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::OneCompartmentWithAbsorption }, + 2, + 2, + ), + "two_compartments" => ( + quote! { ::pharmsol::equation::two_compartments }, + quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartments }, + 3, + 2, + ), + "two_compartments_cl" => ( + quote! { ::pharmsol::equation::two_compartments_cl }, + quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsCl }, + 4, + 2, + ), + "two_compartments_cl_with_absorption" => ( + quote! { ::pharmsol::equation::two_compartments_cl_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsClWithAbsorption }, + 5, + 3, + ), + "two_compartments_with_absorption" => ( + quote! { ::pharmsol::equation::two_compartments_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::TwoCompartmentsWithAbsorption }, + 4, + 3, + ), + "three_compartments" => ( + quote! { ::pharmsol::equation::three_compartments }, + quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartments }, + 5, + 3, + ), + "three_compartments_cl" => ( + quote! { ::pharmsol::equation::three_compartments_cl }, + quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsCl }, + 6, + 3, + ), + "three_compartments_cl_with_absorption" => ( + quote! { ::pharmsol::equation::three_compartments_cl_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsClWithAbsorption }, + 7, + 4, + ), + "three_compartments_with_absorption" => ( + quote! { ::pharmsol::equation::three_compartments_with_absorption }, + quote! { ::pharmsol::equation::AnalyticalKernel::ThreeCompartmentsWithAbsorption }, + 6, + 4, + ), + _ => { + return Err(syn::Error::new_spanned( + structure, + format!("unknown analytical structure `{structure_name}`"), + )); + } + }; + + Ok(AnalyticalKernelSpec { + runtime_path, + metadata_kernel, + parameter_arity, + state_count, + }) +} + +fn expand_analytical_route_map( + label: &str, + closure: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + route_bindings: &[(Ident, usize)], +) -> syn::Result { + let route_consts = generate_mapped_index_consts(route_bindings); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [p.clone(), t.clone(), cov.clone()]; + let reduced_inputs = [t.clone()]; + let input_aliases = generate_supported_input_aliases( + closure, + &[&full_inputs, &reduced_inputs], + &format!( + "built-in `analytical!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|" + ), + )?; + let parameter_bindings = generate_parameter_bindings(params, closure, &p); + let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); + let body = &closure.body; + + Ok(quote! {{ + let __pharmsol_route_map: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + ) -> ::std::collections::HashMap = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #route_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_route_map + }}) +} + +fn expand_analytical_sec( + sec: &ExprClosure, + params: &[Ident], + covariates: &[Ident], +) -> syn::Result { + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [p.clone(), t.clone(), cov.clone()]; + let reduced_inputs = [t.clone()]; + let input_aliases = generate_supported_input_aliases( + sec, + &[&full_inputs, &reduced_inputs], + "built-in `analytical!` requires `sec` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|", + )?; + let parameter_vector = if sec.inputs.len() == full_inputs.len() { + closure_param_ident(sec, 0).unwrap_or_else(|| p.clone()) + } else { + p.clone() + }; + let (parameter_bindings, parameter_writebacks) = + generate_mutable_parameter_bindings(params, sec, ¶meter_vector); + let covariate_bindings = generate_covariate_bindings(covariates, sec, &cov, &t); + let body = &sec.body; + + Ok(quote! {{ + let __pharmsol_sec: fn( + &mut ::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + ) = |#p: &mut ::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #parameter_bindings + #covariate_bindings + #body + #parameter_writebacks + }; + __pharmsol_sec + }}) +} + +fn expand_analytical_init( + init: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let x = generated_ident("__pharmsol_x"); + let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()]; + let reduced_inputs = [t.clone(), x.clone()]; + let input_aliases = generate_supported_input_aliases( + init, + &[&full_inputs, &reduced_inputs], + "built-in `analytical!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|", + )?; + let parameter_bindings = generate_parameter_bindings(params, init, &p); + let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t); + let body = &init.body; + + Ok(quote! {{ + let __pharmsol_init: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #x: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_init + }}) +} + +fn expand_analytical_out( + out: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], + outputs: &[Ident], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let output_consts = generate_index_consts(outputs); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let y = generated_ident("__pharmsol_y"); + let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()]; + let reduced_inputs = [x.clone(), t.clone(), y.clone()]; + let input_aliases = generate_supported_input_aliases( + out, + &[&full_inputs, &reduced_inputs], + "built-in `analytical!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|", + )?; + let parameter_bindings = generate_parameter_bindings(params, out, &p); + let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); + let body = &out.body; + + Ok(quote! {{ + let __pharmsol_out: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #y: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #output_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_out + }}) +} + +fn expand_sde_drift( + drift: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], + routes: &[OdeRouteDecl], + route_bindings: &[(Ident, usize)], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let dx = generated_ident("__pharmsol_dx"); + let rateiv = generated_ident("__pharmsol_rateiv"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()]; + let reduced_inputs = [x.clone(), t.clone(), dx.clone()]; + let input_aliases = generate_supported_input_aliases( + drift, + &[&full_inputs, &reduced_inputs], + "declaration-first `sde!` requires `drift` to have either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", + )?; + let parameter_bindings = generate_parameter_bindings(params, drift, &p); + let covariate_bindings = generate_covariate_bindings(covariates, drift, &cov, &t); + let body = &drift.body; + let dx_binding = if drift.inputs.len() == full_inputs.len() { + closure_param_ident(drift, 3).unwrap_or_else(|| dx.clone()) + } else { + closure_param_ident(drift, 2).unwrap_or_else(|| dx.clone()) + }; + let rate_terms = + expand_injected_sde_rate_terms(routes, states, route_bindings, &dx_binding, &rateiv); + + Ok(quote! {{ + let __pharmsol_drift: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &mut ::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::data::Covariates, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #dx: &mut ::pharmsol::simulator::V, + #rateiv: &::pharmsol::simulator::V, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + #rate_terms + }; + __pharmsol_drift + }}) +} + +fn expand_sde_diffusion( + diffusion: &ExprClosure, + params: &[Ident], + states: &[Ident], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let p = generated_ident("__pharmsol_p"); + let sigma = generated_ident("__pharmsol_sigma"); + let full_inputs = [p.clone(), sigma.clone()]; + let reduced_inputs = [sigma.clone()]; + let input_aliases = generate_supported_input_aliases( + diffusion, + &[&full_inputs, &reduced_inputs], + "declaration-first `sde!` requires `diffusion` to have either 2 parameters: |p, sigma| or 1 parameter: |sigma|", + )?; + let parameter_bindings = generate_parameter_bindings(params, diffusion, &p); + let body = &diffusion.body; + + Ok(quote! {{ + let __pharmsol_diffusion: fn( + &::pharmsol::simulator::V, + &mut ::pharmsol::simulator::V, + ) = |#p: &::pharmsol::simulator::V, + #sigma: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #parameter_bindings + #body + }; + __pharmsol_diffusion + }}) +} + +fn expand_sde_route_map( + label: &str, + closure: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + route_bindings: &[(Ident, usize)], +) -> syn::Result { + let route_consts = generate_mapped_index_consts(route_bindings); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [p.clone(), t.clone(), cov.clone()]; + let reduced_inputs = [t.clone()]; + let input_aliases = generate_supported_input_aliases( + closure, + &[&full_inputs, &reduced_inputs], + &format!( + "declaration-first `sde!` requires `{label}` to have either 3 parameters: |p, t, cov| or 1 parameter: |t|" + ), + )?; + let parameter_bindings = generate_parameter_bindings(params, closure, &p); + let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); + let body = &closure.body; + + Ok(quote! {{ + let __pharmsol_route_map: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + ) -> ::std::collections::HashMap = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #route_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_route_map + }}) +} + +fn expand_sde_init( + init: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let x = generated_ident("__pharmsol_x"); + let full_inputs = [p.clone(), t.clone(), cov.clone(), x.clone()]; + let reduced_inputs = [t.clone(), x.clone()]; + let input_aliases = generate_supported_input_aliases( + init, + &[&full_inputs, &reduced_inputs], + "declaration-first `sde!` requires `init` to have either 4 parameters: |p, t, cov, x| or 2 parameters: |t, x|", + )?; + let parameter_bindings = generate_parameter_bindings(params, init, &p); + let covariate_bindings = generate_covariate_bindings(covariates, init, &cov, &t); + let body = &init.body; + + Ok(quote! {{ + let __pharmsol_init: fn( + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #x: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_init + }}) +} + +fn expand_sde_out( + out: &ExprClosure, + params: &[Ident], + covariates: &[Ident], + states: &[Ident], + outputs: &[Ident], +) -> syn::Result { + let state_consts = generate_index_consts(states); + let output_consts = generate_index_consts(outputs); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let cov = generated_ident("__pharmsol_cov"); + let y = generated_ident("__pharmsol_y"); + let full_inputs = [x.clone(), p.clone(), t.clone(), cov.clone(), y.clone()]; + let reduced_inputs = [x.clone(), t.clone(), y.clone()]; + let input_aliases = generate_supported_input_aliases( + out, + &[&full_inputs, &reduced_inputs], + "declaration-first `sde!` requires `out` to have either 5 parameters: |x, p, t, cov, y| or 3 parameters: |x, t, y|", + )?; + let parameter_bindings = generate_parameter_bindings(params, out, &p); + let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); + let body = &out.body; + + Ok(quote! {{ + let __pharmsol_out: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &::pharmsol::data::Covariates, + &mut ::pharmsol::simulator::V, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #cov: &::pharmsol::data::Covariates, + #y: &mut ::pharmsol::simulator::V| { + #input_aliases + #state_consts + #output_consts + #parameter_bindings + #covariate_bindings + #body + }; + __pharmsol_out + }}) +} + +// --------------------------------------------------------------------------- +// Proc macros +// --------------------------------------------------------------------------- + +#[proc_macro] +pub fn ode(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as OdeInput); + + let route_bindings = ode_route_channel_bindings(&input.routes); + + let lag_routes = match input.lag.as_ref() { + Some(closure) => match extract_route_property_routes( + "declaration-first `ode!`", + "lag", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "declaration-first `ode!`", + "lag", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let fa_routes = match input.fa.as_ref() { + Some(closure) => match extract_route_property_routes( + "declaration-first `ode!`", + "fa", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "declaration-first `ode!`", + "fa", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let diffeq = match expand_diffeq( + &input.diffeq, + &input.params, + &input.covariates, + &input.states, + &input.routes, + &route_bindings, + input.diffeq_mode, + ) { + Ok(diffeq) => diffeq, + Err(error) => return error.to_compile_error().into(), + }; + + let out = match expand_out( + &input.out, + &input.params, + &input.covariates, + &input.states, + &input.outputs, + ) { + Ok(out) => out, + Err(error) => return error.to_compile_error().into(), + }; + + let nstates = input.states.len(); + let ndrugs = dense_index_len(&route_bindings); + let nout = input.outputs.len(); + + let name = &input.name; + let params = &input.params; + let covariates = &input.covariates; + let states = &input.states; + let outputs = &input.outputs; + let routes = expand_route_metadata(&input.routes, input.diffeq_mode, &lag_routes, &fa_routes); + let covariate_metadata = if covariates.is_empty() { + quote! {} + } else { + quote! { + .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*]) + } + }; + + let lag = match input.lag.as_ref() { + Some(closure) => match expand_ode_route_map( + "lag", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { + Ok(lag) => lag, + Err(error) => return error.to_compile_error().into(), + }, + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let fa = match input.fa.as_ref() { + Some(closure) => { + match expand_ode_route_map( + "fa", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { + Ok(fa) => fa, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let init = match input.init.as_ref() { + Some(closure) => { + match expand_ode_init(closure, &input.params, &input.covariates, &input.states) { + Ok(init) => init, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _, _| {} }, + }; + + quote! {{ + let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) + .parameters([#(stringify!(#params)),*]) + #covariate_metadata + .states([#(stringify!(#states)),*]) + .outputs([#(stringify!(#outputs)),*]) + #(.route(#routes))*; + + ::pharmsol::equation::ODE::new( + #diffeq, + #lag, + #fa, + #init, + #out, + ) + .with_nstates(#nstates) + .with_ndrugs(#ndrugs) + .with_nout(#nout) + .with_metadata(__pharmsol_metadata) + .expect("declaration-first `ode!` generated invalid metadata") + }} + .into() +} + +#[proc_macro] +pub fn analytical(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as AnalyticalInput); + let route_bindings = ode_route_channel_bindings(&input.routes); + + let kernel_spec = match resolve_analytical_structure(&input.structure) { + Ok(spec) => spec, + Err(error) => return error.to_compile_error().into(), + }; + + let lag_routes = match input.lag.as_ref() { + Some(closure) => match extract_route_property_routes( + "built-in `analytical!`", + "lag", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "built-in `analytical!`", + "lag", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let fa_routes = match input.fa.as_ref() { + Some(closure) => match extract_route_property_routes( + "built-in `analytical!`", + "fa", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "built-in `analytical!`", + "fa", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let sec = match input.sec.as_ref() { + Some(closure) => match expand_analytical_sec(closure, &input.params, &input.covariates) { + Ok(sec) => sec, + Err(error) => return error.to_compile_error().into(), + }, + None => quote! { |_, _, _| {} }, + }; + + let out = match expand_analytical_out( + &input.out, + &input.params, + &input.covariates, + &input.states, + &input.outputs, + ) { + Ok(out) => out, + Err(error) => return error.to_compile_error().into(), + }; + + let lag = match input.lag.as_ref() { + Some(closure) => { + match expand_analytical_route_map( + "lag", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { + Ok(lag) => lag, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let fa = match input.fa.as_ref() { + Some(closure) => { + match expand_analytical_route_map( + "fa", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { + Ok(fa) => fa, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let init = match input.init.as_ref() { + Some(closure) => { + match expand_analytical_init(closure, &input.params, &input.covariates, &input.states) { + Ok(init) => init, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _, _| {} }, + }; + + let nstates = input.states.len(); + let ndrugs = dense_index_len(&route_bindings); + let nout = input.outputs.len(); + + let name = &input.name; + let params = &input.params; + let covariates = &input.covariates; + let states = &input.states; + let outputs = &input.outputs; + let routes = expand_analytical_route_metadata(&input.routes, &lag_routes, &fa_routes); + let runtime_path = kernel_spec.runtime_path; + let metadata_kernel = kernel_spec.metadata_kernel; + let covariate_metadata = if covariates.is_empty() { + quote! {} + } else { + quote! { + .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*]) + } + }; + + quote! {{ + let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) + .kind(::pharmsol::equation::ModelKind::Analytical) + .parameters([#(stringify!(#params)),*]) + #covariate_metadata + .states([#(stringify!(#states)),*]) + .outputs([#(stringify!(#outputs)),*]) + #(.route(#routes))* + .analytical_kernel(#metadata_kernel); + + ::pharmsol::equation::Analytical::new( + #runtime_path, + #sec, + #lag, + #fa, + #init, + #out, + ) + .with_nstates(#nstates) + .with_ndrugs(#ndrugs) + .with_nout(#nout) + .with_metadata(__pharmsol_metadata) + .expect("built-in `analytical!` generated invalid metadata") + }} + .into() +} + +#[proc_macro] +pub fn sde(input: TokenStream) -> TokenStream { + let input = syn::parse_macro_input!(input as SdeInput); + let route_bindings = ode_route_channel_bindings(&input.routes); + + let lag_routes = match input.lag.as_ref() { + Some(closure) => match extract_route_property_routes( + "declaration-first `sde!`", + "lag", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "declaration-first `sde!`", + "lag", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let fa_routes = match input.fa.as_ref() { + Some(closure) => match extract_route_property_routes( + "declaration-first `sde!`", + "fa", + closure, + &input.routes, + ) { + Ok(routes) => { + if let Err(error) = validate_route_property_kinds( + "declaration-first `sde!`", + "fa", + &input.routes, + &routes, + ) { + return error.to_compile_error().into(); + } + routes + } + Err(error) => return error.to_compile_error().into(), + }, + None => HashSet::new(), + }; + + let drift = match expand_sde_drift( + &input.drift, + &input.params, + &input.covariates, + &input.states, + &input.routes, + &route_bindings, + ) { + Ok(drift) => drift, + Err(error) => return error.to_compile_error().into(), + }; + + let diffusion = match expand_sde_diffusion(&input.diffusion, &input.params, &input.states) { + Ok(diffusion) => diffusion, + Err(error) => return error.to_compile_error().into(), + }; + + let lag = match input.lag.as_ref() { + Some(closure) => match expand_sde_route_map( + "lag", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { + Ok(lag) => lag, + Err(error) => return error.to_compile_error().into(), + }, + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let fa = match input.fa.as_ref() { + Some(closure) => { + match expand_sde_route_map( + "fa", + closure, + &input.params, + &input.covariates, + &route_bindings, + ) { + Ok(fa) => fa, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _| ::std::collections::HashMap::new() }, + }; + + let init = match input.init.as_ref() { + Some(closure) => { + match expand_sde_init(closure, &input.params, &input.covariates, &input.states) { + Ok(init) => init, + Err(error) => return error.to_compile_error().into(), + } + } + None => quote! { |_, _, _, _| {} }, + }; + + let out = match expand_sde_out( + &input.out, + &input.params, + &input.covariates, + &input.states, + &input.outputs, + ) { + Ok(out) => out, + Err(error) => return error.to_compile_error().into(), + }; + + let nstates = input.states.len(); + let ndrugs = dense_index_len(&route_bindings); + let nout = input.outputs.len(); + + let name = &input.name; + let params = &input.params; + let covariates = &input.covariates; + let states = &input.states; + let outputs = &input.outputs; + let particles = &input.particles; + let routes = expand_sde_route_metadata(&input.routes, &lag_routes, &fa_routes); + let bolus_mappings = + expand_injected_sde_bolus_mappings(&input.routes, &input.states, &route_bindings); + let covariate_metadata = if covariates.is_empty() { + quote! {} + } else { + quote! { + .covariates([#(::pharmsol::equation::Covariate::continuous(stringify!(#covariates))),*]) + } + }; + + quote! {{ + let __pharmsol_particles: usize = #particles; + let __pharmsol_metadata = ::pharmsol::equation::metadata::new(#name) + .kind(::pharmsol::equation::ModelKind::Sde) + .parameters([#(stringify!(#params)),*]) + #covariate_metadata + .states([#(stringify!(#states)),*]) + .outputs([#(stringify!(#outputs)),*]) + #(.route(#routes))* + .particles(__pharmsol_particles); + + ::pharmsol::equation::SDE::new( + #drift, + #diffusion, + #lag, + #fa, + #init, + #out, + __pharmsol_particles, + ) + .with_nstates(#nstates) + .with_ndrugs(#ndrugs) + .with_nout(#nout) + #bolus_mappings + .with_metadata(__pharmsol_metadata) + .expect("declaration-first `sde!` generated invalid metadata") + }} + .into() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn rejects_removed_legacy_form() { + let error = syn::parse_str::( + "diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("legacy macro form must fail"); + + assert!(error + .to_string() + .contains("requires `name`, `params`, `states`, `outputs`, and `routes`")); + assert!(error + .to_string() + .contains("old inferred-dimensions form has been removed")); + } + + #[test] + fn validates_route_destinations() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> peripheral }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("unknown route destination must fail"); + + assert!(error + .to_string() + .contains("route destination `peripheral` is not declared in the `states` section")); + } + + #[test] + fn rejects_named_binding_collisions() { + let error = syn::parse_str::( + "name: \"demo\", params: [central, v], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("parameter/state binding collisions must fail"); + + assert!(error + .to_string() + .contains("named parameter binding `central` conflicts with named state binding")); + } + + #[test] + fn ode_route_bindings_share_channels_by_kind_local_ordinal() { + let input = syn::parse_str::( + "name: \"demo\", params: [ka, ke, v], states: [depot, central], outputs: [cp], routes: { bolus(oral) -> depot, infusion(iv) -> central, bolus(sc) -> depot }, diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", + ) + .expect("declaration-first ode input should parse"); + + let bindings = ode_route_channel_bindings(&input.routes); + + assert_eq!(dense_index_len(&bindings), 2); + assert_eq!(bindings[0].0.to_string(), "oral"); + assert_eq!(bindings[0].1, 0); + assert_eq!(bindings[1].0.to_string(), "iv"); + assert_eq!(bindings[1].1, 0); + assert_eq!(bindings[2].0.to_string(), "sc"); + assert_eq!(bindings[2].1, 1); + } + + #[test] + fn generated_parameter_bindings_only_include_referenced_locals_in_hot_closures() { + let params = vec![generated_ident("ke"), generated_ident("v")]; + let closure = syn::parse_str::( + "|x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }", + ) + .expect("closure should parse"); + + let bindings = + generate_parameter_bindings(¶ms, &closure, &generated_ident("__pharmsol_p")) + .to_string(); + + assert!( + bindings.contains("let ke = __pharmsol_p [0usize] ;") + || bindings.contains("let ke = __pharmsol_p [ 0 ] ;") + ); + assert!(!bindings.contains("let v =")); + } + + #[test] + fn generated_parameter_bindings_fall_back_to_all_params_for_stmt_macros() { + let params = vec![generated_ident("ka"), generated_ident("tlag")]; + let closure = syn::parse_str::("|_p, _t, _cov| { lag! { oral => tlag } }") + .expect("closure should parse"); + + let bindings = + generate_parameter_bindings(¶ms, &closure, &generated_ident("__pharmsol_p")) + .to_string(); + + assert!(bindings.contains("let ka =")); + assert!(bindings.contains("let tlag =")); + } + + #[test] + fn analytical_accepts_extra_parameters_beyond_kernel_arity() { + let input = syn::parse_str::( + "name: \"demo\", params: [ka, ke, v, tlag, tvke], covariates: [wt, renal], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, sec: |_t| { ke = tvke; }, out: |x, p, t, cov, y| {}", + ) + .expect("extra declared parameters should be allowed"); + + assert_eq!(input.params.len(), 5); + assert_eq!(input.covariates.len(), 2); + assert!(input.sec.is_some()); + assert_eq!(input.states.len(), 2); + } + + #[test] + fn analytical_rejects_unknown_structure() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, structure: mystery, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("unknown analytical structure must fail"); + + assert!(error + .to_string() + .contains("unknown analytical structure `mystery`")); + } + + #[test] + fn analytical_rejects_insufficient_kernel_parameters() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("insufficient kernel parameters must fail"); + + assert!(error + .to_string() + .contains("requires at least 2 parameter value(s)")); + } + + #[test] + fn analytical_rejects_unknown_route_property_binding() { + let error = syn::parse_str::( + "name: \"demo\", params: [ka, ke, v], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { iv => 1.0 } }, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("unknown lag route must fail"); + + assert!(error + .to_string() + .contains("route `iv` in `lag!` is not declared in the `routes` section")); + } + + #[test] + fn analytical_rejects_infusion_lag_binding() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke, v, tlag], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, structure: one_compartment, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("infusion lag must fail"); + + assert!(error + .to_string() + .contains("built-in `analytical!` does not allow `lag` on infusion route `iv`")); + } + + #[test] + fn sde_requires_particles() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke, theta], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("missing particles must fail"); + + assert!(error + .to_string() + .contains("missing required field `particles` in declaration-first `sde!`")); + } + + #[test] + fn sde_rejects_unknown_route_property_binding() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke, sigma_ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { oral => 1.0 } }, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("unknown lag route must fail"); + + assert!(error + .to_string() + .contains("route `oral` in `lag!` is not declared in the `routes` section")); + } + + #[test] + fn sde_rejects_infusion_lag_binding() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke, sigma_ke, tlag], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("infusion lag must fail"); + + assert!(error + .to_string() + .contains("declaration-first `sde!` does not allow `lag` on infusion route `iv`")); } - .into() } diff --git a/src/dsl/compiled_backend_abi.rs b/src/dsl/compiled_backend_abi.rs index 26a2f825..8717c416 100644 --- a/src/dsl/compiled_backend_abi.rs +++ b/src/dsl/compiled_backend_abi.rs @@ -324,7 +324,9 @@ mod tests { }], routes: vec![NativeRouteInfo { name: "iv".to_string(), + declaration_index: 0, index: 0, + kind: None, destination_offset: 1, inject_input_to_destination: true, }], diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index ed516387..b0f1fe4a 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -731,7 +731,7 @@ fn lower_load( ExecutionLoad::Parameter(index) => load_fixed(builder, env.args.params, *index, ty), ExecutionLoad::Covariate(index) => load_fixed(builder, env.args.covariates, *index, ty), ExecutionLoad::Derived(index) => load_fixed(builder, env.args.derived, *index, ty), - ExecutionLoad::RouteInput(index) => load_fixed(builder, env.args.routes, *index, ty), + ExecutionLoad::RouteInput { index, .. } => load_fixed(builder, env.args.routes, *index, ty), ExecutionLoad::Local(index) => { let binding = env.locals.get(index).ok_or_else(|| { JitCompileError::new(format!("unknown local slot {index}"), Some(span)) @@ -1330,6 +1330,89 @@ mod tests { assert!(debugged.contains("error[DSL4000]"), "{}", debugged); } + #[test] + fn authoring_runtime_shares_channel_between_bolus_and_infusion_routes() { + let source = r#" +name = shared_authoring +kind = ode + +params = ka, ke, v +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central + +out(cp) = central / v ~ continuous() +"#; + let parsed = pharmsol_dsl::parse_model(source).expect("authoring model parses"); + let typed = pharmsol_dsl::analyze_model(&parsed).expect("authoring model analyzes"); + let model = pharmsol_dsl::lower_typed_model(&typed).expect("authoring model lowers"); + let jit = compile_ode_model_to_jit(&model) + .expect("compile jit ode model") + .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); + + let oral = jit.route_index("oral").expect("oral route"); + let iv = jit.route_index("iv").expect("iv route"); + let cp = jit.output_index("cp").expect("cp output"); + assert_eq!(oral, 0); + assert_eq!(iv, 0); + + let subject = Subject::builder("ode") + .bolus(0.0, 120.0, oral) + .infusion(6.0, 60.0, iv, 2.0) + .observation(0.5, 0.0, cp) + .observation(1.0, 0.0, cp) + .observation(2.0, 0.0, cp) + .observation(6.0, 0.0, cp) + .observation(7.0, 0.0, cp) + .observation(9.0, 0.0, cp) + .build(); + + let support = vec![1.2, 0.15, 40.0]; + let jit_predictions = jit + .estimate_predictions(&subject, &support) + .expect("jit predictions"); + + let reference = ODE::new( + |x, p, _t, dx, bolus, rateiv, _cov| { + let ka = p[0]; + let ke = p[1]; + dx[0] = -ka * x[0] + bolus[0]; + dx[1] = ka * x[0] - ke * x[1] + rateiv[0]; + }, + |_p, _t, _cov| std::collections::HashMap::new(), + |_p, _t, _cov| std::collections::HashMap::new(), + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + y[0] = x[1] / p[2]; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); + + let reference_predictions = reference + .estimate_predictions(&subject, &support) + .expect("reference ode predictions"); + + for (jit_pred, reference_pred) in jit_predictions + .predictions() + .iter() + .zip(reference_predictions.predictions()) + { + assert_relative_eq!( + jit_pred.prediction(), + reference_pred.prediction(), + max_relative = 1e-4 + ); + } + } + fn slot_index(layout: &DenseBufferLayout, name: &str) -> usize { layout .slots diff --git a/src/dsl/model_info.rs b/src/dsl/model_info.rs index 8e48a022..0094059f 100644 --- a/src/dsl/model_info.rs +++ b/src/dsl/model_info.rs @@ -1,10 +1,12 @@ +use std::collections::BTreeMap; + use serde::{Deserialize, Serialize}; use pharmsol_dsl::execution::{ ExecutionExpr, ExecutionExprKind, ExecutionLoad, ExecutionModel, ExecutionStmt, ExecutionStmtKind, KernelImplementation, KernelRole, }; -use pharmsol_dsl::{AnalyticalKernel, ModelKind}; +use pharmsol_dsl::{AnalyticalKernel, ModelKind, RouteKind}; #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeModelInfo { @@ -31,7 +33,11 @@ pub struct NativeCovariateInfo { #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeRouteInfo { pub name: String, + #[serde(default)] + pub declaration_index: usize, pub index: usize, + #[serde(default)] + pub kind: Option, pub destination_offset: usize, pub inject_input_to_destination: bool, } @@ -69,10 +75,12 @@ impl NativeModelInfo { .iter() .map(|route| NativeRouteInfo { name: route.name.clone(), + declaration_index: route.declaration_index, index: route.index, + kind: route.kind, destination_offset: route.destination.state_offset, inject_input_to_destination: !explicit_route_input_usage - .get(route.index) + .get(route.declaration_index) .copied() .unwrap_or(false), }) @@ -97,6 +105,12 @@ impl NativeModelInfo { } fn explicit_route_input_usage(model: &ExecutionModel) -> Vec { + let declaration_slots = model + .metadata + .routes + .iter() + .map(|route| (route.symbol, route.declaration_index)) + .collect::>(); let Some(kernel) = (match model.kind { ModelKind::Ode => model.kernel(KernelRole::Dynamics), ModelKind::Sde => model.kernel(KernelRole::Drift), @@ -107,54 +121,155 @@ fn explicit_route_input_usage(model: &ExecutionModel) -> Vec { let mut usage = vec![false; model.metadata.routes.len()]; if let KernelImplementation::Statements(program) = &kernel.implementation { - mark_route_inputs_in_statements(&program.body.statements, &mut usage); + mark_route_inputs_in_statements(&program.body.statements, &declaration_slots, &mut usage); } usage } -fn mark_route_inputs_in_statements(statements: &[ExecutionStmt], usage: &mut [bool]) { +fn mark_route_inputs_in_statements( + statements: &[ExecutionStmt], + declaration_slots: &BTreeMap, + usage: &mut [bool], +) { for statement in statements { match &statement.kind { ExecutionStmtKind::Let(let_stmt) => { - mark_route_inputs_in_expr(&let_stmt.value, usage); + mark_route_inputs_in_expr(&let_stmt.value, declaration_slots, usage); } ExecutionStmtKind::Assign(assign_stmt) => { - mark_route_inputs_in_expr(&assign_stmt.value, usage); + mark_route_inputs_in_expr(&assign_stmt.value, declaration_slots, usage); } ExecutionStmtKind::If(if_stmt) => { - mark_route_inputs_in_expr(&if_stmt.condition, usage); - mark_route_inputs_in_statements(&if_stmt.then_branch, usage); + mark_route_inputs_in_expr(&if_stmt.condition, declaration_slots, usage); + mark_route_inputs_in_statements(&if_stmt.then_branch, declaration_slots, usage); if let Some(else_branch) = &if_stmt.else_branch { - mark_route_inputs_in_statements(else_branch, usage); + mark_route_inputs_in_statements(else_branch, declaration_slots, usage); } } ExecutionStmtKind::For(for_stmt) => { - mark_route_inputs_in_expr(&for_stmt.range.start, usage); - mark_route_inputs_in_expr(&for_stmt.range.end, usage); - mark_route_inputs_in_statements(&for_stmt.body, usage); + mark_route_inputs_in_expr(&for_stmt.range.start, declaration_slots, usage); + mark_route_inputs_in_expr(&for_stmt.range.end, declaration_slots, usage); + mark_route_inputs_in_statements(&for_stmt.body, declaration_slots, usage); } } } } -fn mark_route_inputs_in_expr(expr: &ExecutionExpr, usage: &mut [bool]) { +fn mark_route_inputs_in_expr( + expr: &ExecutionExpr, + declaration_slots: &BTreeMap, + usage: &mut [bool], +) { match &expr.kind { ExecutionExprKind::Literal(_) => {} - ExecutionExprKind::Load(ExecutionLoad::RouteInput(index)) => { - if let Some(slot) = usage.get_mut(*index) { + ExecutionExprKind::Load(ExecutionLoad::RouteInput { route, .. }) => { + if let Some(slot) = declaration_slots + .get(route) + .and_then(|index| usage.get_mut(*index)) + { *slot = true; } } ExecutionExprKind::Load(_) => {} - ExecutionExprKind::Unary { expr, .. } => mark_route_inputs_in_expr(expr, usage), + ExecutionExprKind::Unary { expr, .. } => { + mark_route_inputs_in_expr(expr, declaration_slots, usage) + } ExecutionExprKind::Binary { lhs, rhs, .. } => { - mark_route_inputs_in_expr(lhs, usage); - mark_route_inputs_in_expr(rhs, usage); + mark_route_inputs_in_expr(lhs, declaration_slots, usage); + mark_route_inputs_in_expr(rhs, declaration_slots, usage); } ExecutionExprKind::Call { args, .. } => { for arg in args { - mark_route_inputs_in_expr(arg, usage); + mark_route_inputs_in_expr(arg, declaration_slots, usage); } } } } + +#[cfg(test)] +mod tests { + use super::*; + use pharmsol_dsl::{analyze_model, lower_typed_model, parse_model}; + + fn load_model_info(src: &str) -> NativeModelInfo { + let model = parse_model(src).expect("model parses"); + let typed = analyze_model(&model).expect("model analyzes"); + let lowered = lower_typed_model(&typed).expect("model lowers"); + NativeModelInfo::from_execution_model(&lowered) + } + + #[test] + fn declaration_first_routes_inject_by_default() { + let info = load_model_info( + r#" +model implicit_route_injection { + kind ode + states { central } + routes { iv -> central } + dynamics { + ddt(central) = 0 + } + outputs { + cp = central + } +} +"#, + ); + + assert_eq!(info.routes.len(), 1); + assert!(info.routes[0].inject_input_to_destination); + } + + #[test] + fn explicit_rate_usage_disables_automatic_injection() { + let info = load_model_info( + r#" +model explicit_route_usage { + kind ode + states { central } + routes { iv -> central } + dynamics { + ddt(central) = rate(iv) + } + outputs { + cp = central + } +} +"#, + ); + + assert_eq!(info.routes.len(), 1); + assert!(!info.routes[0].inject_input_to_destination); + } + + #[test] + fn authoring_shared_channel_routes_keep_declaration_specific_injection() { + let info = load_model_info( + r#" +name = shared_authoring +kind = ode + +params = ka, ke, v +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central + +out(cp) = central / v ~ continuous() +"#, + ); + + assert_eq!(info.route_len, 1); + assert_eq!(info.routes.len(), 2); + assert_eq!(info.routes[0].kind, Some(RouteKind::Bolus)); + assert_eq!(info.routes[1].kind, Some(RouteKind::Infusion)); + assert_eq!(info.routes[0].index, 0); + assert_eq!(info.routes[1].index, 0); + assert!(info.routes[0].inject_input_to_destination); + assert!(!info.routes[1].inject_input_to_destination); + } +} diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 4a94715f..202fd45a 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -14,7 +14,7 @@ use cranelift_jit::JITModule; #[cfg(feature = "dsl-aot-load")] use libloading::Library; use pharmsol_dsl::execution::KernelRole; -use pharmsol_dsl::AnalyticalKernel; +use pharmsol_dsl::{AnalyticalKernel, RouteKind}; pub use super::model_info::{ NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, @@ -264,12 +264,74 @@ impl RuntimeArtifact for NativeExecutionArtifact { #[derive(Clone, Debug)] struct SharedNativeModel { info: Arc, + route_semantics: Arc, artifact: Arc, } +#[derive(Clone, Debug)] +struct RouteInputSemantics { + bolus_destinations: Vec>, + infusion_inputs: Vec, + injected_infusion_destinations: Vec>, +} + +impl RouteInputSemantics { + fn from_model_info(info: &NativeModelInfo) -> Self { + let mut bolus_destinations = vec![None; info.route_len]; + let mut infusion_inputs = vec![false; info.route_len]; + let mut injected_infusion_destinations = vec![None; info.route_len]; + + for route in &info.routes { + match route.kind { + Some(RouteKind::Bolus) => { + bolus_destinations[route.index] = Some(route.destination_offset); + } + Some(RouteKind::Infusion) => { + infusion_inputs[route.index] = true; + if route.inject_input_to_destination { + injected_infusion_destinations[route.index] = + Some(route.destination_offset); + } + } + None => { + bolus_destinations[route.index] = Some(route.destination_offset); + infusion_inputs[route.index] = true; + if route.inject_input_to_destination { + injected_infusion_destinations[route.index] = + Some(route.destination_offset); + } + } + } + } + + Self { + bolus_destinations, + infusion_inputs, + injected_infusion_destinations, + } + } + + fn supports_input(&self, input: usize, kind: RouteKind) -> bool { + match kind { + RouteKind::Bolus => self + .bolus_destinations + .get(input) + .copied() + .flatten() + .is_some(), + RouteKind::Infusion => self.infusion_inputs.get(input).copied().unwrap_or(false), + } + } + + fn bolus_destination(&self, input: usize) -> Option { + self.bolus_destinations.get(input).copied().flatten() + } +} + impl SharedNativeModel { fn new(info: NativeModelInfo, artifact: impl RuntimeArtifact + 'static) -> Self { Self { + route_semantics: Arc::new(RouteInputSemantics::from_model_info(&info)), info: Arc::new(info), artifact: Arc::new(artifact), } @@ -313,6 +375,18 @@ impl SharedNativeModel { Ok(()) } + fn validate_input_for_kind(&self, input: usize, kind: RouteKind) -> Result<(), PharmsolError> { + self.validate_input(input)?; + if self.route_semantics.supports_input(input, kind) { + return Ok(()); + } + + Err(PharmsolError::OtherError(format!( + "model `{}` does not declare a {:?} route for input channel {}", + self.info.name, kind, input + ))) + } + fn fill_cov_buffer(&self, covariates: &Covariates, time: f64, buf: &mut [f64]) { for covariate in &self.info.covariates { buf[covariate.index] = match covariates.get_covariate(&covariate.name) { @@ -323,9 +397,14 @@ impl SharedNativeModel { } fn apply_route_inputs_to_rates(&self, rates: &mut [f64], route_inputs: &[f64]) { - for route in &self.info.routes { - if route.inject_input_to_destination { - rates[route.destination_offset] += route_inputs[route.index]; + for (input, destination) in self + .route_semantics + .injected_infusion_destinations + .iter() + .enumerate() + { + if let Some(destination) = destination { + rates[*destination] += route_inputs[input]; } } } @@ -451,7 +530,7 @@ impl SharedNativeModel { for event in events.iter_mut() { if let Event::Bolus(bolus) = event { - self.validate_input(bolus.input())?; + self.validate_input_for_kind(bolus.input(), RouteKind::Bolus)?; if self.artifact.has_kernel(KernelRole::RouteLag) { lag_values.fill(0.0); @@ -525,9 +604,17 @@ impl SharedNativeModel { input: usize, amount: f64, ) -> Result<(), PharmsolError> { - self.validate_input(input)?; - let destination = &self.info.routes[input]; - state[destination.destination_offset] += amount; + self.validate_input_for_kind(input, RouteKind::Bolus)?; + let destination = self + .route_semantics + .bolus_destination(input) + .ok_or_else(|| { + PharmsolError::OtherError(format!( + "model `{}` does not declare a bolus route for input channel {}", + self.info.name, input + )) + })?; + state[destination] += amount; Ok(()) } @@ -654,7 +741,8 @@ impl NativeOdeModel { .collect::>(); for infusion in &infusions { - self.shared.validate_input(infusion.input())?; + self.shared + .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; } let mut events = occasion.process_events(None, true); @@ -919,7 +1007,8 @@ impl NativeAnalyticalModel { .collect::>(); for infusion in &infusions { - self.shared.validate_input(infusion.input())?; + self.shared + .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; } let mut events = occasion.process_events(None, true); @@ -959,6 +1048,7 @@ impl NativeAnalyticalModel { if let Some(next_event) = events.get(index + 1) { self.solve_interval( + &mut *session, &mut state, support_point, occasion.covariates(), @@ -975,6 +1065,7 @@ impl NativeAnalyticalModel { fn solve_interval( &self, + session: &mut dyn KernelSession, state: &mut [f64], support_point: &[f64], covariates: &Covariates, @@ -1001,11 +1092,25 @@ impl NativeAnalyticalModel { breakpoints.dedup_by(|lhs, rhs| (*lhs - *rhs).abs() < 1e-12); let mut current = breakpoints[0]; - let projected = project_analytical_parameters(&self.shared.info, support_point)?; + let mut cov_buf = vec![0.0; self.shared.info.covariates.len()]; + let mut derived = vec![0.0; self.shared.info.derived_len]; for next in breakpoints.iter().copied().skip(1) { let dt = next - current; - let route_inputs = active_route_inputs(infusions, current, self.shared.info.route_len); + let route_inputs = + interval_route_inputs(infusions, current, next, self.shared.info.route_len); + self.shared.refresh_derived( + session, + next, + state, + support_point, + covariates, + &route_inputs, + &mut derived, + &mut cov_buf, + )?; + let projected = + project_analytical_parameters(&self.shared.info, support_point, &derived)?; let next_state = apply_analytical_kernel( self.shared.info.analytical.ok_or_else(|| { PharmsolError::OtherError(format!( @@ -1073,7 +1178,8 @@ impl NativeSdeModel { .collect::>(); for infusion in &infusions { - self.shared.validate_input(infusion.input())?; + self.shared + .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; } let mut events = occasion.process_events(None, true); @@ -1302,6 +1408,22 @@ fn active_route_inputs(infusions: &[Infusion], time: f64, route_len: usize) -> V values } +fn interval_route_inputs( + infusions: &[Infusion], + start_time: f64, + end_time: f64, + route_len: usize, +) -> Vec { + let mut values = vec![0.0; route_len]; + for infusion in infusions { + let finish = infusion.time() + infusion.duration(); + if infusion.input() < route_len && start_time >= infusion.time() && end_time <= finish { + values[infusion.input()] += infusion.amount() / infusion.duration(); + } + } + values +} + fn sort_events(events: &mut [Event]) { events.sort_by(|lhs, rhs| { fn order(event: &Event) -> u8 { @@ -1323,6 +1445,7 @@ fn sort_events(events: &mut [Event]) { fn project_analytical_parameters( info: &NativeModelInfo, support_point: &[f64], + derived: &[f64], ) -> Result { let kernel = info.analytical.ok_or_else(|| { PharmsolError::OtherError(format!( @@ -1339,6 +1462,13 @@ fn project_analytical_parameters( support_point.len() ))); } + + // Analytical authoring models can project kernel arguments through a derive + // kernel by declaring exactly the built-in kernel arity in `derived`. + if derived.len() == arity { + return Ok(V::from_vec(derived.to_vec(), NalgebraContext)); + } + Ok(V::from_vec( support_point[..arity].to_vec(), NalgebraContext, diff --git a/src/dsl/rust_backend.rs b/src/dsl/rust_backend.rs index 19b3e5cc..850e13b7 100644 --- a/src/dsl/rust_backend.rs +++ b/src/dsl/rust_backend.rs @@ -264,7 +264,7 @@ fn emit_load(load: &ExecutionLoad, ty: ValueType) -> Result { ExecutionLoad::Covariate(index) => format!("load_f64(covariates, {index})"), ExecutionLoad::Derived(index) => format!("load_f64(derived, {index})"), ExecutionLoad::Local(index) => return Ok(format!("local_{index}")), - ExecutionLoad::RouteInput(index) => format!("load_f64(routes, {index})"), + ExecutionLoad::RouteInput { index, .. } => format!("load_f64(routes, {index})"), ExecutionLoad::State(state) => { let index = emit_state_ref_index(state)?; format!("load_f64(states, {index})") diff --git a/src/dsl/wasm.rs b/src/dsl/wasm.rs index 16884952..f2504d44 100644 --- a/src/dsl/wasm.rs +++ b/src/dsl/wasm.rs @@ -778,7 +778,9 @@ mod tests { covariates: Vec::new(), routes: vec![NativeRouteInfo { name: "oral".to_string(), + declaration_index: 0, index: 0, + kind: None, destination_offset: 0, inject_input_to_destination: true, }], diff --git a/src/dsl/wasm_compile.rs b/src/dsl/wasm_compile.rs index 66995e8a..caa60216 100644 --- a/src/dsl/wasm_compile.rs +++ b/src/dsl/wasm_compile.rs @@ -848,7 +848,7 @@ mod tests { }; const SIMPLE_SOURCE: &str = r#" -model = example_ode +name = example_ode kind = ode params = ke, v @@ -901,7 +901,7 @@ out(cp) = central / v ~ continuous() cache .compile_module_source_to_wasm_module( r#" -model = second_ode +name = second_ode kind = ode params = ke, v @@ -949,7 +949,7 @@ out(cp) = central / v ~ continuous() #[test] fn compile_module_source_to_wasm_module_preserves_semantic_diagnostic_structure() { let source = r#" -model = broken +name = broken kind = ode states = central @@ -995,7 +995,7 @@ out(cp) = central ~ continuous() #[test] fn compile_module_source_to_wasm_module_preserves_lowering_diagnostic_structure() { let source = r#" -model = broken +name = broken kind = ode states = transit[4], central diff --git a/src/dsl/wasm_direct_emitter.rs b/src/dsl/wasm_direct_emitter.rs index 2d92ad1d..857f2ac7 100644 --- a/src/dsl/wasm_direct_emitter.rs +++ b/src/dsl/wasm_direct_emitter.rs @@ -922,7 +922,7 @@ fn emit_load( function.instruction(&Instruction::LocalGet(local.wasm_local)); emit_cast_stack(local.ty, target_ty, function, state.model_name) } - ExecutionLoad::RouteInput(index) => emit_dense_load( + ExecutionLoad::RouteInput { index, .. } => emit_dense_load( function, KERNEL_PARAM_ROUTES, *index, diff --git a/src/lib.rs b/src/lib.rs index f2691579..c84d4ee1 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,28 +28,31 @@ pub use crate::data::Interpolation::*; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::data::*; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] -pub use crate::equation::*; -#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::optimize::effect::get_e2; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::optimize::spp::SppOptimizer; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +pub use crate::simulator::equation::analytical::*; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] +pub use crate::simulator::equation::metadata; +#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use crate::simulator::equation::{ self, ode::{ExplicitRkTableau, OdeSolver, SdirkTableau}, - ODE, + Analytical, AnalyticalKernel, Cache, Equation, ModelKind, ModelMetadata, ModelMetadataError, + NameDomain, Predictions, RouteInputPolicy, RouteKind, State, ValidatedModelMetadata, ODE, SDE, }; pub use error::PharmsolError; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use nalgebra::dmatrix; -pub use pharmsol_macros::ode; +pub use pharmsol_macros::{analytical, ode, sde}; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use std::collections::HashMap; /// Prelude module that re-exports all commonly used types and traits. /// -/// Use `use pharmsol::prelude::*;` to import everything needed for basic -/// pharmacometric modeling. +/// Importing `pharmsol::prelude::*` brings the main modeling, simulation, +/// and data APIs into scope. /// /// # Example /// ```rust @@ -92,7 +95,7 @@ pub mod prelude { pub use crate::data::auc::{auc, auc_interval, aumc, interpolate_linear}; #[allow(deprecated)] - // Simulator submodule for internal use and advanced users + // Simulator submodule for organized access to simulation types. pub mod simulator { pub use crate::simulator::{ cache::{self, PredictionCache, SdeLikelihoodCache, DEFAULT_CACHE_SIZE}, @@ -136,6 +139,8 @@ pub mod prelude { // Re-export macros (they are exported at crate root via #[macro_export]) #[doc(inline)] + pub use crate::analytical; + #[doc(inline)] pub use crate::fa; #[doc(inline)] pub use crate::fetch_cov; @@ -145,6 +150,8 @@ pub mod prelude { pub use crate::lag; #[doc(inline)] pub use crate::ode; + #[doc(inline)] + pub use crate::sde; } #[macro_export] diff --git a/src/simulator/equation/analytical/mod.rs b/src/simulator/equation/analytical/mod.rs index 0aff0936..4734886c 100644 --- a/src/simulator/equation/analytical/mod.rs +++ b/src/simulator/equation/analytical/mod.rs @@ -8,6 +8,8 @@ pub mod two_compartment_models; use diffsol::{NalgebraContext, Vector, VectorHost}; pub use one_compartment_cl_models::*; pub use one_compartment_models::*; +use pharmsol_dsl::ModelKind; +use thiserror::Error; pub use three_compartment_cl_models::*; pub use three_compartment_models::*; pub use two_compartment_cl_models::*; @@ -15,12 +17,28 @@ pub use two_compartment_models::*; use super::spphash; +use super::{ + EqnKind, Equation, EquationPriv, EquationTypes, ModelMetadata, ModelMetadataError, + ValidatedModelMetadata, +}; use crate::data::error_model::AssayErrorModels; use crate::simulator::cache::{PredictionCache, DEFAULT_CACHE_SIZE}; use crate::PharmsolError; -use crate::{ - data::Covariates, simulator::*, Equation, EquationPriv, EquationTypes, Observation, Subject, -}; +use crate::{data::Covariates, simulator::*, Observation, Subject}; + +#[derive(Clone, Debug, PartialEq, Eq, Error)] +pub enum AnalyticalMetadataError { + #[error(transparent)] + Validation(#[from] ModelMetadataError), + #[error("analytical model declares {declared} state metadata entries but model has {expected} states")] + StateCountMismatch { expected: usize, declared: usize }, + #[error( + "analytical model declares {declared} route metadata entries but model has {expected} input channels" + )] + RouteCountMismatch { expected: usize, declared: usize }, + #[error("analytical model declares {declared} output metadata entries but model has {expected} outputs")] + OutputCountMismatch { expected: usize, declared: usize }, +} /// Model equation using analytical solutions. /// @@ -35,6 +53,7 @@ pub struct Analytical { init: Init, out: Out, neqs: Neqs, + metadata: Option, cache: Option, } @@ -88,6 +107,7 @@ impl Analytical { init, out, neqs: Neqs::default(), + metadata: None, cache: Some(PredictionCache::new(DEFAULT_CACHE_SIZE)), } } @@ -95,20 +115,94 @@ impl Analytical { /// Set the number of state variables. pub fn with_nstates(mut self, nstates: usize) -> Self { self.neqs.nstates = nstates; + self.invalidate_metadata(); self } /// Set the number of drug input channels (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; + self.invalidate_metadata(); self } /// Set the number of output equations. pub fn with_nout(mut self, nout: usize) -> Self { self.neqs.nout = nout; + self.invalidate_metadata(); self } + + /// Attach validated handwritten-model metadata to this analytical model. + pub fn with_metadata( + mut self, + metadata: ModelMetadata, + ) -> Result { + let metadata = metadata.validate_for(ModelKind::Analytical)?; + validate_metadata_dimensions(&metadata, &self.neqs)?; + self.metadata = Some(metadata); + Ok(self) + } + + /// Access the validated metadata attached to this analytical model, if any. + pub fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + + pub fn parameter_index(&self, name: &str) -> Option { + self.metadata()?.parameter_index(name) + } + + pub fn covariate_index(&self, name: &str) -> Option { + self.metadata()?.covariate_index(name) + } + + pub fn state_index(&self, name: &str) -> Option { + self.metadata()?.state_index(name) + } + + pub fn route_index(&self, name: &str) -> Option { + self.metadata()?.route_index(name) + } + + pub fn output_index(&self, name: &str) -> Option { + self.metadata()?.output_index(name) + } + + fn invalidate_metadata(&mut self) { + self.metadata = None; + } +} + +fn validate_metadata_dimensions( + metadata: &ValidatedModelMetadata, + neqs: &Neqs, +) -> Result<(), AnalyticalMetadataError> { + let declared_states = metadata.states().len(); + if declared_states != neqs.nstates { + return Err(AnalyticalMetadataError::StateCountMismatch { + expected: neqs.nstates, + declared: declared_states, + }); + } + + let declared_routes = metadata.route_channel_count(); + if declared_routes != neqs.ndrugs { + return Err(AnalyticalMetadataError::RouteCountMismatch { + expected: neqs.ndrugs, + declared: declared_routes, + }); + } + + let declared_outputs = metadata.outputs().len(); + if declared_outputs != neqs.nout { + return Err(AnalyticalMetadataError::OutputCountMismatch { + expected: neqs.nout, + declared: declared_outputs, + }); + } + + Ok(()) } impl super::Cache for Analytical { @@ -302,6 +396,7 @@ pub(crate) mod tests { use crate::SubjectBuilderExt; use approx::assert_relative_eq; use diffsol::Vector; + use pharmsol_dsl::AnalyticalKernel; use std::collections::HashMap; pub(crate) enum SubjectInfo { @@ -423,6 +518,158 @@ pub(crate) mod tests { assert_eq!(predictions.predictions()[0].prediction(), 4.0); } + fn simple_analytical() -> Analytical { + let eq = |x: &V, _p: &V, _dt: f64, _rateiv: &V, _cov: &Covariates| x.clone(); + let seq_eq = |_params: &mut V, _t: f64, _cov: &Covariates| {}; + let lag = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let fa = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x.fill(0.0); + }; + let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + y[0] = x[0]; + }; + + Analytical::new(eq, seq_eq, lag, fa, init, out) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + } + + #[test] + fn handwritten_analytical_metadata_exposes_name_lookup() { + let analytical = simple_analytical() + .with_metadata( + super::super::metadata::new("one_cmt_analytical") + .parameters(["ke", "v"]) + .covariates([super::super::Covariate::continuous("wt")]) + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect("metadata attachment should validate"); + + assert_eq!(analytical.parameter_index("ke"), Some(0)); + assert_eq!(analytical.parameter_index("v"), Some(1)); + assert_eq!(analytical.covariate_index("wt"), Some(0)); + assert_eq!(analytical.state_index("central"), Some(0)); + assert_eq!(analytical.route_index("iv"), Some(0)); + assert_eq!(analytical.output_index("cp"), Some(0)); + assert_eq!( + analytical.metadata().expect("metadata exists").kind(), + ModelKind::Analytical + ); + } + + #[test] + fn handwritten_analytical_without_metadata_keeps_raw_path() { + let analytical = simple_analytical(); + + assert!(analytical.metadata().is_none()); + assert_eq!(analytical.state_index("central"), None); + assert_eq!(analytical.route_index("iv"), None); + assert_eq!(analytical.output_index("cp"), None); + } + + #[test] + fn handwritten_analytical_rejects_dimension_mismatches() { + let error = simple_analytical() + .with_metadata( + super::super::metadata::new("wrong_outputs") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp", "auc"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect_err("output-count mismatches must fail"); + + assert_eq!( + error, + AnalyticalMetadataError::OutputCountMismatch { + expected: 1, + declared: 2, + } + ); + } + + #[test] + fn handwritten_analytical_rejects_particles_metadata() { + let error = simple_analytical() + .with_metadata( + super::super::metadata::new("invalid_particles") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")) + .particles(64), + ) + .expect_err("analytical metadata must reject particles"); + + assert_eq!( + error, + AnalyticalMetadataError::Validation(ModelMetadataError::ParticlesNotAllowed { + kind: ModelKind::Analytical, + }) + ); + } + + #[test] + fn built_in_analytical_models_can_advertise_kernel_identity() { + let seq_eq = |_params: &mut V, _t: f64, _cov: &Covariates| {}; + let lag = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let fa = |_p: &V, _t: f64, _cov: &Covariates| HashMap::new(); + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x.fill(0.0); + }; + let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + y[0] = x[1]; + }; + + let analytical = + Analytical::new(one_compartment_with_absorption, seq_eq, lag, fa, init, out) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + super::super::metadata::new("one_cmt_abs") + .parameters(["ka", "ke", "v"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + super::super::Route::bolus("oral").to_state("gut"), + super::super::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("built-in analytical metadata should validate"); + + assert_eq!( + analytical + .metadata() + .expect("metadata exists") + .analytical_kernel(), + Some(AnalyticalKernel::OneCompartmentWithAbsorption) + ); + assert_eq!(analytical.route_index("oral"), Some(0)); + assert_eq!(analytical.route_index("iv"), Some(0)); + } + + #[test] + fn changing_dimensions_after_metadata_clears_analytical_metadata() { + let analytical = simple_analytical() + .with_metadata( + super::super::metadata::new("one_cmt_analytical") + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect("metadata attachment should validate") + .with_ndrugs(2); + + assert!(analytical.metadata().is_none()); + assert_eq!(analytical.route_index("iv"), None); + } + fn assert_pm_wrapper_matches_native( native: AnalyticalEq, wrapper: AnalyticalEq, @@ -567,8 +814,8 @@ impl Equation for Analytical { ypred.log_likelihood(error_models) } - fn kind() -> crate::EqnKind { - crate::EqnKind::Analytical + fn kind() -> EqnKind { + EqnKind::Analytical } } diff --git a/src/simulator/equation/meta.rs b/src/simulator/equation/meta.rs deleted file mode 100644 index 1b38ae35..00000000 --- a/src/simulator/equation/meta.rs +++ /dev/null @@ -1,64 +0,0 @@ -#[repr(C)] -#[derive(Debug, Clone)] -/// Model metadata container. -/// -/// This structure holds the metadata associated with a pharmacometric model, -/// including parameter names and other model-specific information that needs -/// to be preserved across simulation and estimation activities. -/// -/// # Examples -/// -/// ``` -/// use pharmsol::simulator::equation::Meta; -/// -/// let model_metadata = Meta::new(vec!["CL", "V", "KA"]); -/// assert_eq!(model_metadata.get_params().len(), 3); -/// ``` -pub struct Meta { - params: Vec, -} - -impl Meta { - /// Creates a new metadata container with the specified parameter names. - /// - /// # Arguments - /// - /// * `params` - A vector of string slices representing parameter names - /// - /// # Returns - /// - /// A new `Meta` instance containing the converted parameter names - /// - /// # Examples - /// - /// ``` - /// use pharmsol::simulator::equation::Meta; - /// - /// let metadata = Meta::new(vec!["CL", "V", "KA"]); - /// ``` - pub fn new(params: Vec<&str>) -> Self { - let params = params.iter().map(|x| x.to_string()).collect(); - Meta { params } - } - - /// Retrieves the parameter names stored in this metadata container. - /// - /// # Returns - /// - /// A reference to the vector of parameter names - /// - /// # Examples - /// - /// ``` - /// use pharmsol::simulator::equation::Meta; - /// - /// let metadata = Meta::new(vec!["CL", "V", "KA"]); - /// let params = metadata.get_params(); - /// assert_eq!(params[0], "CL"); - /// assert_eq!(params[1], "V"); - /// assert_eq!(params[2], "KA"); - /// ``` - pub fn get_params(&self) -> &Vec { - &self.params - } -} diff --git a/src/simulator/equation/metadata.rs b/src/simulator/equation/metadata.rs new file mode 100644 index 00000000..ecf51a52 --- /dev/null +++ b/src/simulator/equation/metadata.rs @@ -0,0 +1,1211 @@ +//! Shared model metadata for handwritten simulator models. +//! +//! This module defines the public metadata contract that handwritten ODE, +//! analytical, and SDE models can attach to. The field set is intentionally +//! aligned with the public subset of the DSL/runtime metadata surface. +//! +//! Internal runtime layout details such as dense buffer lengths, derived buffer +//! shape, or ABI-specific offsets remain internal for now. + +use pharmsol_dsl::{AnalyticalKernel, CovariateInterpolation, ModelKind}; +use std::fmt; +use thiserror::Error; + +/// Create a new handwritten-model metadata builder. +pub fn new(name: impl Into) -> ModelMetadata { + ModelMetadata::new(name) +} + +/// Validation errors for handwritten model metadata. +#[derive(Debug, Clone, PartialEq, Eq, Error)] +pub enum ModelMetadataError { + #[error("model kind is required for metadata validation")] + MissingModelKind, + #[error("metadata declares kind `{declared:?}` but validation requested `{requested:?}`")] + ModelKindConflict { + declared: ModelKind, + requested: ModelKind, + }, + #[error("duplicate {domain} name `{name}`")] + DuplicateName { domain: NameDomain, name: String }, + #[error("route `{route}` must declare a destination state")] + MissingRouteDestination { route: String }, + #[error("route `{route}` targets unknown state `{destination}`")] + UnknownRouteDestination { route: String, destination: String }, + #[error("infusion route `{route}` cannot declare lag")] + InfusionLagNotAllowed { route: String }, + #[error("infusion route `{route}` cannot declare bioavailability")] + InfusionBioavailabilityNotAllowed { route: String }, + #[error("{kind:?} metadata cannot declare particles")] + ParticlesNotAllowed { kind: ModelKind }, + #[error("Sde metadata requires particles")] + MissingParticles, + #[error( + "metadata declares {declared} particle(s) but validation provided {fallback} fallback particle(s)" + )] + ParticleCountConflict { declared: usize, fallback: usize }, + #[error("{kind:?} metadata cannot declare an analytical kernel")] + AnalyticalKernelNotAllowed { kind: ModelKind }, +} + +/// Name domain used in duplicate-name validation messages. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum NameDomain { + Parameter, + Covariate, + State, + Route, + Output, +} + +impl fmt::Display for NameDomain { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let domain = match self { + Self::Parameter => "parameter", + Self::Covariate => "covariate", + Self::State => "state", + Self::Route => "route", + Self::Output => "output", + }; + f.write_str(domain) + } +} + +/// Immutable validated metadata view used by later attachment slices. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ValidatedModelMetadata { + name: String, + kind: ModelKind, + parameters: Vec, + covariates: Vec, + states: Vec, + routes: Vec, + route_channel_count: usize, + outputs: Vec, + particles: Option, + analytical: Option, +} + +impl ValidatedModelMetadata { + pub fn name(&self) -> &str { + &self.name + } + + pub fn kind(&self) -> ModelKind { + self.kind + } + + pub fn parameters(&self) -> &[Parameter] { + &self.parameters + } + + pub fn covariates(&self) -> &[Covariate] { + &self.covariates + } + + pub fn states(&self) -> &[State] { + &self.states + } + + pub fn routes(&self) -> &[ValidatedRoute] { + &self.routes + } + + pub fn route_channel_count(&self) -> usize { + self.route_channel_count + } + + pub fn outputs(&self) -> &[Output] { + &self.outputs + } + + pub fn particles(&self) -> Option { + self.particles + } + + pub fn analytical_kernel(&self) -> Option { + self.analytical + } + + pub fn parameter_index(&self, name: &str) -> Option { + self.parameters + .iter() + .position(|parameter| parameter.name() == name) + } + + pub fn covariate_index(&self, name: &str) -> Option { + self.covariates + .iter() + .position(|covariate| covariate.name() == name) + } + + pub fn state_index(&self, name: &str) -> Option { + self.states.iter().position(|state| state.name() == name) + } + + pub fn route_index(&self, name: &str) -> Option { + self.route(name).map(ValidatedRoute::channel_index) + } + + pub fn route_declaration_index(&self, name: &str) -> Option { + self.routes.iter().position(|route| route.name() == name) + } + + pub fn output_index(&self, name: &str) -> Option { + self.outputs.iter().position(|output| output.name() == name) + } + + pub fn parameter(&self, name: &str) -> Option<&Parameter> { + self.parameter_index(name) + .map(|index| &self.parameters[index]) + } + + pub fn covariate(&self, name: &str) -> Option<&Covariate> { + self.covariate_index(name) + .map(|index| &self.covariates[index]) + } + + pub fn state(&self, name: &str) -> Option<&State> { + self.state_index(name).map(|index| &self.states[index]) + } + + pub fn route(&self, name: &str) -> Option<&ValidatedRoute> { + self.route_declaration_index(name) + .map(|index| &self.routes[index]) + } + + pub fn output(&self, name: &str) -> Option<&Output> { + self.output_index(name).map(|index| &self.outputs[index]) + } +} + +/// One validated route declaration with resolved destination state index. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ValidatedRoute { + name: String, + kind: RouteKind, + declaration_index: usize, + channel_index: usize, + destination: String, + destination_index: usize, + has_lag: bool, + has_bioavailability: bool, + input_policy: Option, +} + +impl ValidatedRoute { + pub fn name(&self) -> &str { + &self.name + } + + pub fn kind(&self) -> RouteKind { + self.kind + } + + pub fn declaration_index(&self) -> usize { + self.declaration_index + } + + pub fn channel_index(&self) -> usize { + self.channel_index + } + + pub fn destination(&self) -> &str { + &self.destination + } + + pub fn destination_index(&self) -> usize { + self.destination_index + } + + pub fn has_lag(&self) -> bool { + self.has_lag + } + + pub fn has_bioavailability(&self) -> bool { + self.has_bioavailability + } + + pub fn input_policy(&self) -> Option { + self.input_policy + } +} + +/// Metadata describing one handwritten simulator model. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ModelMetadata { + name: String, + kind: Option, + parameters: Vec, + covariates: Vec, + states: Vec, + routes: Vec, + outputs: Vec, + particles: Option, + analytical: Option, +} + +impl ModelMetadata { + /// Create a new metadata builder with a model name. + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + kind: None, + parameters: Vec::new(), + covariates: Vec::new(), + states: Vec::new(), + routes: Vec::new(), + outputs: Vec::new(), + particles: None, + analytical: None, + } + } + + /// Set the model kind explicitly. + pub fn kind(mut self, kind: ModelKind) -> Self { + self.kind = Some(kind); + self + } + + /// Replace the ordered parameter list. + pub fn parameters(mut self, parameters: I) -> Self + where + I: IntoIterator, + Parameter: From, + { + self.parameters = parameters.into_iter().map(Parameter::from).collect(); + self + } + + /// Replace the ordered covariate list. + pub fn covariates(mut self, covariates: I) -> Self + where + I: IntoIterator, + { + self.covariates = covariates.into_iter().collect(); + self + } + + /// Replace the ordered state list. + pub fn states(mut self, states: I) -> Self + where + I: IntoIterator, + State: From, + { + self.states = states.into_iter().map(State::from).collect(); + self + } + + /// Add one route declaration. + pub fn route(mut self, route: Route) -> Self { + self.routes.push(route); + self + } + + /// Extend with multiple route declarations. + pub fn routes(mut self, routes: I) -> Self + where + I: IntoIterator, + { + self.routes.extend(routes); + self + } + + /// Replace the ordered output list. + pub fn outputs(mut self, outputs: I) -> Self + where + I: IntoIterator, + Output: From, + { + self.outputs = outputs.into_iter().map(Output::from).collect(); + self + } + + /// Set the particle count for stochastic models. + pub fn particles(mut self, particles: usize) -> Self { + self.particles = Some(particles); + self + } + + /// Set the analytical kernel identity for built-in analytical models. + pub fn analytical_kernel(mut self, analytical: AnalyticalKernel) -> Self { + self.analytical = Some(analytical); + self + } + + /// Get the model name. + pub fn name(&self) -> &str { + &self.name + } + + /// Get the explicit model kind, if already declared. + pub fn kind_decl(&self) -> Option { + self.kind + } + + /// Get the ordered parameter metadata. + pub fn parameters_decl(&self) -> &[Parameter] { + &self.parameters + } + + /// Get the ordered covariate metadata. + pub fn covariates_decl(&self) -> &[Covariate] { + &self.covariates + } + + /// Get the ordered state metadata. + pub fn states_decl(&self) -> &[State] { + &self.states + } + + /// Get the ordered route metadata. + pub fn routes_decl(&self) -> &[Route] { + &self.routes + } + + /// Get the ordered output metadata. + pub fn outputs_decl(&self) -> &[Output] { + &self.outputs + } + + /// Get the declared particle count. + pub fn particles_decl(&self) -> Option { + self.particles + } + + /// Get the declared analytical kernel identity. + pub fn analytical_kernel_decl(&self) -> Option { + self.analytical + } + + /// Validate this metadata using its declared kind. + pub fn validate(self) -> Result { + self.validate_internal(None, None) + } + + /// Validate this metadata for a specific model kind. + pub fn validate_for( + self, + kind: ModelKind, + ) -> Result { + self.validate_internal(Some(kind), None) + } + + /// Validate this metadata for a specific model kind, using a fallback + /// particle count when the metadata itself does not declare one. + pub fn validate_for_with_particles( + self, + kind: ModelKind, + fallback_particles: usize, + ) -> Result { + self.validate_internal(Some(kind), Some(fallback_particles)) + } + + fn validate_internal( + self, + requested_kind: Option, + fallback_particles: Option, + ) -> Result { + let kind = resolve_kind(self.kind, requested_kind)?; + validate_unique_names(&self.parameters, NameDomain::Parameter, Parameter::name)?; + validate_unique_names(&self.covariates, NameDomain::Covariate, Covariate::name)?; + validate_unique_names(&self.states, NameDomain::State, State::name)?; + validate_unique_names(&self.routes, NameDomain::Route, Route::name)?; + validate_unique_names(&self.outputs, NameDomain::Output, Output::name)?; + + let particles = resolve_particles(kind, self.particles, fallback_particles)?; + validate_kind_specific_fields(kind, self.analytical, particles)?; + + let (routes, route_channel_count) = validate_routes(self.routes, &self.states)?; + + Ok(ValidatedModelMetadata { + name: self.name, + kind, + parameters: self.parameters, + covariates: self.covariates, + states: self.states, + routes, + route_channel_count, + outputs: self.outputs, + particles, + analytical: self.analytical, + }) + } +} + +/// One named parameter in model order. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Parameter { + name: String, +} + +impl Parameter { + pub fn new(name: impl Into) -> Self { + Self { name: name.into() } + } + + pub fn name(&self) -> &str { + &self.name + } +} + +impl From for Parameter +where + S: Into, +{ + fn from(value: S) -> Self { + Self::new(value) + } +} + +/// One named covariate plus interpolation semantics. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Covariate { + name: String, + interpolation: Option, +} + +impl Covariate { + pub fn new(name: impl Into) -> Self { + Self { + name: name.into(), + interpolation: None, + } + } + + pub fn continuous(name: impl Into) -> Self { + Self::new(name).with_interpolation(CovariateInterpolation::Linear) + } + + pub fn locf(name: impl Into) -> Self { + Self::new(name).with_interpolation(CovariateInterpolation::Locf) + } + + pub fn with_interpolation(mut self, interpolation: CovariateInterpolation) -> Self { + self.interpolation = Some(interpolation); + self + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn interpolation(&self) -> Option { + self.interpolation + } +} + +/// One named state in model order. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct State { + name: String, +} + +impl State { + pub fn new(name: impl Into) -> Self { + Self { name: name.into() } + } + + pub fn name(&self) -> &str { + &self.name + } +} + +impl From for State +where + S: Into, +{ + fn from(value: S) -> Self { + Self::new(value) + } +} + +/// One named output in model order. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Output { + name: String, +} + +impl Output { + pub fn new(name: impl Into) -> Self { + Self { name: name.into() } + } + + pub fn name(&self) -> &str { + &self.name + } +} + +impl From for Output +where + S: Into, +{ + fn from(value: S) -> Self { + Self::new(value) + } +} + +/// Route declaration kind. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RouteKind { + Bolus, + Infusion, +} + +/// How route inputs should be interpreted by the execution layer. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum RouteInputPolicy { + InjectToDestination, + ExplicitInputVector, +} + +/// One named route declaration. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct Route { + name: String, + kind: RouteKind, + destination: Option, + has_lag: bool, + has_bioavailability: bool, + input_policy: Option, +} + +impl Route { + pub fn bolus(name: impl Into) -> Self { + Self::new(name, RouteKind::Bolus) + } + + pub fn infusion(name: impl Into) -> Self { + Self::new(name, RouteKind::Infusion) + } + + pub fn new(name: impl Into, kind: RouteKind) -> Self { + Self { + name: name.into(), + kind, + destination: None, + has_lag: false, + has_bioavailability: false, + input_policy: None, + } + } + + pub fn to_state(mut self, destination: impl Into) -> Self { + self.destination = Some(destination.into()); + self + } + + pub fn with_lag(mut self) -> Self { + self.has_lag = true; + self + } + + pub fn with_bioavailability(mut self) -> Self { + self.has_bioavailability = true; + self + } + + pub fn inject_input_to_destination(mut self) -> Self { + self.input_policy = Some(RouteInputPolicy::InjectToDestination); + self + } + + pub fn expect_explicit_input(mut self) -> Self { + self.input_policy = Some(RouteInputPolicy::ExplicitInputVector); + self + } + + pub fn name(&self) -> &str { + &self.name + } + + pub fn kind(&self) -> RouteKind { + self.kind + } + + pub fn destination(&self) -> Option<&str> { + self.destination.as_deref() + } + + pub fn has_lag(&self) -> bool { + self.has_lag + } + + pub fn has_bioavailability(&self) -> bool { + self.has_bioavailability + } + + pub fn input_policy(&self) -> Option { + self.input_policy + } +} + +fn resolve_kind( + declared_kind: Option, + requested_kind: Option, +) -> Result { + match (declared_kind, requested_kind) { + (Some(declared), Some(requested)) if declared != requested => { + Err(ModelMetadataError::ModelKindConflict { + declared, + requested, + }) + } + (Some(declared), _) => Ok(declared), + (None, Some(requested)) => Ok(requested), + (None, None) => Err(ModelMetadataError::MissingModelKind), + } +} + +fn resolve_particles( + kind: ModelKind, + declared_particles: Option, + fallback_particles: Option, +) -> Result, ModelMetadataError> { + let particles = match (declared_particles, fallback_particles) { + (Some(declared), Some(fallback)) if declared != fallback => { + return Err(ModelMetadataError::ParticleCountConflict { declared, fallback }); + } + (Some(declared), _) => Some(declared), + (None, Some(fallback)) => Some(fallback), + (None, None) => None, + }; + + match kind { + ModelKind::Ode | ModelKind::Analytical if particles.is_some() => { + Err(ModelMetadataError::ParticlesNotAllowed { kind }) + } + ModelKind::Sde if particles.is_none() => Err(ModelMetadataError::MissingParticles), + _ => Ok(particles), + } +} + +fn validate_kind_specific_fields( + kind: ModelKind, + analytical: Option, + particles: Option, +) -> Result<(), ModelMetadataError> { + match kind { + ModelKind::Ode => { + if analytical.is_some() { + return Err(ModelMetadataError::AnalyticalKernelNotAllowed { kind }); + } + if particles.is_some() { + return Err(ModelMetadataError::ParticlesNotAllowed { kind }); + } + } + ModelKind::Analytical => { + if particles.is_some() { + return Err(ModelMetadataError::ParticlesNotAllowed { kind }); + } + } + ModelKind::Sde => { + if analytical.is_some() { + return Err(ModelMetadataError::AnalyticalKernelNotAllowed { kind }); + } + } + } + Ok(()) +} + +fn validate_unique_names( + values: &[T], + domain: NameDomain, + name_of: impl Fn(&T) -> &str, +) -> Result<(), ModelMetadataError> { + let mut names = std::collections::HashSet::with_capacity(values.len()); + for value in values { + let name = name_of(value); + if !names.insert(name) { + return Err(ModelMetadataError::DuplicateName { + domain, + name: name.to_string(), + }); + } + } + Ok(()) +} + +fn validate_routes( + routes: Vec, + states: &[State], +) -> Result<(Vec, usize), ModelMetadataError> { + let mut bolus_channels = 0; + let mut infusion_channels = 0; + let mut validated_routes = Vec::with_capacity(routes.len()); + + for (declaration_index, route) in routes.into_iter().enumerate() { + let channel_index = match route.kind { + RouteKind::Bolus => { + let index = bolus_channels; + bolus_channels += 1; + index + } + RouteKind::Infusion => { + let index = infusion_channels; + infusion_channels += 1; + index + } + }; + + validated_routes.push(validate_route( + route, + declaration_index, + channel_index, + states, + )?); + } + + Ok((validated_routes, bolus_channels.max(infusion_channels))) +} + +fn validate_route( + route: Route, + declaration_index: usize, + channel_index: usize, + states: &[State], +) -> Result { + if route.kind == RouteKind::Infusion && route.has_lag { + return Err(ModelMetadataError::InfusionLagNotAllowed { + route: route.name.clone(), + }); + } + + if route.kind == RouteKind::Infusion && route.has_bioavailability { + return Err(ModelMetadataError::InfusionBioavailabilityNotAllowed { + route: route.name.clone(), + }); + } + + let destination = + route + .destination + .clone() + .ok_or_else(|| ModelMetadataError::MissingRouteDestination { + route: route.name.clone(), + })?; + let destination_index = states + .iter() + .position(|state| state.name() == destination) + .ok_or_else(|| ModelMetadataError::UnknownRouteDestination { + route: route.name.clone(), + destination: destination.clone(), + })?; + + Ok(ValidatedRoute { + name: route.name, + kind: route.kind, + declaration_index, + channel_index, + destination, + destination_index, + has_lag: route.has_lag, + has_bioavailability: route.has_bioavailability, + input_policy: route.input_policy, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn builds_ode_metadata_shape() { + let metadata = new("bimodal_ke") + .kind(ModelKind::Ode) + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")); + + assert_eq!(metadata.name(), "bimodal_ke"); + assert_eq!(metadata.kind_decl(), Some(ModelKind::Ode)); + assert_eq!(metadata.parameters_decl()[0].name(), "ke"); + assert_eq!(metadata.parameters_decl()[1].name(), "v"); + assert_eq!(metadata.states_decl()[0].name(), "central"); + assert_eq!(metadata.outputs_decl()[0].name(), "cp"); + assert_eq!(metadata.routes_decl()[0].name(), "iv"); + assert_eq!(metadata.routes_decl()[0].kind(), RouteKind::Infusion); + assert_eq!(metadata.routes_decl()[0].destination(), Some("central")); + } + + #[test] + fn builds_analytical_metadata_shape() { + let metadata = new("one_cmt_abs") + .kind(ModelKind::Analytical) + .parameters(["ka", "ke", "v"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .route(Route::bolus("oral").to_state("gut").with_bioavailability()) + .route(Route::infusion("iv").to_state("central")) + .analytical_kernel(AnalyticalKernel::OneCompartmentWithAbsorption); + + assert_eq!(metadata.kind_decl(), Some(ModelKind::Analytical)); + assert_eq!(metadata.states_decl()[0].name(), "gut"); + assert_eq!(metadata.states_decl()[1].name(), "central"); + assert_eq!(metadata.routes_decl()[0].kind(), RouteKind::Bolus); + assert!(metadata.routes_decl()[0].has_bioavailability()); + assert_eq!( + metadata.analytical_kernel_decl(), + Some(AnalyticalKernel::OneCompartmentWithAbsorption) + ); + } + + #[test] + fn builds_sde_metadata_shape() { + let metadata = new("one_cmt_sde") + .kind(ModelKind::Sde) + .parameters(["ke", "sigma", "v"]) + .covariates([Covariate::continuous("wt"), Covariate::locf("age")]) + .states(["central"]) + .outputs(["cp"]) + .route( + Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(128); + + assert_eq!(metadata.kind_decl(), Some(ModelKind::Sde)); + assert_eq!(metadata.covariates_decl()[0].name(), "wt"); + assert_eq!( + metadata.covariates_decl()[0].interpolation(), + Some(CovariateInterpolation::Linear) + ); + assert_eq!(metadata.covariates_decl()[1].name(), "age"); + assert_eq!( + metadata.covariates_decl()[1].interpolation(), + Some(CovariateInterpolation::Locf) + ); + assert_eq!(metadata.particles_decl(), Some(128)); + assert_eq!( + metadata.routes_decl()[0].input_policy(), + Some(RouteInputPolicy::InjectToDestination) + ); + } + + #[test] + fn validates_metadata_and_exposes_lookup_helpers() { + let metadata = new("bimodal_ke") + .kind(ModelKind::Ode) + .parameters(["ke", "v"]) + .covariates([Covariate::continuous("wt")]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate() + .expect("metadata should validate"); + + assert_eq!(metadata.parameter_index("ke"), Some(0)); + assert_eq!(metadata.parameter_index("v"), Some(1)); + assert_eq!(metadata.covariate_index("wt"), Some(0)); + assert_eq!(metadata.state_index("central"), Some(0)); + assert_eq!(metadata.route_index("iv"), Some(0)); + assert_eq!(metadata.route_declaration_index("iv"), Some(0)); + assert_eq!(metadata.route_channel_count(), 1); + assert_eq!(metadata.output_index("cp"), Some(0)); + assert_eq!( + metadata.route("iv").expect("route exists").destination(), + "central" + ); + assert_eq!( + metadata + .route("iv") + .expect("route exists") + .declaration_index(), + 0 + ); + assert_eq!( + metadata.route("iv").expect("route exists").channel_index(), + 0 + ); + assert_eq!( + metadata + .route("iv") + .expect("route exists") + .destination_index(), + 0 + ); + } + + #[test] + fn duplicate_names_fail_validation() { + let error = new("dup_params") + .kind(ModelKind::Ode) + .parameters(["ke", "ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate() + .expect_err("duplicate parameters must fail"); + + assert_eq!( + error, + ModelMetadataError::DuplicateName { + domain: NameDomain::Parameter, + name: "ke".to_string(), + } + ); + } + + #[test] + fn missing_route_destination_fails_validation() { + let error = new("missing_route_destination") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv")) + .validate() + .expect_err("route destination is required"); + + assert_eq!( + error, + ModelMetadataError::MissingRouteDestination { + route: "iv".to_string(), + } + ); + } + + #[test] + fn unknown_route_destination_fails_validation() { + let error = new("unknown_route_destination") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("peripheral")) + .validate() + .expect_err("unknown destinations must fail"); + + assert_eq!( + error, + ModelMetadataError::UnknownRouteDestination { + route: "iv".to_string(), + destination: "peripheral".to_string(), + } + ); + } + + #[test] + fn shared_channel_routes_preserve_declaration_and_channel_identity() { + let metadata = new("shared_channel") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + Route::bolus("oral").to_state("gut"), + Route::infusion("iv").to_state("central"), + ]) + .validate() + .expect("shared-channel metadata should validate"); + + assert_eq!(metadata.routes().len(), 2); + assert_eq!(metadata.route_channel_count(), 1); + assert_eq!(metadata.route_index("oral"), Some(0)); + assert_eq!(metadata.route_index("iv"), Some(0)); + assert_eq!(metadata.route_declaration_index("oral"), Some(0)); + assert_eq!(metadata.route_declaration_index("iv"), Some(1)); + assert_eq!( + metadata.route("oral").expect("oral route").channel_index(), + 0 + ); + assert_eq!(metadata.route("iv").expect("iv route").channel_index(), 0); + assert_eq!( + metadata + .route("oral") + .expect("oral route") + .declaration_index(), + 0 + ); + assert_eq!( + metadata.route("iv").expect("iv route").declaration_index(), + 1 + ); + } + + #[test] + fn infusion_routes_reject_lag_and_bioavailability() { + let lag_error = new("infusion_lag") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central").with_lag()) + .validate() + .expect_err("infusion lag must fail"); + + assert_eq!( + lag_error, + ModelMetadataError::InfusionLagNotAllowed { + route: "iv".to_string(), + } + ); + + let fa_error = new("infusion_fa") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route( + Route::infusion("iv") + .to_state("central") + .with_bioavailability(), + ) + .validate() + .expect_err("infusion bioavailability must fail"); + + assert_eq!( + fa_error, + ModelMetadataError::InfusionBioavailabilityNotAllowed { + route: "iv".to_string(), + } + ); + } + + #[test] + fn validate_requires_or_accepts_a_kind() { + let error = new("kind_required") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate() + .expect_err("kindless metadata needs explicit validation kind"); + + assert_eq!(error, ModelMetadataError::MissingModelKind); + + let validated = new("kind_override") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate_for(ModelKind::Ode) + .expect("caller-provided kind should validate"); + + assert_eq!(validated.kind(), ModelKind::Ode); + } + + #[test] + fn conflicting_kinds_fail_validation() { + let error = new("kind_conflict") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate_for(ModelKind::Sde) + .expect_err("conflicting kinds must fail"); + + assert_eq!( + error, + ModelMetadataError::ModelKindConflict { + declared: ModelKind::Ode, + requested: ModelKind::Sde, + } + ); + } + + #[test] + fn particles_are_rejected_for_ode_and_analytical() { + let ode_error = new("ode_particles") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(64) + .validate() + .expect_err("ODE metadata cannot declare particles"); + + assert_eq!( + ode_error, + ModelMetadataError::ParticlesNotAllowed { + kind: ModelKind::Ode, + } + ); + + let analytical_error = new("analytical_particles") + .kind(ModelKind::Analytical) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(64) + .validate() + .expect_err("Analytical metadata cannot declare particles"); + + assert_eq!( + analytical_error, + ModelMetadataError::ParticlesNotAllowed { + kind: ModelKind::Analytical, + } + ); + } + + #[test] + fn analytical_kernel_is_limited_to_analytical_models() { + let error = new("ode_kernel") + .kind(ModelKind::Ode) + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .analytical_kernel(AnalyticalKernel::OneCompartment) + .validate() + .expect_err("ODE metadata cannot declare an analytical kernel"); + + assert_eq!( + error, + ModelMetadataError::AnalyticalKernelNotAllowed { + kind: ModelKind::Ode, + } + ); + } + + #[test] + fn sde_requires_particles_or_a_fallback_count() { + let error = new("sde_missing_particles") + .kind(ModelKind::Sde) + .parameters(["ke", "sigma"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate() + .expect_err("SDE metadata requires particles"); + + assert_eq!(error, ModelMetadataError::MissingParticles); + + let validated = new("sde_fallback_particles") + .parameters(["ke", "sigma"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .validate_for_with_particles(ModelKind::Sde, 128) + .expect("fallback particle count should satisfy SDE validation"); + + assert_eq!(validated.kind(), ModelKind::Sde); + assert_eq!(validated.particles(), Some(128)); + } + + #[test] + fn conflicting_particle_counts_fail_validation() { + let error = new("sde_particle_conflict") + .parameters(["ke", "sigma"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(64) + .validate_for_with_particles(ModelKind::Sde, 128) + .expect_err("mismatched particle counts must fail"); + + assert_eq!( + error, + ModelMetadataError::ParticleCountConflict { + declared: 64, + fallback: 128, + } + ); + } +} diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index 39cd741f..60cb2d8f 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -1,11 +1,12 @@ use std::fmt::Debug; pub mod analytical; -pub mod meta; +pub mod metadata; pub mod ode; pub mod sde; pub use analytical::*; -pub use meta::*; +pub use metadata::*; pub use ode::*; +pub use pharmsol_dsl::{AnalyticalKernel, ModelKind}; pub use sde::*; use crate::{ diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/equation/ode/mod.rs index 17b04235..cafe6a96 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -27,8 +27,13 @@ use diffsol::{ OdeSolverStopReason, Vector, VectorHost, }; use nalgebra::DVector; +use pharmsol_dsl::ModelKind; +use thiserror::Error; -use super::{Equation, EquationPriv, EquationTypes, State}; +use super::{ + EqnKind, Equation, EquationPriv, EquationTypes, ModelMetadata, ModelMetadataError, State, + ValidatedModelMetadata, +}; const RTOL: f64 = 1e-4; const ATOL: f64 = 1e-4; @@ -76,6 +81,20 @@ pub enum ExplicitRkTableau { Tsit45, } +#[derive(Clone, Debug, PartialEq, Eq, Error)] +pub enum OdeMetadataError { + #[error(transparent)] + Validation(#[from] ModelMetadataError), + #[error("ODE declares {declared} state metadata entries but model has {expected} states")] + StateCountMismatch { expected: usize, declared: usize }, + #[error( + "ODE declares {declared} route metadata entries but model has {expected} input channels" + )] + RouteCountMismatch { expected: usize, declared: usize }, + #[error("ODE declares {declared} output metadata entries but model has {expected} outputs")] + OutputCountMismatch { expected: usize, declared: usize }, +} + #[derive(Clone, Debug)] pub struct ODE { diffeq: DiffEq, @@ -87,6 +106,7 @@ pub struct ODE { solver: OdeSolver, rtol: f64, atol: f64, + metadata: Option, cache: Option, } @@ -102,6 +122,7 @@ impl ODE { solver: OdeSolver::default(), rtol: RTOL, atol: ATOL, + metadata: None, cache: Some(PredictionCache::new(DEFAULT_CACHE_SIZE)), } } @@ -109,18 +130,21 @@ impl ODE { /// Set the number of state variables (ODE compartments). pub fn with_nstates(mut self, nstates: usize) -> Self { self.neqs.nstates = nstates; + self.invalidate_metadata(); self } /// Set the number of drug input channels (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; + self.invalidate_metadata(); self } /// Set the number of output equations. pub fn with_nout(mut self, nout: usize) -> Self { self.neqs.nout = nout; + self.invalidate_metadata(); self } @@ -136,6 +160,74 @@ impl ODE { self.atol = atol; self } + + /// Attach validated handwritten-model metadata to this ODE. + pub fn with_metadata(mut self, metadata: ModelMetadata) -> Result { + let metadata = metadata.validate_for(ModelKind::Ode)?; + validate_metadata_dimensions(&metadata, &self.neqs)?; + self.metadata = Some(metadata); + Ok(self) + } + + /// Access the validated metadata attached to this ODE, if any. + pub fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + + pub fn parameter_index(&self, name: &str) -> Option { + self.metadata()?.parameter_index(name) + } + + pub fn covariate_index(&self, name: &str) -> Option { + self.metadata()?.covariate_index(name) + } + + pub fn state_index(&self, name: &str) -> Option { + self.metadata()?.state_index(name) + } + + pub fn route_index(&self, name: &str) -> Option { + self.metadata()?.route_index(name) + } + + pub fn output_index(&self, name: &str) -> Option { + self.metadata()?.output_index(name) + } + + fn invalidate_metadata(&mut self) { + self.metadata = None; + } +} + +fn validate_metadata_dimensions( + metadata: &ValidatedModelMetadata, + neqs: &Neqs, +) -> Result<(), OdeMetadataError> { + let declared_states = metadata.states().len(); + if declared_states != neqs.nstates { + return Err(OdeMetadataError::StateCountMismatch { + expected: neqs.nstates, + declared: declared_states, + }); + } + + let declared_routes = metadata.route_channel_count(); + if declared_routes != neqs.ndrugs { + return Err(OdeMetadataError::RouteCountMismatch { + expected: neqs.ndrugs, + declared: declared_routes, + }); + } + + let declared_outputs = metadata.outputs().len(); + if declared_outputs != neqs.nout { + return Err(OdeMetadataError::OutputCountMismatch { + expected: neqs.nout, + declared: declared_outputs, + }); + } + + Ok(()) } impl super::Cache for ODE { @@ -280,7 +372,7 @@ impl EquationPriv for ODE { impl ODE { /// Generic event-loop runner, parameterized over the concrete solver type. #[allow(clippy::too_many_arguments)] - fn run_events<'a, S: OdeSolverMethod<'a, PMProblem<'a, DiffEq>>>( + fn run_events<'a, F, S>( &self, solver: &mut S, events: &[Event], @@ -295,7 +387,11 @@ impl ODE { y_out: &mut V, likelihood: &mut Vec, output: &mut SubjectPredictions, - ) -> Result<(), PharmsolError> { + ) -> Result<(), PharmsolError> + where + F: Fn(&V, &V, f64, &mut V, &V, &V, &Covariates) + 'a, + S: OdeSolverMethod<'a, PMProblem<'a, F>>, + { for (index, event) in events.iter().enumerate() { let next_event = events.get(index + 1); @@ -420,8 +516,8 @@ impl Equation for ODE { ypred.log_likelihood(error_models) } - fn kind() -> crate::EqnKind { - crate::EqnKind::ODE + fn kind() -> EqnKind { + EqnKind::ODE } fn simulate_subject( @@ -467,7 +563,9 @@ impl Equation for ODE { .h0(1e-3) .p(support_point.to_vec()) .build_from_eqn(PMProblem::with_params_v( - self.diffeq, + move |x, p, t, dx, bolus, rateiv, cov| { + (self.diffeq)(x, p, t, dx, bolus, rateiv, cov); + }, nstates, ndrugs, support_point.to_vec(), @@ -560,3 +658,235 @@ impl Equation for ODE { Ok((output, ll)) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{fa, lag, Subject, SubjectBuilderExt}; + use approx::assert_relative_eq; + + fn simple_ode() -> ODE { + ODE::new( + |_x, _p, _t, _dx, _b, _rateiv, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, _y| {}, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + } + + fn route_policy_subject() -> Subject { + Subject::builder("route_policy") + .bolus(0.0, 100.0, 0) + .infusion(0.0, 100.0, 0, 1.0) + .observation(1.0, 0.0, 0) + .build() + } + + fn explicit_route_kernel( + _x: &V, + _p: &V, + _t: f64, + dx: &mut V, + b: &V, + rateiv: &V, + _cov: &Covariates, + ) { + dx[0] = b[0] + rateiv[0]; + } + + fn injected_route_kernel( + _x: &V, + _p: &V, + _t: f64, + dx: &mut V, + _b: &V, + _rateiv: &V, + _cov: &Covariates, + ) { + dx[0] = 0.0; + } + + fn zero_lag(_p: &V, _t: f64, _cov: &Covariates) -> std::collections::HashMap { + std::collections::HashMap::new() + } + + fn unit_fa(_p: &V, _t: f64, _cov: &Covariates) -> std::collections::HashMap { + std::collections::HashMap::new() + } + + fn zero_init(_p: &V, _t: f64, _cov: &Covariates, _x: &mut V) {} + + fn state_output(x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V) { + y[0] = x[0]; + } + + #[test] + fn handwritten_ode_metadata_exposes_name_lookup() { + let ode = simple_ode() + .with_metadata( + super::super::metadata::new("bimodal_ke") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect("metadata attachment should validate"); + + assert_eq!(ode.parameter_index("ke"), Some(0)); + assert_eq!(ode.parameter_index("v"), Some(1)); + assert_eq!(ode.state_index("central"), Some(0)); + assert_eq!(ode.route_index("iv"), Some(0)); + assert_eq!(ode.output_index("cp"), Some(0)); + assert_eq!( + ode.metadata().expect("metadata exists").kind(), + ModelKind::Ode + ); + } + + #[test] + fn handwritten_ode_without_metadata_keeps_raw_path() { + let ode = simple_ode(); + + assert!(ode.metadata().is_none()); + assert_eq!(ode.state_index("central"), None); + assert_eq!(ode.route_index("iv"), None); + assert_eq!(ode.output_index("cp"), None); + } + + #[test] + fn handwritten_ode_rejects_dimension_mismatches() { + let error = simple_ode() + .with_metadata( + super::super::metadata::new("wrong_outputs") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp", "auc"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect_err("output-count mismatches must fail"); + + assert_eq!( + error, + OdeMetadataError::OutputCountMismatch { + expected: 1, + declared: 2, + } + ); + } + + #[test] + fn handwritten_ode_rejects_invalid_metadata() { + let error = simple_ode() + .with_metadata( + super::super::metadata::new("missing_destination") + .parameters(["ke"]) + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv")), + ) + .expect_err("invalid metadata must fail during attachment"); + + assert_eq!( + error, + OdeMetadataError::Validation(ModelMetadataError::MissingRouteDestination { + route: "iv".to_string(), + }) + ); + } + + #[test] + fn handwritten_ode_defaults_to_explicit_route_vectors() { + let ode = ODE::new( + explicit_route_kernel, + zero_lag, + unit_fa, + zero_init, + state_output, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + super::super::metadata::new("explicit_routes") + .states(["central"]) + .outputs(["cp"]) + .routes([ + super::super::Route::bolus("oral").to_state("central"), + super::super::Route::infusion("iv").to_state("central"), + ]), + ) + .expect("metadata attachment should validate"); + + let predictions = ode + .simulate_subject(&route_policy_subject(), &[], None) + .expect("simulation should succeed") + .0; + + assert_eq!(ode.route_index("oral").expect("oral route"), 0); + assert_eq!(ode.route_index("iv").expect("iv route"), 0); + assert_relative_eq!( + predictions.predictions()[0].prediction(), + 200.0, + epsilon = 1e-6 + ); + } + + #[test] + fn handwritten_ode_metadata_input_policy_is_descriptive_only() { + let ode = ODE::new( + injected_route_kernel, + zero_lag, + unit_fa, + zero_init, + state_output, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + super::super::metadata::new("injected_routes") + .states(["central"]) + .outputs(["cp"]) + .routes([ + super::super::Route::bolus("oral") + .to_state("central") + .inject_input_to_destination(), + super::super::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]), + ) + .expect("metadata attachment should validate"); + + let predictions = ode + .simulate_subject(&route_policy_subject(), &[], None) + .expect("simulation should succeed") + .0; + + assert_relative_eq!( + predictions.predictions()[0].prediction(), + 0.0, + epsilon = 1e-6 + ); + } + + #[test] + fn changing_dimensions_after_metadata_clears_route_metadata() { + let ode = simple_ode() + .with_metadata( + super::super::metadata::new("bimodal_ke") + .states(["central"]) + .outputs(["cp"]) + .route(super::super::Route::infusion("iv").to_state("central")), + ) + .expect("metadata attachment should validate") + .with_ndrugs(2); + + assert!(ode.metadata().is_none()); + assert_eq!(ode.route_index("iv"), None); + } +} diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/equation/sde/mod.rs index af8ea246..bdafbbc3 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/equation/sde/mod.rs @@ -3,8 +3,10 @@ mod em; use diffsol::{NalgebraContext, Vector}; use nalgebra::DVector; use ndarray::{concatenate, Array2, Axis}; +use pharmsol_dsl::ModelKind; use rand::{rng, RngExt}; use rayon::prelude::*; +use thiserror::Error; use crate::{ data::{Covariates, Infusion}, @@ -21,7 +23,59 @@ use diffsol::VectorCommon; use crate::PharmsolError; -use super::{Equation, EquationPriv, EquationTypes, Predictions, State}; +use super::{ + EqnKind, Equation, EquationPriv, EquationTypes, ModelMetadata, ModelMetadataError, Predictions, + State, ValidatedModelMetadata, +}; + +#[derive(Clone, Debug, PartialEq, Eq, Error)] +pub enum SdeMetadataError { + #[error(transparent)] + Validation(#[from] ModelMetadataError), + #[error("SDE declares {declared} state metadata entries but model has {expected} states")] + StateCountMismatch { expected: usize, declared: usize }, + #[error( + "SDE declares {declared} route metadata entries but model has {expected} input channels" + )] + RouteCountMismatch { expected: usize, declared: usize }, + #[error("SDE declares {declared} output metadata entries but model has {expected} outputs")] + OutputCountMismatch { expected: usize, declared: usize }, +} + +#[derive(Clone, Debug, Default)] +struct InjectedBolusMappings { + destinations: Vec>, +} + +impl InjectedBolusMappings { + fn explicit(ndrugs: usize) -> Self { + Self { + destinations: vec![None; ndrugs], + } + } + + fn from_destinations(ndrugs: usize, destinations: &[Option]) -> Self { + let mut mappings = Self::explicit(ndrugs); + for (input, destination) in destinations.iter().copied().take(ndrugs).enumerate() { + mappings.destinations[input] = destination; + } + mappings + } + + fn invalidate_for_ndrugs(&mut self, ndrugs: usize) { + *self = Self::explicit(ndrugs); + } + + fn apply(&self, state: &mut [DVector], input: usize, amount: f64) -> bool { + let Some(destination) = self.destinations.get(input).copied().flatten() else { + return false; + }; + state.par_iter_mut().for_each(|particle| { + particle[destination] += amount; + }); + true + } +} /// Simulate a stochastic differential equation (SDE) event. /// @@ -44,7 +98,7 @@ use super::{Equation, EquationPriv, EquationTypes, Predictions, State}; /// The state vector at time `tf` after simulation. #[inline(always)] #[allow(clippy::too_many_arguments)] -pub(crate) fn simulate_sde_event( +fn simulate_sde_event( drift: &Drift, difussion: &Diffusion, x: V, @@ -133,6 +187,8 @@ pub struct SDE { out: Out, neqs: Neqs, nparticles: usize, + metadata: Option, + injected_bolus_mappings: InjectedBolusMappings, cache: Option, } @@ -164,6 +220,8 @@ impl SDE { out, neqs: Neqs::default(), nparticles, + metadata: None, + injected_bolus_mappings: InjectedBolusMappings::default(), cache: Some(SdeLikelihoodCache::new(DEFAULT_CACHE_SIZE)), } } @@ -171,20 +229,100 @@ impl SDE { /// Set the number of state variables. pub fn with_nstates(mut self, nstates: usize) -> Self { self.neqs.nstates = nstates; + self.invalidate_metadata(); self } /// Set the number of drug input channels (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; + self.invalidate_metadata(); self } /// Set the number of output equations. pub fn with_nout(mut self, nout: usize) -> Self { self.neqs.nout = nout; + self.invalidate_metadata(); self } + + /// Attach validated handwritten-model metadata to this SDE model. + pub fn with_metadata(mut self, metadata: ModelMetadata) -> Result { + let metadata = metadata.validate_for_with_particles(ModelKind::Sde, self.nparticles)?; + validate_metadata_dimensions(&metadata, &self.neqs)?; + self.metadata = Some(metadata); + Ok(self) + } + + #[doc(hidden)] + pub fn with_injected_bolus_inputs(mut self, destinations: &[Option]) -> Self { + self.injected_bolus_mappings = + InjectedBolusMappings::from_destinations(self.neqs.ndrugs, destinations); + self + } + + /// Access the validated metadata attached to this SDE model, if any. + pub fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + + pub fn parameter_index(&self, name: &str) -> Option { + self.metadata()?.parameter_index(name) + } + + pub fn covariate_index(&self, name: &str) -> Option { + self.metadata()?.covariate_index(name) + } + + pub fn state_index(&self, name: &str) -> Option { + self.metadata()?.state_index(name) + } + + pub fn route_index(&self, name: &str) -> Option { + self.metadata()?.route_index(name) + } + + pub fn output_index(&self, name: &str) -> Option { + self.metadata()?.output_index(name) + } + + fn invalidate_metadata(&mut self) { + self.metadata = None; + self.injected_bolus_mappings + .invalidate_for_ndrugs(self.neqs.ndrugs); + } +} + +fn validate_metadata_dimensions( + metadata: &ValidatedModelMetadata, + neqs: &Neqs, +) -> Result<(), SdeMetadataError> { + let declared_states = metadata.states().len(); + if declared_states != neqs.nstates { + return Err(SdeMetadataError::StateCountMismatch { + expected: neqs.nstates, + declared: declared_states, + }); + } + + let declared_routes = metadata.route_channel_count(); + if declared_routes != neqs.ndrugs { + return Err(SdeMetadataError::RouteCountMismatch { + expected: neqs.ndrugs, + declared: declared_routes, + }); + } + + let declared_outputs = metadata.outputs().len(); + if declared_outputs != neqs.nout { + return Err(SdeMetadataError::OutputCountMismatch { + expected: neqs.nout, + declared: declared_outputs, + }); + } + + Ok(()) } impl super::Cache for SDE { @@ -435,6 +573,63 @@ impl EquationPriv for SDE { } x } + + fn simulate_event( + &self, + support_point: &[f64], + event: &crate::Event, + next_event: Option<&crate::Event>, + error_models: Option<&AssayErrorModels>, + covariates: &Covariates, + x: &mut Self::S, + infusions: &mut Vec, + likelihood: &mut Vec, + output: &mut Self::P, + ) -> Result<(), PharmsolError> { + match event { + crate::Event::Bolus(bolus) => { + if bolus.input() >= self.get_ndrugs() { + return Err(PharmsolError::InputOutOfRange { + input: bolus.input(), + ndrugs: self.get_ndrugs(), + }); + } + if !self + .injected_bolus_mappings + .apply(x, bolus.input(), bolus.amount()) + { + x.add_bolus(bolus.input(), bolus.amount()); + } + } + crate::Event::Infusion(infusion) => { + infusions.push(infusion.clone()); + } + crate::Event::Observation(observation) => { + self.process_observation( + support_point, + observation, + error_models, + event.time(), + covariates, + x, + likelihood, + output, + )?; + } + } + + if let Some(next_event) = next_event { + self.solve( + x, + support_point, + covariates, + infusions, + event.time(), + next_event.time(), + )?; + } + Ok(()) + } } impl Equation for SDE { @@ -475,8 +670,8 @@ impl Equation for SDE { } } - fn kind() -> crate::EqnKind { - crate::EqnKind::SDE + fn kind() -> EqnKind { + EqnKind::SDE } } @@ -533,3 +728,276 @@ fn sysresample(q: &[f64]) -> Vec { } i } + +#[cfg(test)] +mod tests { + use super::*; + use crate::simulator::equation::{self, Covariate, Route}; + use crate::SubjectBuilderExt; + use crate::{fa, fetch_params, lag}; + + fn simple_sde() -> SDE { + let drift = |x: &V, _p: &V, _t: f64, dx: &mut V, rateiv: &V, _cov: &Covariates| { + dx[0] = rateiv[0] - x[0]; + }; + let diffusion = |_p: &V, g: &mut V| { + g[0] = 1.0; + }; + let lag = |_p: &V, _t: f64, _cov: &Covariates| lag! {}; + let fa = |_p: &V, _t: f64, _cov: &Covariates| fa! {}; + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x[0] = 0.0; + }; + let out = |x: &V, p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }; + + SDE::new(drift, diffusion, lag, fa, init, out, 128) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + } + + fn route_policy_sde(drift: Drift) -> SDE { + let diffusion = |_p: &V, sigma: &mut V| { + sigma.fill(0.0); + }; + let lag = |_p: &V, _t: f64, _cov: &Covariates| lag! {}; + let fa = |_p: &V, _t: f64, _cov: &Covariates| fa! {}; + let init = |_p: &V, _t: f64, _cov: &Covariates, x: &mut V| { + x.fill(0.0); + }; + let out = |x: &V, _p: &V, _t: f64, _cov: &Covariates, y: &mut V| { + y[0] = x[1]; + }; + + SDE::new(drift, diffusion, lag, fa, init, out, 16) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + } + + #[test] + fn handwritten_sde_metadata_exposes_name_lookup_and_particles() { + let sde = simple_sde() + .with_metadata( + equation::metadata::new("one_cmt_sde") + .parameters(["ke", "v"]) + .covariates([Covariate::continuous("wt")]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(128), + ) + .expect("SDE metadata attachment should validate"); + + let metadata = sde.metadata().expect("metadata exists"); + assert_eq!(metadata.kind(), ModelKind::Sde); + assert_eq!(metadata.particles(), Some(128)); + assert_eq!(sde.parameter_index("ke"), Some(0)); + assert_eq!(sde.parameter_index("v"), Some(1)); + assert_eq!(sde.covariate_index("wt"), Some(0)); + assert_eq!(sde.state_index("central"), Some(0)); + assert_eq!(sde.route_index("iv"), Some(0)); + assert_eq!(sde.output_index("cp"), Some(0)); + } + + #[test] + fn handwritten_sde_without_metadata_keeps_raw_path() { + let sde = simple_sde(); + + assert!(sde.metadata().is_none()); + assert_eq!(sde.parameter_index("ke"), None); + assert_eq!(sde.route_index("iv"), None); + assert_eq!(sde.output_index("cp"), None); + } + + #[test] + fn handwritten_sde_rejects_dimension_mismatches() { + let error = simple_sde() + .with_metadata( + equation::metadata::new("bad_sde") + .parameters(["ke", "v"]) + .states(["central", "peripheral"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(128), + ) + .expect_err("mismatched state metadata must fail"); + + assert_eq!( + error, + SdeMetadataError::StateCountMismatch { + expected: 1, + declared: 2, + } + ); + } + + #[test] + fn handwritten_sde_rejects_particle_mismatch() { + let error = simple_sde() + .with_metadata( + equation::metadata::new("particle_conflict") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(64), + ) + .expect_err("mismatched SDE particles must fail"); + + assert_eq!( + error, + SdeMetadataError::Validation(ModelMetadataError::ParticleCountConflict { + declared: 64, + fallback: 128, + }) + ); + } + + #[test] + fn changing_dimensions_after_metadata_clears_sde_metadata() { + let sde = simple_sde() + .with_metadata( + equation::metadata::new("one_cmt_sde") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(128), + ) + .expect("metadata attachment should validate") + .with_nout(2); + + assert!(sde.metadata().is_none()); + assert_eq!(sde.route_index("iv"), None); + assert_eq!(sde.output_index("cp"), None); + } + + #[test] + fn sde_metadata_input_policy_is_descriptive_only_for_bolus_routes() { + let zero_drift = |_x: &V, _p: &V, _t: f64, dx: &mut V, _rateiv: &V, _cov: &Covariates| { + dx.fill(0.0); + }; + + let explicit = route_policy_sde(zero_drift) + .with_metadata( + equation::metadata::new("explicit_bolus") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route(Route::bolus("oral").to_state("central")) + .particles(16), + ) + .expect("explicit metadata should validate"); + + let injected = route_policy_sde(zero_drift) + .with_metadata( + equation::metadata::new("injected_bolus") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + Route::bolus("oral") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(16), + ) + .expect("injected metadata should validate"); + + let subject = Subject::builder("bolus_route") + .bolus(0.0, 100.0, 0) + .missing_observation(0.1, 0) + .build(); + + let explicit_predictions = explicit.estimate_predictions(&subject, &[0.0]).unwrap(); + let injected_predictions = injected.estimate_predictions(&subject, &[0.0]).unwrap(); + + assert_eq!(explicit_predictions[[0, 0]].prediction(), 0.0); + assert_eq!(injected_predictions[[0, 0]].prediction(), 0.0); + } + + #[test] + fn sde_metadata_input_policy_does_not_change_explicit_rateiv_behavior() { + let rateiv_drift = |_x: &V, _p: &V, _t: f64, dx: &mut V, rateiv: &V, _cov: &Covariates| { + dx.fill(0.0); + dx[1] = rateiv[0]; + }; + + let explicit = route_policy_sde(rateiv_drift) + .with_metadata( + equation::metadata::new("explicit_infusion") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route(Route::infusion("iv").to_state("central")) + .particles(16), + ) + .expect("explicit metadata should validate"); + + let injected = route_policy_sde(rateiv_drift) + .with_metadata( + equation::metadata::new("injected_infusion") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(16), + ) + .expect("injected metadata should validate"); + + let subject = Subject::builder("infusion_route") + .infusion(0.0, 100.0, 0, 1.0) + .missing_observation(1.0, 0) + .build(); + + let explicit_predictions = explicit.estimate_predictions(&subject, &[0.0]).unwrap(); + let injected_predictions = injected.estimate_predictions(&subject, &[0.0]).unwrap(); + + let explicit_prediction = explicit_predictions[[0, 0]].prediction(); + let injected_prediction = injected_predictions[[0, 0]].prediction(); + + assert!(explicit_prediction > 0.0); + assert!((injected_prediction - explicit_prediction).abs() < 1e-8); + } + + #[test] + fn clearing_sde_metadata_preserves_raw_bolus_behavior() { + let zero_drift = |_x: &V, _p: &V, _t: f64, dx: &mut V, _rateiv: &V, _cov: &Covariates| { + dx.fill(0.0); + }; + + let sde = route_policy_sde(zero_drift) + .with_metadata( + equation::metadata::new("injected_bolus") + .parameters(["theta"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + Route::bolus("oral") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(16), + ) + .expect("injected metadata should validate") + .with_nout(1); + + let subject = Subject::builder("bolus_route") + .bolus(0.0, 100.0, 0) + .missing_observation(0.1, 0) + .build(); + + let predictions = sde.estimate_predictions(&subject, &[0.0]).unwrap(); + + assert!(sde.metadata().is_none()); + assert_eq!(predictions[[0, 0]].prediction(), 0.0); + } +} diff --git a/src/test_fixtures.rs b/src/test_fixtures.rs index 7fb1610e..91d21e90 100644 --- a/src/test_fixtures.rs +++ b/src/test_fixtures.rs @@ -83,7 +83,7 @@ model one_cmt_abs { oral -> depot } analytical { - kernel = one_compartment_with_absorption + structure = one_compartment_with_absorption } outputs { cp = central / v diff --git a/tests/analytical_macro_lowering.rs b/tests/analytical_macro_lowering.rs new file mode 100644 index 00000000..e025ec4f --- /dev/null +++ b/tests/analytical_macro_lowering.rs @@ -0,0 +1,495 @@ +use approx::assert_relative_eq; +use pharmsol::prelude::*; + +fn infusion_subject(input: usize) -> Subject { + Subject::builder("analytical-macro-iv") + .infusion(0.0, 120.0, input, 1.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .build() +} + +fn oral_subject(input: usize) -> Subject { + Subject::builder("analytical-macro-oral") + .bolus(0.0, 100.0, input) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .build() +} + +fn shared_channel_subject(input: usize) -> Subject { + Subject::builder("analytical-macro-shared") + .bolus(0.0, 100.0, input) + .infusion(6.0, 60.0, input, 2.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(6.5, 0) + .missing_observation(7.0, 0) + .missing_observation(8.0, 0) + .build() +} + +fn covariate_subject(oral: usize, iv: usize, cp: usize) -> Subject { + Subject::builder("analytical-macro-covariates") + .bolus(1.0, 100.0, oral) + .infusion(6.0, 140.0, iv, 2.0) + .missing_observation(0.25, cp) + .missing_observation(0.75, cp) + .missing_observation(1.5, cp) + .missing_observation(3.0, cp) + .missing_observation(6.5, cp) + .missing_observation(7.0, cp) + .missing_observation(8.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() +} + +fn macro_one_compartment() -> equation::Analytical { + analytical! { + name: "one_cpt_iv", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + structure: one_compartment, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_one_compartment() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment, + |_p, _t, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cpt_iv") + .kind(equation::ModelKind::Analytical) + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route(equation::Route::infusion("iv").to_state("central")) + .analytical_kernel(equation::AnalyticalKernel::OneCompartment), + ) + .expect("handwritten analytical metadata should validate") +} + +fn macro_one_compartment_with_absorption() -> equation::Analytical { + analytical! { + name: "one_cmt_abs", + params: [ka, ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + structure: one_compartment_with_absorption, + lag: |_t| { + lag! { oral => tlag } + }, + fa: |_t| { + fa! { oral => f_oral } + }, + init: |_t, x| { + x[gut] = 0.0; + x[central] = 0.0; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_one_compartment_with_absorption() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |_p, _t, _cov| {}, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, x| { + x[0] = 0.0; + x[1] = 0.0; + }, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs") + .kind(equation::ModelKind::Analytical) + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + ) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten absorption metadata should validate") +} + +fn macro_shared_channel_analytical() -> equation::Analytical { + analytical! { + name: "one_cmt_abs_shared", + params: [ka, ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + infusion(iv) -> central, + }, + structure: one_compartment_with_absorption, + lag: |_t| { + lag! { oral => tlag } + }, + fa: |_t| { + fa! { oral => f_oral } + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_shared_channel_analytical() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |_p, _t, _cov| {}, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_shared") + .kind(equation::ModelKind::Analytical) + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten shared-channel analytical metadata should validate") +} + +fn macro_covariate_analytical() -> equation::Analytical { + analytical! { + name: "one_cmt_abs_covariates", + params: [ka, ke, v, tlag, f_oral, base_gut, base_central, tvke], + covariates: [wt, renal], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + infusion(iv) -> central, + }, + structure: one_compartment_with_absorption, + sec: |_t| { + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + ke = tvke * wt_scale * renal_scale; + }, + lag: |_t| { + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { oral => tlag * lag_scale } + }, + fa: |_t| { + let fa_scale = (renal / 90.0).powf(0.1); + fa! { oral => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + init: |_t, x| { + x[gut] = base_gut + 0.03 * wt; + x[central] = base_central + 0.08 * renal; + }, + out: |x, _t, y| { + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[cp] = x[central] / adjusted_v; + }, + } +} + +fn handwritten_covariate_analytical() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |p, t, cov| { + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + p[1] = p[7] * wt_scale * renal_scale; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + _f_oral, + base_gut, + base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_gut + 0.03 * wt; + x[1] = base_central + 0.08 * renal; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + v, + _tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_covariates") + .kind(equation::ModelKind::Analytical) + .parameters([ + "ka", + "ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + "tvke", + ]) + .covariates([ + equation::Covariate::continuous("wt"), + equation::Covariate::continuous("renal"), + ]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten covariate analytical metadata should validate") +} + +fn assert_prediction_match(left: &[f64], right: &[f64]) { + assert_eq!(left.len(), right.len()); + for (left, right) in left.iter().zip(right.iter()) { + assert_relative_eq!(left, right, epsilon = 1e-10); + } +} + +#[test] +fn analytical_macro_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_one_compartment(); + let handwritten_model = handwritten_one_compartment(); + let subject = infusion_subject(0); + let support_point = [0.2, 10.0]; + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert_eq!(macro_model.route_index("iv"), Some(0)); + assert_eq!(macro_model.output_index("cp"), Some(0)); + assert_eq!(macro_model.state_index("central"), Some(0)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro analytical model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten analytical model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn analytical_macro_supports_extra_parameters_and_named_route_bindings() { + let macro_model = macro_one_compartment_with_absorption(); + let handwritten_model = handwritten_one_compartment_with_absorption(); + let subject = oral_subject(0); + let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert_eq!(macro_model.route_index("oral"), Some(0)); + assert_eq!(macro_model.output_index("cp"), Some(0)); + assert_eq!(macro_model.state_index("gut"), Some(0)); + assert_eq!( + macro_model + .metadata() + .expect("macro metadata exists") + .analytical_kernel(), + Some(equation::AnalyticalKernel::OneCompartmentWithAbsorption) + ); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro absorption model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten absorption model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn analytical_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_shared_channel_analytical(); + let handwritten_model = handwritten_shared_channel_analytical(); + let subject = shared_channel_subject(0); + let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert_eq!(macro_model.route_index("oral"), Some(0)); + assert_eq!(macro_model.route_index("iv"), Some(0)); + assert_eq!(macro_model.output_index("cp"), Some(0)); + assert_eq!(macro_model.state_index("gut"), Some(0)); + assert_eq!(macro_model.state_index("central"), Some(1)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro shared-channel analytical model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten shared-channel analytical model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn analytical_macro_covariates_lower_to_handwritten_behavior() { + let macro_model = macro_covariate_analytical(); + let handwritten_model = handwritten_covariate_analytical(); + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + + let oral = macro_model.route_index("oral").expect("oral route exists"); + let iv = macro_model.route_index("iv").expect("iv route exists"); + let cp = macro_model.output_index("cp").expect("cp output exists"); + let subject = covariate_subject(oral, iv, cp); + let support_point = [1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16]; + + assert_eq!(oral, iv); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro covariate analytical model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten covariate analytical model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs new file mode 100644 index 00000000..43621e8a --- /dev/null +++ b/tests/authoring_parity_corpus.rs @@ -0,0 +1,1365 @@ +#[cfg(feature = "dsl-jit")] +use pharmsol::dsl::{self, RuntimeCompilationTarget, RuntimePredictions}; +#[cfg(feature = "dsl-jit")] +use pharmsol::equation::RouteInputPolicy; +use pharmsol::equation::{ + self, AnalyticalKernel, RouteKind as HandwrittenRouteKind, ValidatedModelMetadata, +}; +use pharmsol::prelude::*; +#[cfg(feature = "dsl-jit")] +use pharmsol::Predictions; +use pharmsol_dsl::{ + analyze_model, lower_typed_model, parse_model, CovariateInterpolation, ExecutionModel, + ModelKind, RouteKind as DslRouteKind, +}; + +const ODE_DSL: &str = r#" +name = one_cmt_oral_iv +kind = ode + +params = ka, cl, v, tlag, f_oral +covariates = wt @linear +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral + +dx(depot) = -ka * depot +dx(central) = ka * depot - (cl / v) * central + +out(cp) = central / (v * (wt / 70.0)) ~ continuous() +"#; + +const ODE_MACRO_DSL: &str = r#" +name = one_cmt_oral_covariate_parity +kind = ode + +params = ka, cl, v, tlag, f_oral +covariates = wt @linear +states = depot, central +outputs = cp + +bolus(oral) -> depot +lag(oral) = tlag +fa(oral) = f_oral + +dx(depot) = -ka * depot +dx(central) = ka * depot - (cl / v) * central + +out(cp) = central / (v * (wt / 70.0)) ~ continuous() +"#; + +const ODE_INVALID_INFUSION_LAG_DSL: &str = r#" +name = invalid_infusion_lag_parity +kind = ode + +params = ke, v, tlag +states = central +outputs = cp + +infusion(iv) -> central +lag(iv) = tlag + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" +name = shared_channel_one_cpt +kind = ode + +params = ka, ke, v, tlag, f_oral +states = depot, central +outputs = cp + +bolus(oral) -> depot +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central + +out(cp) = central / v ~ continuous() +"#; + +const ANALYTICAL_DSL: &str = r#" +name = one_cmt_abs_parity +kind = analytical + +params = ka, ke, v +states = depot, central +outputs = cp + +bolus(oral) -> depot +structure = one_compartment_with_absorption + +out(cp) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const ANALYTICAL_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" +name = one_cmt_abs_shared +kind = analytical + +params = ka, ke, v, tlag, f_oral +states = gut, central +outputs = cp + +bolus(oral) -> gut +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral +structure = one_compartment_with_absorption + +out(cp) = central / v ~ continuous() +"#; + +const SDE_DSL: &str = r#" +name = one_cmt_sde_parity +kind = sde + +params = ka, ke, v, sigma +covariates = wt @locf +states = depot, central +outputs = cp + +bolus(oral) -> depot +particles = 256 + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central +noise(central) = sigma + +out(cp) = central / (v * wt) ~ continuous() +"#; + +const SDE_MACRO_DSL: &str = r#" +name = one_cmt_sde_macro_parity +kind = sde + +params = ka, ke, v, sigma +states = depot, central +outputs = cp + +bolus(oral) -> depot +particles = 256 + +dx(depot) = -ka * depot +dx(central) = ka * depot - ke * central +noise(central) = sigma + +out(cp) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const SDE_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" +name = one_cmt_shared_sde +kind = sde + +params = ka, ke, sigma_ke, v, tlag, f_oral +states = gut, central +outputs = cp +particles = 8 + +bolus(oral) -> gut +infusion(iv) -> central +lag(oral) = tlag +fa(oral) = f_oral + +dx(gut) = -ka * gut +dx(central) = ka * gut - ke * central +noise(central) = sigma_ke + +out(cp) = central / v ~ continuous() +"#; + +#[derive(Clone, Debug, PartialEq, Eq)] +struct MetadataParityView { + name: String, + kind: ModelKind, + parameters: Vec, + covariates: Vec, + states: Vec, + route_channel_count: usize, + routes: Vec, + outputs: Vec, + analytical_kernel: Option, + particles: Option, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct NamedIndex { + name: String, + index: usize, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct CovariateParity { + name: String, + index: usize, + interpolation: Option, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct RouteParity { + name: String, + kind: Option, + declaration_index: usize, + channel_index: usize, + destination_name: String, + destination_index: usize, + has_lag: bool, + has_bioavailability: bool, +} + +#[cfg(feature = "dsl-jit")] +#[derive(Clone, Debug, PartialEq, Eq)] +struct RouteInputPolicyParity { + name: String, + declaration_index: usize, + channel_index: usize, + input_policy: RouteInputPolicy, +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum RouteKindParity { + Bolus, + Infusion, +} + +impl RouteKindParity { + fn from_dsl(kind: DslRouteKind) -> Self { + match kind { + DslRouteKind::Bolus => Self::Bolus, + DslRouteKind::Infusion => Self::Infusion, + } + } + + fn from_handwritten(kind: HandwrittenRouteKind) -> Self { + match kind { + HandwrittenRouteKind::Bolus => Self::Bolus, + HandwrittenRouteKind::Infusion => Self::Infusion, + } + } +} + +fn load_execution_model(src: &str) -> ExecutionModel { + let parsed = parse_model(src).expect("DSL model should parse"); + let typed = analyze_model(&parsed).expect("DSL model should analyze"); + lower_typed_model(&typed).expect("DSL model should lower") +} + +#[cfg(feature = "dsl-jit")] +fn compile_runtime_jit_model(src: &str, model_name: &str) -> dsl::CompiledRuntimeModel { + dsl::compile_module_source_to_runtime( + src, + Some(model_name), + RuntimeCompilationTarget::Jit, + |_, _| {}, + ) + .expect("DSL runtime model should compile") +} + +#[cfg(feature = "dsl-jit")] +fn shared_channel_prediction_subject(input: usize, output: usize) -> Subject { + Subject::builder("authoring-parity-shared-channel") + .bolus(0.0, 100.0, input) + .infusion(6.0, 60.0, input, 2.0) + .missing_observation(0.5, output) + .missing_observation(1.0, output) + .missing_observation(2.0, output) + .missing_observation(6.5, output) + .missing_observation(7.0, output) + .missing_observation(8.0, output) + .build() +} + +fn dsl_metadata_view(src: &str) -> MetadataParityView { + let model = load_execution_model(src); + + let parameters = model + .metadata + .parameters + .iter() + .map(|parameter| NamedIndex { + name: parameter.name.clone(), + index: parameter.index, + }) + .collect(); + let covariates = model + .metadata + .covariates + .iter() + .map(|covariate| CovariateParity { + name: covariate.name.clone(), + index: covariate.index, + interpolation: covariate.interpolation, + }) + .collect(); + let states = model + .metadata + .states + .iter() + .map(|state| NamedIndex { + name: state.name.clone(), + index: state.offset, + }) + .collect(); + let outputs = model + .metadata + .outputs + .iter() + .map(|output| NamedIndex { + name: output.name.clone(), + index: output.index, + }) + .collect(); + let routes = model + .metadata + .routes + .iter() + .map(|route| RouteParity { + name: route.name.clone(), + kind: route.kind.map(RouteKindParity::from_dsl), + declaration_index: route.declaration_index, + channel_index: route.index, + destination_name: route.destination.state_name.clone(), + destination_index: route.destination.state_offset, + has_lag: route.has_lag, + has_bioavailability: route.has_bioavailability, + }) + .collect(); + + MetadataParityView { + name: model.name, + kind: model.kind, + parameters, + covariates, + states, + route_channel_count: model.abi.route_buffer.len, + routes, + outputs, + analytical_kernel: model.metadata.analytical, + particles: model.metadata.particles, + } +} + +#[cfg(feature = "dsl-jit")] +fn dsl_route_input_policy_view(src: &str) -> Vec { + let model = load_execution_model(src); + let info = dsl::NativeModelInfo::from_execution_model(&model); + + info.routes + .into_iter() + .map(|route| RouteInputPolicyParity { + name: route.name, + declaration_index: route.declaration_index, + channel_index: route.index, + input_policy: if route.inject_input_to_destination { + RouteInputPolicy::InjectToDestination + } else { + RouteInputPolicy::ExplicitInputVector + }, + }) + .collect() +} + +fn validated_metadata_view(metadata: &ValidatedModelMetadata) -> MetadataParityView { + MetadataParityView { + name: metadata.name().to_string(), + kind: metadata.kind(), + parameters: metadata + .parameters() + .iter() + .enumerate() + .map(|(index, parameter)| NamedIndex { + name: parameter.name().to_string(), + index, + }) + .collect(), + covariates: metadata + .covariates() + .iter() + .enumerate() + .map(|(index, covariate)| CovariateParity { + name: covariate.name().to_string(), + index, + interpolation: covariate.interpolation(), + }) + .collect(), + states: metadata + .states() + .iter() + .enumerate() + .map(|(index, state)| NamedIndex { + name: state.name().to_string(), + index, + }) + .collect(), + route_channel_count: metadata.route_channel_count(), + routes: metadata + .routes() + .iter() + .map(|route| RouteParity { + name: route.name().to_string(), + kind: Some(RouteKindParity::from_handwritten(route.kind())), + declaration_index: route.declaration_index(), + channel_index: route.channel_index(), + destination_name: route.destination().to_string(), + destination_index: route.destination_index(), + has_lag: route.has_lag(), + has_bioavailability: route.has_bioavailability(), + }) + .collect(), + outputs: metadata + .outputs() + .iter() + .enumerate() + .map(|(index, output)| NamedIndex { + name: output.name().to_string(), + index, + }) + .collect(), + analytical_kernel: metadata.analytical_kernel(), + particles: metadata.particles(), + } +} + +#[cfg(feature = "dsl-jit")] +fn handwritten_route_input_policy_view( + metadata: &ValidatedModelMetadata, +) -> Vec { + metadata + .routes() + .iter() + .map(|route| RouteInputPolicyParity { + name: route.name().to_string(), + declaration_index: route.declaration_index(), + channel_index: route.channel_index(), + input_policy: route + .input_policy() + .expect("route input policy should be explicit in this handwritten fixture"), + }) + .collect() +} + +fn macro_ode_model() -> equation::ODE { + ode! { + name: "one_cmt_oral_covariate_parity", + params: [ka, cl, v, tlag, f_oral], + covariates: [wt], + states: [depot, central], + outputs: [cp], + routes: { + bolus(oral) -> depot, + }, + diffeq: |x, _p, _t, dx, _cov| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - (cl / v) * x[central]; + }, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + out: |x, _p, t, cov, y| { + fetch_cov!(cov, t, wt); + y[cp] = x[central] / (v * (wt / 70.0)); + }, + } +} + +fn handwritten_ode_macro_model() -> equation::ODE { + equation::ODE::new( + |_x, _p, _t, dx, _bolus, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_oral_covariate_parity") + .parameters(["ka", "cl", "v", "tlag", "f_oral"]) + .covariates([equation::Covariate::continuous("wt")]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("depot") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + ), + ) + .expect("handwritten macro-shape ODE metadata should validate") +} + +fn handwritten_ode_model() -> equation::ODE { + equation::ODE::new( + |_x, _p, _t, dx, _bolus, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_oral_iv") + .parameters(["ka", "cl", "v", "tlag", "f_oral"]) + .covariates([equation::Covariate::continuous("wt")]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("handwritten ODE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_channel_macro_ode() -> equation::ODE { + ode! { + name: "shared_channel_one_cpt", + params: [ka, ke, v, tlag, f_oral], + states: [depot, central], + outputs: [cp], + routes: { + bolus(oral) -> depot, + infusion(iv) -> central, + }, + diffeq: |x, _p, _t, dx, bolus, rateiv, _cov| { + dx[depot] = bolus[oral] - ka * x[depot]; + dx[central] = ka * x[depot] + rateiv[iv] - ke * x[central]; + }, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_channel_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, bolus, rateiv, _cov| { + fetch_params!(p, ka, ke, _v, _tlag, _f_oral); + dx[0] = bolus[0] - ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("shared_channel_one_cpt") + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .with_lag() + .with_bioavailability() + .expect_explicit_input(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("handwritten shared-channel ODE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn runtime_mismatched_shared_channel_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, _rateiv, _cov| { + fetch_params!(p, ka, ke, _v, _tlag, _f_oral); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("shared_channel_one_cpt_mismatched") + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .with_lag() + .with_bioavailability() + .expect_explicit_input(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("mismatched shared-channel ODE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_channel_macro_analytical() -> equation::Analytical { + analytical! { + name: "one_cmt_abs_shared", + params: [ka, ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + infusion(iv) -> central, + }, + structure: one_compartment_with_absorption, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_channel_handwritten_analytical() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |_p, _t, _cov| {}, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_shared") + .kind(equation::ModelKind::Analytical) + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten shared-channel analytical metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_channel_macro_sde() -> equation::SDE { + sde! { + name: "one_cmt_shared_sde", + params: [ka, ke, sigma_ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + particles: 8, + routes: { + bolus(oral) -> gut, + infusion(iv) -> central, + }, + drift: |x, _p, _t, dx, _cov| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central]; + }, + diffusion: |_p, sigma| { + sigma[gut] = 0.0; + sigma[central] = 0.0 * sigma_ke; + }, + lag: |_p, _t, _cov| { + lag! { oral => tlag } + }, + fa: |_p, _t, _cov| { + fa! { oral => f_oral } + }, + init: |_p, _t, _cov, x| { + x[gut] = 0.0; + x[central] = 0.0; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +#[cfg(feature = "dsl-jit")] +fn runtime_shared_channel_handwritten_sde() -> equation::SDE { + equation::SDE::new( + |x, p, _t, dx, rateiv, _cov| { + fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - ke * x[1]; + }, + |_p, sigma| { + sigma[0] = 0.0; + sigma[1] = 0.0; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, x| { + x[0] = 0.0; + x[1] = 0.0; + }, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, _sigma_ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + 8, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_shared_sde") + .kind(equation::ModelKind::Sde) + .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]) + .particles(8), + ) + .expect("handwritten shared-channel SDE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn assert_prediction_vectors_close(left: &[f64], right: &[f64], tolerance: f64) { + assert_eq!(left.len(), right.len()); + for (left_value, right_value) in left.iter().zip(right.iter()) { + let diff = (left_value - right_value).abs(); + assert!( + diff <= tolerance, + "prediction mismatch: left={left_value:.12}, right={right_value:.12}, diff={diff:.12}, tolerance={tolerance:.12}" + ); + } +} + +#[cfg(feature = "dsl-jit")] +fn assert_prediction_vectors_diverge(left: &[f64], right: &[f64], tolerance: f64) { + assert_eq!(left.len(), right.len()); + assert!( + left.iter() + .zip(right.iter()) + .any(|(left_value, right_value)| (left_value - right_value).abs() > tolerance), + "expected prediction vectors to diverge beyond tolerance {tolerance:.12}" + ); +} + +#[cfg(feature = "dsl-jit")] +fn particle_prediction_means(predictions: &ndarray::Array2) -> Vec { + predictions + .get_predictions() + .into_iter() + .map(|prediction| prediction.prediction()) + .collect() +} + +fn macro_analytical_model() -> equation::Analytical { + analytical! { + name: "one_cmt_abs_parity", + params: [ka, ke, v], + states: [depot, central], + outputs: [cp], + routes: { + bolus(oral) -> depot, + }, + structure: one_compartment_with_absorption, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_analytical_model() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |_p, _t, _cov| {}, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_parity") + .kind(ModelKind::Analytical) + .parameters(["ka", "ke", "v"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route(equation::Route::bolus("oral").to_state("depot")) + .analytical_kernel(AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten analytical metadata should validate") +} + +fn macro_sde_model() -> equation::SDE { + sde! { + name: "one_cmt_sde_macro_parity", + params: [ka, ke, v, sigma], + states: [depot, central], + outputs: [cp], + particles: 256, + routes: { + bolus(oral) -> depot, + }, + drift: |x, _p, _t, dx, _cov| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; + }, + diffusion: |_p, sigma_values| { + sigma_values[depot] = 0.0; + sigma_values[central] = sigma; + }, + out: |x, _p, _t, _cov, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_sde_model() -> equation::SDE { + equation::SDE::new( + |_x, _p, _t, dx, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, sigma| { + sigma[0] = 0.0; + sigma[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + 256, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_sde_parity") + .kind(ModelKind::Sde) + .parameters(["ka", "ke", "v", "sigma"]) + .covariates([equation::Covariate::locf("wt")]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("depot") + .inject_input_to_destination(), + ) + .particles(256), + ) + .expect("handwritten SDE metadata should validate") +} + +fn handwritten_sde_macro_model() -> equation::SDE { + equation::SDE::new( + |_x, _p, _t, dx, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, sigma_values| { + sigma_values[0] = 0.0; + sigma_values[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + 256, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_sde_macro_parity") + .kind(ModelKind::Sde) + .parameters(["ka", "ke", "v", "sigma"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("depot") + .inject_input_to_destination(), + ) + .particles(256), + ) + .expect("handwritten macro-shape SDE metadata should validate") +} + +#[cfg(feature = "dsl-jit")] +fn mismatched_ode_model() -> equation::ODE { + equation::ODE::new( + |_x, _p, _t, dx, _bolus, _rateiv, _cov| { + dx[0] = 0.0; + dx[1] = 0.0; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |_x, _p, _t, _cov, y| { + y[0] = 0.0; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_oral_iv") + .parameters(["ka", "cl", "v", "tlag", "f_oral"]) + .covariates([equation::Covariate::continuous("wt")]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .expect_explicit_input() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("mismatched ODE metadata should validate") +} + +#[test] +fn ode_dsl_and_handwritten_metadata_agree_on_public_shape() { + let handwritten = handwritten_ode_model(); + let dsl_view = dsl_metadata_view(ODE_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten ODE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); +} + +#[test] +fn ode_macro_dsl_and_handwritten_metadata_agree_on_macro_authorable_shape() { + let handwritten = handwritten_ode_macro_model(); + let macro_model = macro_ode_model(); + let dsl_view = dsl_metadata_view(ODE_MACRO_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten macro-shape ODE metadata should exist"), + ); + let macro_view = validated_metadata_view( + macro_model + .metadata() + .expect("macro ODE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); + assert_eq!(macro_view, dsl_view); +} + +#[test] +fn analytical_dsl_macro_and_handwritten_metadata_agree_on_public_shape() { + let handwritten = handwritten_analytical_model(); + let macro_model = macro_analytical_model(); + let dsl_view = dsl_metadata_view(ANALYTICAL_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten analytical metadata should exist"), + ); + let macro_view = validated_metadata_view( + macro_model + .metadata() + .expect("macro analytical metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); + assert_eq!(macro_view, dsl_view); +} + +#[test] +fn sde_dsl_and_handwritten_metadata_agree_on_public_shape() { + let handwritten = handwritten_sde_model(); + let dsl_view = dsl_metadata_view(SDE_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten SDE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); +} + +#[test] +fn sde_macro_dsl_and_handwritten_metadata_agree_on_macro_authorable_shape() { + let handwritten = handwritten_sde_macro_model(); + let macro_model = macro_sde_model(); + let dsl_view = dsl_metadata_view(SDE_MACRO_DSL); + let handwritten_view = validated_metadata_view( + handwritten + .metadata() + .expect("handwritten macro-shape SDE metadata should exist"), + ); + let macro_view = validated_metadata_view( + macro_model + .metadata() + .expect("macro SDE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); + assert_eq!(macro_view, dsl_view); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_route_input_policies_agree_with_handwritten_metadata() { + let dsl_view = dsl_route_input_policy_view(ODE_DSL); + let handwritten = handwritten_ode_model(); + let handwritten_view = handwritten_route_input_policy_view( + handwritten + .metadata() + .expect("handwritten ODE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn sde_route_input_policies_agree_with_handwritten_metadata() { + let dsl_view = dsl_route_input_policy_view(SDE_DSL); + let handwritten = handwritten_sde_model(); + let handwritten_view = handwritten_route_input_policy_view( + handwritten + .metadata() + .expect("handwritten SDE metadata should exist"), + ); + + assert_eq!(handwritten_view, dsl_view); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn route_input_policy_mismatches_are_detected_explicitly() { + let dsl_view = dsl_route_input_policy_view(ODE_DSL); + let handwritten = mismatched_ode_model(); + let handwritten_view = handwritten_route_input_policy_view( + handwritten + .metadata() + .expect("mismatched handwritten metadata should exist"), + ); + + assert_ne!(handwritten_view, dsl_view); + assert_eq!(dsl_view[0].name, "oral"); + assert_eq!( + dsl_view[0].input_policy, + RouteInputPolicy::InjectToDestination + ); + assert_eq!( + handwritten_view[0].input_policy, + RouteInputPolicy::ExplicitInputVector + ); +} + +#[test] +fn invalid_dsl_infusion_route_properties_fail_explicitly() { + let model = + parse_model(ODE_INVALID_INFUSION_LAG_DSL).expect("invalid DSL fixture should parse"); + let typed = analyze_model(&model).expect("invalid DSL fixture should analyze"); + let error = lower_typed_model(&typed) + .err() + .expect("infusion lag should fail during lowering"); + + assert!(error + .to_string() + .contains("DSL authoring does not allow `lag` on infusion route `iv`")); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { + let runtime_model = + compile_runtime_jit_model(ODE_RUNTIME_SHARED_CHANNEL_DSL, "shared_channel_one_cpt"); + let macro_model = runtime_shared_channel_macro_ode(); + let handwritten_model = runtime_shared_channel_handwritten_ode(); + + let oral = runtime_model + .route_index("oral") + .expect("runtime oral route should exist"); + let iv = runtime_model + .route_index("iv") + .expect("runtime iv route should exist"); + let cp = runtime_model + .output_index("cp") + .expect("runtime cp output should exist"); + let subject = shared_channel_prediction_subject(oral, cp); + let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(oral, 0); + assert_eq!(iv, oral); + assert_eq!(macro_model.route_index("oral"), Some(oral)); + assert_eq!(macro_model.route_index("iv"), Some(iv)); + assert_eq!(handwritten_model.route_index("oral"), Some(oral)); + assert_eq!(handwritten_model.route_index("iv"), Some(iv)); + + let runtime_predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime ODE model should simulate") + { + RuntimePredictions::Subject(predictions) => predictions.flat_predictions().to_vec(), + RuntimePredictions::Particles(_) => panic!("ODE runtime should return subject predictions"), + }; + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro ODE model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten ODE model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_vectors_close(&runtime_predictions, ¯o_predictions, 1e-4); + assert_prediction_vectors_close(&runtime_predictions, &handwritten_predictions, 1e-4); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { + let runtime_model = + compile_runtime_jit_model(ANALYTICAL_RUNTIME_SHARED_CHANNEL_DSL, "one_cmt_abs_shared"); + let macro_model = runtime_shared_channel_macro_analytical(); + let handwritten_model = runtime_shared_channel_handwritten_analytical(); + + let oral = runtime_model + .route_index("oral") + .expect("runtime oral route should exist"); + let iv = runtime_model + .route_index("iv") + .expect("runtime iv route should exist"); + let cp = runtime_model + .output_index("cp") + .expect("runtime cp output should exist"); + let subject = shared_channel_prediction_subject(oral, cp); + let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(oral, 0); + assert_eq!(iv, oral); + assert_eq!(macro_model.route_index("oral"), Some(oral)); + assert_eq!(macro_model.route_index("iv"), Some(iv)); + assert_eq!(handwritten_model.route_index("oral"), Some(oral)); + assert_eq!(handwritten_model.route_index("iv"), Some(iv)); + + let runtime_predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime analytical model should simulate") + { + RuntimePredictions::Subject(predictions) => predictions.flat_predictions().to_vec(), + RuntimePredictions::Particles(_) => { + panic!("analytical runtime should return subject predictions") + } + }; + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro analytical model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten analytical model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_vectors_close(&runtime_predictions, ¯o_predictions, 1e-8); + assert_prediction_vectors_close(&runtime_predictions, &handwritten_predictions, 1e-8); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { + let runtime_model = + compile_runtime_jit_model(SDE_RUNTIME_SHARED_CHANNEL_DSL, "one_cmt_shared_sde"); + let macro_model = runtime_shared_channel_macro_sde(); + let handwritten_model = runtime_shared_channel_handwritten_sde(); + + let oral = runtime_model + .route_index("oral") + .expect("runtime oral route should exist"); + let iv = runtime_model + .route_index("iv") + .expect("runtime iv route should exist"); + let cp = runtime_model + .output_index("cp") + .expect("runtime cp output should exist"); + let subject = shared_channel_prediction_subject(oral, cp); + let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; + + assert_eq!(oral, 0); + assert_eq!(iv, oral); + assert_eq!(macro_model.route_index("oral"), Some(oral)); + assert_eq!(macro_model.route_index("iv"), Some(iv)); + assert_eq!(handwritten_model.route_index("oral"), Some(oral)); + assert_eq!(handwritten_model.route_index("iv"), Some(iv)); + + let runtime_predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime SDE model should simulate") + { + RuntimePredictions::Particles(predictions) => particle_prediction_means(&predictions), + RuntimePredictions::Subject(_) => panic!("SDE runtime should return particle predictions"), + }; + let macro_predictions = particle_prediction_means( + ¯o_model + .estimate_predictions(&subject, &support_point) + .expect("macro SDE model should simulate"), + ); + let handwritten_predictions = particle_prediction_means( + &handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten SDE model should simulate"), + ); + + assert_prediction_vectors_close(&runtime_predictions, ¯o_predictions, 1e-4); + assert_prediction_vectors_close(&runtime_predictions, &handwritten_predictions, 1e-4); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn route_input_policy_runtime_mismatches_are_detected_explicitly() { + let runtime_model = + compile_runtime_jit_model(ODE_RUNTIME_SHARED_CHANNEL_DSL, "shared_channel_one_cpt"); + let mismatched_model = runtime_mismatched_shared_channel_ode(); + + let oral = runtime_model + .route_index("oral") + .expect("runtime oral route should exist"); + let iv = runtime_model + .route_index("iv") + .expect("runtime iv route should exist"); + let cp = runtime_model + .output_index("cp") + .expect("runtime cp output should exist"); + let subject = shared_channel_prediction_subject(oral, cp); + let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(oral, 0); + assert_eq!(iv, oral); + assert_eq!(mismatched_model.route_index("oral"), Some(oral)); + assert_eq!(mismatched_model.route_index("iv"), Some(iv)); + + let runtime_predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime ODE model should simulate") + { + RuntimePredictions::Subject(predictions) => predictions.flat_predictions().to_vec(), + RuntimePredictions::Particles(_) => panic!("ODE runtime should return subject predictions"), + }; + let mismatched_predictions = mismatched_model + .estimate_predictions(&subject, &support_point) + .expect("mismatched handwritten ODE should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_vectors_diverge(&runtime_predictions, &mismatched_predictions, 1e-4); +} diff --git a/tests/browser-e2e/site/app.mjs b/tests/browser-e2e/site/app.mjs index 87b68eda..c939c670 100644 --- a/tests/browser-e2e/site/app.mjs +++ b/tests/browser-e2e/site/app.mjs @@ -4,7 +4,7 @@ const precompiledInputs = Object.freeze({ }); const compileFlowSource = ` -model = example_ode +name = example_ode kind = ode params = ke, v @@ -25,7 +25,7 @@ const compileFlowInputs = Object.freeze({ }); const invalidCompileSource = ` -model = broken +name = broken kind = ode states = central diff --git a/tests/full_feature_macro_parity.rs b/tests/full_feature_macro_parity.rs new file mode 100644 index 00000000..71a1afa7 --- /dev/null +++ b/tests/full_feature_macro_parity.rs @@ -0,0 +1,472 @@ +use pharmsol::prelude::*; + +fn max_abs_diff(left: &[f64], right: &[f64]) -> f64 { + left.iter() + .zip(right.iter()) + .map(|(lhs, rhs)| (lhs - rhs).abs()) + .fold(0.0_f64, f64::max) +} + +fn macro_ode_model() -> equation::ODE { + ode! { + name: "ode_full_feature_parity", + params: [ka, ke, kcp, kpc, v, tlag, f_oral, base_depot, base_central, base_peripheral], + covariates: [wt, renal], + states: [depot, central, peripheral], + outputs: [cp], + routes: { + bolus(oral) -> depot, + bolus(load) -> central, + infusion(iv) -> central, + }, + diffeq: |x, _t, dx, bolus, rateiv| { + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); + + dx[depot] = bolus[oral] - ka * x[depot]; + dx[central] = bolus[load] + ka * x[depot] + rateiv[iv] + - (adjusted_ke + adjusted_kcp) * x[central] + + kpc * x[peripheral]; + dx[peripheral] = adjusted_kcp * x[central] - kpc * x[peripheral]; + }, + lag: |_t| { + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { oral => tlag * lag_scale } + }, + fa: |_t| { + let fa_scale = (renal / 90.0).powf(0.1); + fa! { oral => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + init: |_t, x| { + x[depot] = base_depot + 0.05 * wt; + x[central] = base_central + 0.1 * renal; + x[peripheral] = base_peripheral + 0.02 * wt; + }, + out: |x, _t, y| { + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[cp] = x[central] / adjusted_v; + }, + } +} + +fn handwritten_ode_model() -> equation::ODE { + equation::ODE::new( + |x, p, t, dx, bolus, rateiv, cov| { + fetch_params!( + p, + ka, + ke, + kcp, + kpc, + _v, + _tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); + + dx[0] = bolus[0] - ka * x[0]; + dx[1] = + bolus[1] + ka * x[0] + rateiv[0] - (adjusted_ke + adjusted_kcp) * x[1] + kpc * x[2]; + dx[2] = adjusted_kcp * x[1] - kpc * x[2]; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + _tlag, + f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + _tlag, + _f_oral, + base_depot, + base_central, + base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_depot + 0.05 * wt; + x[1] = base_central + 0.1 * renal; + x[2] = base_peripheral + 0.02 * wt; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + v, + _tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(3) + .with_ndrugs(2) + .with_nout(1) + .with_metadata( + equation::metadata::new("ode_full_feature_parity") + .parameters([ + "ka", + "ke", + "kcp", + "kpc", + "v", + "tlag", + "f_oral", + "base_depot", + "base_central", + "base_peripheral", + ]) + .covariates([ + equation::Covariate::continuous("wt"), + equation::Covariate::continuous("renal"), + ]) + .states(["depot", "central", "peripheral"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .with_lag() + .with_bioavailability() + .expect_explicit_input(), + equation::Route::bolus("load") + .to_state("central") + .expect_explicit_input(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("handwritten ODE metadata should validate") +} + +fn build_ode_subject(oral: usize, load: usize, iv: usize, cp: usize) -> Subject { + Subject::builder("macro-vs-handwritten-ode-full-features") + .bolus(0.0, 80.0, load) + .bolus(1.0, 120.0, oral) + .infusion(6.0, 150.0, iv, 2.5) + .missing_observation(0.25, cp) + .missing_observation(0.75, cp) + .missing_observation(1.5, cp) + .missing_observation(3.0, cp) + .missing_observation(6.5, cp) + .missing_observation(7.0, cp) + .missing_observation(8.0, cp) + .missing_observation(12.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() +} + +fn macro_analytical_model() -> equation::Analytical { + analytical! { + name: "analytical_full_feature_parity", + params: [ka, ke, v, tlag, f_oral, base_gut, base_central, tvke], + covariates: [wt, renal], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + bolus(load) -> central, + infusion(iv) -> central, + }, + structure: one_compartment_with_absorption, + sec: |_t| { + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + ke = tvke * wt_scale * renal_scale; + }, + lag: |_t| { + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { oral => tlag * lag_scale } + }, + fa: |_t| { + let fa_scale = (renal / 90.0).powf(0.1); + fa! { oral => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + init: |_t, x| { + x[gut] = base_gut + 0.03 * wt; + x[central] = base_central + 0.08 * renal; + }, + out: |x, _t, y| { + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[cp] = x[central] / adjusted_v; + }, + } +} + +fn handwritten_analytical_model() -> equation::Analytical { + equation::Analytical::new( + equation::one_compartment_with_absorption, + |p, t, cov| { + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + p[1] = p[7] * wt_scale * renal_scale; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + _f_oral, + base_gut, + base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_gut + 0.03 * wt; + x[1] = base_central + 0.08 * renal; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + v, + _tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(2) + .with_ndrugs(2) + .with_nout(1) + .with_metadata( + equation::metadata::new("analytical_full_feature_parity") + .kind(equation::ModelKind::Analytical) + .parameters([ + "ka", + "ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + "tvke", + ]) + .covariates([ + equation::Covariate::continuous("wt"), + equation::Covariate::continuous("renal"), + ]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .with_lag() + .with_bioavailability(), + equation::Route::bolus("load").to_state("central"), + equation::Route::infusion("iv").to_state("central"), + ]) + .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), + ) + .expect("handwritten analytical metadata should validate") +} + +fn build_analytical_subject(oral: usize, load: usize, iv: usize, cp: usize) -> Subject { + Subject::builder("macro-vs-handwritten-analytical-full-features") + .bolus(0.0, 60.0, load) + .bolus(1.0, 100.0, oral) + .infusion(6.0, 140.0, iv, 2.0) + .missing_observation(0.25, cp) + .missing_observation(0.75, cp) + .missing_observation(1.5, cp) + .missing_observation(3.0, cp) + .missing_observation(6.5, cp) + .missing_observation(7.0, cp) + .missing_observation(8.0, cp) + .missing_observation(12.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() +} + +#[test] +fn ode_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::PharmsolError> { + let macro_ode = macro_ode_model(); + let handwritten_ode = handwritten_ode_model(); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + + let oral = macro_ode.route_index("oral").expect("oral route exists"); + let load = macro_ode.route_index("load").expect("load route exists"); + let iv = macro_ode.route_index("iv").expect("iv route exists"); + let cp = macro_ode.output_index("cp").expect("cp output exists"); + + assert_eq!(oral, iv); + assert_eq!(load, 1); + assert_eq!(handwritten_ode.route_index("oral"), Some(oral)); + assert_eq!(handwritten_ode.route_index("load"), Some(load)); + assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); + assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); + + let subject = build_ode_subject(oral, load, iv, cp); + let params = [1.1, 0.18, 0.07, 0.04, 35.0, 0.6, 0.85, 4.0, 18.0, 9.0]; + + let macro_predictions = macro_ode.estimate_predictions(&subject, ¶ms)?; + let handwritten_predictions = handwritten_ode.estimate_predictions(&subject, ¶ms)?; + + let diff = max_abs_diff( + ¯o_predictions.flat_predictions(), + &handwritten_predictions.flat_predictions(), + ); + assert!( + diff <= 1e-10, + "macro and handwritten ODE predictions diverged: {diff:e}" + ); + + Ok(()) +} + +#[test] +fn analytical_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::PharmsolError> { + let macro_analytical = macro_analytical_model(); + let handwritten_analytical = handwritten_analytical_model(); + + assert_eq!( + macro_analytical.metadata(), + handwritten_analytical.metadata() + ); + + let oral = macro_analytical + .route_index("oral") + .expect("oral route exists"); + let load = macro_analytical + .route_index("load") + .expect("load route exists"); + let iv = macro_analytical.route_index("iv").expect("iv route exists"); + let cp = macro_analytical + .output_index("cp") + .expect("cp output exists"); + + assert_eq!(oral, iv); + assert_eq!(load, 1); + assert_eq!(handwritten_analytical.route_index("oral"), Some(oral)); + assert_eq!(handwritten_analytical.route_index("load"), Some(load)); + assert_eq!(handwritten_analytical.route_index("iv"), Some(iv)); + assert_eq!(handwritten_analytical.output_index("cp"), Some(cp)); + + let subject = build_analytical_subject(oral, load, iv, cp); + let params = [1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16]; + + let macro_predictions = macro_analytical.estimate_predictions(&subject, ¶ms)?; + let handwritten_predictions = handwritten_analytical.estimate_predictions(&subject, ¶ms)?; + + let diff = max_abs_diff( + ¯o_predictions.flat_predictions(), + &handwritten_predictions.flat_predictions(), + ); + assert!( + diff <= 1e-10, + "macro and handwritten analytical predictions diverged: {diff:e}" + ); + + Ok(()) +} diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs new file mode 100644 index 00000000..7b068733 --- /dev/null +++ b/tests/ode_macro_lowering.rs @@ -0,0 +1,375 @@ +use approx::assert_relative_eq; +use pharmsol::prelude::*; + +fn subject_for_route(input: usize) -> Subject { + Subject::builder("macro-lowering") + .infusion(0.0, 100.0, input, 1.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .build() +} + +fn subject_for_shared_channel(input: usize) -> Subject { + Subject::builder("macro-shared-channel") + .bolus(0.0, 100.0, input) + .infusion(6.0, 60.0, input, 2.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(6.5, 0) + .missing_observation(7.0, 0) + .missing_observation(8.0, 0) + .build() +} + +fn subject_for_covariates(input: usize) -> Subject { + Subject::builder("macro-covariates") + .bolus(0.0, 100.0, input) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .covariate("wt", 0.0, 70.0) + .build() +} + +fn injected_macro_ode() -> equation::ODE { + ode! { + name: "injected_one_cpt", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _t, dx| { + dx[central] = -ke * x[central]; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn injected_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("injected_one_cpt") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ), + ) + .expect("handwritten injected metadata should validate") +} + +fn explicit_macro_ode() -> equation::ODE { + ode! { + name: "explicit_one_cpt", + params: [ke, v], + states: [central], + outputs: [cp], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _t, dx, _bolus, rateiv| { + dx[central] = rateiv[iv] - ke * x[central]; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn explicit_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("explicit_one_cpt") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ), + ) + .expect("handwritten explicit metadata should validate") +} + +fn shared_channel_macro_ode() -> equation::ODE { + ode! { + name: "shared_channel_one_cpt", + params: [ka, ke, v, tlag, f_oral], + states: [depot, central], + outputs: [cp], + routes: { + bolus(oral) -> depot, + infusion(iv) -> central, + }, + diffeq: |x, _t, dx, bolus, rateiv| { + dx[depot] = bolus[oral] - ka * x[depot]; + dx[central] = ka * x[depot] + rateiv[iv] - ke * x[central]; + }, + lag: |_t| { + lag! { oral => tlag } + }, + fa: |_t| { + fa! { oral => f_oral } + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn shared_channel_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, bolus, rateiv, _cov| { + fetch_params!(p, ka, ke, _v, _tlag, _f_oral); + dx[0] = bolus[0] - ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("shared_channel_one_cpt") + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["depot", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("depot") + .with_lag() + .with_bioavailability() + .expect_explicit_input(), + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ]), + ) + .expect("handwritten shared-channel metadata should validate") +} + +fn covariate_macro_ode() -> equation::ODE { + ode! { + name: "covariate_one_cpt", + params: [ka, ke, v], + covariates: [wt], + states: [gut, central], + outputs: [cp], + routes: { + bolus(oral) -> gut, + }, + diffeq: |x, _t, dx| { + let scaled_ke = ke * (wt / 70.0); + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - scaled_ke * x[central]; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn covariate_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, t, dx, bolus, _rateiv, cov| { + fetch_cov!(cov, t, wt); + fetch_params!(p, ka, ke, _v); + let scaled_ke = ke * (wt / 70.0); + dx[0] = bolus[0] - ka * x[0]; + dx[1] = ka * x[0] - scaled_ke * x[1]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("covariate_one_cpt") + .parameters(["ka", "ke", "v"]) + .covariates([equation::Covariate::continuous("wt")]) + .states(["gut", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination(), + ), + ) + .expect("handwritten covariate metadata should validate") +} + +fn assert_prediction_match(left: &[f64], right: &[f64]) { + assert_eq!(left.len(), right.len()); + for (left, right) in left.iter().zip(right.iter()) { + assert_relative_eq!(left, right, epsilon = 1e-10); + } +} + +#[test] +fn macro_injected_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = injected_macro_ode(); + let handwritten_ode = injected_handwritten_ode(); + let subject = subject_for_route(0); + let support_point = [0.2, 10.0]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.route_index("iv"), Some(0)); + assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(0)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro injected model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten injected model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_explicit_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = explicit_macro_ode(); + let handwritten_ode = explicit_handwritten_ode(); + let subject = subject_for_route(0); + let support_point = [0.2, 10.0]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.route_index("iv"), Some(0)); + assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(0)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro explicit model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten explicit model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = shared_channel_macro_ode(); + let handwritten_ode = shared_channel_handwritten_ode(); + let subject = subject_for_shared_channel(0); + let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.route_index("oral"), Some(0)); + assert_eq!(macro_ode.route_index("iv"), Some(0)); + assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert_eq!(macro_ode.state_index("depot"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(1)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro shared-channel model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten shared-channel model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_covariate_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = covariate_macro_ode(); + let handwritten_ode = covariate_handwritten_ode(); + let subject = subject_for_covariates(0); + let support_point = [1.0, 0.2, 10.0]; + let macro_metadata = macro_ode + .metadata() + .expect("macro covariate model should carry metadata"); + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_metadata.covariates().len(), 1); + assert_eq!(macro_ode.route_index("oral"), Some(0)); + assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert_eq!(macro_ode.state_index("gut"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(1)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro covariate model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten covariate model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} diff --git a/tests/runtime_backend_matrix.rs b/tests/runtime_backend_matrix.rs index 6a207398..fdabc94d 100644 --- a/tests/runtime_backend_matrix.rs +++ b/tests/runtime_backend_matrix.rs @@ -84,6 +84,84 @@ mod tests { Ok(()) } + #[test] + fn analytical_full_runtime_backend_matrix_matches_reference_predictions( + ) -> Result<(), Box> { + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let workspace = super::runtime_corpus::ArtifactWorkspace::new()?; + + let jit = corpus::compile_runtime_jit_model(CorpusCase::AnalyticalFull)?; + assert_eq!(jit.backend(), RuntimeBackend::Jit); + corpus::assert_runtime_model_matches_reference( + CorpusCase::AnalyticalFull, + "runtime-jit", + &jit, + )?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let aot = corpus::compile_runtime_native_aot_model(CorpusCase::AnalyticalFull, &workspace)?; + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + assert_eq!(aot.backend(), RuntimeBackend::NativeAot); + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + corpus::assert_runtime_model_matches_reference( + CorpusCase::AnalyticalFull, + "runtime-native-aot", + &aot, + )?; + + let wasm = corpus::compile_runtime_wasm_model(CorpusCase::AnalyticalFull)?; + assert_eq!(wasm.backend(), RuntimeBackend::Wasm); + corpus::assert_runtime_model_matches_reference( + CorpusCase::AnalyticalFull, + "runtime-wasm", + &wasm, + )?; + corpus::assert_runtime_models_match_each_other( + CorpusCase::AnalyticalFull, + "runtime-jit", + &jit, + "runtime-wasm", + &wasm, + )?; + + Ok(()) + } + + #[test] + fn ode_full_runtime_backend_matrix_matches_reference_predictions( + ) -> Result<(), Box> { + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let workspace = super::runtime_corpus::ArtifactWorkspace::new()?; + + let jit = corpus::compile_runtime_jit_model(CorpusCase::OdeFull)?; + assert_eq!(jit.backend(), RuntimeBackend::Jit); + corpus::assert_runtime_model_matches_reference(CorpusCase::OdeFull, "runtime-jit", &jit)?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let aot = corpus::compile_runtime_native_aot_model(CorpusCase::OdeFull, &workspace)?; + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + assert_eq!(aot.backend(), RuntimeBackend::NativeAot); + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + corpus::assert_runtime_model_matches_reference( + CorpusCase::OdeFull, + "runtime-native-aot", + &aot, + )?; + + let wasm = corpus::compile_runtime_wasm_model(CorpusCase::OdeFull)?; + assert_eq!(wasm.backend(), RuntimeBackend::Wasm); + corpus::assert_runtime_model_matches_reference(CorpusCase::OdeFull, "runtime-wasm", &wasm)?; + corpus::assert_runtime_models_match_each_other( + CorpusCase::OdeFull, + "runtime-jit", + &jit, + "runtime-wasm", + &wasm, + )?; + + Ok(()) + } + #[test] fn sde_runtime_backend_matrix_matches_reference_predictions( ) -> Result<(), Box> { diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs new file mode 100644 index 00000000..876d2b23 --- /dev/null +++ b/tests/sde_macro_lowering.rs @@ -0,0 +1,593 @@ +use approx::assert_relative_eq; +use pharmsol::prelude::*; +use pharmsol::Predictions; + +fn infusion_subject(input: usize) -> Subject { + Subject::builder("sde-macro-iv") + .infusion(0.0, 120.0, input, 1.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .build() +} + +fn oral_subject(input: usize) -> Subject { + Subject::builder("sde-macro-oral") + .bolus(0.0, 100.0, input) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .build() +} + +fn shared_channel_subject(input: usize) -> Subject { + Subject::builder("sde-macro-shared") + .bolus(0.0, 100.0, input) + .infusion(6.0, 60.0, input, 2.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(6.5, 0) + .missing_observation(7.0, 0) + .missing_observation(8.0, 0) + .build() +} + +fn covariate_subject(oral: usize, iv: usize, cp: usize) -> Subject { + Subject::builder("sde-macro-covariates") + .bolus(1.0, 100.0, oral) + .infusion(6.0, 140.0, iv, 2.0) + .missing_observation(0.25, cp) + .missing_observation(0.75, cp) + .missing_observation(1.5, cp) + .missing_observation(3.0, cp) + .missing_observation(6.5, cp) + .missing_observation(7.0, cp) + .missing_observation(8.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() +} + +fn prediction_means(predictions: &ndarray::Array2) -> Vec { + predictions + .get_predictions() + .into_iter() + .map(|prediction| prediction.prediction()) + .collect() +} + +fn assert_prediction_match(left: &[f64], right: &[f64]) { + assert_eq!(left.len(), right.len()); + for (left, right) in left.iter().zip(right.iter()) { + assert_relative_eq!(left, right, epsilon = 1e-10); + } +} + +fn macro_infusion_sde() -> equation::SDE { + sde! { + name: "one_cpt_sde", + params: [ke, sigma_ke, v], + states: [central], + outputs: [cp], + particles: 16, + routes: { + infusion(iv) -> central, + }, + drift: |x, _t, dx| { + dx[central] = -ke * x[central]; + }, + diffusion: |sigma| { + sigma[central] = sigma_ke; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_infusion_sde() -> equation::SDE { + equation::SDE::new( + |x, p, _t, dx, rateiv, _cov| { + fetch_params!(p, ke, _sigma_ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |p, sigma| { + fetch_params!(p, _ke, sigma_ke, _v); + sigma[0] = sigma_ke; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, _sigma_ke, v); + y[0] = x[0] / v; + }, + 16, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cpt_sde") + .kind(equation::ModelKind::Sde) + .parameters(["ke", "sigma_ke", "v"]) + .states(["central"]) + .outputs(["cp"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ) + .particles(16), + ) + .expect("handwritten SDE metadata should validate") +} + +fn macro_absorption_sde() -> equation::SDE { + sde! { + name: "one_cmt_abs_sde", + params: [ka, ke, sigma_ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + particles: 8, + routes: { + bolus(oral) -> gut, + }, + drift: |x, _t, dx| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central]; + }, + diffusion: |sigma| { + sigma[gut] = 0.0 * sigma_ke; + sigma[central] = sigma_ke; + }, + lag: |_t| { + lag! { oral => tlag } + }, + fa: |_t| { + fa! { oral => f_oral } + }, + init: |_t, x| { + x[gut] = 0.0; + x[central] = 0.0; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_absorption_sde() -> equation::SDE { + equation::SDE::new( + |x, p, _t, dx, _rateiv, _cov| { + fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] - ke * x[1]; + }, + |p, sigma| { + fetch_params!(p, _ka, _ke, sigma_ke, _v, _tlag, _f_oral); + sigma[0] = 0.0 * sigma_ke; + sigma[1] = sigma_ke; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, x| { + x[0] = 0.0; + x[1] = 0.0; + }, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, _sigma_ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + 8, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_abs_sde") + .kind(equation::ModelKind::Sde) + .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .route( + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + ) + .particles(8), + ) + .expect("handwritten absorption SDE metadata should validate") +} + +fn macro_shared_channel_sde() -> equation::SDE { + sde! { + name: "one_cmt_shared_sde", + params: [ka, ke, sigma_ke, v, tlag, f_oral], + states: [gut, central], + outputs: [cp], + particles: 8, + routes: { + bolus(oral) -> gut, + infusion(iv) -> central, + }, + drift: |x, _t, dx| { + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - ke * x[central]; + }, + diffusion: |sigma| { + sigma[gut] = 0.0; + sigma[central] = 0.0; + }, + lag: |_t| { + lag! { oral => tlag } + }, + fa: |_t| { + fa! { oral => f_oral } + }, + init: |_t, x| { + x[gut] = 0.0; + x[central] = 0.0; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + }, + } +} + +fn handwritten_shared_channel_sde() -> equation::SDE { + equation::SDE::new( + |x, p, _t, dx, rateiv, _cov| { + fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - ke * x[1]; + }, + |_p, sigma| { + sigma[0] = 0.0; + sigma[1] = 0.0; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _sigma_ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, x| { + x[0] = 0.0; + x[1] = 0.0; + }, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, _sigma_ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + 8, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_shared_sde") + .kind(equation::ModelKind::Sde) + .parameters(["ka", "ke", "sigma_ke", "v", "tlag", "f_oral"]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]) + .particles(8), + ) + .expect("handwritten shared-channel SDE metadata should validate") +} + +fn macro_covariate_sde() -> equation::SDE { + sde! { + name: "one_cmt_sde_covariates", + params: [ka, ke, sigma_ke, v, tlag, f_oral, base_gut, base_central], + covariates: [wt, renal], + states: [gut, central], + outputs: [cp], + particles: 8, + routes: { + bolus(oral) -> gut, + infusion(iv) -> central, + }, + drift: |x, _t, dx| { + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + + dx[gut] = -ka * x[gut]; + dx[central] = ka * x[gut] - adjusted_ke * x[central]; + }, + diffusion: |sigma| { + sigma[gut] = 0.0 * sigma_ke; + sigma[central] = 0.0 * sigma_ke; + }, + lag: |_t| { + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { oral => tlag * lag_scale } + }, + fa: |_t| { + let fa_scale = (renal / 90.0).powf(0.1); + fa! { oral => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + init: |_t, x| { + x[gut] = base_gut + 0.03 * wt; + x[central] = base_central + 0.08 * renal; + }, + out: |x, _t, y| { + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[cp] = x[central] / adjusted_v; + }, + } +} + +fn handwritten_covariate_sde() -> equation::SDE { + equation::SDE::new( + |x, p, t, dx, rateiv, cov| { + fetch_params!( + p, + ka, + ke, + _sigma_ke, + _v, + _tlag, + _f_oral, + _base_gut, + _base_central + ); + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + + dx[0] = -ka * x[0]; + dx[1] = ka * x[0] + rateiv[0] - adjusted_ke * x[1]; + }, + |p, sigma| { + fetch_params!( + p, + _ka, + _ke, + sigma_ke, + _v, + _tlag, + _f_oral, + _base_gut, + _base_central + ); + sigma[0] = 0.0 * sigma_ke; + sigma[1] = 0.0 * sigma_ke; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _sigma_ke, + _v, + tlag, + _f_oral, + _base_gut, + _base_central + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _sigma_ke, + _v, + _tlag, + f_oral, + _base_gut, + _base_central + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _sigma_ke, + _v, + _tlag, + _f_oral, + base_gut, + base_central + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_gut + 0.03 * wt; + x[1] = base_central + 0.08 * renal; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + _sigma_ke, + v, + _tlag, + _f_oral, + _base_gut, + _base_central + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + 8, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("one_cmt_sde_covariates") + .kind(equation::ModelKind::Sde) + .parameters([ + "ka", + "ke", + "sigma_ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + ]) + .covariates([ + equation::Covariate::continuous("wt"), + equation::Covariate::continuous("renal"), + ]) + .states(["gut", "central"]) + .outputs(["cp"]) + .routes([ + equation::Route::bolus("oral") + .to_state("gut") + .inject_input_to_destination() + .with_lag() + .with_bioavailability(), + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ]) + .particles(8), + ) + .expect("handwritten covariate SDE metadata should validate") +} + +#[test] +fn sde_macro_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_infusion_sde(); + let handwritten_model = handwritten_infusion_sde(); + let subject = infusion_subject(0); + let support_point = [0.2, 0.0, 10.0]; + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert_eq!(macro_model.route_index("iv"), Some(0)); + assert_eq!(macro_model.output_index("cp"), Some(0)); + assert_eq!(macro_model.state_index("central"), Some(0)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro SDE model should simulate"); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten SDE model should simulate"); + + assert_prediction_match( + &prediction_means(¯o_predictions), + &prediction_means(&handwritten_predictions), + ); +} + +#[test] +fn sde_macro_supports_lag_fa_init_and_named_sigma_bindings() { + let macro_model = macro_absorption_sde(); + let handwritten_model = handwritten_absorption_sde(); + let subject = oral_subject(0); + let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert_eq!(macro_model.route_index("oral"), Some(0)); + assert_eq!(macro_model.output_index("cp"), Some(0)); + assert_eq!(macro_model.state_index("gut"), Some(0)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro absorption SDE should simulate"); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten absorption SDE should simulate"); + + assert_prediction_match( + &prediction_means(¯o_predictions), + &prediction_means(&handwritten_predictions), + ); +} + +#[test] +fn sde_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_shared_channel_sde(); + let handwritten_model = handwritten_shared_channel_sde(); + let subject = shared_channel_subject(0); + let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + assert_eq!(macro_model.route_index("oral"), Some(0)); + assert_eq!(macro_model.route_index("iv"), Some(0)); + assert_eq!(macro_model.output_index("cp"), Some(0)); + assert_eq!(macro_model.state_index("gut"), Some(0)); + assert_eq!(macro_model.state_index("central"), Some(1)); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro shared-channel SDE should simulate"); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten shared-channel SDE should simulate"); + + assert_prediction_match( + &prediction_means(¯o_predictions), + &prediction_means(&handwritten_predictions), + ); +} + +#[test] +fn sde_macro_covariates_lower_to_handwritten_behavior() { + let macro_model = macro_covariate_sde(); + let handwritten_model = handwritten_covariate_sde(); + + assert_eq!(macro_model.metadata(), handwritten_model.metadata()); + + let oral = macro_model.route_index("oral").expect("oral route exists"); + let iv = macro_model.route_index("iv").expect("iv route exists"); + let cp = macro_model.output_index("cp").expect("cp output exists"); + let subject = covariate_subject(oral, iv, cp); + let support_point = [1.0, 0.16, 0.0, 32.0, 0.5, 0.8, 3.0, 14.0]; + + assert_eq!(oral, iv); + + let macro_predictions = macro_model + .estimate_predictions(&subject, &support_point) + .expect("macro covariate SDE should simulate"); + let handwritten_predictions = handwritten_model + .estimate_predictions(&subject, &support_point) + .expect("handwritten covariate SDE should simulate"); + + assert_prediction_match( + &prediction_means(¯o_predictions), + &prediction_means(&handwritten_predictions), + ); +} diff --git a/tests/support/bimodal_ke.rs b/tests/support/bimodal_ke.rs index fe363aa7..4c82be4f 100644 --- a/tests/support/bimodal_ke.rs +++ b/tests/support/bimodal_ke.rs @@ -12,7 +12,7 @@ pub const OBSERVATION_TIMES: [f64; 7] = [0.5, 1.0, 2.0, 3.0, 4.0, 6.0, 8.0]; pub const SUPPORT_POINT: [f64; 2] = [1.2, 50.0]; pub const AUTHORING_DSL: &str = r#" -model = bimodal_ke +name = bimodal_ke kind = ode params = ke, v diff --git a/tests/support/runtime_corpus.rs b/tests/support/runtime_corpus.rs index 0a9917a0..3ed75511 100644 --- a/tests/support/runtime_corpus.rs +++ b/tests/support/runtime_corpus.rs @@ -19,7 +19,7 @@ use pharmsol::{equation, fa, fetch_cov, fetch_params, lag, Subject, SubjectBuild use tempfile::{tempdir, TempDir}; const ODE_SOURCE: &str = r#" -model = one_cmt_oral_iv +name = one_cmt_oral_iv kind = ode params = ka, cl, v, tlag, f_oral @@ -43,8 +43,40 @@ dx(central) = ka * depot - ke * central out(cp) = central / v ~ continuous() "#; +const ODE_FULL_SOURCE: &str = r#" +name = ode_full_feature_parity +kind = ode + +params = ka, ke, kcp, kpc, v, tlag, f_oral, base_depot, base_central, base_peripheral +covariates = wt@linear, renal@linear +derived = adjusted_ke, adjusted_kcp, adjusted_v +states = depot, central, peripheral +outputs = cp + +bolus(oral) -> depot +bolus(load) -> central +infusion(iv) -> central + +lag(oral) = tlag * sqrt(wt / 70.0) * pow(90.0 / renal, 0.1) +fa(oral) = min(max(f_oral * pow(renal / 90.0, 0.1), 0.0), 1.0) + +adjusted_ke = ke * pow(wt / 70.0, 0.75) * pow(renal / 90.0, 0.25) +adjusted_kcp = kcp * pow(wt / 70.0, 0.25) +adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)) + +dx(depot) = -ka * depot +dx(central) = ka * depot - (adjusted_ke + adjusted_kcp) * central + kpc * peripheral +dx(peripheral) = adjusted_kcp * central - kpc * peripheral + +init(depot) = base_depot + 0.05 * wt +init(central) = base_central + 0.1 * renal +init(peripheral) = base_peripheral + 0.02 * wt + +out(cp) = central / adjusted_v ~ continuous() +"#; + const ANALYTICAL_SOURCE: &str = r#" -model = one_cmt_abs +name = one_cmt_abs kind = analytical params = ka, ke, v, tlag, f_oral @@ -56,13 +88,41 @@ bolus(oral) -> depot lag(oral) = tlag fa(oral) = f_oral -kernel = one_compartment_with_absorption +structure = one_compartment_with_absorption out(cp) = central / v ~ continuous() "#; +const ANALYTICAL_FULL_SOURCE: &str = r#" +name = analytical_full_feature_parity +kind = analytical + +params = ka, ke, v, tlag, f_oral, base_gut, base_central, tvke +covariates = wt@linear, renal@linear +derived = ka_proj, ke_proj +states = gut, central +outputs = cp + +bolus(oral) -> gut +bolus(load) -> central +infusion(iv) -> central + +lag(oral) = tlag * sqrt(wt / 70.0) * pow(90.0 / renal, 0.1) +fa(oral) = min(max(f_oral * pow(renal / 90.0, 0.1), 0.0), 1.0) + +ka_proj = ka +ke_proj = tvke * pow(wt / 70.0, 0.75) * pow(renal / 90.0, 0.25) + +structure = one_compartment_with_absorption + +init(gut) = base_gut + 0.03 * wt +init(central) = base_central + 0.08 * renal + +out(cp) = central / (v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0))) ~ continuous() +"#; + const SDE_SOURCE: &str = r#" -model = vanco_sde +name = vanco_sde kind = sde params = ka, ke0, kcp, kpc, vol, ske @@ -90,7 +150,9 @@ pub const SDE_PARTICLE_COUNT: usize = 16; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CorpusCase { Ode, + OdeFull, Analytical, + AnalyticalFull, Sde, } @@ -98,7 +160,9 @@ impl CorpusCase { pub fn label(self) -> &'static str { match self { Self::Ode => "dsl-ode-one_cmt_oral_iv", + Self::OdeFull => "dsl-ode-full-feature-parity", Self::Analytical => "dsl-analytical-one_cmt_abs", + Self::AnalyticalFull => "dsl-analytical-full-feature-parity", Self::Sde => "dsl-sde-vanco_sde", } } @@ -106,7 +170,9 @@ impl CorpusCase { pub fn model_name(self) -> &'static str { match self { Self::Ode => "one_cmt_oral_iv", + Self::OdeFull => "ode_full_feature_parity", Self::Analytical => "one_cmt_abs", + Self::AnalyticalFull => "analytical_full_feature_parity", Self::Sde => "vanco_sde", } } @@ -114,7 +180,9 @@ impl CorpusCase { fn source(self) -> &'static str { match self { Self::Ode => ODE_SOURCE, + Self::OdeFull => ODE_FULL_SOURCE, Self::Analytical => ANALYTICAL_SOURCE, + Self::AnalyticalFull => ANALYTICAL_FULL_SOURCE, Self::Sde => SDE_SOURCE, } } @@ -122,7 +190,9 @@ impl CorpusCase { pub fn tolerance(self) -> f64 { match self { Self::Ode => 1e-4, + Self::OdeFull => 1e-4, Self::Analytical => 1e-8, + Self::AnalyticalFull => 1e-8, Self::Sde => 1e-4, } } @@ -130,7 +200,9 @@ impl CorpusCase { pub fn support_point(self) -> &'static [f64] { match self { Self::Ode => &[1.2, 5.0, 40.0, 0.5, 0.8], + Self::OdeFull => &[1.1, 0.18, 0.07, 0.04, 35.0, 0.6, 0.85, 4.0, 18.0, 9.0], Self::Analytical => &[1.0, 0.15, 25.0, 0.5, 0.8], + Self::AnalyticalFull => &[1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16], Self::Sde => &[1.1, 0.2, 0.12, 0.08, 15.0, 0.0], } } @@ -160,6 +232,34 @@ impl CorpusCase { .missing_observation(9.0, cp) .build() } + Self::OdeFull => { + let oral = model.route_index("oral").ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; + let load = model.route_index("load").ok_or_else(|| { + io::Error::other(format!("{}: missing load route", self.label())) + })?; + let iv = model.route_index("iv").ok_or_else(|| { + io::Error::other(format!("{}: missing iv route", self.label())) + })?; + Subject::builder(self.label()) + .bolus(0.0, 80.0, load) + .bolus(1.0, 120.0, oral) + .infusion(6.0, 150.0, iv, 2.5) + .missing_observation(0.25, cp) + .missing_observation(0.75, cp) + .missing_observation(1.5, cp) + .missing_observation(3.0, cp) + .missing_observation(6.5, cp) + .missing_observation(7.0, cp) + .missing_observation(8.0, cp) + .missing_observation(12.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() + } Self::Analytical => { let oral = model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) @@ -172,6 +272,34 @@ impl CorpusCase { .missing_observation(4.0, cp) .build() } + Self::AnalyticalFull => { + let oral = model.route_index("oral").ok_or_else(|| { + io::Error::other(format!("{}: missing oral route", self.label())) + })?; + let load = model.route_index("load").ok_or_else(|| { + io::Error::other(format!("{}: missing load route", self.label())) + })?; + let iv = model.route_index("iv").ok_or_else(|| { + io::Error::other(format!("{}: missing iv route", self.label())) + })?; + Subject::builder(self.label()) + .bolus(0.0, 60.0, load) + .bolus(1.0, 100.0, oral) + .infusion(6.0, 140.0, iv, 2.0) + .missing_observation(0.25, cp) + .missing_observation(0.75, cp) + .missing_observation(1.5, cp) + .missing_observation(3.0, cp) + .missing_observation(6.5, cp) + .missing_observation(7.0, cp) + .missing_observation(8.0, cp) + .missing_observation(12.0, cp) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build() + } Self::Sde => { let oral = model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) @@ -193,9 +321,15 @@ impl CorpusCase { fn reference_predictions(self) -> Result> { match self { Self::Ode => Ok(ExpectedPredictions::Subject(reference_ode_predictions()?)), + Self::OdeFull => Ok(ExpectedPredictions::Subject( + reference_ode_full_predictions()?, + )), Self::Analytical => Ok(ExpectedPredictions::Subject( reference_analytical_predictions()?, )), + Self::AnalyticalFull => Ok(ExpectedPredictions::Subject( + reference_analytical_full_predictions()?, + )), Self::Sde => Ok(ExpectedPredictions::Particles(reference_sde_predictions()?)), } } @@ -605,6 +739,137 @@ fn reference_ode_predictions() -> Result> { )?) } +fn reference_ode_full_predictions() -> Result> { + Ok(equation::ODE::new( + |x, p, t, dx, bolus, rateiv, cov| { + fetch_params!( + p, + ka, + ke, + kcp, + kpc, + _v, + _tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + let adjusted_ke = ke * wt_scale * renal_scale; + let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); + + dx[0] = bolus[0] - ka * x[0]; + dx[1] = + bolus[1] + ka * x[0] + rateiv[0] - (adjusted_ke + adjusted_kcp) * x[1] + kpc * x[2]; + dx[2] = adjusted_kcp * x[1] - kpc * x[2]; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + _tlag, + f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + _v, + _tlag, + _f_oral, + base_depot, + base_central, + base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_depot + 0.05 * wt; + x[1] = base_central + 0.1 * renal; + x[2] = base_peripheral + 0.02 * wt; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + _kcp, + _kpc, + v, + _tlag, + _f_oral, + _base_depot, + _base_central, + _base_peripheral + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(3) + .with_ndrugs(2) + .with_nout(1) + .estimate_predictions( + &Subject::builder(CorpusCase::OdeFull.label()) + .bolus(0.0, 80.0, 1) + .bolus(1.0, 120.0, 0) + .infusion(6.0, 150.0, 0, 2.5) + .missing_observation(0.25, 0) + .missing_observation(0.75, 0) + .missing_observation(1.5, 0) + .missing_observation(3.0, 0) + .missing_observation(6.5, 0) + .missing_observation(7.0, 0) + .missing_observation(8.0, 0) + .missing_observation(12.0, 0) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build(), + CorpusCase::OdeFull.support_point(), + )?) +} + fn reference_analytical_predictions() -> Result> { Ok(equation::Analytical::new( one_compartment_with_absorption, @@ -638,6 +903,110 @@ fn reference_analytical_predictions() -> Result Result> { + Ok(equation::Analytical::new( + equation::one_compartment_with_absorption, + |p, t, cov| { + fetch_cov!(cov, t, wt, renal); + + let wt_scale = (wt / 70.0).powf(0.75); + let renal_scale = (renal / 90.0).powf(0.25); + p[1] = p[7] * wt_scale * renal_scale; + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let lag_scale = (wt / 70.0).sqrt() * (90.0 / renal).powf(0.1); + lag! { 0 => tlag * lag_scale } + }, + |p, t, cov| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let fa_scale = (renal / 90.0).powf(0.1); + fa! { 0 => (f_oral * fa_scale).clamp(0.0, 1.0) } + }, + |p, t, cov, x| { + fetch_params!( + p, + _ka, + _ke, + _v, + _tlag, + _f_oral, + base_gut, + base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + x[0] = base_gut + 0.03 * wt; + x[1] = base_central + 0.08 * renal; + }, + |x, p, t, cov, y| { + fetch_params!( + p, + _ka, + _ke, + v, + _tlag, + _f_oral, + _base_gut, + _base_central, + _tvke + ); + fetch_cov!(cov, t, wt, renal); + + let adjusted_v = v * (wt / 70.0) * (1.0 + 0.001 * (renal - 90.0)); + y[0] = x[1] / adjusted_v; + }, + ) + .with_nstates(2) + .with_ndrugs(2) + .with_nout(1) + .estimate_predictions( + &Subject::builder(CorpusCase::AnalyticalFull.label()) + .bolus(0.0, 60.0, 1) + .bolus(1.0, 100.0, 0) + .infusion(6.0, 140.0, 0, 2.0) + .missing_observation(0.25, 0) + .missing_observation(0.75, 0) + .missing_observation(1.5, 0) + .missing_observation(3.0, 0) + .missing_observation(6.5, 0) + .missing_observation(7.0, 0) + .missing_observation(8.0, 0) + .missing_observation(12.0, 0) + .covariate("wt", 0.0, 68.0) + .covariate("wt", 8.0, 74.0) + .covariate("renal", 0.0, 95.0) + .covariate("renal", 8.0, 72.0) + .build(), + CorpusCase::AnalyticalFull.support_point(), + )?) +} + fn reference_sde_predictions() -> Result, Box> { Ok(SDE::new( |x, p, _t, dx, _rateiv, _cov| {