diff --git a/README.md b/README.md index d24aea87..73932de2 100644 --- a/README.md +++ b/README.md @@ -27,24 +27,21 @@ let analytical = analytical! { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; }, }; -let iv = analytical.route_index("iv").unwrap(); -let cp = analytical.output_index("cp").unwrap(); - let subject = Subject::builder("patient_001") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build(); let predictions = analytical @@ -64,9 +61,9 @@ let ode = ode! { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, @@ -121,12 +118,12 @@ use pharmsol::prelude::*; use pharmsol::nca::NCAOptions; let subject = Subject::builder("patient_001") - .bolus(0.0, 100.0, 0) // 100 mg oral dose - .observation(0.5, 5.0, 0) - .observation(1.0, 10.0, 0) - .observation(2.0, 8.0, 0) - .observation(4.0, 4.0, 0) - .observation(8.0, 2.0, 0) + .bolus(0.0, 100.0, "oral") // 100 mg oral dose + .observation(0.5, 5.0, "cp") + .observation(1.0, 10.0, "cp") + .observation(2.0, 8.0, "cp") + .observation(4.0, 4.0, "cp") + .observation(8.0, 2.0, "cp") .build(); let result = subject.nca(&NCAOptions::default()).expect("NCA failed"); diff --git a/examples/analytical_readme.rs b/examples/analytical_readme.rs index 8e5b97f7..8451b478 100644 --- a/examples/analytical_readme.rs +++ b/examples/analytical_readme.rs @@ -6,24 +6,21 @@ fn main() -> Result<(), pharmsol::PharmsolError> { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; }, }; - let iv = analytical.route_index("iv").expect("iv route exists"); - let cp = analytical.output_index("cp").expect("cp output exists"); - let subject = Subject::builder("analytical_readme") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build(); let predictions = analytical.estimate_predictions(&subject, &[1.022, 194.0])?; diff --git a/examples/analytical_vs_ode.rs b/examples/analytical_vs_ode.rs index 290d6632..3fd58fd1 100644 --- a/examples/analytical_vs_ode.rs +++ b/examples/analytical_vs_ode.rs @@ -72,9 +72,9 @@ fn one_cmt_iv(params: &[f64]) { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -86,9 +86,9 @@ fn one_cmt_iv(params: &[f64]) { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, @@ -114,9 +114,9 @@ fn one_cmt_oral(params: &[f64]) { params: [ka, ke, v], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], structure: one_compartment_with_absorption, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -128,9 +128,9 @@ fn one_cmt_oral(params: &[f64]) { params: [ka, ke, v], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; @@ -157,9 +157,9 @@ fn two_cmt_iv(params: &[f64]) { params: [ke, k12, k21, v], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: two_compartments, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -171,9 +171,9 @@ fn two_cmt_iv(params: &[f64]) { params: [ke, k12, k21, v], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central] - k12 * x[central] + k21 * x[peripheral]; dx[peripheral] = k12 * x[central] - k21 * x[peripheral]; @@ -200,9 +200,9 @@ fn two_cmt_oral(params: &[f64]) { params: [ka, ke, k12, k21, v], states: [gut, central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], structure: two_compartments_with_absorption, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -214,9 +214,9 @@ fn two_cmt_oral(params: &[f64]) { params: [ka, ke, k12, k21, v], states: [gut, central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central] - k12 * x[central] + k21 * x[peripheral]; diff --git a/examples/compare_solvers.rs b/examples/compare_solvers.rs index ebec4caa..5d8fdbb6 100644 --- a/examples/compare_solvers.rs +++ b/examples/compare_solvers.rs @@ -24,10 +24,10 @@ fn two_cpt(solver: OdeSolver) -> equation::ODE { params: [ke, kcp, kpc, v], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(load) -> central, infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; @@ -48,32 +48,24 @@ fn main() { let trbdf2 = two_cpt(OdeSolver::Sdirk(SdirkTableau::TrBdf2)); let esdirk34 = two_cpt(OdeSolver::Sdirk(SdirkTableau::Esdirk34)); - // Both declarations resolve to the same shared input channel, so subject + // Both declarations resolve to the same shared input, so subject // authoring still uses one numeric index for the loading bolus and the // maintenance infusion. - let load = bdf.route_index("load").expect("load route exists"); - let iv = bdf.route_index("iv").expect("iv route exists"); - let cp = bdf.output_index("cp").expect("cp output exists"); - - assert_eq!( - load, iv, - "mixed IV declarations should share one numeric channel" - ); let subject = Subject::builder("id1") - .bolus(0.0, 100.0, iv) - .infusion(12.0, 200.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) - .missing_observation(12.5, cp) - .missing_observation(13.0, cp) - .missing_observation(14.0, cp) - .missing_observation(16.0, cp) - .missing_observation(24.0, cp) + .bolus(0.0, 100.0, "load") + .infusion(12.0, 200.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") + .missing_observation(12.5, "cp") + .missing_observation(13.0, "cp") + .missing_observation(14.0, "cp") + .missing_observation(16.0, "cp") + .missing_observation(24.0, "cp") .build(); let spp = vec![0.1, 0.05, 0.03, 50.0]; // ke, kcp, kpc, V diff --git a/examples/covariates.rs b/examples/covariates.rs index 9aabf491..180a0173 100644 --- a/examples/covariates.rs +++ b/examples/covariates.rs @@ -7,9 +7,9 @@ fn main() { covariates: [creatinine, age], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], diffeq: |x, _t, dx| { let scaled_ke = ke * (creatinine / 75.0).powf(0.75) * (age / 25.0).powf(0.5); diff --git a/examples/dsl_runtime_jit.rs b/examples/dsl_runtime_jit.rs index 932acaae..3f7d1efe 100644 --- a/examples/dsl_runtime_jit.rs +++ b/examples/dsl_runtime_jit.rs @@ -43,24 +43,16 @@ out(cp) = central / v on_compile_event, )?; - // 2. Resolve the route and output indices declared by the model. - let iv = model - .route_index("iv") - .ok_or_else(|| io::Error::other("missing iv route"))?; - let cp = model - .output_index("cp") - .ok_or_else(|| io::Error::other("missing cp output"))?; - // 3. Define the subject data. let subject = Subject::builder("bimodal_ke") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(3.0, cp) - .missing_observation(4.0, cp) - .missing_observation(6.0, cp) - .missing_observation(8.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(3.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(8.0, "cp") .build(); // 4. Estimate predictions for one support point. diff --git a/examples/macro_vs_handwritten_one_cpt.rs b/examples/macro_vs_handwritten_one_cpt.rs index 4d8f74d0..c7b088a5 100644 --- a/examples/macro_vs_handwritten_one_cpt.rs +++ b/examples/macro_vs_handwritten_one_cpt.rs @@ -12,9 +12,9 @@ fn macro_model() -> equation::ODE { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, @@ -26,6 +26,9 @@ fn macro_model() -> equation::ODE { fn handwritten_model() -> equation::ODE { equation::ODE::new( + // Handwritten closures stay on dense internal slots. + // Public labels like `iv` and `cp` live in attached metadata, not in + // the low-level `rateiv[]` / `y[]` buffers. |x, p, _t, dx, _bolus, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = rateiv[0] - ke * x[0]; @@ -75,12 +78,12 @@ fn main() -> Result<(), pharmsol::PharmsolError> { assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); let subject = Subject::builder("macro-vs-handwritten-one-cpt") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) - .missing_observation(8.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(8.0, "cp") .build(); let params = [1.022, 194.0]; diff --git a/examples/macro_vs_handwritten_two_cpt.rs b/examples/macro_vs_handwritten_two_cpt.rs index 9ab1a675..377e1e88 100644 --- a/examples/macro_vs_handwritten_two_cpt.rs +++ b/examples/macro_vs_handwritten_two_cpt.rs @@ -1,5 +1,5 @@ //! Compares a declaration-first macro ODE with the equivalent handwritten ODE -//! on a two-compartment IV problem that shares one numeric input channel across +//! on a two-compartment IV problem that shares one numeric input across //! a loading bolus and a maintenance infusion. //! //! This keeps the macro story as the default surface while showing the @@ -9,14 +9,14 @@ use pharmsol::prelude::*; fn macro_model() -> equation::ODE { ode! { - name: "two_cpt_shared_channel_parity", + name: "two_cpt_shared_input_parity", params: [ke, kcp, kpc, v], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(load) -> central, infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; @@ -29,6 +29,10 @@ fn macro_model() -> equation::ODE { fn handwritten_model() -> equation::ODE { equation::ODE::new( + // Handwritten closures stay on dense internal slots. + // Public route labels like `load` and `iv` are metadata names; the + // low-level `bolus[]`, `rateiv[]`, and `y[]` buffers remain indexed by + // dense internal slots. |x, p, _t, dx, bolus, rateiv, _cov| { fetch_params!(p, ke, kcp, kpc, _v); dx[0] = -ke * x[0] - kcp * x[0] + kpc * x[1] + rateiv[0] + bolus[0]; @@ -46,7 +50,7 @@ fn handwritten_model() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("two_cpt_shared_channel_parity") + equation::metadata::new("two_cpt_shared_input_parity") .parameters(["ke", "kcp", "kpc", "v"]) .states(["central", "peripheral"]) .outputs(["cp"]) @@ -79,28 +83,25 @@ fn main() -> Result<(), pharmsol::PharmsolError> { let iv = macro_ode.route_index("iv").expect("iv route exists"); let cp = macro_ode.output_index("cp").expect("cp output exists"); - assert_eq!( - load, iv, - "load and iv should share one numeric input channel" - ); + assert_eq!(load, iv, "load and iv should share one numeric input"); assert_eq!(handwritten_ode.route_index("load"), Some(load)); assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); let subject = Subject::builder("macro-vs-handwritten-two-cpt") - .bolus(0.0, 100.0, load) - .infusion(12.0, 200.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) - .missing_observation(12.5, cp) - .missing_observation(13.0, cp) - .missing_observation(14.0, cp) - .missing_observation(16.0, cp) - .missing_observation(24.0, cp) + .bolus(0.0, 100.0, "load") + .infusion(12.0, 200.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") + .missing_observation(12.5, "cp") + .missing_observation(13.0, "cp") + .missing_observation(14.0, "cp") + .missing_observation(16.0, "cp") + .missing_observation(24.0, "cp") .build(); let params = [0.1, 0.05, 0.03, 50.0]; diff --git a/examples/ode_readme.rs b/examples/ode_readme.rs index a0174801..2989895f 100644 --- a/examples/ode_readme.rs +++ b/examples/ode_readme.rs @@ -6,9 +6,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, @@ -17,15 +17,12 @@ fn main() -> Result<(), pharmsol::PharmsolError> { }, }; - let iv = ode.route_index("iv").expect("iv route exists"); - let cp = ode.output_index("cp").expect("cp output exists"); - let subject = Subject::builder("id1") - .infusion(0.0, 100.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .infusion(0.0, 100.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build(); let predictions = ode.estimate_predictions(&subject, &[1.022, 194.0])?; diff --git a/examples/one_compartment.rs b/examples/one_compartment.rs index aafdf2b2..021e06f2 100644 --- a/examples/one_compartment.rs +++ b/examples/one_compartment.rs @@ -6,9 +6,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -20,9 +20,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, diff --git a/examples/sde_readme.rs b/examples/sde_readme.rs index 6106b17a..97b5fed4 100644 --- a/examples/sde_readme.rs +++ b/examples/sde_readme.rs @@ -7,9 +7,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { states: [central], outputs: [cp], particles: 16, - routes: { + routes: [ infusion(iv) -> central, - }, + ], drift: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, @@ -21,15 +21,12 @@ fn main() -> Result<(), pharmsol::PharmsolError> { }, }; - let iv = sde.route_index("iv").expect("iv route exists"); - let cp = sde.output_index("cp").expect("cp output exists"); - let subject = Subject::builder("sde_readme") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build(); let predictions = sde.estimate_predictions(&subject, &[1.022, 0.0, 194.0])?; diff --git a/examples/two_compartment.rs b/examples/two_compartment.rs index 64d554af..fdba715e 100644 --- a/examples/two_compartment.rs +++ b/examples/two_compartment.rs @@ -27,9 +27,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { covariates: [wt], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _t, dx| { // CL: Clearance (L/hr), V: Central volume (L) // Vp: Peripheral volume (L), Q: Inter-compartmental clearance (L/hr) diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index c81c19eb..0496c0fc 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -20,6 +20,7 @@ struct AuthoringParser<'a> { states: Vec, declared_derived: BTreeSet, declared_outputs: BTreeSet, + explicit_output_order: Vec, explicit_outputs: BTreeMap, assigned_outputs: BTreeMap, declared_outputs_span: Option, @@ -77,6 +78,7 @@ impl<'a> AuthoringParser<'a> { states: Vec::new(), declared_derived: BTreeSet::new(), declared_outputs: BTreeSet::new(), + explicit_output_order: Vec::new(), explicit_outputs: BTreeMap::new(), assigned_outputs: BTreeMap::new(), declared_outputs_span: None, @@ -175,6 +177,20 @@ impl<'a> AuthoringParser<'a> { )); } + if !self.explicit_output_order.is_empty() { + let output_order = self + .explicit_output_order + .iter() + .enumerate() + .map(|(index, name)| (name.clone(), index)) + .collect::>(); + self.output_statements.sort_by_key(|statement| { + output_statement_name(statement) + .and_then(|name| output_order.get(name).copied()) + .unwrap_or(usize::MAX) + }); + } + let mut derivative_statements = std::mem::take(&mut self.derivative_statements); inject_infusion_rates(&surface_routes, &routes, &mut derivative_statements); @@ -371,7 +387,8 @@ impl<'a> AuthoringParser<'a> { if lhs_trimmed == "outputs" { self.declared_outputs_span = Some(span); - for ident in parse_ident_list(rhs, rhs_abs)? { + for ident in parse_output_label_list(rhs, rhs_abs)? { + self.explicit_output_order.push(ident.text.clone()); self.declared_outputs.insert(ident.text.clone()); self.explicit_outputs.insert(ident.text, ident.span); } @@ -413,7 +430,20 @@ impl<'a> AuthoringParser<'a> { return self.parse_call_assignment(call, rhs, rhs_abs, span); } - let target = parse_ident_segment(lhs, lhs_abs)?; + let target = match parse_ident_segment(lhs, lhs_abs) { + Ok(target) => target, + Err(error) => { + if self.declared_outputs_span.is_none() { + return Err(error); + } + + let target = parse_output_label_segment(lhs, lhs_abs)?; + if !self.declared_outputs.contains(&target.text) { + return Err(self.undeclared_output_error(&target.text, target.span)); + } + target + } + }; let rhs = parse_surface_rhs(rhs, rhs_abs)?; let stmt = build_assignment_statement( AssignTarget { @@ -454,7 +484,7 @@ impl<'a> AuthoringParser<'a> { } }; - let input = parse_ident_segment(call.argument, call.argument_start)?; + let input = parse_label_segment(call.argument, call.argument_start, "route label")?; let route_name = input.text.clone(); let destination = parse_place_at(rhs, line_start + arrow + 2)?; if self.routes.contains_key(&route_name) { @@ -485,7 +515,8 @@ impl<'a> AuthoringParser<'a> { ) -> Result<(), ParseError> { match call.callee.text.as_str() { "lag" | "fa" => { - let route_name = parse_ident_segment(call.argument, call.argument_start)?; + let route_name = + parse_label_segment(call.argument, call.argument_start, "route label")?; let value = parse_expr_at(rhs, rhs_abs)?; let property_name = match call.callee.text.as_str() { "lag" => "lag", @@ -552,7 +583,7 @@ impl<'a> AuthoringParser<'a> { self.init_statements.push(stmt); } "out" => { - let output = parse_ident_segment(call.argument, call.argument_start)?; + let output = parse_output_label_segment(call.argument, call.argument_start)?; self.validate_output_target(&output)?; self.declared_outputs.insert(output.text.clone()); self.note_output_assignment(&output); @@ -839,6 +870,13 @@ fn parse_ident_list(src: &str, abs_start: usize) -> Result, ParseErro .collect() } +fn parse_output_label_list(src: &str, abs_start: usize) -> Result, ParseError> { + split_top_level(src, ',') + .into_iter() + .map(|(segment, start)| parse_output_label_segment(segment, abs_start + start)) + .collect() +} + fn parse_covariates_list(src: &str, abs_start: usize) -> Result, ParseError> { let mut covariates = Vec::new(); for (segment, start) in split_top_level(src, ',') { @@ -907,6 +945,31 @@ fn parse_ident_segment(src: &str, abs_start: usize) -> Result )) } +fn parse_output_label_segment(src: &str, abs_start: usize) -> Result { + parse_label_segment(src, abs_start, "output label") +} + +fn parse_label_segment(src: &str, abs_start: usize, expected: &str) -> Result { + let trimmed = src.trim(); + let leading = src.len() - src.trim_start().len(); + if trimmed.is_empty() { + return Err(ParseError::new( + format!("expected {expected}"), + Span::new(abs_start, abs_start + src.len()), + )); + } + if !is_valid_output_label(trimmed) { + return Err(ParseError::new( + format!("expected {expected}, found `{trimmed}`"), + Span::new(abs_start + leading, abs_start + leading + trimmed.len()), + )); + } + Ok(Ident::new( + trimmed, + Span::new(abs_start + leading, abs_start + leading + trimmed.len()), + )) +} + fn parse_place_at(src: &str, abs_start: usize) -> Result { let mut place = parse_place_fragment(src).map_err(|error| error.shifted(abs_start))?; shift_place(&mut place, abs_start); @@ -1344,6 +1407,10 @@ fn is_valid_ident(src: &str) -> bool { chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_') } +fn is_valid_output_label(src: &str) -> bool { + is_valid_ident(src) || src.chars().all(|ch| ch.is_ascii_digit()) +} + fn is_ident_byte(byte: u8) -> bool { (byte as char).is_ascii_alphanumeric() || byte == b'_' } @@ -1372,6 +1439,16 @@ fn join_covariate_spans(items: &[CovariateDecl]) -> Span { .unwrap_or_else(|| Span::empty(0)) } +fn output_statement_name(statement: &Stmt) -> Option<&str> { + match &statement.kind { + StmtKind::Assign(assign) => match &assign.target.kind { + AssignTargetKind::Name(name) => Some(name.text.as_str()), + _ => None, + }, + _ => None, + } +} + fn join_state_spans(items: &[StateDecl]) -> Span { items .iter() diff --git a/pharmsol-dsl/src/execution.rs b/pharmsol-dsl/src/execution.rs index 886d570a..8bac1d69 100644 --- a/pharmsol-dsl/src/execution.rs +++ b/pharmsol-dsl/src/execution.rs @@ -1516,7 +1516,7 @@ mod tests { } #[test] - fn authoring_routes_share_channel_indices_by_kind_local_ordinal() { + fn authoring_routes_share_input_indices_by_kind_local_ordinal() { let src = r#"name = shared_authoring kind = ode diff --git a/pharmsol-dsl/src/parser.rs b/pharmsol-dsl/src/parser.rs index f07fbd50..98c6b0a4 100644 --- a/pharmsol-dsl/src/parser.rs +++ b/pharmsol-dsl/src/parser.rs @@ -106,12 +106,18 @@ struct Parser { #[derive(Clone, Copy)] enum LayoutBoundary { ModelItem, - Statement, + Statement(StatementContext), Binding, IdentItem, RouteDecl, } +#[derive(Clone, Copy, PartialEq, Eq)] +enum StatementContext { + Standard, + Outputs, +} + impl Parser { fn new(src: &str) -> Result { Ok(Self::from_tokens(lex(src)?, src.len())) @@ -557,7 +563,7 @@ impl Parser { } fn parse_route_decl(&mut self) -> Result { - let input = self.parse_ident()?; + let input = self.parse_label_name("route label")?; let arrow = self.expect_simple(|kind| matches!(kind, TokenKind::Arrow), "`->`")?; self.ensure_not_layout_boundary( arrow.span, @@ -655,8 +661,13 @@ impl Parser { fn parse_statement_block(&mut self, name: &str) -> Result { let start = self.bump().unwrap().span; let open = self.expect_simple(|kind| matches!(kind, TokenKind::LBrace), "`{`")?; + let statement_context = if name == "outputs" { + StatementContext::Outputs + } else { + StatementContext::Standard + }; let (statements, mut errors) = - self.with_layout_boundary(LayoutBoundary::Statement, |parser| { + self.with_layout_boundary(LayoutBoundary::Statement(statement_context), |parser| { let mut statements = Vec::new(); let mut errors = Vec::new(); while !parser.is_eof() && !parser.at(|kind| matches!(kind, TokenKind::RBrace)) { @@ -790,8 +801,9 @@ impl Parser { fn parse_stmt_body(&mut self) -> Result, ParseError> { let open = self.expect_simple(|kind| matches!(kind, TokenKind::LBrace), "`{`")?; + let statement_context = self.current_statement_context(); let (statements, mut errors) = - self.with_layout_boundary(LayoutBoundary::Statement, |parser| { + self.with_layout_boundary(LayoutBoundary::Statement(statement_context), |parser| { let mut statements = Vec::new(); let mut errors = Vec::new(); while !parser.is_eof() && !parser.at(|kind| matches!(kind, TokenKind::RBrace)) { @@ -854,7 +866,11 @@ impl Parser { } fn parse_assign_target(&mut self) -> Result { - let name = self.parse_ident()?; + let name = if matches!(self.current_statement_context(), StatementContext::Outputs) { + self.parse_output_target_name()? + } else { + self.parse_ident()? + }; let mut span = name.span; let kind = if let Some(open) = self.take_if(|kind| matches!(kind, TokenKind::LParen)) { let args = self.parse_expr_list(&open, TokenKindMatcher::RPAREN)?; @@ -885,6 +901,34 @@ impl Parser { Ok(AssignTarget { kind, span }) } + fn parse_output_target_name(&mut self) -> Result { + self.parse_label_name("output label") + } + + fn parse_label_name(&mut self, expected: &str) -> Result { + let token = self.bump().ok_or_else(|| { + ParseError::new(format!("expected {expected}"), Span::empty(self.src_len)) + })?; + match token.kind { + TokenKind::Ident(name) => Ok(Ident::new(name, token.span)), + TokenKind::Number(value) + if value.is_finite() + && value >= 0.0 + && value.fract() == 0.0 + && value <= usize::MAX as f64 => + { + Ok(Ident::new((value as usize).to_string(), token.span)) + } + other => Err(ParseError::new( + format!( + "expected {expected} identifier or non-negative integer, found {}", + other.describe() + ), + token.span, + )), + } + } + fn parse_ident(&mut self) -> Result { let token = self .bump() @@ -1320,9 +1364,12 @@ impl Parser { | TokenKind::Diffusion | TokenKind::Particles ), - LayoutBoundary::Statement => match &token.kind { + LayoutBoundary::Statement(context) => match &token.kind { TokenKind::If | TokenKind::For | TokenKind::Let => true, TokenKind::Ident(_) => self.line_starts_assignment_target(index), + TokenKind::Number(_) if matches!(context, StatementContext::Outputs) => { + self.line_starts_numeric_output_assignment(index) + } _ => false, }, LayoutBoundary::Binding => self.line_starts_named_assignment(index), @@ -1379,6 +1426,26 @@ impl Parser { } } + fn line_starts_numeric_output_assignment(&self, index: usize) -> bool { + matches!( + self.tokens.get(index).map(|token| &token.kind), + Some(TokenKind::Number(_)) + ) && self + .next_same_line_index(index) + .is_some_and(|next| matches!(self.tokens[next].kind, TokenKind::Eq)) + } + + fn current_statement_context(&self) -> StatementContext { + self.layout_boundaries + .iter() + .rev() + .find_map(|boundary| match boundary { + LayoutBoundary::Statement(context) => Some(*context), + _ => None, + }) + .unwrap_or(StatementContext::Standard) + } + fn next_same_line_index(&self, index: usize) -> Option { let next = index + 1; let token = self.tokens.get(next)?; @@ -1413,7 +1480,7 @@ impl Parser { fn current_boundary_label(&self) -> &'static str { match self.current_layout_boundary() { Some(LayoutBoundary::ModelItem) => "next model item starts here", - Some(LayoutBoundary::Statement) => "next statement starts here", + Some(LayoutBoundary::Statement(_)) => "next statement starts here", Some(LayoutBoundary::Binding) => "next binding starts here", Some(LayoutBoundary::IdentItem) => "next declaration starts here", Some(LayoutBoundary::RouteDecl) => "next route starts here", @@ -1424,7 +1491,7 @@ impl Parser { fn current_boundary_subject(&self) -> &'static str { match self.current_layout_boundary() { Some(LayoutBoundary::ModelItem) => "model item", - Some(LayoutBoundary::Statement) => "statement", + Some(LayoutBoundary::Statement(_)) => "statement", Some(LayoutBoundary::Binding) => "binding", Some(LayoutBoundary::IdentItem) => "declaration", Some(LayoutBoundary::RouteDecl) => "route", diff --git a/pharmsol-dsl/src/semantic.rs b/pharmsol-dsl/src/semantic.rs index ac9223dd..f20288a9 100644 --- a/pharmsol-dsl/src/semantic.rs +++ b/pharmsol-dsl/src/semantic.rs @@ -1617,29 +1617,32 @@ impl<'a> Analyzer<'a> { span, ))); } - if let Some(existing) = self.globals.all_names.get(name) { - return Err(SemanticAssist::default() - .context_label( - self.symbol_span(*existing), - self.symbol_declared_here(*existing), - ) - .help(format!( - "rename this declaration to a unique name such as `{}_2`", - name - )) - .replacement_suggestion( - span, - format!("{}_2", name), - format!("rename this declaration to `{}_2`", name), - Applicability::MaybeIncorrect, - ) - .apply(SemanticError::new( - format!( - "symbol name `{name}` collides with existing `{}`", - self.symbol_name(*existing) - ), - span, - ))); + if let Some(existing) = self.globals.all_names.get(name).copied() { + let existing_kind = self.symbols.get(existing).expect("valid symbol id").kind; + if !allows_route_output_name_overlap(existing_kind, kind) { + return Err(SemanticAssist::default() + .context_label( + self.symbol_span(existing), + self.symbol_declared_here(existing), + ) + .help(format!( + "rename this declaration to a unique name such as `{}_2`", + name + )) + .replacement_suggestion( + span, + format!("{}_2", name), + format!("rename this declaration to `{}_2`", name), + Applicability::MaybeIncorrect, + ) + .apply(SemanticError::new( + format!( + "symbol name `{name}` collides with existing `{}`", + self.symbol_name(existing) + ), + span, + ))); + } } let id = self.symbols.len(); self.symbols.push(PendingSymbol { @@ -1649,7 +1652,7 @@ impl<'a> Analyzer<'a> { ty, span, }); - self.globals.all_names.insert(name.to_string(), id); + self.globals.all_names.entry(name.to_string()).or_insert(id); Ok(id) } @@ -2132,6 +2135,13 @@ impl<'a> Analyzer<'a> { } } +fn allows_route_output_name_overlap(existing: SymbolKind, new: SymbolKind) -> bool { + matches!( + (existing, new), + (SymbolKind::Route, SymbolKind::Output) | (SymbolKind::Output, SymbolKind::Route) + ) +} + #[derive(Default)] struct Globals { all_names: BTreeMap, diff --git a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs index 797be3e9..4d1651f5 100644 --- a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs +++ b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs @@ -1,4 +1,4 @@ -use pharmsol_dsl::{analyze_model, parse_model, parse_module}; +use pharmsol_dsl::{analyze_model, lower_typed_model, parse_model, parse_module}; #[test] fn output_annotation_is_optional() { @@ -161,6 +161,126 @@ out(cp) = central ~ continous() ); } +#[test] +fn mixed_named_and_numeric_output_labels_lower_and_round_trip() { + let src = r#" +name = mixed_output_labels +kind = ode +params = ke, v +states = central +outputs = cp, 0, 1 +infusion(iv) -> central +ddt(central) = -ke * central +out(cp) = central / v +out(0) = 2 * central / v +out(1) = 3 * central / v +"#; + + let module = parse_module(src).expect("mixed output labels should parse in authoring DSL"); + let model = module + .models + .first() + .expect("authoring DSL should produce one model"); + let typed = analyze_model(&model).expect("mixed output labels should analyze"); + let lowered = lower_typed_model(&typed).expect("mixed output labels should lower"); + + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp", "0", "1"] + ); + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.index) + .collect::>(), + vec![0, 1, 2] + ); + + let rendered = module.to_string(); + let reparsed = parse_module(&rendered).expect("rendered mixed-output model should reparse"); + + assert_eq!(rendered, reparsed.to_string()); +} + +#[test] +fn shared_numeric_route_and_output_labels_lower_and_round_trip() { + let src = r#" +name = shared_numeric_route_output_labels +kind = ode +params = ke, v +states = central +outputs = 1 +infusion(1) -> central +ddt(central) = -ke * central +out(1) = central / v +"#; + + let module = parse_module(src).expect("shared numeric route/output labels should parse"); + let model = module + .models + .first() + .expect("authoring DSL should produce one model"); + let typed = analyze_model(model).expect("shared numeric route/output labels should analyze"); + let lowered = + lower_typed_model(&typed).expect("shared numeric route/output labels should lower"); + + assert_eq!( + lowered + .metadata + .routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["1"] + ); + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["1"] + ); + + let rendered = module.to_string(); + let reparsed = parse_module(&rendered).expect("rendered shared-label model should reparse"); + + assert_eq!(rendered, reparsed.to_string()); +} + +#[test] +fn route_labels_still_collide_with_scalar_symbol_names() { + let src = r#" +name = route_state_collision +kind = ode +params = ke +states = central, iv +outputs = cp +infusion(iv) -> central +ddt(central) = -ke * central +ddt(iv) = 0 +out(cp) = central +"#; + + let model = parse_model(src).expect("route/state collision model parses"); + let err = analyze_model(&model).expect_err("route label should still collide with state name"); + let rendered = err.render(src); + + assert!( + rendered.contains("symbol name `iv` collides with existing `iv`"), + "{}", + rendered + ); +} + #[test] fn unknown_route_destination_state_suggests_declared_state() { let src = r#" diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index 54a79fe3..7e483951 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -6,13 +6,14 @@ use proc_macro::TokenStream; use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, ToTokens}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use syn::{ parse::{Parse, ParseStream, Parser}, punctuated::Punctuated, token, visit::Visit, - Expr, ExprClosure, Ident, LitStr, Pat, Stmt, Token, + visit_mut::VisitMut, + Expr, ExprClosure, Ident, Lit, LitInt, LitStr, Pat, Stmt, Token, }; // --------------------------------------------------------------------------- @@ -24,9 +25,8 @@ struct OdeInput { params: Vec, covariates: Vec, states: Vec, - outputs: Vec, + outputs: Vec, routes: Vec, - diffeq_mode: OdeDiffeqMode, diffeq: ExprClosure, lag: Option, fa: Option, @@ -39,7 +39,7 @@ struct AnalyticalInput { params: Vec, covariates: Vec, states: Vec, - outputs: Vec, + outputs: Vec, routes: Vec, structure: Ident, sec: Option, @@ -54,7 +54,7 @@ struct SdeInput { params: Vec, covariates: Vec, states: Vec, - outputs: Vec, + outputs: Vec, routes: Vec, particles: Expr, drift: ExprClosure, @@ -65,15 +65,9 @@ struct SdeInput { out: ExprClosure, } -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum OdeDiffeqMode { - InjectedRouteInputs, - ExplicitRouteVectors, -} - struct OdeRouteDecl { kind: OdeRouteKind, - input: Ident, + input: SymbolicIndex, destination: Ident, } @@ -91,10 +85,78 @@ struct AnalyticalKernelSpec { } struct RoutePropertyEntry { - route: Ident, + route: SymbolicIndex, value: Expr, } +#[derive(Clone)] +enum SymbolicIndex { + Ident(Ident), + Int(LitInt), +} + +impl SymbolicIndex { + fn name(&self) -> String { + match self { + Self::Ident(ident) => ident.to_string(), + Self::Int(lit) => lit.base10_digits().to_string(), + } + } + + fn ident(&self) -> Option<&Ident> { + match self { + Self::Ident(ident) => Some(ident), + Self::Int(_) => None, + } + } + + fn numeric_value(&self) -> Option { + match self { + Self::Ident(_) => None, + Self::Int(lit) => Some( + lit.base10_parse::() + .expect("validated numeric label should fit usize"), + ), + } + } + + fn numeric(value: usize) -> Self { + Self::Int(LitInt::new(&value.to_string(), Span::call_site())) + } +} + +impl Parse for SymbolicIndex { + fn parse(input: ParseStream) -> syn::Result { + if input.peek(LitInt) { + let lit: LitInt = input.parse()?; + lit.base10_parse::().map_err(|_| { + syn::Error::new_spanned( + &lit, + "numeric declaration-first labels must be non-negative base-10 integers that fit in usize", + ) + })?; + Ok(Self::Int(lit)) + } else { + Ok(Self::Ident(input.parse()?)) + } + } +} + +impl ToTokens for SymbolicIndex { + fn to_tokens(&self, tokens: &mut TokenStream2) { + match self { + Self::Ident(ident) => ident.to_tokens(tokens), + Self::Int(lit) => lit.to_tokens(tokens), + } + } +} + +impl std::fmt::Display for SymbolicIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.name()) + } +} + impl Parse for OdeRouteDecl { fn parse(input: ParseStream) -> syn::Result { let kind_ident: Ident = input.parse()?; @@ -111,7 +173,7 @@ impl Parse for OdeRouteDecl { let content; syn::parenthesized!(content in input); - let route_input: Ident = content.parse()?; + let route_input: SymbolicIndex = content.parse()?; if !content.is_empty() { return Err(content.error("expected a single route input name inside `(...)`")); } @@ -166,7 +228,12 @@ impl Parse for OdeInput { "covariates", )?, "states" => set_once_ode(&mut states, parse_ident_list(input)?, &key, "states")?, - "outputs" => set_once_ode(&mut outputs, parse_ident_list(input)?, &key, "outputs")?, + "outputs" => set_once_ode( + &mut outputs, + parse_symbolic_index_list(input)?, + &key, + "outputs", + )?, "routes" => set_once_ode(&mut routes, parse_route_list(input)?, &key, "routes")?, "diffeq" => set_once_ode(&mut diffeq, input.parse()?, &key, "diffeq")?, "lag" => set_once_ode(&mut lag, input.parse()?, &key, "lag")?, @@ -201,19 +268,21 @@ impl Parse for OdeInput { let routes = routes.ok_or_else(|| missing_required_ode_field("routes"))?; let diffeq = diffeq.ok_or_else(|| missing_required_ode_field("diffeq"))?; let out = out.ok_or_else(|| missing_required_ode_field("out"))?; - let diffeq_mode = classify_diffeq_mode(&diffeq, &routes)?; + validate_ode_diffeq_uses_automatic_injection(&diffeq, &routes)?; validate_unique_idents("parameter", ¶ms, "ode!")?; validate_unique_idents("covariate", &covariates, "ode!")?; validate_unique_idents("state", &states, "ode!")?; - validate_unique_idents("output", &outputs, "ode!")?; + let output_idents = symbolic_index_idents(&outputs); + + validate_unique_symbolic_indices("output", &outputs, "ode!")?; validate_routes(&routes, &states, "ode!")?; validate_named_binding_compatibility( NamedBindingSets { params: ¶ms, covariates: &covariates, states: &states, - outputs: &outputs, + outputs: &output_idents, routes: &routes, }, OdeBindingClosures { @@ -224,7 +293,6 @@ impl Parse for OdeInput { init: init.as_ref(), out: &out, }, - diffeq_mode, }, )?; @@ -235,7 +303,6 @@ impl Parse for OdeInput { states, outputs, routes, - diffeq_mode, diffeq, lag, fa, @@ -247,7 +314,7 @@ impl Parse for OdeInput { impl Parse for RoutePropertyEntry { fn parse(input: ParseStream) -> syn::Result { - let route: Ident = input.parse()?; + let route: SymbolicIndex = input.parse()?; input.parse::]>()?; let value: Expr = input.parse()?; Ok(Self { route, value }) @@ -287,9 +354,12 @@ impl Parse for AnalyticalInput { "states" => { set_once_analytical(&mut states, parse_ident_list(input)?, &key, "states")? } - "outputs" => { - set_once_analytical(&mut outputs, parse_ident_list(input)?, &key, "outputs")? - } + "outputs" => set_once_analytical( + &mut outputs, + parse_symbolic_index_list(input)?, + &key, + "outputs", + )?, "routes" => { set_once_analytical(&mut routes, parse_route_list(input)?, &key, "routes")? } @@ -328,7 +398,9 @@ impl Parse for AnalyticalInput { validate_unique_idents("parameter", ¶ms, "analytical!")?; validate_unique_idents("covariate", &covariates, "analytical!")?; validate_unique_idents("state", &states, "analytical!")?; - validate_unique_idents("output", &outputs, "analytical!")?; + let output_idents = symbolic_index_idents(&outputs); + + validate_unique_symbolic_indices("output", &outputs, "analytical!")?; validate_routes(&routes, &states, "analytical!")?; let kernel_spec = resolve_analytical_structure(&structure)?; @@ -358,7 +430,7 @@ impl Parse for AnalyticalInput { params: ¶ms, covariates: &covariates, states: &states, - outputs: &outputs, + outputs: &output_idents, routes: &routes, }, AnalyticalBindingClosures { @@ -431,7 +503,12 @@ impl Parse for SdeInput { "covariates", )?, "states" => set_once_sde(&mut states, parse_ident_list(input)?, &key, "states")?, - "outputs" => set_once_sde(&mut outputs, parse_ident_list(input)?, &key, "outputs")?, + "outputs" => set_once_sde( + &mut outputs, + parse_symbolic_index_list(input)?, + &key, + "outputs", + )?, "routes" => set_once_sde(&mut routes, parse_route_list(input)?, &key, "routes")?, "particles" => set_once_sde(&mut particles, input.parse()?, &key, "particles")?, "drift" => set_once_sde(&mut drift, input.parse()?, &key, "drift")?, @@ -469,14 +546,16 @@ impl Parse for SdeInput { validate_unique_idents("parameter", ¶ms, "sde!")?; validate_unique_idents("covariate", &covariates, "sde!")?; validate_unique_idents("state", &states, "sde!")?; - validate_unique_idents("output", &outputs, "sde!")?; + let output_idents = symbolic_index_idents(&outputs); + + validate_unique_symbolic_indices("output", &outputs, "sde!")?; validate_routes(&routes, &states, "sde!")?; validate_sde_named_binding_compatibility( NamedBindingSets { params: ¶ms, covariates: &covariates, states: &states, - outputs: &outputs, + outputs: &output_idents, routes: &routes, }, SdeBindingClosures { @@ -595,9 +674,29 @@ fn parse_ident_list(input: ParseStream) -> syn::Result> { .collect()) } +fn parse_symbolic_index_list(input: ParseStream) -> syn::Result> { + let content; + syn::bracketed!(content in input); + Ok( + Punctuated::::parse_terminated(&content)? + .into_iter() + .collect(), + ) +} + fn parse_route_list(input: ParseStream) -> syn::Result> { + if input.peek(token::Brace) { + return Err(input.error("declaration-first macro `routes` must use `[...]`, not `{...}`")); + } + + if !input.peek(token::Bracket) { + return Err( + input.error("expected a bracketed route list like `routes: [infusion(iv) -> central]`") + ); + } + let content; - syn::braced!(content in input); + syn::bracketed!(content in input); Ok( Punctuated::::parse_terminated(&content)? .into_iter() @@ -627,6 +726,29 @@ fn generated_ident(name: &str) -> Ident { Ident::new(name, Span::call_site()) } +fn symbolic_index_idents(labels: &[SymbolicIndex]) -> Vec { + labels + .iter() + .filter_map(|label| label.ident().cloned()) + .collect() +} + +fn symbolic_index_bindings(labels: &[SymbolicIndex]) -> Vec<(SymbolicIndex, usize)> { + labels + .iter() + .cloned() + .enumerate() + .map(|(index, label)| (label, index)) + .collect() +} + +fn symbolic_numeric_binding_map(bindings: &[(SymbolicIndex, usize)]) -> HashMap { + bindings + .iter() + .filter_map(|(label, index)| label.numeric_value().map(|value| (value, *index))) + .collect() +} + #[derive(Default)] struct ClosureBodyUsage { idents: HashSet, @@ -713,6 +835,124 @@ impl<'ast> Visit<'ast> for ClosureBodyUsage { } } +struct IndexRewriteTarget { + container: Ident, + labels: HashMap, +} + +impl IndexRewriteTarget { + fn new(container: Ident, labels: HashMap) -> Self { + Self { container, labels } + } +} + +struct NumericLabelRewriter { + index_targets: Vec, + route_labels: Option>, +} + +impl NumericLabelRewriter { + fn rewrite( + expr: &Expr, + index_targets: Vec, + route_labels: Option>, + ) -> Expr { + let mut rewritten = expr.clone(); + let mut rewriter = Self { + index_targets, + route_labels, + }; + rewriter.visit_expr_mut(&mut rewritten); + rewritten + } + + fn target_labels(&self, path: &syn::ExprPath) -> Option<&HashMap> { + if path.qself.is_some() + || path.path.leading_colon.is_some() + || path.path.segments.len() != 1 + { + return None; + } + + let ident = &path.path.segments[0].ident; + self.index_targets + .iter() + .find(|target| target.container == *ident) + .map(|target| &target.labels) + } + + fn rewrite_route_macro(&self, mac: &mut syn::Macro) { + let Some(route_labels) = self.route_labels.as_ref() else { + return; + }; + if !(mac.path.is_ident("lag") || mac.path.is_ident("fa")) { + return; + } + + let Ok(entries) = Punctuated::::parse_terminated + .parse2(mac.tokens.clone()) + else { + return; + }; + + let entries = entries.into_iter().map(|mut entry| { + if let Some(value) = entry.route.numeric_value() { + if let Some(internal_index) = route_labels.get(&value) { + entry.route = SymbolicIndex::numeric(*internal_index); + } + } + entry + }); + + let tokens = entries.map(|entry| { + let route = entry.route; + let value = entry.value; + quote! { #route => #value } + }); + mac.tokens = quote! { #(#tokens),* }; + } +} + +impl VisitMut for NumericLabelRewriter { + fn visit_expr_index_mut(&mut self, expr_index: &mut syn::ExprIndex) { + syn::visit_mut::visit_expr_index_mut(self, expr_index); + + let Expr::Path(expr_path) = expr_index.expr.as_ref() else { + return; + }; + let Some(labels) = self.target_labels(expr_path) else { + return; + }; + let Expr::Lit(expr_lit) = expr_index.index.as_ref() else { + return; + }; + let Lit::Int(lit) = &expr_lit.lit else { + return; + }; + let Ok(external_index) = lit.base10_parse::() else { + return; + }; + let Some(internal_index) = labels.get(&external_index) else { + return; + }; + + expr_index.index = Box::new(Expr::Lit(syn::ExprLit { + attrs: Vec::new(), + lit: Lit::Int(LitInt::new(&internal_index.to_string(), lit.span())), + })); + } + + fn visit_expr_macro_mut(&mut self, expr_macro: &mut syn::ExprMacro) { + self.rewrite_route_macro(&mut expr_macro.mac); + syn::visit_mut::visit_expr_macro_mut(self, expr_macro); + } + + fn visit_stmt_macro_mut(&mut self, stmt_macro: &mut syn::StmtMacro) { + self.rewrite_route_macro(&mut stmt_macro.mac); + syn::visit_mut::visit_stmt_macro_mut(self, stmt_macro); + } +} + fn generate_closure_input_aliases( closure: &ExprClosure, internal_names: &[Ident], @@ -824,13 +1064,12 @@ fn generate_covariate_bindings( } } -fn classify_diffeq_mode( +fn validate_ode_diffeq_uses_automatic_injection( diffeq: &ExprClosure, routes: &[OdeRouteDecl], -) -> syn::Result { +) -> syn::Result<()> { match closure_param_names(diffeq).len() { - 3 => Ok(OdeDiffeqMode::InjectedRouteInputs), - 7 => Ok(OdeDiffeqMode::ExplicitRouteVectors), + 3 => Ok(()), 5 => { let usage = ClosureBodyUsage::analyze(diffeq.body.as_ref()); let route_inputs = route_input_idents(routes); @@ -843,23 +1082,33 @@ fn classify_diffeq_mode( .is_some_and(|ident| usage.indexes(ident) && !usage.assigns_index(ident)); if mentions_route_inputs || indexes_fifth_param || reads_fourth_param_as_input { - Ok(OdeDiffeqMode::ExplicitRouteVectors) + Err(syn::Error::new_spanned( + diffeq, + "declaration-first `ode!` only supports automatic route injection in `diffeq`; use either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx| and remove manual `bolus[...]` / `rateiv[...]` terms", + )) } else { - Ok(OdeDiffeqMode::InjectedRouteInputs) + Ok(()) } } _ => Err(syn::Error::new_spanned( diffeq, - "declaration-first `ode!` requires `diffeq` to have either 3 parameters: |x, t, dx|, 5 parameters: |x, p, t, dx, cov| or |x, t, dx, bolus, rateiv|, or 7 parameters: |x, p, t, dx, bolus, rateiv, cov|", + "declaration-first `ode!` only supports automatic route injection in `diffeq`; use either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", )), } } fn route_input_idents(routes: &[OdeRouteDecl]) -> Vec { - routes.iter().map(|route| route.input.clone()).collect() + routes + .iter() + .filter_map(|route| route.input.ident().cloned()) + .collect() +} + +fn route_input_names(routes: &[OdeRouteDecl]) -> Vec { + routes.iter().map(|route| route.input.name()).collect() } -fn ode_route_channel_bindings(routes: &[OdeRouteDecl]) -> Vec<(Ident, usize)> { +fn ode_route_input_bindings(routes: &[OdeRouteDecl]) -> Vec<(SymbolicIndex, usize)> { let mut next_bolus_index = 0usize; let mut next_infusion_index = 0usize; @@ -883,7 +1132,7 @@ fn ode_route_channel_bindings(routes: &[OdeRouteDecl]) -> Vec<(Ident, usize)> { .collect() } -fn dense_index_len(bindings: &[(Ident, usize)]) -> usize { +fn dense_index_len(bindings: &[(SymbolicIndex, usize)]) -> usize { bindings .iter() .map(|(_, index)| index + 1) @@ -968,7 +1217,6 @@ struct AnalyticalBindingClosures<'a> { struct OdeBindingClosures<'a> { diffeq: &'a ExprClosure, common: CommonBindingClosures<'a>, - diffeq_mode: OdeDiffeqMode, } #[derive(Clone, Copy)] @@ -992,7 +1240,6 @@ fn validate_named_binding_compatibility( let OdeBindingClosures { diffeq, common: CommonBindingClosures { lag, fa, init, out }, - diffeq_mode, } = closures; let route_inputs = route_input_idents(routes); @@ -1043,31 +1290,6 @@ fn validate_named_binding_compatibility( validate_closure_param_conflicts("diffeq", diffeq, covariates, "covariate")?; validate_closure_param_conflicts("diffeq", diffeq, states, "state")?; - if diffeq_mode == OdeDiffeqMode::ExplicitRouteVectors { - validate_binding_conflicts( - "parameter", - params, - "route", - &route_inputs, - "`diffeq` named binding generation", - )?; - validate_binding_conflicts( - "state", - states, - "route", - &route_inputs, - "`diffeq` named binding generation", - )?; - validate_binding_conflicts( - "covariate", - covariates, - "route", - &route_inputs, - "`diffeq` named binding generation", - )?; - validate_closure_param_conflicts("diffeq", diffeq, &route_inputs, "route")?; - } - if let Some(lag) = lag { validate_binding_conflicts( "covariate", @@ -1361,12 +1583,14 @@ fn generate_index_consts(idents: &[Ident]) -> TokenStream2 { } } -fn generate_mapped_index_consts(bindings: &[(Ident, usize)]) -> TokenStream2 { - let bindings = bindings.iter().map(|(ident, index)| { - quote! { - #[allow(non_upper_case_globals, dead_code)] - const #ident: usize = #index; - } +fn generate_mapped_index_consts(bindings: &[(SymbolicIndex, usize)]) -> TokenStream2 { + let bindings = bindings.iter().filter_map(|(label, index)| { + label.ident().map(|ident| { + quote! { + #[allow(non_upper_case_globals, dead_code)] + const #ident: usize = #index; + } + }) }); quote! { @@ -1379,10 +1603,11 @@ fn expand_out( params: &[Ident], covariates: &[Ident], states: &[Ident], - outputs: &[Ident], + outputs: &[SymbolicIndex], ) -> syn::Result { let state_consts = generate_index_consts(states); - let output_consts = generate_index_consts(outputs); + let output_bindings = symbolic_index_bindings(outputs); + let output_consts = generate_mapped_index_consts(&output_bindings); let x = generated_ident("__pharmsol_x"); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); @@ -1397,7 +1622,19 @@ fn expand_out( )?; let parameter_bindings = generate_parameter_bindings(params, out, &p); let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); - let body = &out.body; + let y_binding = if out.inputs.len() == full_inputs.len() { + closure_param_ident(out, 4).unwrap_or_else(|| y.clone()) + } else { + closure_param_ident(out, 2).unwrap_or_else(|| y.clone()) + }; + let body = NumericLabelRewriter::rewrite( + out.body.as_ref(), + vec![IndexRewriteTarget::new( + y_binding, + symbolic_numeric_binding_map(&output_bindings), + )], + None, + ); Ok(quote! {{ let __pharmsol_out: fn( @@ -1480,14 +1717,13 @@ fn extract_route_property_routes( let macro_expr = find_terminal_macro_invocation(macro_name, label, closure)?; let entries = Punctuated::::parse_terminated .parse2(macro_expr.tokens.clone())?; - let known_routes = route_input_idents(routes) + let known_routes = route_input_names(routes) .into_iter() - .map(|route| route.to_string()) .collect::>(); let mut seen = HashSet::new(); for entry in entries { - let route_name = entry.route.to_string(); + let route_name = entry.route.name(); if !known_routes.contains(&route_name) { return Err(syn::Error::new_spanned( &entry.route, @@ -1515,7 +1751,7 @@ fn validate_route_property_kinds( property_routes: &HashSet, ) -> syn::Result<()> { for route in routes { - if property_routes.contains(&route.input.to_string()) + if property_routes.contains(&route.input.name()) && matches!(route.kind, OdeRouteKind::Infusion) { return Err(syn::Error::new_spanned( @@ -1536,7 +1772,7 @@ fn expand_ode_route_map( closure: &ExprClosure, params: &[Ident], covariates: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> syn::Result { let route_consts = generate_mapped_index_consts(route_bindings); let p = generated_ident("__pharmsol_p"); @@ -1553,7 +1789,11 @@ fn expand_ode_route_map( )?; let parameter_bindings = generate_parameter_bindings(params, closure, &p); let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); - let body = &closure.body; + let body = NumericLabelRewriter::rewrite( + closure.body.as_ref(), + Vec::new(), + Some(symbolic_numeric_binding_map(route_bindings)), + ); Ok(quote! {{ let __pharmsol_route_map: fn( @@ -1617,7 +1857,6 @@ fn expand_ode_init( fn expand_route_metadata( routes: &[OdeRouteDecl], - diffeq_mode: OdeDiffeqMode, lag_routes: &HashSet, fa_routes: &HashSet, ) -> Vec { @@ -1626,7 +1865,7 @@ fn expand_route_metadata( .map(|route| { let input = &route.input; let destination = &route.destination; - let route_name = route.input.to_string(); + let route_name = route.input.name(); let route_builder = match route.kind { OdeRouteKind::Bolus => { quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } @@ -1635,10 +1874,6 @@ fn expand_route_metadata( quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } } }; - let input_policy = match diffeq_mode { - OdeDiffeqMode::InjectedRouteInputs => quote! { .inject_input_to_destination() }, - OdeDiffeqMode::ExplicitRouteVectors => quote! { .expect_explicit_input() }, - }; let lag_flag = if lag_routes.contains(&route_name) { quote! { .with_lag() } } else { @@ -1655,7 +1890,7 @@ fn expand_route_metadata( .to_state(stringify!(#destination)) #lag_flag #fa_flag - #input_policy + .inject_input_to_destination() } }) .collect() @@ -1671,7 +1906,7 @@ fn expand_analytical_route_metadata( .map(|route| { let input = &route.input; let destination = &route.destination; - let route_name = route.input.to_string(); + let route_name = route.input.name(); let route_builder = match route.kind { OdeRouteKind::Bolus => { quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } @@ -1711,7 +1946,7 @@ fn expand_sde_route_metadata( .map(|route| { let input = &route.input; let destination = &route.destination; - let route_name = route.input.to_string(); + let route_name = route.input.name(); let route_builder = match route.kind { OdeRouteKind::Bolus => { quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } @@ -1752,7 +1987,7 @@ fn route_destination_index(route: &OdeRouteDecl, states: &[Ident]) -> usize { fn expand_injected_ode_route_terms( routes: &[OdeRouteDecl], states: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], dx: &Ident, bolus: &Ident, rateiv: &Ident, @@ -1760,14 +1995,14 @@ fn expand_injected_ode_route_terms( let terms = routes .iter() .zip(route_bindings.iter()) - .map(|(route, (_, channel_index))| { + .map(|(route, (_, input_index))| { let destination = route_destination_index(route, states); match route.kind { OdeRouteKind::Bolus => quote! { - #dx[#destination] += #bolus[#channel_index]; + #dx[#destination] += #bolus[#input_index]; }, OdeRouteKind::Infusion => quote! { - #dx[#destination] += #rateiv[#channel_index]; + #dx[#destination] += #rateiv[#input_index]; }, } }); @@ -1780,23 +2015,22 @@ fn expand_injected_ode_route_terms( fn expand_injected_sde_rate_terms( routes: &[OdeRouteDecl], states: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], dx: &Ident, rateiv: &Ident, ) -> TokenStream2 { - let terms = - routes - .iter() - .zip(route_bindings.iter()) - .filter_map(|(route, (_, channel_index))| match route.kind { - OdeRouteKind::Bolus => None, - OdeRouteKind::Infusion => { - let destination = route_destination_index(route, states); - Some(quote! { - #dx[#destination] += #rateiv[#channel_index]; - }) - } - }); + let terms = routes + .iter() + .zip(route_bindings.iter()) + .filter_map(|(route, (_, input_index))| match route.kind { + OdeRouteKind::Bolus => None, + OdeRouteKind::Infusion => { + let destination = route_destination_index(route, states); + Some(quote! { + #dx[#destination] += #rateiv[#input_index]; + }) + } + }); quote! { #(#terms)* @@ -1806,14 +2040,14 @@ fn expand_injected_sde_rate_terms( fn expand_injected_sde_bolus_mappings( routes: &[OdeRouteDecl], states: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> TokenStream2 { let mut destinations = vec![quote! { None }; dense_index_len(route_bindings)]; - for (route, (_, channel_index)) in routes.iter().zip(route_bindings.iter()) { + for (route, (_, input_index)) in routes.iter().zip(route_bindings.iter()) { if let OdeRouteKind::Bolus = route.kind { let destination = route_destination_index(route, states); - destinations[*channel_index] = quote! { Some(#destination) }; + destinations[*input_index] = quote! { Some(#destination) }; } } @@ -1836,12 +2070,30 @@ fn validate_unique_idents(kind: &str, idents: &[Ident], macro_name: &str) -> syn Ok(()) } +fn validate_unique_symbolic_indices( + kind: &str, + labels: &[SymbolicIndex], + macro_name: &str, +) -> syn::Result<()> { + let mut seen = HashSet::new(); + for label in labels { + let name = label.name(); + if !seen.insert(name.clone()) { + return Err(syn::Error::new_spanned( + label, + format!("duplicate {kind} `{name}` in declaration-first `{macro_name}`"), + )); + } + } + Ok(()) +} + fn validate_routes(routes: &[OdeRouteDecl], states: &[Ident], macro_name: &str) -> syn::Result<()> { let known_states = states.iter().map(Ident::to_string).collect::>(); let mut seen_routes = HashSet::new(); for route in routes { - let route_name = route.input.to_string(); + let route_name = route.input.name(); if !seen_routes.insert(route_name.clone()) { return Err(syn::Error::new_spanned( &route.input, @@ -1869,131 +2121,65 @@ fn expand_diffeq( covariates: &[Ident], states: &[Ident], routes: &[OdeRouteDecl], - route_bindings: &[(Ident, usize)], - diffeq_mode: OdeDiffeqMode, + route_bindings: &[(SymbolicIndex, usize)], ) -> syn::Result { let state_consts = generate_index_consts(states); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let dx = generated_ident("__pharmsol_dx"); + let bolus = generated_ident("__pharmsol_bolus"); + let rateiv = generated_ident("__pharmsol_rateiv"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()]; + let reduced_inputs = [x.clone(), t.clone(), dx.clone()]; + let input_aliases = generate_supported_input_aliases( + diffeq, + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` injected-route `diffeq` requires either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", + )?; + let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); + let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); + let body = &diffeq.body; + let dx_binding = if diffeq.inputs.len() == full_inputs.len() { + closure_param_ident(diffeq, 3).unwrap_or_else(|| dx.clone()) + } else { + closure_param_ident(diffeq, 2).unwrap_or_else(|| dx.clone()) + }; + let route_terms = expand_injected_ode_route_terms( + routes, + states, + route_bindings, + &dx_binding, + &bolus, + &rateiv, + ); - match diffeq_mode { - OdeDiffeqMode::ExplicitRouteVectors => { - let route_consts = generate_mapped_index_consts(route_bindings); - let x = generated_ident("__pharmsol_x"); - let p = generated_ident("__pharmsol_p"); - let t = generated_ident("__pharmsol_t"); - let dx = generated_ident("__pharmsol_dx"); - let bolus = generated_ident("__pharmsol_bolus"); - let rateiv = generated_ident("__pharmsol_rateiv"); - let cov = generated_ident("__pharmsol_cov"); - let full_inputs = [ - x.clone(), - p.clone(), - t.clone(), - dx.clone(), - bolus.clone(), - rateiv.clone(), - cov.clone(), - ]; - let reduced_inputs = [ - x.clone(), - t.clone(), - dx.clone(), - bolus.clone(), - rateiv.clone(), - ]; - let input_aliases = generate_supported_input_aliases( - diffeq, - &[&full_inputs, &reduced_inputs], - "declaration-first `ode!` explicit-route `diffeq` requires either 7 parameters: |x, p, t, dx, bolus, rateiv, cov| or 5 parameters: |x, t, dx, bolus, rateiv|", - )?; - let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); - let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); - let body = &diffeq.body; - - Ok(quote! {{ - let __pharmsol_diffeq: fn( - &::pharmsol::simulator::V, - &::pharmsol::simulator::V, - f64, - &mut ::pharmsol::simulator::V, - &::pharmsol::simulator::V, - &::pharmsol::simulator::V, - &::pharmsol::data::Covariates, - ) = |#x: &::pharmsol::simulator::V, - #p: &::pharmsol::simulator::V, - #t: f64, - #dx: &mut ::pharmsol::simulator::V, - #bolus: &::pharmsol::simulator::V, - #rateiv: &::pharmsol::simulator::V, - #cov: &::pharmsol::data::Covariates| { - #input_aliases - #state_consts - #route_consts - #parameter_bindings - #covariate_bindings - #body - }; - __pharmsol_diffeq - }}) - } - OdeDiffeqMode::InjectedRouteInputs => { - let x = generated_ident("__pharmsol_x"); - let p = generated_ident("__pharmsol_p"); - let t = generated_ident("__pharmsol_t"); - let dx = generated_ident("__pharmsol_dx"); - let bolus = generated_ident("__pharmsol_bolus"); - let rateiv = generated_ident("__pharmsol_rateiv"); - let cov = generated_ident("__pharmsol_cov"); - let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()]; - let reduced_inputs = [x.clone(), t.clone(), dx.clone()]; - let input_aliases = generate_supported_input_aliases( - diffeq, - &[&full_inputs, &reduced_inputs], - "declaration-first `ode!` injected-route `diffeq` requires either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", - )?; - let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); - let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); - let body = &diffeq.body; - let dx_binding = if diffeq.inputs.len() == full_inputs.len() { - closure_param_ident(diffeq, 3).unwrap_or_else(|| dx.clone()) - } else { - closure_param_ident(diffeq, 2).unwrap_or_else(|| dx.clone()) - }; - let route_terms = expand_injected_ode_route_terms( - routes, - states, - route_bindings, - &dx_binding, - &bolus, - &rateiv, - ); - - Ok(quote! {{ - let __pharmsol_diffeq: fn( - &::pharmsol::simulator::V, - &::pharmsol::simulator::V, - f64, - &mut ::pharmsol::simulator::V, - &::pharmsol::simulator::V, - &::pharmsol::simulator::V, - &::pharmsol::data::Covariates, - ) = |#x: &::pharmsol::simulator::V, - #p: &::pharmsol::simulator::V, - #t: f64, - #dx: &mut ::pharmsol::simulator::V, - #bolus: &::pharmsol::simulator::V, - #rateiv: &::pharmsol::simulator::V, - #cov: &::pharmsol::data::Covariates| { - #input_aliases - #state_consts - #parameter_bindings - #covariate_bindings - #body - #route_terms - }; - __pharmsol_diffeq - }}) - } - } + Ok(quote! {{ + let __pharmsol_diffeq: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &mut ::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::data::Covariates, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #dx: &mut ::pharmsol::simulator::V, + #bolus: &::pharmsol::simulator::V, + #rateiv: &::pharmsol::simulator::V, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + #route_terms + }; + __pharmsol_diffeq + }}) } fn resolve_analytical_structure(structure: &Ident) -> syn::Result { @@ -2094,7 +2280,7 @@ fn expand_analytical_route_map( closure: &ExprClosure, params: &[Ident], covariates: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> syn::Result { let route_consts = generate_mapped_index_consts(route_bindings); let p = generated_ident("__pharmsol_p"); @@ -2111,7 +2297,11 @@ fn expand_analytical_route_map( )?; let parameter_bindings = generate_parameter_bindings(params, closure, &p); let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); - let body = &closure.body; + let body = NumericLabelRewriter::rewrite( + closure.body.as_ref(), + Vec::new(), + Some(symbolic_numeric_binding_map(route_bindings)), + ); Ok(quote! {{ let __pharmsol_route_map: fn( @@ -2221,10 +2411,11 @@ fn expand_analytical_out( params: &[Ident], covariates: &[Ident], states: &[Ident], - outputs: &[Ident], + outputs: &[SymbolicIndex], ) -> syn::Result { let state_consts = generate_index_consts(states); - let output_consts = generate_index_consts(outputs); + let output_bindings = symbolic_index_bindings(outputs); + let output_consts = generate_mapped_index_consts(&output_bindings); let x = generated_ident("__pharmsol_x"); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); @@ -2239,7 +2430,19 @@ fn expand_analytical_out( )?; let parameter_bindings = generate_parameter_bindings(params, out, &p); let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); - let body = &out.body; + let y_binding = if out.inputs.len() == full_inputs.len() { + closure_param_ident(out, 4).unwrap_or_else(|| y.clone()) + } else { + closure_param_ident(out, 2).unwrap_or_else(|| y.clone()) + }; + let body = NumericLabelRewriter::rewrite( + out.body.as_ref(), + vec![IndexRewriteTarget::new( + y_binding, + symbolic_numeric_binding_map(&output_bindings), + )], + None, + ); Ok(quote! {{ let __pharmsol_out: fn( @@ -2270,7 +2473,7 @@ fn expand_sde_drift( covariates: &[Ident], states: &[Ident], routes: &[OdeRouteDecl], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> syn::Result { let state_consts = generate_index_consts(states); let x = generated_ident("__pharmsol_x"); @@ -2360,7 +2563,7 @@ fn expand_sde_route_map( closure: &ExprClosure, params: &[Ident], covariates: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> syn::Result { let route_consts = generate_mapped_index_consts(route_bindings); let p = generated_ident("__pharmsol_p"); @@ -2377,7 +2580,11 @@ fn expand_sde_route_map( )?; let parameter_bindings = generate_parameter_bindings(params, closure, &p); let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); - let body = &closure.body; + let body = NumericLabelRewriter::rewrite( + closure.body.as_ref(), + Vec::new(), + Some(symbolic_numeric_binding_map(route_bindings)), + ); Ok(quote! {{ let __pharmsol_route_map: fn( @@ -2444,10 +2651,11 @@ fn expand_sde_out( params: &[Ident], covariates: &[Ident], states: &[Ident], - outputs: &[Ident], + outputs: &[SymbolicIndex], ) -> syn::Result { let state_consts = generate_index_consts(states); - let output_consts = generate_index_consts(outputs); + let output_bindings = symbolic_index_bindings(outputs); + let output_consts = generate_mapped_index_consts(&output_bindings); let x = generated_ident("__pharmsol_x"); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); @@ -2462,7 +2670,19 @@ fn expand_sde_out( )?; let parameter_bindings = generate_parameter_bindings(params, out, &p); let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); - let body = &out.body; + let y_binding = if out.inputs.len() == full_inputs.len() { + closure_param_ident(out, 4).unwrap_or_else(|| y.clone()) + } else { + closure_param_ident(out, 2).unwrap_or_else(|| y.clone()) + }; + let body = NumericLabelRewriter::rewrite( + out.body.as_ref(), + vec![IndexRewriteTarget::new( + y_binding, + symbolic_numeric_binding_map(&output_bindings), + )], + None, + ); Ok(quote! {{ let __pharmsol_out: fn( @@ -2495,7 +2715,7 @@ fn expand_sde_out( pub fn ode(input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as OdeInput); - let route_bindings = ode_route_channel_bindings(&input.routes); + let route_bindings = ode_route_input_bindings(&input.routes); let lag_routes = match input.lag.as_ref() { Some(closure) => match extract_route_property_routes( @@ -2550,7 +2770,6 @@ pub fn ode(input: TokenStream) -> TokenStream { &input.states, &input.routes, &route_bindings, - input.diffeq_mode, ) { Ok(diffeq) => diffeq, Err(error) => return error.to_compile_error().into(), @@ -2576,7 +2795,7 @@ pub fn ode(input: TokenStream) -> TokenStream { let covariates = &input.covariates; let states = &input.states; let outputs = &input.outputs; - let routes = expand_route_metadata(&input.routes, input.diffeq_mode, &lag_routes, &fa_routes); + let routes = expand_route_metadata(&input.routes, &lag_routes, &fa_routes); let covariate_metadata = if covariates.is_empty() { quote! {} } else { @@ -2652,7 +2871,7 @@ pub fn ode(input: TokenStream) -> TokenStream { #[proc_macro] pub fn analytical(input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as AnalyticalInput); - let route_bindings = ode_route_channel_bindings(&input.routes); + let route_bindings = ode_route_input_bindings(&input.routes); let kernel_spec = match resolve_analytical_structure(&input.structure) { Ok(spec) => spec, @@ -2816,7 +3035,7 @@ pub fn analytical(input: TokenStream) -> TokenStream { #[proc_macro] pub fn sde(input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as SdeInput); - let route_bindings = ode_route_channel_bindings(&input.routes); + let route_bindings = ode_route_input_bindings(&input.routes); let lag_routes = match input.lag.as_ref() { Some(closure) => match extract_route_property_routes( @@ -3006,7 +3225,7 @@ mod tests { #[test] fn validates_route_destinations() { let error = syn::parse_str::( - "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> peripheral }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: [infusion(iv) -> peripheral], diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", ) .err() .expect("unknown route destination must fail"); @@ -3019,7 +3238,7 @@ mod tests { #[test] fn rejects_named_binding_collisions() { let error = syn::parse_str::( - "name: \"demo\", params: [central, v], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [central, v], states: [central], outputs: [cp], routes: [infusion(iv) -> central], diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", ) .err() .expect("parameter/state binding collisions must fail"); @@ -3030,20 +3249,20 @@ mod tests { } #[test] - fn ode_route_bindings_share_channels_by_kind_local_ordinal() { + fn ode_route_bindings_share_inputs_by_kind_local_ordinal() { let input = syn::parse_str::( - "name: \"demo\", params: [ka, ke, v], states: [depot, central], outputs: [cp], routes: { bolus(oral) -> depot, infusion(iv) -> central, bolus(sc) -> depot }, diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ka, ke, v], states: [depot, central], outputs: [cp], routes: [bolus(oral) -> depot, infusion(iv) -> central, bolus(sc) -> depot], diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", ) .expect("declaration-first ode input should parse"); - let bindings = ode_route_channel_bindings(&input.routes); + let bindings = ode_route_input_bindings(&input.routes); assert_eq!(dense_index_len(&bindings), 2); - assert_eq!(bindings[0].0.to_string(), "oral"); + assert_eq!(bindings[0].0.name(), "oral"); assert_eq!(bindings[0].1, 0); - assert_eq!(bindings[1].0.to_string(), "iv"); + assert_eq!(bindings[1].0.name(), "iv"); assert_eq!(bindings[1].1, 0); - assert_eq!(bindings[2].0.to_string(), "sc"); + assert_eq!(bindings[2].0.name(), "sc"); assert_eq!(bindings[2].1, 1); } @@ -3083,7 +3302,7 @@ mod tests { #[test] fn analytical_accepts_extra_parameters_beyond_kernel_arity() { let input = syn::parse_str::( - "name: \"demo\", params: [ka, ke, v, tlag, tvke], covariates: [wt, renal], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, sec: |_t| { ke = tvke; }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ka, ke, v, tlag, tvke], covariates: [wt, renal], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, sec: |_t| { ke = tvke; }, out: |x, p, t, cov, y| {}", ) .expect("extra declared parameters should be allowed"); @@ -3096,7 +3315,7 @@ mod tests { #[test] fn analytical_rejects_unknown_structure() { let error = syn::parse_str::( - "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, structure: mystery, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: [infusion(iv) -> central], structure: mystery, out: |x, p, t, cov, y| {}", ) .err() .expect("unknown analytical structure must fail"); @@ -3109,7 +3328,7 @@ mod tests { #[test] fn analytical_rejects_insufficient_kernel_parameters() { let error = syn::parse_str::( - "name: \"demo\", params: [ke], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, out: |x, p, t, cov, y| {}", ) .err() .expect("insufficient kernel parameters must fail"); @@ -3122,7 +3341,7 @@ mod tests { #[test] fn analytical_rejects_unknown_route_property_binding() { let error = syn::parse_str::( - "name: \"demo\", params: [ka, ke, v], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { iv => 1.0 } }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ka, ke, v], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { iv => 1.0 } }, out: |x, p, t, cov, y| {}", ) .err() .expect("unknown lag route must fail"); @@ -3135,7 +3354,7 @@ mod tests { #[test] fn analytical_rejects_infusion_lag_binding() { let error = syn::parse_str::( - "name: \"demo\", params: [ke, v, tlag], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, structure: one_compartment, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke, v, tlag], states: [central], outputs: [cp], routes: [infusion(iv) -> central], structure: one_compartment, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", ) .err() .expect("infusion lag must fail"); @@ -3148,7 +3367,7 @@ mod tests { #[test] fn sde_requires_particles() { let error = syn::parse_str::( - "name: \"demo\", params: [ke, theta], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke, theta], states: [central], outputs: [cp], routes: [infusion(iv) -> central], drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, out: |x, p, t, cov, y| {}", ) .err() .expect("missing particles must fail"); @@ -3161,7 +3380,7 @@ mod tests { #[test] fn sde_rejects_unknown_route_property_binding() { let error = syn::parse_str::( - "name: \"demo\", params: [ke, sigma_ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { oral => 1.0 } }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke, sigma_ke], states: [central], outputs: [cp], routes: [infusion(iv) -> central], particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { oral => 1.0 } }, out: |x, p, t, cov, y| {}", ) .err() .expect("unknown lag route must fail"); @@ -3174,7 +3393,7 @@ mod tests { #[test] fn sde_rejects_infusion_lag_binding() { let error = syn::parse_str::( - "name: \"demo\", params: [ke, sigma_ke, tlag], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke, sigma_ke, tlag], states: [central], outputs: [cp], routes: [infusion(iv) -> central], particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", ) .err() .expect("infusion lag must fail"); @@ -3183,4 +3402,17 @@ mod tests { .to_string() .contains("declaration-first `sde!` does not allow `lag` on infusion route `iv`")); } + + #[test] + fn rejects_braced_route_lists() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("braced route lists must fail"); + + assert!(error + .to_string() + .contains("declaration-first macro `routes` must use `[...]`, not `{...}`")); + } } diff --git a/src/data/builder.rs b/src/data/builder.rs index 18aa17fe..ed0a57a8 100644 --- a/src/data/builder.rs +++ b/src/data/builder.rs @@ -1,6 +1,21 @@ +//! Builder API for constructing [`Subject`] schedules in Rust. +//! +//! Use `Subject::builder(...)` when you want to describe a subject directly in +//! code with a schedule-oriented API. This is the preferred high-level +//! path for hand-written datasets. +//! +//! Builder methods accept public input and output labels. Prefer stable strings +//! such as `"depot"`, `"iv"`, and `"cp"`. Numeric values are accepted, but +//! they remain public labels rather than automatically becoming dense internal +//! indices. + use crate::{data::*, Censor}; -/// Extension trait for creating [Subject] instances using the builder pattern +/// Extension trait that enables `Subject::builder(...)`. +/// +/// Most users do not need to import [`SubjectBuilder`] directly. Import this +/// trait from the crate root or [`crate::prelude`] and then start with +/// `Subject::builder("id")`. pub trait SubjectBuilderExt { /// Create a new SubjectBuilder with the specified ID /// @@ -14,8 +29,8 @@ pub trait SubjectBuilderExt { /// use pharmsol::*; /// /// let subject = Subject::builder("patient_001") - /// .bolus(0.0, 100.0, 0) - /// .observation(1.0, 10.5, 0) + /// .bolus(0.0, 100.0, "depot") + /// .observation(1.0, 10.5, "cp") /// .build(); /// ``` fn builder(id: impl Into) -> SubjectBuilder; @@ -34,11 +49,37 @@ impl SubjectBuilderExt for Subject { } } -/// Builder for creating [Subject] instances with a fluent API +/// Builder for creating [`Subject`] values with a fluent API. +/// +/// Use [`SubjectBuilder`] when you want to author common dose and observation +/// schedules directly in Rust without constructing low-level event values by +/// hand. +/// +/// A builder instance accumulates events inside the current [`Occasion`]. +/// [`SubjectBuilder::repeat`] duplicates the most recently added event at later +/// times, and [`SubjectBuilder::reset`] closes the current occasion and starts a +/// new one with fresh occasion-local state. +/// +/// Input and output arguments are public labels. Prefer stable model-facing +/// names such as `"depot"`, `"iv"`, and `"cp"`. +/// +/// # Example +/// +/// ```rust +/// use pharmsol::*; +/// +/// let subject = Subject::builder("patient_001") +/// .bolus(0.0, 100.0, "depot") +/// .repeat(1, 24.0) +/// .observation(1.0, 12.3, "cp") +/// .missing_observation(25.0, "cp") +/// .reset() +/// .bolus(0.0, 80.0, "depot") +/// .observation(1.0, 10.1, "cp") +/// .build(); /// -/// The [SubjectBuilder] allows for constructing complex subject data with a -/// chainable, readable syntax. Events like doses and observations can be -/// added sequentially, and the builder handles organizing them into occasions. +/// assert_eq!(subject.occasions().len(), 2); +/// ``` #[derive(Debug, Clone)] pub struct SubjectBuilder { id: String, @@ -49,52 +90,54 @@ pub struct SubjectBuilder { } impl SubjectBuilder { - /// Add an event to the current occasion + /// Add a fully constructed event to the current occasion. /// - /// # Arguments - /// - /// * `event` - The event to add + /// Use this when you want to mix builder convenience methods with direct + /// [`Event`] values. pub fn event(mut self, event: Event) -> Self { self.last_added_event = Some(event.clone()); self.current_occasion.add_event(event); self } - /// Add a bolus dosing event + /// Add an instantaneous dose. /// /// # Arguments /// /// * `time` - Time of the bolus dose /// * `amount` - Amount of drug administered - /// * `input` - The compartment number receiving the dose - pub fn bolus(self, time: f64, amount: f64, input: usize) -> Self { + /// * `input` - Public input label receiving the dose + /// + /// Prefer stable route names such as `"depot"` or `"iv"` when the model + /// declares named routes. + pub fn bolus(self, time: f64, amount: f64, input: impl ToString) -> Self { let bolus = Bolus::new(time, amount, input, self.current_occasion.index()); let event = Event::Bolus(bolus); self.event(event) } - /// Add an infusion event + /// Add a continuous dose over a duration. /// /// # Arguments /// /// * `time` - Start time of the infusion /// * `amount` - Total amount of drug to be administered - /// * `input` - The compartment number receiving the dose + /// * `input` - Public input label receiving the dose /// * `duration` - Duration of the infusion in time units - pub fn infusion(self, time: f64, amount: f64, input: usize, duration: f64) -> Self { + pub fn infusion(self, time: f64, amount: f64, input: impl ToString, duration: f64) -> Self { let infusion = Infusion::new(time, amount, input, duration, self.current_occasion.index()); let event = Event::Infusion(infusion); self.event(event) } - /// Add an observation + /// Add an observed value at a given time. /// /// # Arguments /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number corresponding to this observation - pub fn observation(self, time: f64, value: f64, outeq: usize) -> Self { + /// * `outeq` - Public output label for this observation + pub fn observation(self, time: f64, value: f64, outeq: impl ToString) -> Self { let observation = Observation::new( time, Some(value), @@ -107,18 +150,19 @@ impl SubjectBuilder { self.event(event) } - /// Add a censored observation + /// Add an observed value with explicit censoring information. + /// /// # Arguments /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number (zero-indexed) corresponding to this - /// observation + /// * `outeq` - Public output label for this observation + /// * `censoring` - Censoring status for the observation value pub fn censored_observation( self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, censoring: Censor, ) -> Self { let observation = Observation::new( @@ -133,13 +177,16 @@ impl SubjectBuilder { self.event(event) } - /// Add an observation + /// Add a prediction-only observation slot. /// /// # Arguments /// /// * `time` - Time of the observation - /// * `outeq` - Output equation number (zero-indexed) corresponding to this observation - pub fn missing_observation(self, time: f64, outeq: usize) -> Self { + /// * `outeq` - Public output label for this observation + /// + /// Use this when you want a prediction at a time point but do not have an + /// observed value. + pub fn missing_observation(self, time: f64, outeq: impl ToString) -> Self { let observation = Observation::new( time, None, @@ -152,20 +199,20 @@ impl SubjectBuilder { self.event(event) } - /// Add an observation with a specific error polynomial + /// Add an observed value with an explicit assay error polynomial. /// /// # Arguments /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number (zero-indexed) corresponding to this observation + /// * `outeq` - Public output label for this observation /// * `errorpoly` - Error polynomial coefficients (c0, c1, c2, c3) - /// * `censored` - Whether the observation is censored + /// * `censored` - Censoring status for the observation value pub fn observation_with_error( self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, errorpoly: ErrorPoly, censored: Censor, ) -> Self { @@ -181,7 +228,10 @@ impl SubjectBuilder { self.event(event) } - /// Repeat the last event `n` times, separated by some interval `delta` + /// Repeat the last event `n` times, separated by `delta`. + /// + /// The repeated events keep the same label, value, censoring state, and + /// error polynomial as the original event. Only the event time changes. /// /// # Arguments /// @@ -193,9 +243,8 @@ impl SubjectBuilder { /// ```rust /// use pharmsol::*; /// - /// /// let subject = Subject::builder("patient_001") - /// .bolus(0.0, 100.0, 0) // First dose at time 0 + /// .bolus(0.0, 100.0, "depot") // First dose at time 0 /// .repeat(3, 24.0) // Repeat the dose at times 24, 48, and 72 /// .build(); /// ``` @@ -255,12 +304,14 @@ impl SubjectBuilder { self } - /// Complete the current occasion and start a new one + /// Complete the current occasion and start a new one. /// /// This finalizes the current occasion, adds it to the subject, /// and creates a new occasion for subsequent events. - /// This is useful if a patient has new observations at some other occasion. - /// Note that all states are reset! + /// Use this when the subject should begin a new occasion with reset state. + /// + /// Covariates collected since the previous reset are attached to the + /// finished occasion. The new occasion starts empty and its state is reset. pub fn reset(mut self) -> Self { let block_index = self.current_occasion.index() + 1; self.current_occasion.sort(); @@ -274,7 +325,7 @@ impl SubjectBuilder { self } - /// Add a covariate value at a specific time + /// Add a covariate value at a specific time. /// /// Multiple calls for the same covariate at different times will create /// linear interpolation between the time points. @@ -300,7 +351,7 @@ impl SubjectBuilder { self } - /// Finalize and build the Subject + /// Finalize and build the [`Subject`]. /// /// This completes the current occasion and returns a new Subject with all /// the accumulated data. diff --git a/src/data/event.rs b/src/data/event.rs index ff88e097..02a4c9a7 100644 --- a/src/data/event.rs +++ b/src/data/event.rs @@ -1,3 +1,15 @@ +//! Event types and public label wrappers for subject schedules. +//! +//! These types are the low-level representation behind the higher-level +//! builder and parsing APIs. Most users can start with +//! [`crate::data::builder::SubjectBuilder`], then inspect or transform +//! [`Event`] values after construction. +//! +//! Dose events carry an [`InputLabel`], and observations carry an +//! [`OutputLabel`]. Prefer stable strings such as `"depot"`, `"iv"`, and +//! `"cp"`. Numeric values are accepted, but they remain labels until a +//! downstream workflow explicitly interprets them as indices. + use crate::data::error_model::ErrorPoly; use crate::prelude::simulator::Prediction; use serde::{Deserialize, Serialize}; @@ -7,12 +19,16 @@ use std::fmt; // Shared Analysis Types // ============================================================================ -/// Administration route for a dosing event +/// Administration route classification used by downstream analyses. +/// +/// [`Route`] is a coarse route category, not the original public input label. +/// In the current data-side heuristic: +/// - [`Event::Infusion`] maps to [`Route::IVInfusion`] +/// - [`Event::Bolus`] with input label `0` maps to [`Route::Extravascular`] +/// - [`Event::Bolus`] with any other label maps to [`Route::IVBolus`] /// -/// Determined by the type of dose events and their target compartment: -/// - [`Event::Infusion`] → [`Route::IVInfusion`] -/// - [`Event::Bolus`] with `input >= 1` (central compartment) → [`Route::IVBolus`] -/// - [`Event::Bolus`] with `input == 0` (depot compartment) → [`Route::Extravascular`] +/// If you need the original model-facing label, read [`Bolus::input`] or +/// [`Infusion::input`] instead. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] pub enum Route { /// Intravenous bolus @@ -78,12 +94,15 @@ pub enum BLQRule { }, } -/// Represents a pharmacokinetic/pharmacodynamic event +/// One scheduled item in a subject record. /// -/// Events represent key occurrences in a PK/PD profile, including: -/// - [Bolus] doses (instantaneous drug input) -/// - [Infusion]s (continuous drug input over a duration) -/// - [Observation]s (measured concentrations or other values) +/// Events are the low-level representation for doses and observations: +/// - [`Bolus`] for instantaneous input +/// - [`Infusion`] for input over a duration +/// - [`Observation`] for measured or missing outputs +/// +/// Most users create these through `Subject::builder(...)`, row ingestion, or +/// file parsing rather than constructing them all by hand. #[derive(Serialize, Debug, Clone, Deserialize)] pub enum Event { /// A bolus dose (instantaneous drug input) @@ -93,6 +112,109 @@ pub enum Event { /// An observation of drug concentration or other measure Observation(Observation), } + +macro_rules! impl_label_type { + ($(#[$meta:meta])* $name:ident) => { + $(#[$meta])* + #[derive( + Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, + )] + pub struct $name(String); + + impl $name { + /// Create a new public label. + /// + /// Prefer stable names when the model declares named routes or + /// outputs. + pub fn new(label: impl ToString) -> Self { + Self(label.to_string()) + } + + /// Borrow the stored label as a string. + pub fn as_str(&self) -> &str { + &self.0 + } + + /// Try to interpret the label as a numeric index. + /// + /// This is mainly a compatibility helper for lower-level paths that + /// still operate on dense indices after label resolution. + pub fn index(&self) -> Option { + self.0.parse::().ok() + } + } + + impl From for $name { + fn from(value: String) -> Self { + Self(value) + } + } + + impl From<&str> for $name { + fn from(value: &str) -> Self { + Self(value.to_string()) + } + } + + impl From for $name { + fn from(value: usize) -> Self { + Self(value.to_string()) + } + } + + impl AsRef for $name { + fn as_ref(&self) -> &str { + self.as_str() + } + } + + impl fmt::Display for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } + } + + impl PartialEq for $name { + fn eq(&self, other: &usize) -> bool { + self.index() == Some(*other) + } + } + + impl PartialEq<$name> for usize { + fn eq(&self, other: &$name) -> bool { + other == self + } + } + + impl PartialEq for &$name { + fn eq(&self, other: &usize) -> bool { + (**self).eq(other) + } + } + + impl PartialEq<&$name> for usize { + fn eq(&self, other: &&$name) -> bool { + other.eq(self) + } + } + }; +} + +impl_label_type!( + /// Public label for a dosing input or route. + /// + /// [`Bolus`] and [`Infusion`] store the original user-facing route name in + /// this type. + InputLabel +); +impl_label_type!( + /// Public label for an observation output. + /// + /// [`Observation`] stores the original user-facing output name in this + /// type. + OutputLabel +); + impl Event { /// Get the time of the event pub fn time(&self) -> f64 { @@ -145,14 +267,15 @@ impl Event { } } -/// Represents an instantaneous input of drug +/// Instantaneous dose input. /// -/// A [Bolus] is a discrete amount of drug added to a specific compartment at a specific time. +/// A [`Bolus`] records one discrete amount at one time, tagged with the public +/// input label that should be matched against the model. #[derive(Serialize, Debug, Clone, Deserialize)] pub struct Bolus { time: f64, amount: f64, - input: usize, + input: InputLabel, occasion: usize, } impl Bolus { @@ -162,12 +285,12 @@ impl Bolus { /// /// * `time` - Time of the bolus dose /// * `amount` - Amount of drug administered - /// * `input` - The compartment number receiving the dose - pub fn new(time: f64, amount: f64, input: usize, occasion: usize) -> Self { + /// * `input` - The route label receiving the dose + pub fn new(time: f64, amount: f64, input: impl ToString, occasion: usize) -> Self { Bolus { time, amount, - input, + input: InputLabel::new(input), occasion, } } @@ -177,9 +300,16 @@ impl Bolus { self.amount } - /// Get the compartment number that receives the bolus - pub fn input(&self) -> usize { - self.input + /// Get the route label that receives the bolus + pub fn input(&self) -> &InputLabel { + &self.input + } + + /// Try to interpret the input label as a numeric index. + /// + /// Prefer [`Bolus::input`] when working with the public label itself. + pub fn input_index(&self) -> Option { + self.input.index() } /// Get the time of the bolus administration @@ -192,9 +322,9 @@ impl Bolus { self.amount = amount; } - /// Set the compartment number that receives the bolus - pub fn set_input(&mut self, input: usize) { - self.input = input; + /// Set the route label that receives the bolus + pub fn set_input(&mut self, input: impl ToString) { + self.input = InputLabel::new(input); } /// Set the time of the bolus administration @@ -207,8 +337,8 @@ impl Bolus { &mut self.amount } - /// Get a mutable reference to the compartment number (1-indexed) that receives the bolus - pub fn mut_input(&mut self) -> &mut usize { + /// Get a mutable reference to the route label that receives the bolus + pub fn mut_input(&mut self) -> &mut InputLabel { &mut self.input } @@ -228,14 +358,15 @@ impl Bolus { } } -/// Represents a continuous dose of drug over time +/// Continuous dose input over a duration. /// -/// An [Infusion] administers drug at a constant rate over a specified duration. +/// An [`Infusion`] records the total amount, start time, duration, and public +/// input label for one infusion event. #[derive(Serialize, Debug, Clone, Deserialize)] pub struct Infusion { time: f64, amount: f64, - input: usize, + input: InputLabel, duration: f64, occasion: usize, } @@ -246,13 +377,19 @@ impl Infusion { /// /// * `time` - Start time of the infusion /// * `amount` - Total amount of drug to be administered - /// * `input` - The compartment number receiving the dose + /// * `input` - The route label receiving the dose /// * `duration` - Duration of the infusion in time units - pub fn new(time: f64, amount: f64, input: usize, duration: f64, occasion: usize) -> Self { + pub fn new( + time: f64, + amount: f64, + input: impl ToString, + duration: f64, + occasion: usize, + ) -> Self { Infusion { time, amount, - input, + input: InputLabel::new(input), duration, occasion, } @@ -263,9 +400,16 @@ impl Infusion { self.amount } - /// Get the compartment number that receives the infusion - pub fn input(&self) -> usize { - self.input + /// Get the route label that receives the infusion + pub fn input(&self) -> &InputLabel { + &self.input + } + + /// Try to interpret the input label as a numeric index. + /// + /// Prefer [`Infusion::input`] when working with the public label itself. + pub fn input_index(&self) -> Option { + self.input.index() } /// Get the duration of the infusion @@ -285,9 +429,9 @@ impl Infusion { self.amount = amount; } - /// Set the compartment number that receives the infusion - pub fn set_input(&mut self, input: usize) { - self.input = input; + /// Set the route label that receives the infusion + pub fn set_input(&mut self, input: impl ToString) { + self.input = InputLabel::new(input); } /// Set the time of the infusion administration @@ -305,8 +449,8 @@ impl Infusion { &mut self.amount } - /// Get a mutable reference to the compartment number (1-indexed) that receives the infusion - pub fn mut_input(&mut self) -> &mut usize { + /// Get a mutable reference to the route label that receives the infusion + pub fn mut_input(&mut self) -> &mut InputLabel { &mut self.input } @@ -343,12 +487,16 @@ pub enum Censor { ALOQ, } -/// Represents an observation of drug concentration or other measured value + /// Observation of a model output. + /// + /// An [`Observation`] can carry a measured value or `None` for a prediction-only + /// time point. Observations also carry the public output label, optional assay + /// error polynomial, occasion index, and censoring state. #[derive(Serialize, Debug, Clone, Deserialize)] pub struct Observation { time: f64, value: Option, - outeq: usize, + outeq: OutputLabel, errorpoly: Option, occasion: usize, censoring: Censor, @@ -360,14 +508,14 @@ impl Observation { /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number corresponding to this observation + /// * `outeq` - Output label corresponding to this observation /// * `errorpoly` - Optional error polynomial coefficients (c0, c1, c2, c3) /// * `occasion` - Occasion index /// * `censoring` - Censoring type for this observation pub(crate) fn new( time: f64, value: Option, - outeq: usize, + outeq: impl ToString, errorpoly: Option, occasion: usize, censoring: Censor, @@ -375,7 +523,7 @@ impl Observation { Observation { time, value, - outeq, + outeq: OutputLabel::new(outeq), errorpoly, occasion, censoring, @@ -387,14 +535,23 @@ impl Observation { self.time } - /// Get the value of the observation (e.g., drug concentration) + /// Get the value of the observation. + /// + /// `None` means this is a prediction-only or missing-observation slot. pub fn value(&self) -> Option { self.value } - /// Get the output equation number corresponding to this observation - pub fn outeq(&self) -> usize { - self.outeq + /// Get the output label corresponding to this observation + pub fn outeq(&self) -> &OutputLabel { + &self.outeq + } + + /// Try to interpret the output label as a numeric index. + /// + /// Prefer [`Observation::outeq`] when working with the public label itself. + pub fn outeq_index(&self) -> Option { + self.outeq.index() } /// Get the error polynomial coefficients (c0, c1, c2, c3) if available @@ -414,9 +571,9 @@ impl Observation { self.value = value; } - /// Set the output equation number corresponding to this observation - pub fn set_outeq(&mut self, outeq: usize) { - self.outeq = outeq; + /// Set the output label corresponding to this observation + pub fn set_outeq(&mut self, outeq: impl ToString) { + self.outeq = OutputLabel::new(outeq); } /// Set the [ErrorPoly] for this observation @@ -434,8 +591,8 @@ impl Observation { &mut self.value } - /// Get a mutable reference to the output equation number - pub fn mut_outeq(&mut self) -> &mut usize { + /// Get a mutable reference to the output label + pub fn mut_outeq(&mut self) -> &mut OutputLabel { &mut self.outeq } @@ -454,13 +611,19 @@ impl Observation { &mut self.occasion } - /// Create a [Prediction] from this observation + /// Create a [`Prediction`] from this observation. + /// + /// This is a low-level helper for code paths that already operate on a + /// resolved or numeric output index. Named output labels must be resolved by + /// the caller before this conversion happens. pub fn to_prediction(&self, pred: f64, state: Vec) -> Prediction { Prediction { time: self.time(), observation: self.value(), prediction: pred, - outeq: self.outeq(), + outeq: self + .outeq_index() + .expect("prediction requires a resolved or numeric output label"), errorpoly: self.errorpoly(), state, occasion: self.occasion(), @@ -539,6 +702,7 @@ mod tests { assert_eq!(bolus.time(), 2.5); assert_eq!(bolus.amount(), 100.0); assert_eq!(bolus.input(), 1); + assert_eq!(bolus.input().as_str(), "1"); } #[test] @@ -561,6 +725,7 @@ mod tests { assert_eq!(infusion.time(), 1.0); assert_eq!(infusion.amount(), 200.0); assert_eq!(infusion.input(), 1); + assert_eq!(infusion.input().as_str(), "1"); assert_eq!(infusion.duration(), 2.5); } @@ -589,6 +754,7 @@ mod tests { assert_eq!(observation.time(), 5.0); assert_eq!(observation.value(), Some(75.5)); assert_eq!(observation.outeq(), 2); + assert_eq!(observation.outeq().as_str(), "2"); assert_eq!(observation.errorpoly(), error_poly); } diff --git a/src/data/mod.rs b/src/data/mod.rs index 996c791d..28a80b32 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -1,32 +1,83 @@ -//! Data structures and utilities for pharmacometric modeling +//! Data structures for building pharmacometric input data. //! -//! This module provides types for representing pharmacokinetic/pharmacodynamic data, -//! including subjects, dosing events, observations, and covariates. It also includes -//! utilities for reading and manipulating this data. +//! Use this module when you need to describe what happened to each subject: +//! doses, infusions, observations, covariates, and occasion boundaries. //! -//! # Key Components +//! This module is the input side of `pharmsol`. It is where you assemble +//! subjects and datasets before simulation, estimation, or NCA. It is not where +//! you define model equations or choose a backend. For those workflows, move to +//! [`crate::simulator`], [`crate::nca`], or the feature-gated `pharmsol::dsl` +//! surface. //! -//! - **Events**: Dosing events (bolus, infusion) and observations -//! - **Covariates**: Time-varying subject characteristics -//! - **Subjects**: Collections of events and covariates for a single individual -//! - **Data**: Collections of subjects, representing a complete dataset -//! - **Error Models**: Two types for different algorithm families: -//! - [`ErrorModel`]: Observation-based (assay error) for non-parametric algorithms -//! - [`ResidualErrorModel`]: Prediction-based (residual error) for parametric algorithms +//! # Start Here //! -//! # Examples +//! Most users only need three entrypoints first: //! -//! Creating a subject with the builder pattern: +//! - [`Subject`] for one individual and their full schedule. +//! - [`Data`] for a dataset containing many subjects. +//! - `Subject::builder` for the smallest fluent API to create doses, +//! observations, and covariates in Rust. +//! +//! The main supporting types are: +//! +//! - [`Occasion`] for repeated periods within one subject. +//! - [`Event`], [`Bolus`], [`Infusion`], and [`Observation`] for explicit +//! event-level control. +//! - [`Covariate`] and [`Covariates`] for time-varying subject characteristics. +//! - [`ErrorModel`], [`ResidualErrorModel`], and [`ObservationError`] for the +//! different error surfaces used by downstream workflows. +//! +//! # Choose A Data Input Path +//! +//! - Use `Subject::builder` when you are authoring a schedule directly in Rust. +//! - Use [`row::DataRow`] and [`row::DataRowBuilder`] when your source data is +//! already row-shaped in memory. +//! - Use [`parser::read_pmetrics`] when you are loading a Pmetrics-style file +//! from disk. +//! - Use [`Event`] variants directly when you already have validated event +//! records and need lower-level control than the builder offers. +//! +//! # Label Semantics +//! +//! Dosing inputs and observation outputs use public labels. +//! +//! - The `input` on [`Bolus`] and [`Infusion`] is the route or input label that +//! will be matched against the model. +//! - The `outeq` on [`Observation`] is the output label that identifies which +//! model output the observation belongs to. +//! - Prefer stable names such as `"depot"`, `"central"`, `"iv"`, or `"cp"`. +//! - If you pass a number, it is still treated as a public label string. Use +//! numeric values only when your model intentionally declares numeric labels. +//! +//! [`Occasion`] indices are different: they are integer period markers used to +//! separate repeated dosing blocks within one subject. +//! +//! # Error Surfaces +//! +//! This module exposes three related but different error families: +//! +//! - [`ErrorModel`] for assay or measurement error driven by the observation +//! value, commonly used in non-parametric workflows. +//! - [`ResidualErrorModel`] for residual unexplained variability driven by the +//! prediction value, commonly used in parametric workflows. +//! - [`ObservationError`] for invalid or insufficient observation data during +//! profile construction and related preprocessing. +//! +//! # Example //! //! ```rust //! use pharmsol::*; //! //! let subject = Subject::builder("patient_001") -//! .bolus(0.0, 100.0, 0) -//! .observation(1.0, 10.5, 0) -//! .observation(2.0, 8.2, 0) +//! .bolus(0.0, 100.0, "depot") +//! .observation(1.0, 12.3, "cp") +//! .missing_observation(2.0, "cp") //! .covariate("weight", 0.0, 70.0) //! .build(); +//! +//! let data = Data::new(vec![subject]); +//! +//! assert_eq!(data.subjects().len(), 1); //! ``` pub mod auc; diff --git a/src/data/parser/mod.rs b/src/data/parser/mod.rs index 7bfde3ca..74a50a84 100644 --- a/src/data/parser/mod.rs +++ b/src/data/parser/mod.rs @@ -1,3 +1,15 @@ +//! File-based parsers and parser-facing row utilities. +//! +//! Use this module when your source data starts as files or parser-shaped rows. +//! It re-exports the row ingestion API from [`crate::data::row`] and provides +//! format-specific loaders such as [`read_pmetrics`]. +//! +//! Choose the entrypoint by source shape: +//! - Use [`DataRow`] or [`build_data`] when you already mapped external data into +//! canonical row fields yourself. +//! - Use [`read_pmetrics`] when the source file already follows the Pmetrics CSV +//! convention. + pub mod pmetrics; pub use crate::data::row::{build_data, DataError, DataRow, DataRowBuilder}; diff --git a/src/data/parser/pmetrics.rs b/src/data/parser/pmetrics.rs index c410d689..2c90e2a7 100644 --- a/src/data/parser/pmetrics.rs +++ b/src/data/parser/pmetrics.rs @@ -1,3 +1,12 @@ +//! Pmetrics CSV parsing and export helpers. +//! +//! This module reads and writes the Pmetrics-style tabular format while keeping +//! pharmsol's public input and output labels intact. +//! +//! `INPUT` and `OUTEQ` values are parsed as labels, not rewritten to dense +//! indices. Named values such as `iv` and `cp` are preserved exactly, and +//! numeric values such as `1` are preserved as numeric-looking labels. + use crate::{data::*, PharmsolError}; use csv::WriterBuilder; use serde::de::{MapAccess, Visitor}; @@ -10,19 +19,27 @@ use crate::data::row::DataRow; use std::fmt; use std::str::FromStr; -/// Read a Pmetrics datafile and convert it to a [Data] object +/// Read a Pmetrics CSV file into [`Data`]. +/// +/// Use [`read_pmetrics`] when the source file already follows the usual +/// Pmetrics column convention instead of mapping the file into [`DataRow`] +/// values yourself. /// -/// This function parses a Pmetrics-formatted CSV file and constructs a [Data] object containing the structured -/// pharmacokinetic/pharmacodynamic data. The function handles various data formats including doses, observations, -/// and covariates. +/// The parser normalizes header names to lowercase, preserves `INPUT` and +/// `OUTEQ` as public labels, expands `ADDL` dosing rows through the shared row +/// ingestion path, and groups rows into occasions using `EVID=4`. +/// +/// All columns not claimed by the core Pmetrics schema are treated as +/// covariates. /// /// # Arguments /// -/// * `path` - The path to the Pmetrics CSV file +/// * `path` - Path to the Pmetrics CSV file /// /// # Returns /// -/// * `Result` - A result containing either the parsed [Data] object or an error +/// A parsed [`Data`] object or a [`DataError`] if the file cannot be read or a +/// required row field is missing. /// /// # Example /// @@ -33,14 +50,25 @@ use std::str::FromStr; /// println!("Number of subjects: {}", data.subjects().len()); /// ``` /// -/// # Format details +/// # Expected columns +/// +/// The canonical columns are `ID`, `TIME`, `EVID`, `DOSE`, `DUR`, `ADDL`, +/// `II`, `INPUT`, `OUT`, `OUTEQ`, `CENS`, and optional `C0..C3` error +/// coefficients. /// -/// The Pmetrics format expects columns like ID, TIME, EVID, DOSE, DUR, etc. The function will: +/// All other numeric columns are treated as covariates. +/// +/// # Parsing behavior +/// +/// The parser will: /// - Convert all headers to lowercase for case-insensitivity /// - Group rows by subject ID /// - Create occasions based on EVID=4 events /// - Parse covariates and create appropriate interpolations /// - Handle additional doses via ADDL and II fields +/// - Preserve raw `INPUT` and `OUTEQ` labels as strings until model resolution +/// - Treat `OUT=-99` as a missing observation value, matching the common +/// Pmetrics convention /// /// For specific column definitions, see the `Row` struct. #[allow(dead_code)] @@ -72,7 +100,7 @@ pub fn read_pmetrics(path: impl Into) -> Result { build_data(data_rows) } -/// A [Row] represents a row in the Pmetrics data format +/// One row from a Pmetrics file after serde deserialization. #[derive(Deserialize, Debug, Serialize, Default, Clone)] #[serde(rename_all = "lowercase")] struct Row { @@ -94,15 +122,15 @@ struct Row { /// Dosing interval #[serde(deserialize_with = "deserialize_option_f64")] ii: Option, - /// Input compartment - #[serde(deserialize_with = "deserialize_option_usize")] - input: Option, + /// Input label from the `INPUT` column + #[serde(deserialize_with = "deserialize_option_route_label")] + input: Option, /// Observed value #[serde(deserialize_with = "deserialize_option_f64")] out: Option, - /// Corresponding output equation for the observation - #[serde(deserialize_with = "deserialize_option_usize")] - outeq: Option, + /// Output label from the `OUTEQ` column + #[serde(deserialize_with = "deserialize_option_output_label")] + outeq: Option, /// Censoring output #[serde(default, deserialize_with = "deserialize_option_censor")] cens: Option, @@ -134,12 +162,12 @@ impl Row { dur: self.dur, addl: self.addl.map(|a| a as i64), ii: self.ii, - input: self.input, + input: self.input.clone(), // Treat -99 as missing value (Pmetrics convention) out: self .out .and_then(|v| if v == -99.0 { None } else { Some(v) }), - outeq: self.outeq, + outeq: self.outeq.clone(), cens: self.cens, c0: self.c0, c1: self.c1, @@ -196,11 +224,18 @@ where } } -fn deserialize_option_usize<'de, D>(deserializer: D) -> Result, D::Error> +fn deserialize_option_route_label<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + deserialize_option::(deserializer).map(|value| value.map(InputLabel::from)) +} + +fn deserialize_option_output_label<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { - deserialize_option::(deserializer) + deserialize_option::(deserializer).map(|value| value.map(OutputLabel::from)) } fn deserialize_option_isize<'de, D>(deserializer: D) -> Result, D::Error> @@ -257,7 +292,14 @@ where } impl Data { - /// Write the dataset to a file in Pmetrics format + /// Write the dataset to a file in Pmetrics format. + /// + /// `INPUT` and `OUTEQ` are written using their stored public labels. Named + /// labels such as `iv` and `cp` remain named labels, and numeric-looking + /// labels are written back exactly as stored. + /// + /// Missing optional fields are emitted as `.` placeholders to match the + /// usual Pmetrics text convention. /// /// # Arguments /// @@ -496,4 +538,50 @@ mod tests { assert_eq!(second.get(11), Some(".")); assert_eq!(second.get(14), Some(".")); } + + #[test] + fn read_pmetrics_preserves_named_route_and_output_labels() { + let file = NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,1,100,.,.,iv,.,.,.,.,.,.,.\npt1,0,1,.,.,.,.,.,42,cp,0,.,.,.,.\n", + ) + .unwrap(); + + let data = read_pmetrics(file.path().display().to_string()).unwrap(); + let events = data.subjects()[0].occasions()[0].events(); + + match &events[0] { + Event::Infusion(infusion) => assert_eq!(infusion.input().as_str(), "iv"), + _ => panic!("expected infusion event"), + } + + match &events[1] { + Event::Observation(observation) => assert_eq!(observation.outeq().as_str(), "cp"), + _ => panic!("expected observation event"), + } + } + + #[test] + fn read_pmetrics_preserves_numeric_labels_as_strings() { + let file = NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,.,100,.,.,1,.,.,.,.,.,.,.\npt1,0,1,.,.,.,.,.,42,1,0,.,.,.,.\n", + ) + .unwrap(); + + let data = read_pmetrics(file.path().display().to_string()).unwrap(); + let events = data.subjects()[0].occasions()[0].events(); + + match &events[0] { + Event::Bolus(bolus) => assert_eq!(bolus.input().as_str(), "1"), + _ => panic!("expected bolus event"), + } + + match &events[1] { + Event::Observation(observation) => assert_eq!(observation.outeq().as_str(), "1"), + _ => panic!("expected observation event"), + } + } } diff --git a/src/data/row.rs b/src/data/row.rs index b3b38ad8..fcb610ea 100644 --- a/src/data/row.rs +++ b/src/data/row.rs @@ -1,39 +1,56 @@ -//! Row representation of [Data] for flexible parsing +//! Row-shaped data ingestion for [`Data`] and [`Subject`] assembly. +//! +//! Use this module when your source data already looks like rows from a table, +//! CSV file, database export, or ETL pipeline. +//! +//! Choose the ingestion path by source shape: +//! - Use [`crate::data::builder::SubjectBuilder`] when you want to author a +//! schedule directly in Rust. +//! - Use [`DataRow`] and [`build_data`] when your application already has +//! validated row records in memory. +//! - Use [`crate::data::parser::read_pmetrics`] when the source file already +//! follows the Pmetrics column convention. +//! +//! [`DataRow`] keeps public route and output labels as strings. Labels such as +//! `"iv"`, `"depot"`, and `"cp"` are preserved through row parsing and later +//! resolved against model metadata by downstream workflows. //! //! # Example //! //! ```rust //! use pharmsol::data::parser::DataRow; //! -//! // Create a dosing row with ADDL expansion //! let row = DataRow::builder("subject_1", 0.0) //! .evid(1) //! .dose(100.0) -//! .input(1) -//! .addl(3) // 3 additional doses -//! .ii(12.0) // 12 hours apart +//! .input("iv") +//! .addl(3) +//! .ii(12.0) //! .build(); //! //! let events = row.into_events().unwrap(); -//! assert_eq!(events.len(), 4); // Original + 3 additional doses +//! assert_eq!(events.len(), 4); //! ``` -//! use crate::data::*; use std::collections::HashMap; use thiserror::Error; -/// A format-agnostic representation of a single data row +/// A format-agnostic representation of one input row. +/// +/// [`DataRow`] collects the canonical fields needed to turn one external row +/// into one or more [`Event`] values. /// -/// This struct represents the canonical fields needed to create pharmsol Events. -/// Consumers construct this from their source data (regardless of column names), -/// then call [`into_events()`](DataRow::into_events) to get properly parsed -/// Events with full ADDL expansion, EVID handling, censoring, etc. +/// Build this type from your own column mapping or external schema, then call +/// [`DataRow::into_events`] or [`build_data`] to assemble subjects and datasets. +/// +/// A single row can expand into several events when `ADDL` and `II` are both +/// present. /// /// # Fields /// -/// All fields use Pmetrics conventions: -/// - `input` and `outeq` are **1-indexed** (kept as-is, user must size arrays accordingly) +/// All fields use the public labeling conventions: +/// - `input` and `outeq` preserve the route and output labels from the source data /// - `evid`: 0=observation, 1=dose, 4=reset/new occasion /// - `addl`: positive=forward in time, negative=backward in time /// @@ -42,24 +59,22 @@ use thiserror::Error; /// ```rust /// use pharmsol::data::parser::DataRow; /// -/// // Observation row /// let obs = DataRow::builder("pt1", 1.0) /// .evid(0) /// .out(25.5) -/// .outeq(1) +/// .outeq("cp") /// .build(); /// -/// // Dosing row with negative ADDL (doses before time 0) /// let dose = DataRow::builder("pt1", 0.0) /// .evid(1) /// .dose(100.0) -/// .input(1) -/// .addl(-10) // 10 doses BEFORE time 0 +/// .input("iv") +/// .addl(-10) /// .ii(12.0) /// .build(); /// /// let events = dose.into_events().unwrap(); -/// // Events at times: -120, -108, -96, ..., -12, 0 +/// assert_eq!(obs.outeq.as_ref().map(|label| label.as_str()), Some("cp")); /// assert_eq!(events.len(), 11); /// ``` #[derive(Debug, Clone, Default)] @@ -78,12 +93,12 @@ pub struct DataRow { pub addl: Option, /// Interdose interval for ADDL pub ii: Option, - /// Input compartment - pub input: Option, + /// Input route label + pub input: Option, /// Observed value (for EVID=0) pub out: Option, - /// Output equation number - pub outeq: Option, + /// Output label + pub outeq: Option, /// Censoring indicator pub cens: Option, /// Error polynomial coefficients @@ -99,7 +114,7 @@ pub struct DataRow { } impl DataRow { - /// Create a new builder for constructing a DataRow + /// Create a builder for constructing one [`DataRow`]. /// /// # Arguments /// @@ -114,7 +129,7 @@ impl DataRow { /// let row = DataRow::builder("patient_001", 0.0) /// .evid(1) /// .dose(100.0) - /// .input(1) + /// .input("depot") /// .build(); /// ``` pub fn builder(id: impl Into, time: f64) -> DataRowBuilder { @@ -129,13 +144,14 @@ impl DataRow { } } - /// Convert this row into pharmsol Events + /// Convert this row into one or more [`Event`] values. /// - /// This method contains all the complex parsing logic: + /// This method performs the row-level translation logic: /// - EVID interpretation (0=observation, 1=dose, 4=reset) /// - ADDL/II expansion (both positive and negative directions) /// - Infusion vs bolus detection based on DUR /// - Censoring and error polynomial handling + /// - Preservation of public input and output labels /// /// # ADDL Expansion /// @@ -163,13 +179,13 @@ impl DataRow { /// let row = DataRow::builder("pt1", 0.0) /// .evid(1) /// .dose(100.0) - /// .input(1) + /// .input("iv") /// .addl(2) /// .ii(24.0) /// .build(); /// /// let events = row.into_events().unwrap(); - /// assert_eq!(events.len(), 3); // doses at 24, 48, and 0 + /// assert_eq!(events.len(), 3); /// /// let times: Vec = events.iter().map(|e| e.time()).collect(); /// assert_eq!(times, vec![24.0, 48.0, 0.0]); @@ -180,14 +196,17 @@ impl DataRow { match self.evid { 0 => { // Observation event - events.push(Event::Observation(Observation::new( - self.time, - self.out, + let outeq = self.outeq + .clone() .ok_or_else(|| DataError::MissingObservationOuteq { id: self.id.clone(), time: self.time, - })?, // Keep 1-indexed as provided by Pmetrics + })?; + events.push(Event::Observation(Observation::new( + self.time, + self.out, + outeq, self.get_errorpoly(), 0, // occasion set later self.cens.unwrap_or(Censor::None), @@ -196,10 +215,13 @@ impl DataRow { 1 | 4 => { // Dosing event (1) or reset with dose (4) - let input = self.input.ok_or_else(|| DataError::MissingBolusInput { - id: self.id.clone(), - time: self.time, - })?; // Keep 1-indexed as provided by Pmetrics + let input = self + .input + .clone() + .ok_or_else(|| DataError::MissingBolusInput { + id: self.id.clone(), + time: self.time, + })?; let event = if self.dur.unwrap_or(0.0) > 0.0 { // Infusion @@ -281,7 +303,11 @@ impl DataRow { } } -/// Builder for constructing DataRow with a fluent API +/// Fluent builder for [`DataRow`]. +/// +/// Use [`DataRowBuilder`] when you have row-shaped data in memory and want to +/// construct rows incrementally before calling [`DataRow::into_events`] or +/// [`build_data`]. /// /// # Example /// @@ -292,7 +318,7 @@ impl DataRow { /// let row = DataRow::builder("patient_001", 1.5) /// .evid(0) /// .out(25.5) -/// .outeq(1) +/// .outeq("cp") /// .cens(Censor::None) /// .covariate("weight", 70.0) /// .covariate("age", 45.0) @@ -367,12 +393,13 @@ impl DataRowBuilder { self } - /// Set the input compartment (1-indexed) + /// Set the input route label. /// - /// Required for EVID=1 (dosing events). - /// Kept as 1-indexed; user must size state arrays accordingly. - pub fn input(mut self, input: usize) -> Self { - self.row.input = Some(input); + /// Required for EVID=1 dosing rows. + /// The provided value is preserved as the public label until downstream + /// model resolution. + pub fn input(mut self, input: impl ToString) -> Self { + self.row.input = Some(InputLabel::new(input)); self } @@ -384,12 +411,13 @@ impl DataRowBuilder { self } - /// Set the output equation (1-indexed) + /// Set the output label. /// - /// Required for EVID=0 (observation events). - /// Will be converted to 0-indexed internally. - pub fn outeq(mut self, outeq: usize) -> Self { - self.row.outeq = Some(outeq); + /// Required for EVID=0 observation rows. + /// The provided value is preserved as the public label until downstream + /// model resolution. + pub fn outeq(mut self, outeq: impl ToString) -> Self { + self.row.outeq = Some(OutputLabel::new(outeq)); self } @@ -430,13 +458,18 @@ impl DataRowBuilder { } } -/// Build a [Data] object from an iterator of [DataRow]s +/// Build a [`Data`] object from row-shaped input. /// -/// This function handles all the complex assembly logic: +/// This function assembles rows into subjects and occasions: /// - Groups rows by subject ID /// - Splits into occasions at EVID=4 boundaries /// - Converts rows to events via [`DataRow::into_events()`] /// - Builds covariates from row covariate data +/// - Preserves per-subject row order within each occasion block +/// +/// Use this when you already have a collection of [`DataRow`] values in memory. +/// If your source file is a Pmetrics CSV, use [`crate::data::parser::read_pmetrics`] +/// instead. /// /// # Example /// @@ -444,23 +477,21 @@ impl DataRowBuilder { /// use pharmsol::data::parser::{DataRow, build_data}; /// /// let rows = vec![ -/// // Subject 1, Occasion 0 /// DataRow::builder("pt1", 0.0) -/// .evid(1).dose(100.0).input(1).build(), +/// .evid(1).dose(100.0).input("iv").build(), /// DataRow::builder("pt1", 1.0) -/// .evid(0).out(50.0).outeq(1).build(), -/// // Subject 1, Occasion 1 (EVID=4 starts new occasion) +/// .evid(0).out(50.0).outeq("cp").build(), /// DataRow::builder("pt1", 24.0) -/// .evid(4).dose(100.0).input(1).build(), +/// .evid(4).dose(100.0).input("iv").build(), /// DataRow::builder("pt1", 25.0) -/// .evid(0).out(48.0).outeq(1).build(), -/// // Subject 2 +/// .evid(0).out(48.0).outeq("cp").build(), /// DataRow::builder("pt2", 0.0) -/// .evid(1).dose(50.0).input(1).build(), +/// .evid(1).dose(50.0).input("iv").build(), /// ]; /// /// let data = build_data(rows).unwrap(); /// assert_eq!(data.subjects().len(), 2); +/// assert_eq!(data.subjects()[0].occasions().len(), 2); /// ``` pub fn build_data(rows: impl IntoIterator) -> Result { // Group rows by subject ID @@ -556,14 +587,14 @@ pub enum DataError { /// Required observation value (OUT) is missing #[error("Observation OUT is missing for {id} at time {time}")] MissingObservationOut { id: String, time: f64 }, - /// Required observation output equation (OUTEQ) is missing - #[error("Observation OUTEQ is missing in for {id} at time {time}")] + /// Required observation output label (`OUTEQ`) is missing + #[error("Observation OUTEQ is missing for {id} at time {time}")] MissingObservationOuteq { id: String, time: f64 }, /// Required infusion dose amount is missing #[error("Infusion amount (DOSE) is missing for {id} at time {time}")] MissingInfusionDose { id: String, time: f64 }, - /// Required infusion input compartment is missing - #[error("Infusion compartment (INPUT) is missing for {id} at time {time}")] + /// Required infusion input label (`INPUT`) is missing + #[error("Infusion input label (INPUT) is missing for {id} at time {time}")] MissingInfusionInput { id: String, time: f64 }, /// Required infusion duration is missing #[error("Infusion duration (DUR) is missing for {id} at time {time}")] @@ -571,8 +602,8 @@ pub enum DataError { /// Required bolus dose amount is missing #[error("Bolus amount (DOSE) is missing for {id} at time {time}")] MissingBolusDose { id: String, time: f64 }, - /// Required bolus input compartment is missing - #[error("Bolus compartment (INPUT) is missing for {id} at time {time}")] + /// Required bolus input label (`INPUT`) is missing + #[error("Bolus input label (INPUT) is missing for {id} at time {time}")] MissingBolusInput { id: String, time: f64 }, } diff --git a/src/data/structs.rs b/src/data/structs.rs index 82cd3faf..d7d123b1 100644 --- a/src/data/structs.rs +++ b/src/data/structs.rs @@ -180,17 +180,18 @@ impl Data { let old_events = occasion.process_events(None, true); // Create a set of existing (time, outeq) pairs for fast lookup - let existing_obs: std::collections::HashSet<(u64, usize)> = old_events - .iter() - .filter_map(|event| match event { - Event::Observation(obs) => { - // Convert to microseconds for consistent comparison - let time_key = (obs.time() * 1e6).round() as u64; - Some((time_key, obs.outeq())) - } - _ => None, - }) - .collect(); + let existing_obs: std::collections::HashSet<(u64, OutputLabel)> = + old_events + .iter() + .filter_map(|event| match event { + Event::Observation(obs) => { + // Convert to microseconds for consistent comparison + let time_key = (obs.time() * 1e6).round() as u64; + Some((time_key, obs.outeq().clone())) + } + _ => None, + }) + .collect(); // Generate new observation times let mut new_events = Vec::new(); @@ -198,13 +199,13 @@ impl Data { while time < last_time { let time_key = (time * 1e6).round() as u64; - for &outeq in &outeq_values { + for outeq in &outeq_values { // Only add if this (time, outeq) combination doesn't exist - if !existing_obs.contains(&(time_key, outeq)) { + if !existing_obs.contains(&(time_key, outeq.clone())) { let obs = Observation::new( time, None, - outeq, + outeq.clone(), None, occasion.index, Censor::None, @@ -273,15 +274,15 @@ impl Data { self.subjects.is_empty() } - /// Get a vector of all unique output equations (outeq) across all subjects - pub fn get_output_equations(&self) -> Vec { + /// Get a vector of all unique output labels (outeq) across all subjects + pub fn get_output_equations(&self) -> Vec { // Collect all unique outeq values in order of occurrence - let mut outeq_values: Vec = self + let mut outeq_values: Vec = self .subjects .iter() .flat_map(|subject| subject.get_output_equations()) .collect(); - outeq_values.sort_unstable(); + outeq_values.sort(); outeq_values.dedup(); outeq_values } @@ -396,14 +397,14 @@ impl Subject { self.occasions.iter_mut() } - pub fn get_output_equations(&self) -> Vec { + pub fn get_output_equations(&self) -> Vec { // Collect all unique outeq values in order of occurrence - let outeq_values: Vec = self + let outeq_values: Vec = self .occasions .iter() .flat_map(|occasion| { occasion.events.iter().filter_map(|event| match event { - Event::Observation(obs) => Some(obs.outeq()), + Event::Observation(obs) => Some(obs.outeq().clone()), _ => None, }) }) @@ -598,8 +599,10 @@ impl Occasion { let time = event.time(); if let Event::Bolus(bolus) = event { let lagtime = fn_lag(&spp.clone().into(), time, covariates); - if let Some(l) = lagtime.get(&bolus.input()) { - *bolus.mut_time() += l; + if let Some(input) = bolus.input_index() { + if let Some(l) = lagtime.get(&input) { + *bolus.mut_time() += l; + } } } } @@ -615,8 +618,10 @@ impl Occasion { let time = event.time(); if let Event::Bolus(bolus) = event { let fa = fn_fa(&spp.clone().into(), time, covariates); - if let Some(f) = fa.get(&bolus.input()) { - bolus.set_amount(bolus.amount() * f); + if let Some(input) = bolus.input_index() { + if let Some(f) = fa.get(&input) { + bolus.set_amount(bolus.amount() * f); + } } } } @@ -703,7 +708,7 @@ impl Occasion { &mut self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, errorpoly: Option, censored: Censor, ) { @@ -713,7 +718,7 @@ impl Occasion { } /// Add a missing [Observation] event to the [Occasion] - pub fn add_missing_observation(&mut self, time: f64, outeq: usize) { + pub fn add_missing_observation(&mut self, time: f64, outeq: impl ToString) { let observation = Observation::new(time, None, outeq, None, self.index, Censor::None); self.add_event(Event::Observation(observation)); } @@ -725,7 +730,7 @@ impl Occasion { &mut self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, errorpoly: ErrorPoly, censored: Censor, ) { @@ -741,13 +746,13 @@ impl Occasion { } /// Add a [Bolus] event to the [Occasion] - pub fn add_bolus(&mut self, time: f64, amount: f64, input: usize) { + pub fn add_bolus(&mut self, time: f64, amount: f64, input: impl ToString) { let bolus = Bolus::new(time, amount, input, self.index); self.add_event(Event::Bolus(bolus)); } /// Add an [Infusion] event to the [Occasion] - pub fn add_infusion(&mut self, time: f64, amount: f64, input: usize, duration: f64) { + pub fn add_infusion(&mut self, time: f64, amount: f64, input: impl ToString, duration: f64) { let infusion = Infusion::new(time, amount, input, duration, self.index); self.add_event(Event::Infusion(infusion)); } @@ -775,17 +780,6 @@ impl Occasion { .unwrap_or(0.0) } - pub(crate) fn infusions_ref(&self) -> Vec<&Infusion> { - //TODO this can be pre-computed when the struct is initially created - self.events - .iter() - .filter_map(|event| match event { - Event::Infusion(infusion) => Some(infusion), - _ => None, - }) - .collect() - } - /// Get an iterator over all events /// /// # Returns @@ -967,7 +961,7 @@ impl Occasion { for event in &self.events { if let Event::Observation(obs) = event { - if obs.outeq() == outeq { + if obs.outeq_index() == Some(outeq) { if let Some(value) = obs.value() { times.push(obs.time()); concs.push(value); diff --git a/src/dsl/aot.rs b/src/dsl/aot.rs index 3749f183..6557e015 100644 --- a/src/dsl/aot.rs +++ b/src/dsl/aot.rs @@ -37,18 +37,23 @@ use pharmsol_dsl::ModelKind; use pharmsol_dsl::{analyze_module, lower_typed_model, parse_module, ExecutionModel}; use pharmsol_dsl::{Diagnostic, DiagnosticReport, LoweringError, ParseError, SemanticError}; +/// ABI version for native AoT artifacts produced by this crate. pub const AOT_API_VERSION: u32 = 1; #[cfg(feature = "dsl-aot")] +/// Selects the compilation target for a native ahead-of-time artifact. #[derive(Debug, Clone, PartialEq, Eq, Default)] pub enum NativeAotTarget { + /// Compile for the current host toolchain target. #[default] Host, + /// Compile for an explicit Rust target triple. Triple(String), } #[cfg(feature = "dsl-aot")] impl NativeAotTarget { + /// Create a target selector for an explicit Rust target triple. pub fn triple(target: impl Into) -> Self { Self::Triple(target.into()) } @@ -62,15 +67,24 @@ impl NativeAotTarget { } #[cfg(feature = "dsl-aot")] +/// Options that control native ahead-of-time artifact export. +/// +/// AoT export writes a small template crate under [`template_root`](Self::template_root), +/// builds a native shared library, and then copies the resulting artifact to +/// [`output`](Self::output) or a generated default path. #[derive(Debug, Clone, PartialEq, Eq)] pub struct NativeAotCompileOptions { + /// Target triple selection for the emitted artifact. pub target: NativeAotTarget, + /// Optional final artifact location. pub output: Option, + /// Working directory used for the temporary template crate and build output. pub template_root: PathBuf, } #[cfg(feature = "dsl-aot")] impl NativeAotCompileOptions { + /// Create AoT options rooted at a template build directory. pub fn new(template_root: PathBuf) -> Self { Self { target: NativeAotTarget::Host, @@ -79,17 +93,20 @@ impl NativeAotCompileOptions { } } + /// Set the final artifact output path. pub fn with_output(mut self, output: PathBuf) -> Self { self.output = Some(output); self } + /// Set the compilation target triple. pub fn with_target(mut self, target: NativeAotTarget) -> Self { self.target = target; self } } +/// Error produced while exporting, reading, or loading a native AoT artifact. #[derive(Error)] pub enum AotError { #[error(transparent)] @@ -151,6 +168,43 @@ impl fmt::Debug for AotError { } #[cfg(feature = "dsl-aot")] +/// Parse DSL source, lower one selected model, and export a native AoT artifact. +/// +/// Use this when you want a reusable native artifact that can be loaded later +/// with [`load_aot_model`] or [`crate::dsl::load_runtime_artifact`]. +/// +/// This function requires the `dsl-aot` feature. Loading the resulting artifact +/// later requires `dsl-aot-load`. +/// +/// ```rust,no_run +/// use std::path::PathBuf; +/// +/// use pharmsol::dsl::{compile_module_source_to_aot, load_aot_model, NativeAotCompileOptions}; +/// +/// let source = r#" +/// name = bimodal_ke +/// kind = ode +/// +/// params = ke, v +/// states = central +/// outputs = cp +/// +/// infusion(iv) -> central +/// +/// dx(central) = -ke * central +/// out(cp) = central / v +/// "#; +/// +/// let artifact = compile_module_source_to_aot( +/// source, +/// Some("bimodal_ke"), +/// NativeAotCompileOptions::new(PathBuf::from("target/doc-aot-build")), +/// |_, _| {}, +/// )?; +/// let loaded = load_aot_model(&artifact)?; +/// # let _ = loaded; +/// # Ok::<(), Box>(()) +/// ``` pub fn compile_module_source_to_aot( source: &str, model_name: Option<&str>, @@ -184,6 +238,10 @@ pub fn compile_module_source_to_aot( } #[cfg(feature = "dsl-aot")] +/// Export a lowered execution model as a native AoT artifact. +/// +/// Use this lower-level entrypoint when you already own the frontend pipeline +/// and only need artifact generation. pub fn export_execution_model_to_aot( model: &ExecutionModel, options: NativeAotCompileOptions, @@ -240,6 +298,10 @@ pub fn export_execution_model_to_aot( } #[cfg(feature = "dsl-aot-load")] +/// Read only the metadata from a native AoT artifact. +/// +/// This is useful when you need to inspect model identity, routes, outputs, or +/// buffer sizes without loading the executable kernels. pub fn read_aot_model_info(path: impl AsRef) -> Result { let library = unsafe { Library::new(path.as_ref()) } .map_err(|error| AotError::Load(error.to_string()))?; @@ -248,6 +310,7 @@ pub fn read_aot_model_info(path: impl AsRef) -> Result) -> Result { let path = path.as_ref(); let library = @@ -543,14 +606,14 @@ mod tests { let subject = crate::Subject::builder("ode") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(6.0, cp) - .missing_observation(7.0, cp) - .missing_observation(9.0, cp) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") .build(); let support = vec![1.2, 5.0, 40.0, 0.5, 0.8]; diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index b0f1fe4a..684b7810 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -83,6 +83,11 @@ pub type JitAnalyticalModel = NativeAnalyticalModel; pub type JitSdeModel = NativeSdeModel; pub type CompiledJitModel = CompiledNativeModel; +/// Error reported while lowering an execution model into native in-process JIT +/// code. +/// +/// The error retains the backend diagnostic so callers can render the message +/// against the original DSL source when available. #[derive(Clone, PartialEq, Eq)] pub struct JitCompileError { diagnostic: Box, @@ -214,6 +219,10 @@ struct LoweredValue { ty: ValueType, } +/// Compile one lowered execution model into a reusable JIT kernel artifact. +/// +/// This builds the raw Cranelift-compiled kernel bundle for all roles present in +/// the model. Most callers should use [`compile_execution_model_to_jit`] instead. pub fn compile_execution_artifact( model: &ExecutionModel, ) -> Result { @@ -1217,6 +1226,41 @@ fn state_address( Ok(builder.ins().iadd(base, byte_offset)) } +/// Compile an [`ExecutionModel`](pharmsol_dsl::ExecutionModel) to the native +/// in-process JIT backend. +/// +/// Use this low-level entrypoint when you already own the parse, analyze, and +/// lower steps and want the JIT backend directly instead of the higher-level +/// runtime facade. +/// +/// This function requires the `dsl-jit` feature. +/// +/// ```rust,no_run +/// use pharmsol::dsl::{ +/// analyze_model, compile_execution_model_to_jit, lower_typed_model, parse_model, +/// }; +/// +/// let parsed = parse_model( +/// r#" +/// model implicit_route_injection { +/// kind ode +/// states { central } +/// routes { iv -> central } +/// dynamics { +/// ddt(central) = 0 +/// } +/// outputs { +/// cp = central +/// } +/// } +/// "#, +/// )?; +/// let typed = analyze_model(&parsed)?; +/// let execution = lower_typed_model(&typed)?; +/// let compiled = compile_execution_model_to_jit(&execution)?; +/// # let _ = compiled; +/// # Ok::<(), Box>(()) +/// ``` pub fn compile_execution_model_to_jit( model: &ExecutionModel, ) -> Result { @@ -1229,6 +1273,7 @@ pub fn compile_execution_model_to_jit( } } +/// Compile an ODE execution model to the native in-process JIT backend. pub fn compile_ode_model_to_jit(model: &ExecutionModel) -> Result { if model.kind != ModelKind::Ode { return Err(JitCompileError::new( @@ -1245,6 +1290,7 @@ pub fn compile_ode_model_to_jit(model: &ExecutionModel) -> Result Result { @@ -1263,6 +1309,7 @@ pub fn compile_analytical_model_to_jit( )) } +/// Compile an SDE execution model to the native in-process JIT backend. pub fn compile_sde_model_to_jit(model: &ExecutionModel) -> Result { if model.kind != ModelKind::Sde { return Err(JitCompileError::new( @@ -1331,7 +1378,7 @@ mod tests { } #[test] - fn authoring_runtime_shares_channel_between_bolus_and_infusion_routes() { + fn authoring_runtime_shares_input_between_bolus_and_infusion_routes() { let source = r#" name = shared_authoring kind = ode @@ -1360,21 +1407,33 @@ out(cp) = central / v ~ continuous() let cp = jit.output_index("cp").expect("cp output"); assert_eq!(oral, 0); assert_eq!(iv, 0); + assert_eq!(cp, 0); + + let jit_subject = Subject::builder("ode") + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .observation(0.5, 0.0, "cp") + .observation(1.0, 0.0, "cp") + .observation(2.0, 0.0, "cp") + .observation(6.0, 0.0, "cp") + .observation(7.0, 0.0, "cp") + .observation(9.0, 0.0, "cp") + .build(); - let subject = Subject::builder("ode") - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .observation(0.5, 0.0, cp) - .observation(1.0, 0.0, cp) - .observation(2.0, 0.0, cp) - .observation(6.0, 0.0, cp) - .observation(7.0, 0.0, cp) - .observation(9.0, 0.0, cp) + let reference_subject = Subject::builder("ode") + .bolus(0.0, 120.0, 0) + .infusion(6.0, 60.0, 0, 2.0) + .observation(0.5, 0.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(6.0, 0.0, 0) + .observation(7.0, 0.0, 0) + .observation(9.0, 0.0, 0) .build(); let support = vec![1.2, 0.15, 40.0]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit predictions"); let reference = ODE::new( @@ -1397,7 +1456,7 @@ out(cp) = central / v ~ continuous() .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference ode predictions"); for (jit_pred, reference_pred) in jit_predictions @@ -1491,22 +1550,35 @@ out(cp) = central / v ~ continuous() let cp = jit.output_index("cp").expect("cp output"); assert_eq!(oral, 0); assert_eq!(iv, 1); + assert_eq!(cp, 0); - let subject = Subject::builder("ode") + let jit_subject = Subject::builder("ode") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(6.0, cp) - .missing_observation(7.0, cp) - .missing_observation(9.0, cp) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") + .build(); + + let reference_subject = Subject::builder("ode") + .covariate("wt", 0.0, 70.0) + .bolus(0.0, 120.0, 0) + .infusion(6.0, 60.0, 1, 2.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(6.0, 0) + .missing_observation(7.0, 0) + .missing_observation(9.0, 0) .build(); let support = vec![1.2, 5.0, 40.0, 0.5, 0.8]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit predictions"); let reference = ODE::new( @@ -1551,7 +1623,7 @@ out(cp) = central / v ~ continuous() .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference ode predictions"); for (jit_pred, reference_pred) in jit_predictions @@ -1574,18 +1646,28 @@ out(cp) = central / v ~ continuous() let oral = jit.route_index("oral").expect("oral route"); let cp = jit.output_index("cp").expect("cp output"); + assert_eq!(oral, 0); + assert_eq!(cp, 0); + + let jit_subject = Subject::builder("analytical") + .bolus(0.0, 100.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .build(); - let subject = Subject::builder("analytical") - .bolus(0.0, 100.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + let reference_subject = Subject::builder("analytical") + .bolus(0.0, 100.0, 0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(4.0, 0) .build(); let support = vec![1.0, 0.15, 25.0]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit analytical predictions"); let reference = equation::Analytical::new( @@ -1603,7 +1685,7 @@ out(cp) = central / v ~ continuous() .with_nout(1); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference analytical predictions"); for (jit_pred, reference_pred) in jit_predictions @@ -1628,19 +1710,30 @@ out(cp) = central / v ~ continuous() let oral = jit.route_index("oral").expect("oral route"); let cp = jit.output_index("cp").expect("cp output"); + assert_eq!(oral, 0); + assert_eq!(cp, 0); + + let jit_subject = Subject::builder("sde") + .covariate("wt", 0.0, 70.0) + .bolus(0.0, 80.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .build(); - let subject = Subject::builder("sde") + let reference_subject = Subject::builder("sde") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 80.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .bolus(0.0, 80.0, 0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(4.0, 0) .build(); let support = vec![1.1, 0.2, 0.12, 0.08, 15.0, 0.0]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit sde predictions"); let reference = SDE::new( @@ -1677,7 +1770,7 @@ out(cp) = central / v ~ continuous() .with_nout(1); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference sde predictions"); for (jit_pred, reference_pred) in jit_predictions diff --git a/src/dsl/mod.rs b/src/dsl/mod.rs index 563e4cf6..f536c377 100644 --- a/src/dsl/mod.rs +++ b/src/dsl/mod.rs @@ -1,9 +1,94 @@ //! Public DSL facade for pharmsol. //! -//! The backend-neutral frontend is being extracted into `pharmsol-dsl`. -//! Frontend syntax, diagnostics, semantic analysis, and lowering now come -//! from `pharmsol-dsl`, while runtime and backend compilation entrypoints -//! remain owned by `pharmsol`. +//! Use this module when you want to work with pharmsol models as source text +//! and stay inside the main crate for the full workflow: parse DSL source, +//! inspect diagnostics, lower to the execution model, compile to a runtime +//! backend, load saved artifacts, and run predictions. +//! +//! Use the `pharmsol-dsl` crate directly only when you need the backend-neutral +//! frontend as an engineering API. That crate owns parsing, diagnostics, +//! semantic analysis, and lowering. This module re-exports that stable +//! frontend surface and adds the backend-specific entrypoints that stay owned +//! by `pharmsol`. +//! +//! Main entrypoints: +//! +//! - [`parse_model`], [`parse_module`], [`analyze_model`], and +//! [`analyze_module`] for frontend-only validation and inspection. +//! - [`lower_typed_model`] and [`lower_typed_module`] for lowering typed models +//! into the execution representation used by the runtime backends. +//! - [`compile_module_source_to_runtime`] and [`compile_execution_model_to_runtime`] +//! for the one-stop compile-and-run path. +//! - [`load_runtime_artifact`], [`load_aot_model`], and +//! [`load_runtime_wasm_bytes`] for loading saved artifacts back into a model +//! you can execute. +//! +//! Common workflow choices: +//! +//! - Frontend only: parse, analyze, and lower when you need diagnostics, +//! authoring tools, or your own backend. +//! - In-process execution: compile straight to [`RuntimeCompilationTarget`] and +//! keep everything inside the current process. +//! - Native artifact shipping: export a native AoT artifact, then load it later +//! on a compatible host. +//! - WASM artifact shipping: emit `.wasm` bytes or a bundled module for browser +//! or portable runtime use. +//! +//! Feature map: +//! +//! - `dsl-core`: enables this facade and the frontend re-exports from +//! `pharmsol-dsl`. +//! - `dsl-jit`: enables in-process JIT compilation through +//! [`compile_module_source_to_runtime`] with +//! [`RuntimeCompilationTarget::Jit`], plus the lower-level JIT compile +//! entrypoints. +//! - `dsl-aot`: enables native ahead-of-time artifact export through +//! [`compile_module_source_to_aot`] and [`export_execution_model_to_aot`]. +//! - `dsl-aot-load`: enables native AoT artifact loading through +//! [`load_aot_model`] and [`read_aot_model_info`]. +//! - `dsl-wasm-compile`: enables WASM artifact emission through +//! [`compile_module_source_to_wasm_bytes`], +//! [`compile_module_source_to_wasm_module`], and the browser loader helpers. +//! - `dsl-wasm`: enables host-side WASM loading and runtime execution on +//! non-browser native hosts. This includes +//! [`compile_module_source_to_runtime_wasm`], [`load_runtime_wasm_bytes`], +//! [`read_wasm_model_info`], and [`read_wasm_model_info_bytes`]. +//! +//! Smallest compile-to-runtime example: +//! +//! This example requires `dsl-jit`. +//! +//! ```rust,no_run +//! use pharmsol::dsl::{compile_module_source_to_runtime, RuntimeCompilationTarget}; +//! +//! let source = r#" +//! name = bimodal_ke +//! kind = ode +//! +//! params = ke, v +//! states = central +//! outputs = cp +//! +//! infusion(iv) -> central +//! +//! dx(central) = -ke * central +//! out(cp) = central / v +//! "#; +//! +//! let model = compile_module_source_to_runtime( +//! source, +//! Some("bimodal_ke"), +//! RuntimeCompilationTarget::Jit, +//! |_, _| {}, +//! )?; +//! +//! # let _ = model; +//! # Ok::<(), pharmsol::dsl::RuntimeError>(()) +//! ``` +//! +//! For a lower-level frontend pipeline without backend selection, use +//! `pharmsol-dsl`. For a complete runtime path inside the main crate, stay in +//! [`pharmsol::dsl`](self). #[cfg(any(feature = "dsl-aot", feature = "dsl-aot-load"))] mod aot; diff --git a/src/dsl/model_info.rs b/src/dsl/model_info.rs index 0094059f..7dd3f72a 100644 --- a/src/dsl/model_info.rs +++ b/src/dsl/model_info.rs @@ -8,47 +8,80 @@ use pharmsol_dsl::execution::{ }; use pharmsol_dsl::{AnalyticalKernel, ModelKind, RouteKind}; +/// Public metadata extracted from a compiled backend model. +/// +/// This is the shared inspection surface returned by the native AoT, WASM, and +/// runtime loaders. It keeps public labels and buffer sizes available without +/// exposing backend-specific kernel details. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeModelInfo { + /// Public model name. pub name: String, + /// High-level model family. pub kind: ModelKind, + /// Parameter names in support-point order. pub parameters: Vec, + /// Declared covariates and their dense runtime indices. pub covariates: Vec, + /// Declared routes together with declaration-order and dense runtime indices. pub routes: Vec, + /// Declared outputs and their dense runtime indices. pub outputs: Vec, + /// Length of the state buffer used during execution. pub state_len: usize, + /// Length of the derived-value buffer used during execution. pub derived_len: usize, + /// Length of the output buffer used during execution. pub output_len: usize, + /// Length of the dense route-input buffer used during execution. pub route_len: usize, + /// Analytical kernel metadata when the compiled model is analytical. pub analytical: Option, + /// Particle count when the compiled model is stochastic. pub particles: Option, } +/// Metadata for one compiled covariate. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeCovariateInfo { + /// Public covariate name. pub name: String, + /// Dense runtime covariate index. pub index: usize, } +/// Metadata for one compiled route. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeRouteInfo { + /// Public route label. pub name: String, + /// Route position in declaration order. #[serde(default)] pub declaration_index: usize, + /// Dense runtime route-input index. pub index: usize, + /// Coarse route kind when declared in metadata. #[serde(default)] pub kind: Option, + /// Dense destination state offset used by compiled kernels. pub destination_offset: usize, + /// Whether the compiled backend injects the route input into the destination + /// state automatically when the model does not read the route input + /// explicitly. pub inject_input_to_destination: bool, } +/// Metadata for one compiled output. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeOutputInfo { + /// Public output label. pub name: String, + /// Dense runtime output index. pub index: usize, } impl NativeModelInfo { + /// Build public compiled-model metadata from a lowered execution model. pub fn from_execution_model(model: &ExecutionModel) -> Self { let explicit_route_input_usage = explicit_route_input_usage(model); Self { @@ -243,7 +276,7 @@ model explicit_route_usage { } #[test] - fn authoring_shared_channel_routes_keep_declaration_specific_injection() { + fn authoring_shared_input_routes_keep_declaration_specific_injection() { let info = load_model_info( r#" name = shared_authoring diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 202fd45a..d9598172 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -1,4 +1,5 @@ use std::cell::RefCell; +use std::collections::HashMap; use std::sync::Arc; use diffsol::{ @@ -20,16 +21,19 @@ pub use super::model_info::{ NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, }; use crate::{ - data::{Covariates, Infusion}, + data::error_model::AssayErrorModels, + data::{Covariates, Infusion, InputLabel, OutputLabel}, simulator::{ + cache::{PredictionCache, DEFAULT_CACHE_SIZE}, equation::{ ode::{closure_helpers::PMProblem, ExplicitRkTableau, OdeSolver, SdirkTableau}, sde::simulate_sde_event_with, + EqnKind, Equation, EquationPriv, EquationTypes, }, likelihood::{Prediction, SubjectPredictions}, - M, V, + Fa, Lag, M, T, V, }, - Event, Observation, PharmsolError, Subject, + Event, Observation, Occasion, PharmsolError, Subject, }; pub type DenseKernelFn = unsafe extern "C" fn( @@ -375,6 +379,16 @@ impl SharedNativeModel { Ok(()) } + fn validate_output(&self, outeq: usize) -> Result<(), PharmsolError> { + if outeq >= self.info.output_len { + return Err(PharmsolError::OuteqOutOfRange { + outeq, + nout: self.info.output_len, + }); + } + Ok(()) + } + fn validate_input_for_kind(&self, input: usize, kind: RouteKind) -> Result<(), PharmsolError> { self.validate_input(input)?; if self.route_semantics.supports_input(input, kind) { @@ -382,11 +396,55 @@ impl SharedNativeModel { } Err(PharmsolError::OtherError(format!( - "model `{}` does not declare a {:?} route for input channel {}", + "model `{}` does not declare a {:?} route for input {}", self.info.name, kind, input ))) } + fn resolve_input_label( + &self, + label: &InputLabel, + kind: RouteKind, + ) -> Result { + let input = + self.route_index(label.as_str()) + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + })?; + self.validate_input_for_kind(input, kind)?; + Ok(input) + } + + fn resolve_output_label(&self, label: &OutputLabel) -> Result { + self.output_index(label.as_str()) + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: label.to_string(), + }) + } + + fn resolve_events(&self, occasion: &Occasion) -> Result, PharmsolError> { + let mut events = occasion.process_events(None, true); + + for event in events.iter_mut() { + match event { + Event::Bolus(bolus) => { + let input = self.resolve_input_label(bolus.input(), RouteKind::Bolus)?; + bolus.set_input(input); + } + Event::Infusion(infusion) => { + let input = self.resolve_input_label(infusion.input(), RouteKind::Infusion)?; + infusion.set_input(input); + } + Event::Observation(observation) => { + let outeq = self.resolve_output_label(observation.outeq())?; + observation.set_outeq(outeq); + } + } + } + + Ok(events) + } + fn fill_cov_buffer(&self, covariates: &Covariates, time: f64, buf: &mut [f64]) { for covariate in &self.info.covariates { buf[covariate.index] = match covariates.get_covariate(&covariate.name) { @@ -530,7 +588,13 @@ impl SharedNativeModel { for event in events.iter_mut() { if let Event::Bolus(bolus) = event { - self.validate_input_for_kind(bolus.input(), RouteKind::Bolus)?; + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + self.validate_input_for_kind(input, RouteKind::Bolus)?; if self.artifact.has_kernel(KernelRole::RouteLag) { lag_values.fill(0.0); @@ -556,7 +620,7 @@ impl SharedNativeModel { lag_values.as_mut_ptr(), )?; } - let lag = lag_values[bolus.input()]; + let lag = lag_values[input]; if lag != 0.0 { *bolus.mut_time() += lag; } @@ -586,7 +650,7 @@ impl SharedNativeModel { fa_values.as_mut_ptr(), )?; } - let factor = fa_values[bolus.input()]; + let factor = fa_values[input]; if factor != 1.0 { bolus.set_amount(bolus.amount() * factor); } @@ -610,7 +674,7 @@ impl SharedNativeModel { .bolus_destination(input) .ok_or_else(|| { PharmsolError::OtherError(format!( - "model `{}` does not declare a bolus route for input channel {}", + "model `{}` does not declare a bolus route for input index {}", self.info.name, input )) })?; @@ -651,13 +715,13 @@ impl SharedNativeModel { &cov_buf, &mut outputs, )?; - if observation.outeq() >= outputs.len() { - return Err(PharmsolError::OuteqOutOfRange { - outeq: observation.outeq(), - nout: outputs.len(), - }); - } - Ok(observation.to_prediction(outputs[observation.outeq()], state.to_vec())) + let outeq = observation + .outeq_index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + })?; + self.validate_output(outeq)?; + Ok(observation.to_prediction(outputs[outeq], state.to_vec())) } } @@ -667,6 +731,7 @@ pub struct NativeOdeModel { solver: OdeSolver, rtol: f64, atol: f64, + cache: Option, } #[derive(Clone, Debug)] @@ -694,6 +759,7 @@ impl NativeOdeModel { solver: OdeSolver::default(), rtol: DEFAULT_ODE_RTOL, atol: DEFAULT_ODE_ATOL, + cache: Some(PredictionCache::new(DEFAULT_CACHE_SIZE)), } } @@ -734,18 +800,15 @@ impl NativeOdeModel { let support_vector: V = DVector::from_vec(support_point.to_vec()).into(); for occasion in subject.occasions() { - let infusion_refs = occasion.infusions_ref(); - let infusions = infusion_refs + let mut events = self.shared.resolve_events(occasion)?; + let infusions = events .iter() - .map(|infusion| (*infusion).clone()) + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion.clone()), + _ => None, + }) .collect::>(); - - for infusion in &infusions { - self.shared - .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; - } - - let mut events = occasion.process_events(None, true); + let infusion_refs = infusions.iter().collect::>(); let session = RefCell::new(self.shared.artifact.start_session()?); let mut route_session = session.borrow_mut(); self.shared.apply_route_properties( @@ -901,9 +964,15 @@ impl NativeOdeModel { for (index, event) in events.iter().enumerate() { match event { Event::Bolus(bolus) => { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; self.shared.apply_bolus( solver.state_mut().y.as_mut_slice(), - bolus.input(), + input, bolus.amount(), )?; } @@ -968,6 +1037,175 @@ impl NativeOdeModel { } } +fn runtime_no_lag(_: &V, _: T, _: &Covariates) -> HashMap { + HashMap::new() +} + +fn runtime_no_fa(_: &V, _: T, _: &Covariates) -> HashMap { + HashMap::new() +} + +#[inline(always)] +fn runtime_ode_predictions( + model: &NativeOdeModel, + subject: &Subject, + support_point: &[f64], +) -> Result { + if let Some(cache) = &model.cache { + let key = ( + subject.hash(), + crate::simulator::equation::spphash(support_point), + ); + if let Some(cached) = cache.get(&key) { + return Ok(cached); + } + + let result = model.estimate_predictions(subject, support_point)?; + cache.insert(key, result.clone()); + Ok(result) + } else { + model.estimate_predictions(subject, support_point) + } +} + +impl crate::simulator::equation::Cache for NativeOdeModel { + fn with_cache_capacity(mut self, size: u64) -> Self { + self.cache = Some(PredictionCache::new(size)); + self + } + + fn enable_cache(mut self) -> Self { + self.cache = Some(PredictionCache::new(DEFAULT_CACHE_SIZE)); + self + } + + fn clear_cache(&self) { + if let Some(cache) = &self.cache { + cache.invalidate_all(); + } + } + + fn disable_cache(mut self) -> Self { + self.cache = None; + self + } +} + +impl EquationTypes for NativeOdeModel { + type S = V; + type P = SubjectPredictions; +} + +impl EquationPriv for NativeOdeModel { + fn lag(&self) -> &Lag { + &(runtime_no_lag as Lag) + } + + fn fa(&self) -> &Fa { + &(runtime_no_fa as Fa) + } + + fn get_nstates(&self) -> usize { + self.shared.info.state_len + } + + fn get_ndrugs(&self) -> usize { + self.shared.info.route_len + } + + fn get_nouteqs(&self) -> usize { + self.shared.info.output_len + } + + fn metadata(&self) -> Option<&crate::ValidatedModelMetadata> { + None + } + + fn solve( + &self, + _state: &mut Self::S, + _support_point: &[f64], + _covariates: &Covariates, + _infusions: &[Infusion], + _start_time: f64, + _end_time: f64, + ) -> Result<(), PharmsolError> { + unimplemented!("solve is not used for runtime ODE models") + } + + fn process_observation( + &self, + _support_point: &[f64], + _observation: &Observation, + _error_models: Option<&AssayErrorModels>, + _time: f64, + _covariates: &Covariates, + _x: &mut Self::S, + _likelihood: &mut Vec, + _output: &mut Self::P, + ) -> Result<(), PharmsolError> { + unimplemented!("process_observation is not used for runtime ODE models") + } + + fn initial_state( + &self, + _support_point: &[f64], + _covariates: &Covariates, + _occasion_index: usize, + ) -> Self::S { + V::zeros(self.shared.info.state_len, NalgebraContext) + } +} + +impl Equation for NativeOdeModel { + fn estimate_likelihood( + &self, + subject: &Subject, + support_point: &[f64], + error_models: &AssayErrorModels, + ) -> Result { + Ok(self + .estimate_log_likelihood(subject, support_point, error_models)? + .exp()) + } + + fn estimate_log_likelihood( + &self, + subject: &Subject, + support_point: &[f64], + error_models: &AssayErrorModels, + ) -> Result { + let predictions = runtime_ode_predictions(self, subject, support_point)?; + predictions.log_likelihood(error_models) + } + + fn kind() -> EqnKind { + EqnKind::ODE + } + + fn estimate_predictions( + &self, + subject: &Subject, + support_point: &[f64], + ) -> Result { + runtime_ode_predictions(self, subject, support_point) + } + + fn simulate_subject( + &self, + subject: &Subject, + support_point: &[f64], + error_models: Option<&AssayErrorModels>, + ) -> Result<(Self::P, Option), PharmsolError> { + let predictions = runtime_ode_predictions(self, subject, support_point)?; + let likelihood = match error_models { + Some(error_models) => Some(predictions.log_likelihood(error_models)?.exp()), + None => None, + }; + Ok((predictions, likelihood)) + } +} + impl NativeAnalyticalModel { pub(crate) fn new(info: NativeModelInfo, artifact: impl RuntimeArtifact + 'static) -> Self { Self { @@ -1000,18 +1238,14 @@ impl NativeAnalyticalModel { let mut output = SubjectPredictions::default(); for occasion in subject.occasions() { - let infusions = occasion - .infusions_ref() + let mut events = self.shared.resolve_events(occasion)?; + let infusions = events .iter() - .map(|infusion| (*infusion).clone()) + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion.clone()), + _ => None, + }) .collect::>(); - - for infusion in &infusions { - self.shared - .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; - } - - let mut events = occasion.process_events(None, true); let mut session = self.shared.artifact.start_session()?; self.shared.apply_route_properties( &mut *session, @@ -1030,8 +1264,12 @@ impl NativeAnalyticalModel { for (index, event) in events.iter().enumerate() { match event { Event::Bolus(bolus) => { - self.shared - .apply_bolus(&mut state, bolus.input(), bolus.amount())? + let input = bolus.input_index().ok_or_else(|| { + PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + } + })?; + self.shared.apply_bolus(&mut state, input, bolus.amount())? } Event::Infusion(_) => {} Event::Observation(observation) => { @@ -1171,18 +1409,14 @@ impl NativeSdeModel { let mut output = Array2::from_shape_fn((self.nparticles, 0), |_| Prediction::default()); for occasion in subject.occasions() { - let infusions = occasion - .infusions_ref() + let mut events = self.shared.resolve_events(occasion)?; + let infusions = events .iter() - .map(|infusion| (*infusion).clone()) + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion.clone()), + _ => None, + }) .collect::>(); - - for infusion in &infusions { - self.shared - .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; - } - - let mut events = occasion.process_events(None, true); let mut session = self.shared.artifact.start_session()?; self.shared.apply_route_properties( &mut *session, @@ -1204,10 +1438,15 @@ impl NativeSdeModel { for (index, event) in events.iter().enumerate() { match event { Event::Bolus(bolus) => { + let input = bolus.input_index().ok_or_else(|| { + PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + } + })?; for particle in &mut particles { self.shared.apply_bolus( particle.as_mut_slice(), - bolus.input(), + input, bolus.amount(), )?; } @@ -1398,11 +1637,14 @@ impl NativeSdeModel { fn active_route_inputs(infusions: &[Infusion], time: f64, route_len: usize) -> Vec { let mut values = vec![0.0; route_len]; for infusion in infusions { - if infusion.input() < route_len + let input = infusion + .input_index() + .expect("resolved infusions should use numeric input labels"); + if input < route_len && time >= infusion.time() && time <= infusion.time() + infusion.duration() { - values[infusion.input()] += infusion.amount() / infusion.duration(); + values[input] += infusion.amount() / infusion.duration(); } } values @@ -1417,8 +1659,11 @@ fn interval_route_inputs( let mut values = vec![0.0; route_len]; for infusion in infusions { let finish = infusion.time() + infusion.duration(); - if infusion.input() < route_len && start_time >= infusion.time() && end_time <= finish { - values[infusion.input()] += infusion.amount() / infusion.duration(); + let input = infusion + .input_index() + .expect("resolved infusions should use numeric input labels"); + if input < route_len && start_time >= infusion.time() && end_time <= finish { + values[input] += infusion.amount() / infusion.duration(); } } values diff --git a/src/dsl/runtime.rs b/src/dsl/runtime.rs index 1d49d82a..1cef784e 100644 --- a/src/dsl/runtime.rs +++ b/src/dsl/runtime.rs @@ -1,3 +1,82 @@ +//! Unified runtime entrypoints for DSL-backed models. +//! +//! Use this module when you already know you want an executable model and need +//! one backend-neutral surface for compile, load, and prediction workflows. +//! It normalizes the backend-specific JIT, native AoT, and WASM entrypoints so +//! callers can choose a deployment target without rewriting the downstream +//! prediction code. +//! +//! Use the backend modules directly only when you need a backend-specific +//! artifact or compile control: +//! +//! - [`super::jit`] for direct in-process JIT compilation. +//! - [`compile_module_source_to_aot`][crate::dsl::compile_module_source_to_aot] for native artifact export and reload. +//! - [`compile_module_source_to_wasm_bytes`][crate::dsl::compile_module_source_to_wasm_bytes] and [`load_runtime_wasm_bytes`] for portable WASM bytes, +//! browser-loader assets, and host-side WASM loading. +//! +//! Main entrypoints: +//! +//! - [`compile_module_source_to_runtime`] for the one-stop source-to-runtime +//! path. +//! - [`compile_execution_model_to_runtime`] when you already have an +//! [`ExecutionModel`](pharmsol_dsl::ExecutionModel). +//! - [`load_runtime_artifact`] and [`load_runtime_wasm_bytes`] when the model +//! has already been compiled and stored elsewhere. +//! - [`CompiledRuntimeModel::estimate_predictions`] for backend-neutral +//! execution against a [`Subject`](crate::Subject). +//! +//! Backend choice guide: +//! +//! - [`RuntimeCompilationTarget::Jit`] keeps compilation and execution inside +//! the current process. Use it for native interactive workflows and tests. +//! - [`RuntimeCompilationTarget::NativeAot`] emits a native artifact and reloads +//! it into the same runtime model shape. Use it when you want reusable native +//! artifacts and can control the target platform. +//! - [`RuntimeCompilationTarget::Wasm`] emits portable WASM bytes and reloads +//! them into the host-side runtime adapter. Use it when you need a portable +//! artifact or browser-aligned deployment story. +//! +//! Smallest compile-and-run example: +//! +//! This example requires `dsl-jit`. +//! +//! ```rust,no_run +//! use pharmsol::dsl::{compile_module_source_to_runtime, RuntimeCompilationTarget}; +//! use pharmsol::prelude::*; +//! +//! let source = r#" +//! name = bimodal_ke +//! kind = ode +//! +//! params = ke, v +//! states = central +//! outputs = cp +//! +//! infusion(iv) -> central +//! +//! dx(central) = -ke * central +//! out(cp) = central / v +//! "#; +//! +//! let model = compile_module_source_to_runtime( +//! source, +//! Some("bimodal_ke"), +//! RuntimeCompilationTarget::Jit, +//! |_, _| {}, +//! )?; +//! +//! let subject = Subject::builder("patient_001") +//! .infusion(0.0, 500.0, "iv", 0.5) +//! .missing_observation(0.5, "cp") +//! .missing_observation(1.0, "cp") +//! .missing_observation(2.0, "cp") +//! .build(); +//! +//! let predictions = model.estimate_predictions(&subject, &[1.2, 50.0])?; +//! assert!(predictions.as_subject().is_some()); +//! # Ok::<(), pharmsol::dsl::RuntimeError>(()) +//! ``` + use std::fmt; use std::path::Path; @@ -39,24 +118,39 @@ pub type RuntimeOdeModel = NativeOdeModel; pub type RuntimeAnalyticalModel = NativeAnalyticalModel; pub type RuntimeSdeModel = NativeSdeModel; +/// Selects which backend should produce the executable runtime model. +/// +/// This enum is the main backend-switching point for +/// [`compile_module_source_to_runtime`] and +/// [`compile_execution_model_to_runtime`]. #[derive(Debug, Clone, PartialEq, Eq)] pub enum RuntimeCompilationTarget { + /// Compile and execute the model inside the current native process. #[cfg(feature = "dsl-jit")] Jit, + /// Export a native artifact and reload it as a runtime model. #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] NativeAot(NativeAotCompileOptions), + /// Emit WASM bytes and reload them through the host-side WASM runtime. #[cfg(feature = "dsl-wasm")] Wasm, } +/// Identifies the on-disk artifact format for [`load_runtime_artifact`]. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RuntimeArtifactFormat { + /// A native ahead-of-time artifact produced by the AoT compiler. #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] NativeAot, + /// A WASM artifact produced by the WASM compiler. #[cfg(feature = "dsl-wasm")] Wasm, } +/// Backend-neutral prediction output from a compiled runtime model. +/// +/// ODE and analytical models return subject predictions. SDE models return the +/// particle matrix used by the stochastic workflow. #[derive(Clone, Debug)] pub enum RuntimePredictions { Subject(SubjectPredictions), @@ -93,6 +187,10 @@ impl RuntimePredictions { } } +/// Executable runtime model returned by the backend-neutral runtime surface. +/// +/// This type hides the concrete backend and keeps the prediction entrypoint the +/// same across JIT, native AoT, and WASM-based flows. #[derive(Clone, Debug)] pub enum CompiledRuntimeModel { Ode(RuntimeOdeModel), @@ -166,6 +264,8 @@ impl CompiledRuntimeModel { } } +/// Errors produced while parsing, lowering, compiling, loading, or executing a +/// runtime DSL model. #[derive(Error)] pub enum RuntimeError { #[error("failed to parse DSL source: {0}")] @@ -231,6 +331,10 @@ impl fmt::Debug for RuntimeError { } } +/// Parse, analyze, lower, compile, and return a runtime model in one step. +/// +/// Use this when your input is DSL source text and you want the shortest path +/// from source to predictions. pub fn compile_module_source_to_runtime( source: &str, model_name: Option<&str>, @@ -269,6 +373,10 @@ pub fn compile_module_source_to_runtime( }) } +/// Compile a lowered execution model to a selected runtime backend. +/// +/// Use this when you already own the frontend pipeline and only need the final +/// backend step. pub fn compile_execution_model_to_runtime( model: &ExecutionModel, target: RuntimeCompilationTarget, @@ -309,6 +417,7 @@ pub fn compile_execution_model_to_runtime( } } +/// Load a previously compiled native AoT or WASM artifact from disk. pub fn load_runtime_artifact( path: impl AsRef, format: RuntimeArtifactFormat, @@ -330,6 +439,7 @@ pub fn load_runtime_artifact( } #[cfg(feature = "dsl-wasm")] +/// Compile DSL source straight to a host-side runtime model via the WASM path. pub fn compile_module_source_to_runtime_wasm( source: &str, model_name: Option<&str>, @@ -339,6 +449,8 @@ pub fn compile_module_source_to_runtime_wasm( } #[cfg(feature = "dsl-wasm")] +/// Compile a lowered execution model straight to a host-side runtime model via +/// the WASM path. pub fn compile_execution_model_to_runtime_wasm( model: &ExecutionModel, ) -> Result { @@ -347,6 +459,7 @@ pub fn compile_execution_model_to_runtime_wasm( } #[cfg(feature = "dsl-wasm")] +/// Load a runtime model from in-memory WASM bytes. pub fn load_runtime_wasm_bytes(bytes: &[u8]) -> Result { let (info, artifact) = load_wasm_artifact_bytes(bytes)?; Ok(runtime_model_from_parts(info, artifact)) @@ -377,11 +490,110 @@ mod tests { use super::*; use crate::dsl::compile_sde_model_to_jit; use crate::test_fixtures::STRUCTURED_BLOCK_CORPUS; + use crate::PharmsolError; use crate::SubjectBuilderExt; use approx::assert_relative_eq; use pharmsol_dsl::{DiagnosticPhase, DSL_BACKEND_GENERIC, DSL_PARSE_GENERIC}; use tempfile::tempdir; + const MULTI_DIGIT_OUTPUT_ORDER_RUNTIME_DSL: &str = r#" +name = multi_digit_output_runtime +kind = ode + +params = ke, v +states = central +outputs = 2, 10, 11 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(10) = central / v ~ continuous() +out(2) = central / v ~ continuous() +out(11) = central / v ~ continuous() +"#; + + const NUMERIC_ROUTE_LABELS_RUNTIME_DSL: &str = r#" +name = numeric_route_runtime +kind = ode + +params = ke, v +states = central +outputs = cp + +bolus(10) -> central +bolus(11) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + + const SHARED_NUMERIC_ROUTE_OUTPUT_LABEL_RUNTIME_DSL: &str = r#" +name = shared_numeric_route_output_runtime +kind = ode + +params = ke, v +states = central +outputs = 1 + +infusion(1) -> central + +dx(central) = -ke * central + +out(1) = central / v ~ continuous() +"#; + + const UNDECLARED_NUMERIC_OUTPUT_LABEL_RUNTIME_DSL: &str = r#" +name = undeclared_numeric_output_runtime +kind = ode + +params = ke, v +states = central +outputs = a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(a0) = central / v ~ continuous() +out(a1) = central / v ~ continuous() +out(a2) = central / v ~ continuous() +out(a3) = central / v ~ continuous() +out(a4) = central / v ~ continuous() +out(a5) = central / v ~ continuous() +out(a6) = central / v ~ continuous() +out(a7) = central / v ~ continuous() +out(a8) = central / v ~ continuous() +out(a9) = central / v ~ continuous() +out(a10) = central / v ~ continuous() +"#; + + const UNDECLARED_NUMERIC_INPUT_LABEL_RUNTIME_DSL: &str = r#" +name = undeclared_numeric_input_runtime +kind = ode + +params = ke, v +states = central +outputs = cp + +bolus(r0) -> central +bolus(r1) -> central +bolus(r2) -> central +bolus(r3) -> central +bolus(r4) -> central +bolus(r5) -> central +bolus(r6) -> central +bolus(r7) -> central +bolus(r8) -> central +bolus(r9) -> central +bolus(r10) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + fn corpus_source() -> &'static str { STRUCTURED_BLOCK_CORPUS } @@ -397,17 +609,17 @@ mod tests { pharmsol_dsl::lower_typed_model(model).expect("lower corpus model") } - fn ode_subject(output: usize, oral: usize, iv: usize) -> Subject { + fn ode_subject() -> Subject { Subject::builder("ode") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, output) - .missing_observation(1.0, output) - .missing_observation(2.0, output) - .missing_observation(6.0, output) - .missing_observation(7.0, output) - .missing_observation(9.0, output) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") .build() } @@ -421,6 +633,92 @@ mod tests { .collect() } + fn compile_runtime_backend_matrix( + source: &str, + model_name: &str, + work_dir: &std::path::Path, + ) -> ( + CompiledRuntimeModel, + CompiledRuntimeModel, + CompiledRuntimeModel, + ) { + let jit = compile_module_source_to_runtime( + source, + Some(model_name), + RuntimeCompilationTarget::Jit, + |_, _| {}, + ) + .expect("compile jit runtime model"); + let aot = compile_module_source_to_runtime( + source, + Some(model_name), + RuntimeCompilationTarget::NativeAot( + NativeAotCompileOptions::new(work_dir.join(format!("{model_name}-aot-build"))) + .with_output(work_dir.join(format!("{model_name}.pkm"))), + ), + |_, _| {}, + ) + .expect("compile aot runtime model"); + let wasm = compile_module_source_to_runtime( + source, + Some(model_name), + RuntimeCompilationTarget::Wasm, + |_, _| {}, + ) + .expect("compile wasm runtime model"); + + (jit, aot, wasm) + } + + fn numeric_route_subject() -> Subject { + Subject::builder("numeric-route-runtime") + .bolus(0.0, 120.0, "10") + .bolus(1.0, 80.0, "11") + .missing_observation(0.5, "cp") + .missing_observation(1.5, "cp") + .build() + } + + fn shared_numeric_route_output_subject() -> Subject { + Subject::builder("shared-numeric-route-output-runtime") + .infusion(0.0, 120.0, "1", 1.0) + .missing_observation(0.5, "1") + .missing_observation(1.5, "1") + .build() + } + + fn assert_unknown_output_label( + model: &CompiledRuntimeModel, + subject: &Subject, + support: &[f64], + expected_label: &str, + ) { + let error = model + .estimate_predictions(subject, support) + .expect_err("undeclared numeric output label should fail"); + + assert!(matches!( + error, + RuntimeError::Runtime(PharmsolError::UnknownOutputLabel { label }) if label == expected_label + )); + } + + fn assert_unknown_input_label( + model: &CompiledRuntimeModel, + subject: &Subject, + support: &[f64], + expected_label: &str, + ) { + let error = model + .estimate_predictions(subject, support) + .expect_err("undeclared numeric input label should fail"); + + assert!(matches!( + error, + RuntimeError::Runtime(PharmsolError::UnknownInputLabel { label }) if label == expected_label + )); + } + #[test] fn runtime_backend_matrix_matches_ode_predictions() { let work_dir = tempdir().expect("tempdir"); @@ -460,10 +758,116 @@ mod tests { vec!["ka", "cl", "v", "tlag", "f_oral"] ); - let oral = jit.route_index("oral").expect("oral route"); - let iv = jit.route_index("iv").expect("iv route"); - let cp = jit.output_index("cp").expect("cp output"); - let subject = ode_subject(cp, oral, iv); + assert!(jit.route_index("oral").is_some()); + assert!(jit.route_index("iv").is_some()); + assert_eq!(jit.output_index("cp"), Some(0)); + let subject = ode_subject(); + + let jit_values = subject_values( + &jit.estimate_predictions(&subject, &support) + .expect("jit predictions"), + ); + let aot_values = subject_values( + &aot.estimate_predictions(&subject, &support) + .expect("aot predictions"), + ); + let wasm_values = subject_values( + &wasm + .estimate_predictions(&subject, &support) + .expect("wasm predictions"), + ); + + for ((jit_value, aot_value), wasm_value) in jit_values + .iter() + .zip(aot_values.iter()) + .zip(wasm_values.iter()) + { + assert_relative_eq!(jit_value, aot_value, max_relative = 1e-4); + assert_relative_eq!(jit_value, wasm_value, max_relative = 1e-4); + } + } + + #[test] + fn runtime_backend_matrix_preserves_multi_digit_output_label_order() { + let work_dir = tempdir().expect("tempdir"); + let (jit, aot, wasm) = compile_runtime_backend_matrix( + MULTI_DIGIT_OUTPUT_ORDER_RUNTIME_DSL, + "multi_digit_output_runtime", + work_dir.path(), + ); + + assert_eq!(jit.output_index("2"), Some(0)); + assert_eq!(jit.output_index("10"), Some(1)); + assert_eq!(jit.output_index("11"), Some(2)); + assert_eq!(aot.output_index("2"), Some(0)); + assert_eq!(aot.output_index("10"), Some(1)); + assert_eq!(aot.output_index("11"), Some(2)); + assert_eq!(wasm.output_index("2"), Some(0)); + assert_eq!(wasm.output_index("10"), Some(1)); + assert_eq!(wasm.output_index("11"), Some(2)); + } + + #[test] + fn runtime_backend_matrix_supports_multi_digit_numeric_route_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + NUMERIC_ROUTE_LABELS_RUNTIME_DSL, + "numeric_route_runtime", + work_dir.path(), + ); + + assert_eq!(jit.route_index("10"), Some(0)); + assert_eq!(jit.route_index("11"), Some(1)); + assert_eq!(aot.route_index("10"), Some(0)); + assert_eq!(aot.route_index("11"), Some(1)); + assert_eq!(wasm.route_index("10"), Some(0)); + assert_eq!(wasm.route_index("11"), Some(1)); + + let subject = numeric_route_subject(); + + let jit_values = subject_values( + &jit.estimate_predictions(&subject, &support) + .expect("jit predictions"), + ); + let aot_values = subject_values( + &aot.estimate_predictions(&subject, &support) + .expect("aot predictions"), + ); + let wasm_values = subject_values( + &wasm + .estimate_predictions(&subject, &support) + .expect("wasm predictions"), + ); + + for ((jit_value, aot_value), wasm_value) in jit_values + .iter() + .zip(aot_values.iter()) + .zip(wasm_values.iter()) + { + assert_relative_eq!(jit_value, aot_value, max_relative = 1e-4); + assert_relative_eq!(jit_value, wasm_value, max_relative = 1e-4); + } + } + + #[test] + fn runtime_backend_matrix_supports_shared_numeric_route_and_output_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + SHARED_NUMERIC_ROUTE_OUTPUT_LABEL_RUNTIME_DSL, + "shared_numeric_route_output_runtime", + work_dir.path(), + ); + + assert_eq!(jit.route_index("1"), Some(0)); + assert_eq!(jit.output_index("1"), Some(0)); + assert_eq!(aot.route_index("1"), Some(0)); + assert_eq!(aot.output_index("1"), Some(0)); + assert_eq!(wasm.route_index("1"), Some(0)); + assert_eq!(wasm.output_index("1"), Some(0)); + + let subject = shared_numeric_route_output_subject(); let jit_values = subject_values( &jit.estimate_predictions(&subject, &support) @@ -489,6 +893,44 @@ mod tests { } } + #[test] + fn runtime_backend_matrix_rejects_undeclared_numeric_output_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + UNDECLARED_NUMERIC_OUTPUT_LABEL_RUNTIME_DSL, + "undeclared_numeric_output_runtime", + work_dir.path(), + ); + let subject = Subject::builder("runtime-undeclared-numeric-output") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "10") + .build(); + + assert_unknown_output_label(&jit, &subject, &support, "10"); + assert_unknown_output_label(&aot, &subject, &support, "10"); + assert_unknown_output_label(&wasm, &subject, &support, "10"); + } + + #[test] + fn runtime_backend_matrix_rejects_undeclared_numeric_input_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + UNDECLARED_NUMERIC_INPUT_LABEL_RUNTIME_DSL, + "undeclared_numeric_input_runtime", + work_dir.path(), + ); + let subject = Subject::builder("runtime-undeclared-numeric-input") + .bolus(0.0, 100.0, "10") + .missing_observation(0.5, "cp") + .build(); + + assert_unknown_input_label(&jit, &subject, &support, "10"); + assert_unknown_input_label(&aot, &subject, &support, "10"); + assert_unknown_input_label(&wasm, &subject, &support, "10"); + } + #[test] fn runtime_compile_preserves_parse_diagnostic_structure() { let source = "model broken { kind ode outputs { cp = 1 + } }"; diff --git a/src/dsl/wasm.rs b/src/dsl/wasm.rs index f2504d44..e95b799a 100644 --- a/src/dsl/wasm.rs +++ b/src/dsl/wasm.rs @@ -406,11 +406,39 @@ impl RuntimeArtifact for WasmExecutionArtifact { } } +/// Read only the metadata from a compiled WASM artifact on disk. +/// +/// Use this when you need model identity, route labels, output labels, or +/// buffer sizes without loading the executable runtime wrapper. pub fn read_wasm_model_info(path: impl AsRef) -> Result { let (info, _) = load_wasm_artifact(path)?; Ok(info) } +/// Read only the metadata from in-memory compiled WASM bytes. +/// +/// ```rust,no_run +/// use pharmsol::dsl::{compile_module_source_to_wasm_bytes, read_wasm_model_info_bytes}; +/// +/// let source = r#" +/// name = bimodal_ke +/// kind = ode +/// +/// params = ke, v +/// states = central +/// outputs = cp +/// +/// infusion(iv) -> central +/// +/// dx(central) = -ke * central +/// out(cp) = central / v +/// "#; +/// +/// let bytes = compile_module_source_to_wasm_bytes(source, Some("bimodal_ke"))?; +/// let info = read_wasm_model_info_bytes(&bytes)?; +/// assert_eq!(info.name, "bimodal_ke"); +/// # Ok::<(), Box>(()) +/// ``` pub fn read_wasm_model_info_bytes(bytes: &[u8]) -> Result { let (info, _) = load_wasm_artifact_bytes(bytes)?; Ok(info) diff --git a/src/dsl/wasm_compile.rs b/src/dsl/wasm_compile.rs index caa60216..cda4727d 100644 --- a/src/dsl/wasm_compile.rs +++ b/src/dsl/wasm_compile.rs @@ -19,15 +19,24 @@ use pharmsol_dsl::{ LoweringError, ParseError, SemanticError, }; +/// ABI version for compiled WASM artifacts produced by this crate. pub const WASM_API_VERSION: u32 = 1; +/// Default entry capacity for [`WasmCompileCache`]. pub const DEFAULT_WASM_COMPILE_CACHE_CAPACITY: usize = 32; static BROWSER_LOADER_SOURCE: OnceLock = OnceLock::new(); +/// Portable WASM artifact bundle produced by the WASM compiler path. +/// +/// The bundle includes the raw WASM bytes, model metadata, and a browser loader +/// source string that can instantiate the model in JavaScript. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct CompiledWasmModule { + /// Raw compiled WASM bytes. pub wasm_bytes: Vec, + /// Serialized model metadata and kernel availability. pub metadata: CompiledModelInfoEnvelope, + /// JavaScript loader source for browser-side instantiation. pub browser_loader_source: String, } @@ -52,6 +61,7 @@ struct WasmCompileCacheState { lru: VecDeque, } +/// In-memory LRU cache for repeated WASM compilation from the same DSL source. #[derive(Debug)] pub struct WasmCompileCache { capacity: usize, @@ -65,6 +75,7 @@ impl Default for WasmCompileCache { } impl WasmCompileCache { + /// Create a new compile cache with at least one entry of capacity. pub fn new(capacity: usize) -> Self { Self { capacity: capacity.max(1), @@ -72,10 +83,12 @@ impl WasmCompileCache { } } + /// Return the configured cache capacity. pub fn capacity(&self) -> usize { self.capacity } + /// Return the number of cached compiled modules. pub fn entry_count(&self) -> usize { self.state .lock() @@ -84,6 +97,7 @@ impl WasmCompileCache { .len() } + /// Remove all cached compiled modules. pub fn clear(&self) { let mut state = self .state @@ -93,6 +107,8 @@ impl WasmCompileCache { state.lru.clear(); } + /// Compile DSL source to a full WASM module bundle, reusing the cache when + /// possible. pub fn compile_module_source_to_wasm_module( &self, source: &str, @@ -108,6 +124,7 @@ impl WasmCompileCache { Ok(compiled) } + /// Compile DSL source to raw WASM bytes, reusing the cache when possible. pub fn compile_module_source_to_wasm_bytes( &self, source: &str, @@ -145,6 +162,8 @@ impl WasmCompileCache { } } +/// Error produced while compiling, inspecting, or loading a DSL-backed WASM +/// artifact. #[derive(Error)] pub enum WasmError { #[error(transparent)] @@ -224,10 +243,12 @@ impl fmt::Debug for WasmError { } } +/// Compile a lowered execution model to raw WASM bytes. pub fn compile_execution_model_to_wasm_bytes(model: &ExecutionModel) -> Result, WasmError> { emit_execution_model_to_wasm_bytes(model, WASM_API_VERSION) } +/// Compile a lowered execution model to a portable WASM bundle. pub fn compile_execution_model_to_wasm_module( model: &ExecutionModel, ) -> Result { @@ -238,6 +259,7 @@ pub fn compile_execution_model_to_wasm_module( }) } +/// Parse DSL source, lower one selected model, and return raw WASM bytes. pub fn compile_module_source_to_wasm_bytes( source: &str, model_name: Option<&str>, @@ -245,6 +267,35 @@ pub fn compile_module_source_to_wasm_bytes( Ok(compile_module_source_to_wasm_module(source, model_name)?.wasm_bytes) } +/// Parse DSL source, lower one selected model, and return the full WASM bundle. +/// +/// Use this when you want a portable artifact for browser or host-side loading +/// together with the browser loader source. +/// +/// This function requires `dsl-wasm-compile`. +/// +/// ```rust,no_run +/// use pharmsol::dsl::{browser_loader_source, compile_module_source_to_wasm_module}; +/// +/// let source = r#" +/// name = bimodal_ke +/// kind = ode +/// +/// params = ke, v +/// states = central +/// outputs = cp +/// +/// infusion(iv) -> central +/// +/// dx(central) = -ke * central +/// out(cp) = central / v +/// "#; +/// +/// let compiled = compile_module_source_to_wasm_module(source, Some("bimodal_ke"))?; +/// let loader = browser_loader_source(); +/// # let _ = (compiled, loader); +/// # Ok::<(), pharmsol::dsl::WasmError>(()) +/// ``` pub fn compile_module_source_to_wasm_module( source: &str, model_name: Option<&str>, @@ -282,6 +333,10 @@ fn compile_module_source_to_wasm_module_uncached( compile_execution_model_to_wasm_module(&execution) } +/// Return the JavaScript loader source for browser-side WASM model execution. +/// +/// This helper is useful when you want to ship compiled WASM bytes together +/// with the minimal browser glue code that understands the pharmsol ABI. pub fn browser_loader_source() -> String { BROWSER_LOADER_SOURCE .get_or_init(build_browser_loader_source) diff --git a/src/error/mod.rs b/src/error/mod.rs index 1316b8a4..c8f70b58 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -37,7 +37,11 @@ pub enum PharmsolError { ZeroLikelihood, #[error("Missing observation in prediction")] MissingObservation, - #[error("Input channel {input} is out of range (ndrugs = {ndrugs})")] + #[error("Input label `{label}` could not be resolved to a route input")] + UnknownInputLabel { label: String }, + #[error("Output label `{label}` could not be resolved to an output")] + UnknownOutputLabel { label: String }, + #[error("Input index {input} is out of range (ndrugs = {ndrugs})")] InputOutOfRange { input: usize, ndrugs: usize }, #[error("Output equation {outeq} is out of range (nout = {nout})")] OuteqOutOfRange { outeq: usize, nout: usize }, diff --git a/src/lib.rs b/src/lib.rs index c84d4ee1..9c9e40b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,105 @@ +//! `pharmsol` is a Rust library for pharmacometric work. +//! +//! You can use it to: +//! +//! - build PK/PD datasets from dose and observation events +//! - simulate analytical, ODE, and SDE models +//! - run non-compartmental analysis (NCA) +//! - compile and run models from the pharmsol DSL when the DSL features are enabled +//! +//! Most users start in one of these places: +//! +//! - [`prelude`] for the common types, traits, and macros +//! - [`data`] to build subjects, occasions, events, and covariates +//! - [`simulator`] to define models and generate predictions +//! - [`nca`] to calculate NCA metrics from the same data structures +//! - [`optimize`] for optimizer-oriented workflows +//! +//! The DSL runtime surface is feature-gated. When you enable `dsl-core`, the +//! `pharmsol::dsl` module adds parsing, analysis, lowering, compile, and runtime +//! entrypoints for models written as DSL source text. +//! +//! ## Quick Start +//! +//! This example shows the smallest full workflow: define a model, build a +//! subject, and generate predictions. +//! +//! ```rust +//! use pharmsol::prelude::*; +//! +//! let model = analytical! { +//! name: "one_cmt_iv", +//! params: [ke, v], +//! states: [central], +//! outputs: [cp], +//! routes: [ +//! infusion(iv) -> central, +//! ], +//! structure: one_compartment, +//! out: |x, _p, _t, _cov, y| { +//! y[cp] = x[central] / v; +//! }, +//! }; +//! +//! let subject = Subject::builder("patient_001") +//! .infusion(0.0, 500.0, "iv", 0.5) +//! .missing_observation(0.5, "cp") +//! .missing_observation(1.0, "cp") +//! .build(); +//! +//! let predictions = model.estimate_predictions(&subject, &[1.022, 194.0])?; +//! assert_eq!(predictions.flat_predictions().len(), 2); +//! # Ok::<(), pharmsol::PharmsolError>(()) +//! ``` +//! +//! ## Choose A Workflow +//! +//! Use this guide when you are deciding where to start. +//! +//! | Task | Start Here | Notes | +//! | --- | --- | --- | +//! | Build subject data | [`data`] or [`prelude`] | Best when you already know dose times, labels, and observations. | +//! | Simulate a model written in Rust | [`simulator`] or [`prelude`] | Supports analytical, ODE, and SDE models. | +//! | Run NCA | [`nca`] or [`prelude`] | Reuses the same `Subject`, `Occasion`, and `Data` types. | +//! | Use optimization helpers | [`optimize`] | Intended for advanced workflows. | +//! | Parse or compile DSL source | `pharmsol::dsl` | Requires one or more DSL features. | +//! +//! ## Feature Guide +//! +//! Core simulation and NCA APIs do not need extra crate features on native +//! targets. +//! +//! DSL work is feature-gated: +//! +//! - `dsl-core`: exposes the `pharmsol::dsl` facade and frontend types +//! - `dsl-jit`: adds in-process JIT compilation +//! - `dsl-aot`: adds native ahead-of-time artifact compilation +//! - `dsl-aot-load`: adds native artifact loading +//! - `dsl-wasm-compile`: adds WASM artifact generation +//! - `dsl-wasm`: adds WASM runtime loading and execution +//! +//! ## Labels And Indices +//! +//! Public data APIs use route labels and output labels such as `"iv"`, +//! `"oral"`, and `"cp"`. +//! +//! Use labels in builders and parsed data unless you are deliberately working +//! with dense internal indices from a lower-level API. +//! +//! ## Platform Notes +//! +//! The main `data`, `simulator`, `nca`, and `optimize` modules are documented +//! for native targets. Some surfaces are not built on `wasm32-unknown-unknown`. +//! The DSL runtime also has feature-specific platform limits. +//! +//! ## Next Stops +//! +//! - Start with [`prelude`] if you want one import for the common workflow. +//! - Open [`data`] if you need to construct subjects or parse input files. +//! - Open [`simulator`] if you need predictions from analytical, ODE, or SDE models. +//! - Open [`nca`] if you need exposure and terminal metrics. +//! - Use `pharmsol::dsl` if the model comes from source text instead of Rust code. + #[cfg(feature = "dsl-aot")] mod build_support; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] @@ -49,19 +151,31 @@ pub use pharmsol_macros::{analytical, ode, sde}; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use std::collections::HashMap; -/// Prelude module that re-exports all commonly used types and traits. +/// Common imports for the main pharmsol workflow. +/// +/// Use the prelude when you want one import that covers the common public API: +/// +/// - subject and dataset types +/// - subject builders and events +/// - simulation types and prediction results +/// - NCA traits and option types +/// - declaration-first macros such as [`crate::ode`] and [`crate::analytical`] /// -/// Importing `pharmsol::prelude::*` brings the main modeling, simulation, -/// and data APIs into scope. +/// This is the fastest way to get started with examples, scripts, and small +/// applications. +/// +/// If you need a narrower import surface, use the modules directly instead. /// /// # Example /// ```rust /// use pharmsol::prelude::*; /// /// let subject = Subject::builder("patient_001") -/// .bolus(0.0, 100.0, 0) -/// .observation(1.0, 10.5, 0) +/// .infusion(0.0, 100.0, "iv", 1.0) +/// .missing_observation(1.0, "cp") /// .build(); +/// +/// assert_eq!(subject.id(), "patient_001"); /// ``` #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub mod prelude { diff --git a/src/simulator/equation/analytical/mod.rs b/src/simulator/equation/analytical/mod.rs index 4734886c..1dd4bbb5 100644 --- a/src/simulator/equation/analytical/mod.rs +++ b/src/simulator/equation/analytical/mod.rs @@ -32,9 +32,7 @@ pub enum AnalyticalMetadataError { Validation(#[from] ModelMetadataError), #[error("analytical model declares {declared} state metadata entries but model has {expected} states")] StateCountMismatch { expected: usize, declared: usize }, - #[error( - "analytical model declares {declared} route metadata entries but model has {expected} input channels" - )] + #[error("analytical model declares {declared} route metadata entries but model has {expected} inputs")] RouteCountMismatch { expected: usize, declared: usize }, #[error("analytical model declares {declared} output metadata entries but model has {expected} outputs")] OutputCountMismatch { expected: usize, declared: usize }, @@ -119,7 +117,7 @@ impl Analytical { self } - /// Set the number of drug input channels (size of bolus[] and rateiv[]). + /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; self.invalidate_metadata(); @@ -186,7 +184,7 @@ fn validate_metadata_dimensions( }); } - let declared_routes = metadata.route_channel_count(); + let declared_routes = metadata.route_input_count(); if declared_routes != neqs.ndrugs { return Err(AnalyticalMetadataError::RouteCountMismatch { expected: neqs.ndrugs, @@ -278,6 +276,11 @@ impl EquationPriv for Analytical { fn get_nouteqs(&self) -> usize { self.neqs.nout } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + #[inline(always)] fn solve( &self, @@ -321,13 +324,19 @@ impl EquationPriv for Analytical { let s = inf.time(); let e = s + inf.duration(); if current_t >= s && next_t <= e { - if inf.input() >= self.get_ndrugs() { + let input = + inf.input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: inf.input().to_string(), + })?; + + if input >= self.get_ndrugs() { return Err(PharmsolError::InputOutOfRange { - input: inf.input(), + input, ndrugs: self.get_ndrugs(), }); } - rateiv[inf.input()] += inf.amount() / inf.duration(); + rateiv[input] += inf.amount() / inf.duration(); } } @@ -365,7 +374,12 @@ impl EquationPriv for Analytical { covariates, &mut y, ); - let pred = y[observation.outeq()]; + let outeq = observation + .outeq_index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + })?; + let pred = y[outeq]; let pred = observation.to_prediction(pred, x.as_slice().to_vec()); if let Some(error_models) = error_models { likelihood.push(pred.log_likelihood(error_models)?.exp()); diff --git a/src/simulator/equation/metadata.rs b/src/simulator/equation/metadata.rs index ecf51a52..c7fbd4c9 100644 --- a/src/simulator/equation/metadata.rs +++ b/src/simulator/equation/metadata.rs @@ -1,17 +1,40 @@ -//! Shared model metadata for handwritten simulator models. +//! Metadata builders and validated metadata views for handwritten models. //! -//! This module defines the public metadata contract that handwritten ODE, -//! analytical, and SDE models can attach to. The field set is intentionally -//! aligned with the public subset of the DSL/runtime metadata surface. +//! Use this module when a handwritten [`crate::ODE`], [`crate::Analytical`], or +//! [`crate::SDE`] model should expose the same public names that appear in data +//! rows, subject builders, or parsed files. //! -//! Internal runtime layout details such as dense buffer lengths, derived buffer -//! shape, or ABI-specific offsets remain internal for now. +//! Metadata gives names to parameters, covariates, states, routes, and outputs. +//! After validation, the execution layer can resolve public labels such as +//! `"iv"` and `"cp"` against those declarations before simulation. +//! +//! Without metadata, handwritten models fall back to numeric labels. With +//! metadata, labels are matched by name. +//! +//! # Example +//! +//! ```rust +//! use pharmsol::{metadata, ModelKind}; +//! +//! let metadata = metadata::new("one_cmt") +//! .kind(ModelKind::Ode) +//! .parameters(["cl", "v"]) +//! .states(["central"]) +//! .outputs(["cp"]) +//! .route(metadata::Route::infusion("iv").to_state("central")) +//! .validate() +//! .unwrap(); +//! +//! assert_eq!(metadata.name(), "one_cmt"); +//! assert_eq!(metadata.route("iv").unwrap().destination(), "central"); +//! assert_eq!(metadata.output_index("cp"), Some(0)); +//! ``` use pharmsol_dsl::{AnalyticalKernel, CovariateInterpolation, ModelKind}; use std::fmt; use thiserror::Error; -/// Create a new handwritten-model metadata builder. +/// Shorthand for [`ModelMetadata::new`]. pub fn new(name: impl Into) -> ModelMetadata { ModelMetadata::new(name) } @@ -71,7 +94,17 @@ impl fmt::Display for NameDomain { } } -/// Immutable validated metadata view used by later attachment slices. +/// Validated metadata view used by the execution layer. +/// +/// This type is what handwritten equation builders store after metadata has +/// passed validation. It provides stable lookup helpers from public names to the +/// dense indices used during execution. +/// +/// Route lookups expose two different indices: +/// - [`ValidatedModelMetadata::route_declaration_index`] is the route position in +/// declaration order. +/// - [`ValidatedModelMetadata::route_index`] is the dense execution input index +/// for that route kind. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ValidatedModelMetadata { name: String, @@ -80,17 +113,19 @@ pub struct ValidatedModelMetadata { covariates: Vec, states: Vec, routes: Vec, - route_channel_count: usize, + route_input_count: usize, outputs: Vec, particles: Option, analytical: Option, } impl ValidatedModelMetadata { + /// Get the public model name. pub fn name(&self) -> &str { &self.name } + /// Get the validated model family. pub fn kind(&self) -> ModelKind { self.kind } @@ -111,8 +146,11 @@ impl ValidatedModelMetadata { &self.routes } - pub fn route_channel_count(&self) -> usize { - self.route_channel_count + /// Get the number of dense execution input slots needed for routes. + /// + /// This is the maximum of the bolus-route count and infusion-route count. + pub fn route_input_count(&self) -> usize { + self.route_input_count } pub fn outputs(&self) -> &[Output] { @@ -143,14 +181,17 @@ impl ValidatedModelMetadata { self.states.iter().position(|state| state.name() == name) } + /// Look up a route by public name and return its dense execution input index. pub fn route_index(&self, name: &str) -> Option { - self.route(name).map(ValidatedRoute::channel_index) + self.route(name).map(ValidatedRoute::input_index) } + /// Look up a route by public name and return its declaration-order index. pub fn route_declaration_index(&self, name: &str) -> Option { self.routes.iter().position(|route| route.name() == name) } + /// Look up an output by public name and return its dense output index. pub fn output_index(&self, name: &str) -> Option { self.outputs.iter().position(|output| output.name() == name) } @@ -179,13 +220,17 @@ impl ValidatedModelMetadata { } } -/// One validated route declaration with resolved destination state index. +/// One validated route declaration with resolved execution details. +/// +/// A validated route keeps both the declaration-order index and the dense input +/// index used during execution. Those values can differ from each other when a +/// model mixes bolus and infusion routes. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ValidatedRoute { name: String, kind: RouteKind, declaration_index: usize, - channel_index: usize, + input_index: usize, destination: String, destination_index: usize, has_lag: bool, @@ -194,6 +239,7 @@ pub struct ValidatedRoute { } impl ValidatedRoute { + /// Get the public route name used for label matching. pub fn name(&self) -> &str { &self.name } @@ -202,18 +248,22 @@ impl ValidatedRoute { self.kind } + /// Get the declaration-order index for this route. pub fn declaration_index(&self) -> usize { self.declaration_index } - pub fn channel_index(&self) -> usize { - self.channel_index + /// Get the dense execution input index for this route kind. + pub fn input_index(&self) -> usize { + self.input_index } + /// Get the destination state name. pub fn destination(&self) -> &str { &self.destination } + /// Get the destination state index in model order. pub fn destination_index(&self) -> usize { self.destination_index } @@ -231,7 +281,12 @@ impl ValidatedRoute { } } -/// Metadata describing one handwritten simulator model. +/// Builder for handwritten model metadata. +/// +/// Use [`ModelMetadata`] to declare the public names that should be attached to +/// a handwritten equation. After validation, the resulting metadata can be +/// attached to handwritten [`crate::ODE`], [`crate::Analytical`], and +/// [`crate::SDE`] models through their `with_metadata(...)` methods. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ModelMetadata { name: String, @@ -379,11 +434,17 @@ impl ModelMetadata { } /// Validate this metadata using its declared kind. + /// + /// Use this when the metadata itself already declares whether the model is + /// ODE, analytical, or SDE. pub fn validate(self) -> Result { self.validate_internal(None, None) } /// Validate this metadata for a specific model kind. + /// + /// Use this when the equation type determines the model family and you want + /// validation to enforce that family explicitly. pub fn validate_for( self, kind: ModelKind, @@ -416,7 +477,7 @@ impl ModelMetadata { let particles = resolve_particles(kind, self.particles, fallback_particles)?; validate_kind_specific_fields(kind, self.analytical, particles)?; - let (routes, route_channel_count) = validate_routes(self.routes, &self.states)?; + let (routes, route_input_count) = validate_routes(self.routes, &self.states)?; Ok(ValidatedModelMetadata { name: self.name, @@ -425,7 +486,7 @@ impl ModelMetadata { covariates: self.covariates, states: self.states, routes, - route_channel_count, + route_input_count, outputs: self.outputs, particles, analytical: self.analytical, @@ -440,6 +501,7 @@ pub struct Parameter { } impl Parameter { + /// Create a named parameter declaration. pub fn new(name: impl Into) -> Self { Self { name: name.into() } } @@ -466,6 +528,7 @@ pub struct Covariate { } impl Covariate { + /// Create a named covariate without an explicit interpolation policy. pub fn new(name: impl Into) -> Self { Self { name: name.into(), @@ -473,14 +536,17 @@ impl Covariate { } } + /// Create a continuous covariate that uses linear interpolation. pub fn continuous(name: impl Into) -> Self { Self::new(name).with_interpolation(CovariateInterpolation::Linear) } + /// Create a covariate that uses last-observation-carried-forward semantics. pub fn locf(name: impl Into) -> Self { Self::new(name).with_interpolation(CovariateInterpolation::Locf) } + /// Set the interpolation policy explicitly. pub fn with_interpolation(mut self, interpolation: CovariateInterpolation) -> Self { self.interpolation = Some(interpolation); self @@ -502,6 +568,7 @@ pub struct State { } impl State { + /// Create a named state declaration. pub fn new(name: impl Into) -> Self { Self { name: name.into() } } @@ -527,6 +594,7 @@ pub struct Output { } impl Output { + /// Create a named output declaration. pub fn new(name: impl Into) -> Self { Self { name: name.into() } } @@ -548,18 +616,25 @@ where /// Route declaration kind. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RouteKind { + /// Instantaneous dose input. Bolus, + /// Dose input over a duration. Infusion, } /// How route inputs should be interpreted by the execution layer. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RouteInputPolicy { + /// Inject the resolved input directly into the declared destination state. InjectToDestination, + /// Expect the low-level execution path to provide an explicit input vector. ExplicitInputVector, } /// One named route declaration. +/// +/// Route names are the public labels matched against dose events such as `iv` +/// or `oral`. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Route { name: String, @@ -571,14 +646,17 @@ pub struct Route { } impl Route { + /// Create a named bolus route declaration. pub fn bolus(name: impl Into) -> Self { Self::new(name, RouteKind::Bolus) } + /// Create a named infusion route declaration. pub fn infusion(name: impl Into) -> Self { Self::new(name, RouteKind::Infusion) } + /// Create a route declaration with an explicit kind. pub fn new(name: impl Into, kind: RouteKind) -> Self { Self { name: name.into(), @@ -590,26 +668,31 @@ impl Route { } } + /// Declare which state this route targets. pub fn to_state(mut self, destination: impl Into) -> Self { self.destination = Some(destination.into()); self } + /// Mark this route as supporting lag handling. pub fn with_lag(mut self) -> Self { self.has_lag = true; self } + /// Mark this route as supporting bioavailability handling. pub fn with_bioavailability(mut self) -> Self { self.has_bioavailability = true; self } + /// Request direct injection into the destination state at execution time. pub fn inject_input_to_destination(mut self) -> Self { self.input_policy = Some(RouteInputPolicy::InjectToDestination); self } + /// Request an explicit low-level input vector at execution time. pub fn expect_explicit_input(mut self) -> Self { self.input_policy = Some(RouteInputPolicy::ExplicitInputVector); self @@ -730,20 +813,20 @@ fn validate_routes( routes: Vec, states: &[State], ) -> Result<(Vec, usize), ModelMetadataError> { - let mut bolus_channels = 0; - let mut infusion_channels = 0; + let mut bolus_inputs = 0; + let mut infusion_inputs = 0; let mut validated_routes = Vec::with_capacity(routes.len()); for (declaration_index, route) in routes.into_iter().enumerate() { - let channel_index = match route.kind { + let input_index = match route.kind { RouteKind::Bolus => { - let index = bolus_channels; - bolus_channels += 1; + let index = bolus_inputs; + bolus_inputs += 1; index } RouteKind::Infusion => { - let index = infusion_channels; - infusion_channels += 1; + let index = infusion_inputs; + infusion_inputs += 1; index } }; @@ -751,18 +834,18 @@ fn validate_routes( validated_routes.push(validate_route( route, declaration_index, - channel_index, + input_index, states, )?); } - Ok((validated_routes, bolus_channels.max(infusion_channels))) + Ok((validated_routes, bolus_inputs.max(infusion_inputs))) } fn validate_route( route: Route, declaration_index: usize, - channel_index: usize, + input_index: usize, states: &[State], ) -> Result { if route.kind == RouteKind::Infusion && route.has_lag { @@ -796,7 +879,7 @@ fn validate_route( name: route.name, kind: route.kind, declaration_index, - channel_index, + input_index, destination, destination_index, has_lag: route.has_lag, @@ -902,7 +985,7 @@ mod tests { assert_eq!(metadata.state_index("central"), Some(0)); assert_eq!(metadata.route_index("iv"), Some(0)); assert_eq!(metadata.route_declaration_index("iv"), Some(0)); - assert_eq!(metadata.route_channel_count(), 1); + assert_eq!(metadata.route_input_count(), 1); assert_eq!(metadata.output_index("cp"), Some(0)); assert_eq!( metadata.route("iv").expect("route exists").destination(), @@ -915,10 +998,7 @@ mod tests { .declaration_index(), 0 ); - assert_eq!( - metadata.route("iv").expect("route exists").channel_index(), - 0 - ); + assert_eq!(metadata.route("iv").expect("route exists").input_index(), 0); assert_eq!( metadata .route("iv") @@ -988,8 +1068,8 @@ mod tests { } #[test] - fn shared_channel_routes_preserve_declaration_and_channel_identity() { - let metadata = new("shared_channel") + fn shared_input_routes_preserve_declaration_and_input_identity() { + let metadata = new("shared_input") .kind(ModelKind::Ode) .parameters(["ke"]) .states(["gut", "central"]) @@ -999,19 +1079,16 @@ mod tests { Route::infusion("iv").to_state("central"), ]) .validate() - .expect("shared-channel metadata should validate"); + .expect("shared-input metadata should validate"); assert_eq!(metadata.routes().len(), 2); - assert_eq!(metadata.route_channel_count(), 1); + assert_eq!(metadata.route_input_count(), 1); assert_eq!(metadata.route_index("oral"), Some(0)); assert_eq!(metadata.route_index("iv"), Some(0)); assert_eq!(metadata.route_declaration_index("oral"), Some(0)); assert_eq!(metadata.route_declaration_index("iv"), Some(1)); - assert_eq!( - metadata.route("oral").expect("oral route").channel_index(), - 0 - ); - assert_eq!(metadata.route("iv").expect("iv route").channel_index(), 0); + assert_eq!(metadata.route("oral").expect("oral route").input_index(), 0); + assert_eq!(metadata.route("iv").expect("iv route").input_index(), 0); assert_eq!( metadata .route("oral") diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index 60cb2d8f..03e5318c 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -1,3 +1,51 @@ +//! Handwritten equation families and their shared simulation interfaces. +//! +//! This module is the public home for handwritten [`ODE`], [`Analytical`], and +//! [`SDE`] models, plus the shared [`Equation`] trait and the metadata types +//! that attach public names to parameters, states, routes, and outputs. +//! +//! Use this module when you want to: +//! - choose between deterministic ODE, analytical, and stochastic SDE models +//! - attach metadata so dataset labels such as `"iv"` and `"cp"` resolve by +//! name instead of by dense numeric index +//! - work with prediction or likelihood APIs across equation families +//! +//! # Equation Families +//! +//! - [`ODE`] for deterministic models that must be numerically integrated. +//! - [`Analytical`] for supported closed-form models. +//! - [`SDE`] for stochastic models that use particles. +//! +//! # Labels And Metadata +//! +//! Input and output labels arrive from public data APIs as strings. +//! +//! - Without metadata, handwritten models fall back to numeric labels such as +//! `0` or `1`. +//! - With [`metadata::ModelMetadata`] attached, route and output labels are +//! resolved by name against the declared routes and outputs before +//! simulation. +//! +//! That label-first path is the preferred public workflow for current authoring. +//! +//! # Example +//! +//! ```rust +//! use pharmsol::{metadata, ModelKind}; +//! +//! let metadata = metadata::new("one_cmt") +//! .kind(ModelKind::Ode) +//! .parameters(["cl", "v"]) +//! .states(["central"]) +//! .outputs(["cp"]) +//! .route(metadata::Route::infusion("iv").to_state("central")) +//! .validate() +//! .unwrap(); +//! +//! assert_eq!(metadata.route_index("iv"), Some(0)); +//! assert_eq!(metadata.output_index("cp"), Some(0)); +//! ``` + use std::fmt::Debug; pub mod analytical; pub mod metadata; @@ -12,17 +60,18 @@ pub use sde::*; use crate::{ error_model::AssayErrorModels, simulator::{Fa, Lag}, - Covariates, Event, Infusion, Observation, PharmsolError, Subject, + Covariates, Event, Infusion, InputLabel, Observation, Occasion, OutputLabel, PharmsolError, + Subject, }; use super::likelihood::Prediction; /// Trait for state vectors that can receive bolus doses. pub trait State { - /// Add a bolus dose to the state at the specified input compartment. + /// Add a bolus dose to the state at the specified resolved input index. /// /// # Parameters - /// - `input`: The compartment index + /// - `input`: The resolved dense input index used by the execution layer /// - `amount`: The bolus amount fn add_bolus(&mut self, input: usize, amount: f64); } @@ -113,7 +162,7 @@ pub trait Cache: Sized { fn disable_cache(self) -> Self; } -/// Trait defining the associated types for equations. +/// Associated state and prediction container types for an equation family. pub trait EquationTypes { /// The state vector type type S: State + Debug; @@ -129,6 +178,7 @@ pub(crate) trait EquationPriv: EquationTypes { fn get_nstates(&self) -> usize; fn get_ndrugs(&self) -> usize; fn get_nouteqs(&self) -> usize; + fn metadata(&self) -> Option<&ValidatedModelMetadata>; fn solve( &self, state: &mut Self::S, @@ -141,6 +191,85 @@ pub(crate) trait EquationPriv: EquationTypes { fn nparticles(&self) -> usize { 1 } + + fn resolve_input_label( + &self, + label: &InputLabel, + expected_kind: RouteKind, + ) -> Result { + if let Some(metadata) = self.metadata() { + let route = + metadata + .route(label.as_str()) + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + })?; + + if route.kind() != expected_kind { + return Err(PharmsolError::OtherError(format!( + "input label `{}` is declared as {:?} but used as {:?}", + label, + route.kind(), + expected_kind + ))); + } + + return Ok(route.input_index()); + } + + label + .index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + }) + } + + fn resolve_output_label(&self, label: &OutputLabel) -> Result { + if let Some(metadata) = self.metadata() { + return metadata.output_index(label.as_str()).ok_or_else(|| { + PharmsolError::UnknownOutputLabel { + label: label.to_string(), + } + }); + } + + label + .index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: label.to_string(), + }) + } + + fn resolve_occasion_events( + &self, + occasion: &Occasion, + support_point: &[f64], + covariates: &Covariates, + ) -> Result, PharmsolError> { + let mut resolved = occasion.clone(); + + for event in resolved.events_iter_mut() { + match event { + Event::Bolus(bolus) => { + let input = self.resolve_input_label(bolus.input(), RouteKind::Bolus)?; + bolus.set_input(input); + } + Event::Infusion(infusion) => { + let input = self.resolve_input_label(infusion.input(), RouteKind::Infusion)?; + infusion.set_input(input); + } + Event::Observation(observation) => { + let outeq = self.resolve_output_label(observation.outeq())?; + observation.set_outeq(outeq); + } + } + } + + Ok(resolved.process_events( + Some((self.fa(), self.lag(), support_point, covariates)), + true, + )) + } #[allow(dead_code)] fn is_sde(&self) -> bool { false @@ -181,13 +310,20 @@ pub(crate) trait EquationPriv: EquationTypes { ) -> Result<(), PharmsolError> { match event { Event::Bolus(bolus) => { - if bolus.input() >= self.get_ndrugs() { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + + if input >= self.get_ndrugs() { return Err(PharmsolError::InputOutOfRange { - input: bolus.input(), + input, ndrugs: self.get_ndrugs(), }); } - x.add_bolus(bolus.input(), bolus.amount()); + x.add_bolus(input, bolus.amount()); } Event::Infusion(infusion) => { infusions.push(infusion.clone()); @@ -220,11 +356,15 @@ pub(crate) trait EquationPriv: EquationTypes { } } -/// Trait for model equations that can be simulated. +/// Trait for handwritten model equations that can be simulated. +/// +/// [`Equation`] is the shared interface implemented by handwritten [`ODE`], +/// [`Analytical`], and [`SDE`] models. /// -/// This trait defines the interface for different types of model equations -/// (ODE, SDE, analytical) that can be simulated to generate predictions -/// and estimate parameters. +/// Subject data enters this layer through public labels on dose and observation +/// events. If metadata is attached to the equation, those labels are resolved by +/// name before simulation. Otherwise, the execution layer expects numeric labels +/// that can be interpreted as dense indices. /// /// # Likelihood Calculation /// @@ -332,10 +472,7 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { let mut x = self.initial_state(support_point, covariates, occasion.index()); let mut infusions = Vec::new(); - let events = occasion.process_events( - Some((self.fa(), self.lag(), support_point, covariates)), - true, - ); + let events = self.resolve_occasion_events(occasion, support_point, covariates)?; for (index, event) in events.iter().enumerate() { self.simulate_event( support_point, @@ -355,6 +492,7 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { } } +/// Runtime family tag for handwritten equations. #[repr(C)] #[derive(Clone, Debug)] pub enum EqnKind { diff --git a/src/simulator/equation/ode/closure.rs b/src/simulator/equation/ode/closure.rs index eed65e7a..47f2a81e 100644 --- a/src/simulator/equation/ode/closure.rs +++ b/src/simulator/equation/ode/closure.rs @@ -11,13 +11,13 @@ type C = ::C; type T = ::T; #[derive(Debug, Clone)] -struct InfusionChannel { +struct InfusionTrack { input: usize, event_times: Vec, cumulative_rates: Vec, } -impl InfusionChannel { +impl InfusionTrack { fn new(input: usize, mut events: Vec<(f64, f64)>) -> Self { events.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal)); @@ -63,15 +63,13 @@ impl InfusionChannel { #[derive(Debug, Clone, Default)] struct InfusionSchedule { - channels: Vec, + tracks: Vec, } impl InfusionSchedule { fn new(ndrugs: usize, infusions: &[&Infusion]) -> Result { if ndrugs == 0 || infusions.is_empty() { - return Ok(Self { - channels: Vec::new(), - }); + return Ok(Self { tracks: Vec::new() }); } let mut per_input: Vec> = vec![Vec::new(); ndrugs]; @@ -80,7 +78,11 @@ impl InfusionSchedule { continue; } - let input = infusion.input(); + let input = infusion + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: infusion.input().to_string(), + })?; if input >= ndrugs { return Err(PharmsolError::InputOutOfRange { input, ndrugs }); } @@ -90,27 +92,27 @@ impl InfusionSchedule { per_input[input].push((infusion.time() + infusion.duration(), -rate)); } - let channels = per_input + let tracks = per_input .into_iter() .enumerate() .filter_map(|(input, events)| { if events.is_empty() { None } else { - Some(InfusionChannel::new(input, events)) + Some(InfusionTrack::new(input, events)) } }) .collect(); - Ok(Self { channels }) + Ok(Self { tracks }) } fn fill_rate_vector(&self, time: f64, rateiv: &mut V) { rateiv.fill(0.0); - for channel in &self.channels { - let rate = channel.rate_at(time); + for track in &self.tracks { + let rate = track.rate_at(time); if rate != 0.0 { - rateiv[channel.input] = rate; + rateiv[track.input] = rate; } } } diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/equation/ode/mod.rs index cafe6a96..c65f16a9 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -87,9 +87,7 @@ pub enum OdeMetadataError { Validation(#[from] ModelMetadataError), #[error("ODE declares {declared} state metadata entries but model has {expected} states")] StateCountMismatch { expected: usize, declared: usize }, - #[error( - "ODE declares {declared} route metadata entries but model has {expected} input channels" - )] + #[error("ODE declares {declared} route metadata entries but model has {expected} inputs")] RouteCountMismatch { expected: usize, declared: usize }, #[error("ODE declares {declared} output metadata entries but model has {expected} outputs")] OutputCountMismatch { expected: usize, declared: usize }, @@ -134,7 +132,7 @@ impl ODE { self } - /// Set the number of drug input channels (size of bolus[] and rateiv[]). + /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; self.invalidate_metadata(); @@ -211,7 +209,7 @@ fn validate_metadata_dimensions( }); } - let declared_routes = metadata.route_channel_count(); + let declared_routes = metadata.route_input_count(); if declared_routes != neqs.ndrugs { return Err(OdeMetadataError::RouteCountMismatch { expected: neqs.ndrugs, @@ -330,6 +328,11 @@ impl EquationPriv for ODE { fn get_nouteqs(&self) -> usize { self.neqs.nout } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + #[inline(always)] fn solve( &self, @@ -397,14 +400,21 @@ impl ODE { match event { Event::Bolus(bolus) => { - if bolus.input() >= bolus_v.len() { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + + if input >= bolus_v.len() { return Err(PharmsolError::InputOutOfRange { - input: bolus.input(), + input, ndrugs: bolus_v.len(), }); } bolus_v.fill(0.0); - bolus_v[bolus.input()] = bolus.amount(); + bolus_v[input] = bolus.amount(); state_with_bolus.fill(0.0); state_without_bolus.fill(0.0); @@ -444,7 +454,12 @@ impl ODE { covariates, y_out, ); - let pred = y_out[observation.outeq()]; + let outeq = observation.outeq_index().ok_or_else(|| { + PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + } + })?; + let pred = y_out[outeq]; let pred = observation.to_prediction(pred, solver.state().y.as_slice().to_vec()); if let Some(error_models) = error_models { @@ -550,11 +565,14 @@ impl Equation for ODE { // Iterate over occasions for occasion in subject.occasions() { let covariates = occasion.covariates(); - let infusions = occasion.infusions_ref(); - let events = occasion.process_events( - Some((self.fa(), self.lag(), support_point, covariates)), - true, - ); + let events = self.resolve_occasion_events(occasion, support_point, covariates)?; + let infusions = events + .iter() + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion), + _ => None, + }) + .collect::>(); let problem = OdeBuilder::::new() .atol(vec![self.atol]) @@ -680,9 +698,9 @@ mod tests { fn route_policy_subject() -> Subject { Subject::builder("route_policy") - .bolus(0.0, 100.0, 0) - .infusion(0.0, 100.0, 0, 1.0) - .observation(1.0, 0.0, 0) + .bolus(0.0, 100.0, "oral") + .infusion(0.0, 100.0, "iv", 1.0) + .observation(1.0, 0.0, "cp") .build() } diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/equation/sde/mod.rs index bdafbbc3..43a1d48a 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/equation/sde/mod.rs @@ -34,9 +34,7 @@ pub enum SdeMetadataError { Validation(#[from] ModelMetadataError), #[error("SDE declares {declared} state metadata entries but model has {expected} states")] StateCountMismatch { expected: usize, declared: usize }, - #[error( - "SDE declares {declared} route metadata entries but model has {expected} input channels" - )] + #[error("SDE declares {declared} route metadata entries but model has {expected} inputs")] RouteCountMismatch { expected: usize, declared: usize }, #[error("SDE declares {declared} output metadata entries but model has {expected} outputs")] OutputCountMismatch { expected: usize, declared: usize }, @@ -124,7 +122,10 @@ fn simulate_sde_event( let mut rateiv = V::zeros(ndrugs, NalgebraContext); for infusion in &infusion_events { if time >= infusion.time() && time <= infusion.duration() + infusion.time() { - rateiv[infusion.input()] += infusion.amount() / infusion.duration(); + let input = infusion + .input_index() + .expect("resolved infusions should use numeric input labels"); + rateiv[input] += infusion.amount() / infusion.duration(); } } @@ -233,7 +234,7 @@ impl SDE { self } - /// Set the number of drug input channels (size of bolus[] and rateiv[]). + /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; self.invalidate_metadata(); @@ -306,7 +307,7 @@ fn validate_metadata_dimensions( }); } - let declared_routes = metadata.route_channel_count(); + let declared_routes = metadata.route_input_count(); if declared_routes != neqs.ndrugs { return Err(SdeMetadataError::RouteCountMismatch { expected: neqs.ndrugs, @@ -466,6 +467,11 @@ impl EquationPriv for SDE { fn get_nouteqs(&self) -> usize { self.neqs.nout } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + #[inline(always)] fn solve( &self, @@ -524,7 +530,10 @@ impl EquationPriv for SDE { covariates, &mut y, ); - *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec()); + let outeq = observation + .outeq_index() + .expect("resolved observations should use numeric output labels"); + *p = observation.to_prediction(y[outeq], x[i].as_slice().to_vec()); }); let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?; *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap(); @@ -588,17 +597,21 @@ impl EquationPriv for SDE { ) -> Result<(), PharmsolError> { match event { crate::Event::Bolus(bolus) => { - if bolus.input() >= self.get_ndrugs() { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + + if input >= self.get_ndrugs() { return Err(PharmsolError::InputOutOfRange { - input: bolus.input(), + input, ndrugs: self.get_ndrugs(), }); } - if !self - .injected_bolus_mappings - .apply(x, bolus.input(), bolus.amount()) - { - x.add_bolus(bolus.input(), bolus.amount()); + if !self.injected_bolus_mappings.apply(x, input, bolus.amount()) { + x.add_bolus(input, bolus.amount()); } } crate::Event::Infusion(infusion) => { @@ -909,8 +922,8 @@ mod tests { .expect("injected metadata should validate"); let subject = Subject::builder("bolus_route") - .bolus(0.0, 100.0, 0) - .missing_observation(0.1, 0) + .bolus(0.0, 100.0, "oral") + .missing_observation(0.1, "cp") .build(); let explicit_predictions = explicit.estimate_predictions(&subject, &[0.0]).unwrap(); @@ -954,8 +967,8 @@ mod tests { .expect("injected metadata should validate"); let subject = Subject::builder("infusion_route") - .infusion(0.0, 100.0, 0, 1.0) - .missing_observation(1.0, 0) + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(1.0, "cp") .build(); let explicit_predictions = explicit.estimate_predictions(&subject, &[0.0]).unwrap(); diff --git a/src/simulator/mod.rs b/src/simulator/mod.rs index 5cea84fe..058ca125 100644 --- a/src/simulator/mod.rs +++ b/src/simulator/mod.rs @@ -200,7 +200,7 @@ pub type Fa = fn(&V, T, &Covariates) -> HashMap; /// /// # Fields /// - `nstates`: Number of state variables (ODE compartments) -/// - `ndrugs`: Number of drug input channels (size of bolus[] and rateiv[]) +/// - `ndrugs`: Number of drug inputs (size of bolus[] and rateiv[]) /// - `nout`: Number of output equations /// /// # Defaults @@ -218,7 +218,7 @@ pub type Fa = fn(&V, T, &Covariates) -> HashMap; pub struct Neqs { /// Number of state variables pub nstates: usize, - /// Number of drug input channels (bolus/rateiv size) + /// Number of drug inputs (bolus/rateiv size) pub ndrugs: usize, /// Number of output equations pub nout: usize, diff --git a/tests/analytical_macro_lowering.rs b/tests/analytical_macro_lowering.rs index e025ec4f..f527978f 100644 --- a/tests/analytical_macro_lowering.rs +++ b/tests/analytical_macro_lowering.rs @@ -1,47 +1,47 @@ use approx::assert_relative_eq; use pharmsol::prelude::*; -fn infusion_subject(input: usize) -> Subject { +fn infusion_subject(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("analytical-macro-iv") .infusion(0.0, 120.0, input, 1.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn oral_subject(input: usize) -> Subject { +fn oral_subject(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("analytical-macro-oral") .bolus(0.0, 100.0, input) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn shared_channel_subject(input: usize) -> Subject { +fn shared_input_subject() -> Subject { Subject::builder("analytical-macro-shared") - .bolus(0.0, 100.0, input) - .infusion(6.0, 60.0, input, 2.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) - .missing_observation(6.5, 0) - .missing_observation(7.0, 0) - .missing_observation(8.0, 0) + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") .build() } -fn covariate_subject(oral: usize, iv: usize, cp: usize) -> Subject { +fn covariate_subject(oral: impl ToString, iv: impl ToString, cp: impl ToString) -> Subject { Subject::builder("analytical-macro-covariates") .bolus(1.0, 100.0, oral) .infusion(6.0, 140.0, iv, 2.0) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) + .missing_observation(0.25, cp.to_string()) + .missing_observation(0.75, cp.to_string()) + .missing_observation(1.5, cp.to_string()) + .missing_observation(3.0, cp.to_string()) + .missing_observation(6.5, cp.to_string()) + .missing_observation(7.0, cp.to_string()) .missing_observation(8.0, cp) .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) @@ -56,9 +56,9 @@ fn macro_one_compartment() -> equation::Analytical { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _t, y| { y[cp] = x[central] / v; @@ -99,9 +99,9 @@ fn macro_one_compartment_with_absorption() -> equation::Analytical { params: [ka, ke, v, tlag, f_oral], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], structure: one_compartment_with_absorption, lag: |_t| { lag! { oral => tlag } @@ -160,16 +160,16 @@ fn handwritten_one_compartment_with_absorption() -> equation::Analytical { .expect("handwritten absorption metadata should validate") } -fn macro_shared_channel_analytical() -> equation::Analytical { +fn macro_shared_input_analytical() -> equation::Analytical { analytical! { name: "one_cmt_abs_shared", params: [ka, ke, v, tlag, f_oral], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], structure: one_compartment_with_absorption, lag: |_t| { lag! { oral => tlag } @@ -183,7 +183,7 @@ fn macro_shared_channel_analytical() -> equation::Analytical { } } -fn handwritten_shared_channel_analytical() -> equation::Analytical { +fn handwritten_shared_input_analytical() -> equation::Analytical { equation::Analytical::new( equation::one_compartment_with_absorption, |_p, _t, _cov| {}, @@ -219,7 +219,7 @@ fn handwritten_shared_channel_analytical() -> equation::Analytical { ]) .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), ) - .expect("handwritten shared-channel analytical metadata should validate") + .expect("handwritten shared-input analytical metadata should validate") } fn macro_covariate_analytical() -> equation::Analytical { @@ -229,10 +229,10 @@ fn macro_covariate_analytical() -> equation::Analytical { covariates: [wt, renal], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], structure: one_compartment_with_absorption, sec: |_t| { let wt_scale = (wt / 70.0).powf(0.75); @@ -382,7 +382,7 @@ fn assert_prediction_match(left: &[f64], right: &[f64]) { fn analytical_macro_lowering_matches_handwritten_metadata_and_predictions() { let macro_model = macro_one_compartment(); let handwritten_model = handwritten_one_compartment(); - let subject = infusion_subject(0); + let subject = infusion_subject("iv", "cp"); let support_point = [0.2, 10.0]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -408,7 +408,7 @@ fn analytical_macro_lowering_matches_handwritten_metadata_and_predictions() { fn analytical_macro_supports_extra_parameters_and_named_route_bindings() { let macro_model = macro_one_compartment_with_absorption(); let handwritten_model = handwritten_one_compartment_with_absorption(); - let subject = oral_subject(0); + let subject = oral_subject("oral", "cp"); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -438,10 +438,10 @@ fn analytical_macro_supports_extra_parameters_and_named_route_bindings() { } #[test] -fn analytical_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { - let macro_model = macro_shared_channel_analytical(); - let handwritten_model = handwritten_shared_channel_analytical(); - let subject = shared_channel_subject(0); +fn analytical_macro_shared_input_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_shared_input_analytical(); + let handwritten_model = handwritten_shared_input_analytical(); + let subject = shared_input_subject(); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -453,12 +453,12 @@ fn analytical_macro_shared_channel_lowering_matches_handwritten_metadata_and_pre let macro_predictions = macro_model .estimate_predictions(&subject, &support_point) - .expect("macro shared-channel analytical model should simulate") + .expect("macro shared-input analytical model should simulate") .flat_predictions() .to_vec(); let handwritten_predictions = handwritten_model .estimate_predictions(&subject, &support_point) - .expect("handwritten shared-channel analytical model should simulate") + .expect("handwritten shared-input analytical model should simulate") .flat_predictions() .to_vec(); @@ -472,14 +472,9 @@ fn analytical_macro_covariates_lower_to_handwritten_behavior() { assert_eq!(macro_model.metadata(), handwritten_model.metadata()); - let oral = macro_model.route_index("oral").expect("oral route exists"); - let iv = macro_model.route_index("iv").expect("iv route exists"); - let cp = macro_model.output_index("cp").expect("cp output exists"); - let subject = covariate_subject(oral, iv, cp); + let subject = covariate_subject("oral", "iv", "cp"); let support_point = [1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16]; - assert_eq!(oral, iv); - let macro_predictions = macro_model .estimate_predictions(&subject, &support_point) .expect("macro covariate analytical model should simulate") diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs index 43621e8a..be80f10e 100644 --- a/tests/authoring_parity_corpus.rs +++ b/tests/authoring_parity_corpus.rs @@ -1,3 +1,4 @@ +use approx::assert_relative_eq; #[cfg(feature = "dsl-jit")] use pharmsol::dsl::{self, RuntimeCompilationTarget, RuntimePredictions}; #[cfg(feature = "dsl-jit")] @@ -52,6 +53,59 @@ dx(central) = ka * depot - (cl / v) * central out(cp) = central / (v * (wt / 70.0)) ~ continuous() "#; +const ODE_MULTI_DIGIT_OUTPUT_ORDER_DSL: &str = r#" +name = multi_digit_output_order +kind = ode + +params = ke, v +states = central +outputs = 2, 10, 11 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(10) = central / v ~ continuous() +out(2) = central / v ~ continuous() +out(11) = central / v ~ continuous() +"#; + +const ODE_NUMERIC_ROUTE_LABELS_AUTHORING_DSL: &str = r#" +name = authoring_numeric_routes +kind = ode + +states = first, second +outputs = cp + +bolus(10) -> first +bolus(11) -> second + +dx(first) = 0 +dx(second) = 0 + +out(cp) = first + second ~ continuous() +"#; + +const ODE_NUMERIC_ROUTE_LABELS_STRUCTURED_DSL: &str = r#"model structured_numeric_routes { + kind ode + states { + first, + second, + } + routes { + 10 -> first + 11 -> second + } + dynamics { + ddt(first) = 0 + ddt(second) = 0 + } + outputs { + cp = first + second + } +} +"#; + const ODE_INVALID_INFUSION_LAG_DSL: &str = r#" name = invalid_infusion_lag_parity kind = ode @@ -69,8 +123,8 @@ out(cp) = central / v ~ continuous() "#; #[cfg(feature = "dsl-jit")] -const ODE_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" -name = shared_channel_one_cpt +const ODE_RUNTIME_SHARED_INPUT_DSL: &str = r#" +name = shared_input_one_cpt kind = ode params = ka, ke, v, tlag, f_oral @@ -88,6 +142,76 @@ dx(central) = ka * depot - ke * central out(cp) = central / v ~ continuous() "#; +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_MIXED_OUTPUT_LABELS_DSL: &str = r#" +name = mixed_output_labels_runtime +kind = ode + +params = ke, v +states = central +outputs = cp, 0, 1 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +out(0) = 2 * central / v ~ continuous() +out(1) = 3 * central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_UNDECLARED_NUMERIC_OUTPUT_LABEL_DSL: &str = r#" +name = undeclared_numeric_output_runtime +kind = ode + +params = ke, v +states = central +outputs = a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(a0) = central / v ~ continuous() +out(a1) = central / v ~ continuous() +out(a2) = central / v ~ continuous() +out(a3) = central / v ~ continuous() +out(a4) = central / v ~ continuous() +out(a5) = central / v ~ continuous() +out(a6) = central / v ~ continuous() +out(a7) = central / v ~ continuous() +out(a8) = central / v ~ continuous() +out(a9) = central / v ~ continuous() +out(a10) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_UNDECLARED_NUMERIC_INPUT_LABEL_DSL: &str = r#" +name = undeclared_numeric_input_runtime +kind = ode + +params = ke, v +states = central +outputs = cp + +bolus(r0) -> central +bolus(r1) -> central +bolus(r2) -> central +bolus(r3) -> central +bolus(r4) -> central +bolus(r5) -> central +bolus(r6) -> central +bolus(r7) -> central +bolus(r8) -> central +bolus(r9) -> central +bolus(r10) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + const ANALYTICAL_DSL: &str = r#" name = one_cmt_abs_parity kind = analytical @@ -103,7 +227,7 @@ out(cp) = central / v ~ continuous() "#; #[cfg(feature = "dsl-jit")] -const ANALYTICAL_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" +const ANALYTICAL_RUNTIME_SHARED_INPUT_DSL: &str = r#" name = one_cmt_abs_shared kind = analytical @@ -158,7 +282,7 @@ out(cp) = central / v ~ continuous() "#; #[cfg(feature = "dsl-jit")] -const SDE_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" +const SDE_RUNTIME_SHARED_INPUT_DSL: &str = r#" name = one_cmt_shared_sde kind = sde @@ -186,7 +310,7 @@ struct MetadataParityView { parameters: Vec, covariates: Vec, states: Vec, - route_channel_count: usize, + route_input_count: usize, routes: Vec, outputs: Vec, analytical_kernel: Option, @@ -211,7 +335,7 @@ struct RouteParity { name: String, kind: Option, declaration_index: usize, - channel_index: usize, + input_index: usize, destination_name: String, destination_index: usize, has_lag: bool, @@ -223,7 +347,7 @@ struct RouteParity { struct RouteInputPolicyParity { name: String, declaration_index: usize, - channel_index: usize, + input_index: usize, input_policy: RouteInputPolicy, } @@ -267,16 +391,16 @@ fn compile_runtime_jit_model(src: &str, model_name: &str) -> dsl::CompiledRuntim } #[cfg(feature = "dsl-jit")] -fn shared_channel_prediction_subject(input: usize, output: usize) -> Subject { - Subject::builder("authoring-parity-shared-channel") - .bolus(0.0, 100.0, input) - .infusion(6.0, 60.0, input, 2.0) - .missing_observation(0.5, output) - .missing_observation(1.0, output) - .missing_observation(2.0, output) - .missing_observation(6.5, output) - .missing_observation(7.0, output) - .missing_observation(8.0, output) +fn shared_input_prediction_subject() -> Subject { + Subject::builder("authoring-parity-shared-input") + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") .build() } @@ -328,7 +452,7 @@ fn dsl_metadata_view(src: &str) -> MetadataParityView { name: route.name.clone(), kind: route.kind.map(RouteKindParity::from_dsl), declaration_index: route.declaration_index, - channel_index: route.index, + input_index: route.index, destination_name: route.destination.state_name.clone(), destination_index: route.destination.state_offset, has_lag: route.has_lag, @@ -342,7 +466,7 @@ fn dsl_metadata_view(src: &str) -> MetadataParityView { parameters, covariates, states, - route_channel_count: model.abi.route_buffer.len, + route_input_count: model.abi.route_buffer.len, routes, outputs, analytical_kernel: model.metadata.analytical, @@ -360,7 +484,7 @@ fn dsl_route_input_policy_view(src: &str) -> Vec { .map(|route| RouteInputPolicyParity { name: route.name, declaration_index: route.declaration_index, - channel_index: route.index, + input_index: route.index, input_policy: if route.inject_input_to_destination { RouteInputPolicy::InjectToDestination } else { @@ -402,7 +526,7 @@ fn validated_metadata_view(metadata: &ValidatedModelMetadata) -> MetadataParityV index, }) .collect(), - route_channel_count: metadata.route_channel_count(), + route_input_count: metadata.route_input_count(), routes: metadata .routes() .iter() @@ -410,7 +534,7 @@ fn validated_metadata_view(metadata: &ValidatedModelMetadata) -> MetadataParityV name: route.name().to_string(), kind: Some(RouteKindParity::from_handwritten(route.kind())), declaration_index: route.declaration_index(), - channel_index: route.channel_index(), + input_index: route.input_index(), destination_name: route.destination().to_string(), destination_index: route.destination_index(), has_lag: route.has_lag(), @@ -441,7 +565,7 @@ fn handwritten_route_input_policy_view( .map(|route| RouteInputPolicyParity { name: route.name().to_string(), declaration_index: route.declaration_index(), - channel_index: route.channel_index(), + input_index: route.input_index(), input_policy: route .input_policy() .expect("route input policy should be explicit in this handwritten fixture"), @@ -456,9 +580,9 @@ fn macro_ode_model() -> equation::ODE { covariates: [wt], states: [depot, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[depot] = -ka * x[depot]; dx[central] = ka * x[depot] - (cl / v) * x[central]; @@ -546,19 +670,19 @@ fn handwritten_ode_model() -> equation::ODE { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_macro_ode() -> equation::ODE { +fn runtime_shared_input_macro_ode() -> equation::ODE { ode! { - name: "shared_channel_one_cpt", + name: "shared_input_one_cpt", params: [ka, ke, v, tlag, f_oral], states: [depot, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, infusion(iv) -> central, - }, - diffeq: |x, _p, _t, dx, bolus, rateiv, _cov| { - dx[depot] = bolus[oral] - ka * x[depot]; - dx[central] = ka * x[depot] + rateiv[iv] - ke * x[central]; + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; }, lag: |_p, _t, _cov| { lag! { oral => tlag } @@ -573,7 +697,7 @@ fn runtime_shared_channel_macro_ode() -> equation::ODE { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_handwritten_ode() -> equation::ODE { +fn runtime_shared_input_handwritten_ode() -> equation::ODE { equation::ODE::new( |x, p, _t, dx, bolus, rateiv, _cov| { fetch_params!(p, ka, ke, _v, _tlag, _f_oral); @@ -598,7 +722,7 @@ fn runtime_shared_channel_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("shared_channel_one_cpt") + equation::metadata::new("shared_input_one_cpt") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) .outputs(["cp"]) @@ -607,17 +731,17 @@ fn runtime_shared_channel_handwritten_ode() -> equation::ODE { .to_state("depot") .with_lag() .with_bioavailability() - .expect_explicit_input(), + .inject_input_to_destination(), equation::Route::infusion("iv") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ]), ) - .expect("handwritten shared-channel ODE metadata should validate") + .expect("handwritten shared-input ODE metadata should validate") } #[cfg(feature = "dsl-jit")] -fn runtime_mismatched_shared_channel_ode() -> equation::ODE { +fn runtime_mismatched_shared_input_ode() -> equation::ODE { equation::ODE::new( |x, p, _t, dx, _bolus, _rateiv, _cov| { fetch_params!(p, ka, ke, _v, _tlag, _f_oral); @@ -642,7 +766,7 @@ fn runtime_mismatched_shared_channel_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("shared_channel_one_cpt_mismatched") + equation::metadata::new("shared_input_one_cpt_mismatched") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) .outputs(["cp"]) @@ -657,20 +781,20 @@ fn runtime_mismatched_shared_channel_ode() -> equation::ODE { .expect_explicit_input(), ]), ) - .expect("mismatched shared-channel ODE metadata should validate") + .expect("mismatched shared-input ODE metadata should validate") } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_macro_analytical() -> equation::Analytical { +fn runtime_shared_input_macro_analytical() -> equation::Analytical { analytical! { name: "one_cmt_abs_shared", params: [ka, ke, v, tlag, f_oral], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { oral => tlag } @@ -685,7 +809,7 @@ fn runtime_shared_channel_macro_analytical() -> equation::Analytical { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_handwritten_analytical() -> equation::Analytical { +fn runtime_shared_input_handwritten_analytical() -> equation::Analytical { equation::Analytical::new( equation::one_compartment_with_absorption, |_p, _t, _cov| {}, @@ -721,21 +845,21 @@ fn runtime_shared_channel_handwritten_analytical() -> equation::Analytical { ]) .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), ) - .expect("handwritten shared-channel analytical metadata should validate") + .expect("handwritten shared-input analytical metadata should validate") } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_macro_sde() -> equation::SDE { +fn runtime_shared_input_macro_sde() -> equation::SDE { sde! { name: "one_cmt_shared_sde", params: [ka, ke, sigma_ke, v, tlag, f_oral], states: [gut, central], outputs: [cp], particles: 8, - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], drift: |x, _p, _t, dx, _cov| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; @@ -761,7 +885,7 @@ fn runtime_shared_channel_macro_sde() -> equation::SDE { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_handwritten_sde() -> equation::SDE { +fn runtime_shared_input_handwritten_sde() -> equation::SDE { equation::SDE::new( |x, p, _t, dx, rateiv, _cov| { fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); @@ -811,7 +935,7 @@ fn runtime_shared_channel_handwritten_sde() -> equation::SDE { ]) .particles(8), ) - .expect("handwritten shared-channel SDE metadata should validate") + .expect("handwritten shared-input SDE metadata should validate") } #[cfg(feature = "dsl-jit")] @@ -852,9 +976,9 @@ fn macro_analytical_model() -> equation::Analytical { params: [ka, ke, v], states: [depot, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, - }, + ], structure: one_compartment_with_absorption, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -895,9 +1019,9 @@ fn macro_sde_model() -> equation::SDE { states: [depot, central], outputs: [cp], particles: 256, - routes: { + routes: [ bolus(oral) -> depot, - }, + ], drift: |x, _p, _t, dx, _cov| { dx[depot] = -ka * x[depot]; dx[central] = ka * x[depot] - ke * x[central]; @@ -1037,6 +1161,93 @@ fn ode_dsl_and_handwritten_metadata_agree_on_public_shape() { assert_eq!(handwritten_view, dsl_view); } +#[test] +fn ode_dsl_declared_output_order_controls_dense_indices_for_multi_digit_labels() { + let dsl_view = dsl_metadata_view(ODE_MULTI_DIGIT_OUTPUT_ORDER_DSL); + + assert_eq!( + dsl_view.outputs, + vec![ + NamedIndex { + name: "2".to_string(), + index: 0, + }, + NamedIndex { + name: "10".to_string(), + index: 1, + }, + NamedIndex { + name: "11".to_string(), + index: 2, + }, + ] + ); +} + +#[test] +fn ode_authoring_dsl_supports_multi_digit_numeric_route_labels() { + let dsl_view = dsl_metadata_view(ODE_NUMERIC_ROUTE_LABELS_AUTHORING_DSL); + + assert_eq!(dsl_view.route_input_count, 2); + assert_eq!( + dsl_view.routes, + vec![ + RouteParity { + name: "10".to_string(), + kind: Some(RouteKindParity::Bolus), + declaration_index: 0, + input_index: 0, + destination_name: "first".to_string(), + destination_index: 0, + has_lag: false, + has_bioavailability: false, + }, + RouteParity { + name: "11".to_string(), + kind: Some(RouteKindParity::Bolus), + declaration_index: 1, + input_index: 1, + destination_name: "second".to_string(), + destination_index: 1, + has_lag: false, + has_bioavailability: false, + }, + ] + ); +} + +#[test] +fn ode_structured_dsl_supports_multi_digit_numeric_route_labels() { + let dsl_view = dsl_metadata_view(ODE_NUMERIC_ROUTE_LABELS_STRUCTURED_DSL); + + assert_eq!(dsl_view.route_input_count, 2); + assert_eq!( + dsl_view.routes, + vec![ + RouteParity { + name: "10".to_string(), + kind: None, + declaration_index: 0, + input_index: 0, + destination_name: "first".to_string(), + destination_index: 0, + has_lag: false, + has_bioavailability: false, + }, + RouteParity { + name: "11".to_string(), + kind: None, + declaration_index: 1, + input_index: 1, + destination_name: "second".to_string(), + destination_index: 1, + has_lag: false, + has_bioavailability: false, + }, + ] + ); +} + #[test] fn ode_macro_dsl_and_handwritten_metadata_agree_on_macro_authorable_shape() { let handwritten = handwritten_ode_macro_model(); @@ -1177,11 +1388,11 @@ fn invalid_dsl_infusion_route_properties_fail_explicitly() { #[cfg(feature = "dsl-jit")] #[test] -fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { +fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape() { let runtime_model = - compile_runtime_jit_model(ODE_RUNTIME_SHARED_CHANNEL_DSL, "shared_channel_one_cpt"); - let macro_model = runtime_shared_channel_macro_ode(); - let handwritten_model = runtime_shared_channel_handwritten_ode(); + compile_runtime_jit_model(ODE_RUNTIME_SHARED_INPUT_DSL, "shared_input_one_cpt"); + let macro_model = runtime_shared_input_macro_ode(); + let handwritten_model = runtime_shared_input_handwritten_ode(); let oral = runtime_model .route_index("oral") @@ -1192,11 +1403,12 @@ fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(oral, cp); + let subject = shared_input_prediction_subject(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); + assert_eq!(cp, 0); assert_eq!(macro_model.route_index("oral"), Some(oral)); assert_eq!(macro_model.route_index("iv"), Some(iv)); assert_eq!(handwritten_model.route_index("oral"), Some(oral)); @@ -1226,11 +1438,11 @@ fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha #[cfg(feature = "dsl-jit")] #[test] -fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { +fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape() { let runtime_model = - compile_runtime_jit_model(ANALYTICAL_RUNTIME_SHARED_CHANNEL_DSL, "one_cmt_abs_shared"); - let macro_model = runtime_shared_channel_macro_analytical(); - let handwritten_model = runtime_shared_channel_handwritten_analytical(); + compile_runtime_jit_model(ANALYTICAL_RUNTIME_SHARED_INPUT_DSL, "one_cmt_abs_shared"); + let macro_model = runtime_shared_input_macro_analytical(); + let handwritten_model = runtime_shared_input_handwritten_analytical(); let oral = runtime_model .route_index("oral") @@ -1241,11 +1453,12 @@ fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_chan let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(oral, cp); + let subject = shared_input_prediction_subject(); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); + assert_eq!(cp, 0); assert_eq!(macro_model.route_index("oral"), Some(oral)); assert_eq!(macro_model.route_index("iv"), Some(iv)); assert_eq!(handwritten_model.route_index("oral"), Some(oral)); @@ -1277,11 +1490,11 @@ fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_chan #[cfg(feature = "dsl-jit")] #[test] -fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { +fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape() { let runtime_model = - compile_runtime_jit_model(SDE_RUNTIME_SHARED_CHANNEL_DSL, "one_cmt_shared_sde"); - let macro_model = runtime_shared_channel_macro_sde(); - let handwritten_model = runtime_shared_channel_handwritten_sde(); + compile_runtime_jit_model(SDE_RUNTIME_SHARED_INPUT_DSL, "one_cmt_shared_sde"); + let macro_model = runtime_shared_input_macro_sde(); + let handwritten_model = runtime_shared_input_handwritten_sde(); let oral = runtime_model .route_index("oral") @@ -1292,11 +1505,12 @@ fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(oral, cp); + let subject = shared_input_prediction_subject(); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); + assert_eq!(cp, 0); assert_eq!(macro_model.route_index("oral"), Some(oral)); assert_eq!(macro_model.route_index("iv"), Some(iv)); assert_eq!(handwritten_model.route_index("oral"), Some(oral)); @@ -1328,8 +1542,8 @@ fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha #[test] fn route_input_policy_runtime_mismatches_are_detected_explicitly() { let runtime_model = - compile_runtime_jit_model(ODE_RUNTIME_SHARED_CHANNEL_DSL, "shared_channel_one_cpt"); - let mismatched_model = runtime_mismatched_shared_channel_ode(); + compile_runtime_jit_model(ODE_RUNTIME_SHARED_INPUT_DSL, "shared_input_one_cpt"); + let mismatched_model = runtime_mismatched_shared_input_ode(); let oral = runtime_model .route_index("oral") @@ -1340,11 +1554,12 @@ fn route_input_policy_runtime_mismatches_are_detected_explicitly() { let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(oral, cp); + let subject = shared_input_prediction_subject(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); + assert_eq!(cp, 0); assert_eq!(mismatched_model.route_index("oral"), Some(oral)); assert_eq!(mismatched_model.route_index("iv"), Some(iv)); @@ -1363,3 +1578,81 @@ fn route_input_policy_runtime_mismatches_are_detected_explicitly() { assert_prediction_vectors_diverge(&runtime_predictions, &mismatched_predictions, 1e-4); } + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_preserves_mixed_output_labels() { + let runtime_model = compile_runtime_jit_model( + ODE_RUNTIME_MIXED_OUTPUT_LABELS_DSL, + "mixed_output_labels_runtime", + ); + let subject = Subject::builder("runtime-mixed-output-labels") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "cp") + .missing_observation(0.5, "0") + .missing_observation(0.5, "1") + .build(); + let support_point = [0.2, 10.0]; + + assert_eq!(runtime_model.output_index("cp"), Some(0)); + assert_eq!(runtime_model.output_index("0"), Some(1)); + assert_eq!(runtime_model.output_index("1"), Some(2)); + + let predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime mixed-output model should simulate") + { + RuntimePredictions::Subject(predictions) => predictions.flat_predictions().to_vec(), + RuntimePredictions::Particles(_) => panic!("ODE runtime should return subject predictions"), + }; + + assert_eq!(predictions.len(), 3); + assert_relative_eq!(predictions[1], 2.0 * predictions[0], epsilon = 1e-6); + assert_relative_eq!(predictions[2], 3.0 * predictions[0], epsilon = 1e-6); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_rejects_undeclared_numeric_output_labels_even_when_dense_index_exists() { + let runtime_model = compile_runtime_jit_model( + ODE_RUNTIME_UNDECLARED_NUMERIC_OUTPUT_LABEL_DSL, + "undeclared_numeric_output_runtime", + ); + let subject = Subject::builder("runtime-undeclared-numeric-output") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "10") + .build(); + let support_point = [0.2, 10.0]; + + let error = runtime_model + .estimate_predictions(&subject, &support_point) + .expect_err("undeclared numeric output label should fail"); + + assert!(matches!( + error, + dsl::RuntimeError::Runtime(PharmsolError::UnknownOutputLabel { label }) if label == "10" + )); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_rejects_undeclared_numeric_input_labels_even_when_dense_index_exists() { + let runtime_model = compile_runtime_jit_model( + ODE_RUNTIME_UNDECLARED_NUMERIC_INPUT_LABEL_DSL, + "undeclared_numeric_input_runtime", + ); + let subject = Subject::builder("runtime-undeclared-numeric-input") + .bolus(0.0, 100.0, "10") + .missing_observation(0.5, "cp") + .build(); + let support_point = [0.2, 10.0]; + + let error = runtime_model + .estimate_predictions(&subject, &support_point) + .expect_err("undeclared numeric input label should fail"); + + assert!(matches!( + error, + dsl::RuntimeError::Runtime(PharmsolError::UnknownInputLabel { label }) if label == "10" + )); +} diff --git a/tests/full_feature_dsl_backend_parity.rs b/tests/full_feature_dsl_backend_parity.rs new file mode 100644 index 00000000..929e7243 --- /dev/null +++ b/tests/full_feature_dsl_backend_parity.rs @@ -0,0 +1,216 @@ +#[path = "support/runtime_corpus.rs"] +mod runtime_corpus; + +#[cfg(all(feature = "dsl-jit", feature = "dsl-wasm"))] +mod tests { + use super::runtime_corpus::{self as corpus, CorpusCase}; + use pharmsol::dsl::{CompiledRuntimeModel, RuntimeBackend}; + + fn owned_names(names: &[&str]) -> Vec { + names.iter().map(|name| (*name).to_owned()).collect() + } + + fn assert_info_matches( + left_label: &str, + left: &CompiledRuntimeModel, + right_label: &str, + right: &CompiledRuntimeModel, + ) { + assert_eq!( + left.info(), + right.info(), + "{left_label} model info diverged from {right_label}" + ); + } + + fn assert_ode_full_public_shape(model: &CompiledRuntimeModel) { + let info = model.info(); + + assert_eq!(info.name, "ode_full_feature_parity"); + assert_eq!( + info.parameters, + owned_names(&[ + "ka", + "ke", + "kcp", + "kpc", + "v", + "tlag", + "f_oral", + "base_depot", + "base_central", + "base_peripheral", + ]) + ); + assert_eq!( + info.covariates + .iter() + .map(|covariate| covariate.name.as_str()) + .collect::>(), + vec!["wt", "renal"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["oral", "load", "iv"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.declaration_index) + .collect::>(), + vec![0, 1, 2] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.index) + .collect::>(), + vec![0, 1, 0] + ); + assert_eq!( + info.outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp"] + ); + assert_eq!(model.route_index("oral"), Some(0)); + assert_eq!(model.route_index("load"), Some(1)); + assert_eq!(model.route_index("iv"), Some(0)); + assert_eq!(model.output_index("cp"), Some(0)); + } + + fn assert_analytical_full_public_shape(model: &CompiledRuntimeModel) { + let info = model.info(); + + assert_eq!(info.name, "analytical_full_feature_parity"); + assert_eq!( + info.parameters, + owned_names(&[ + "ka", + "ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + "tvke", + ]) + ); + assert_eq!( + info.covariates + .iter() + .map(|covariate| covariate.name.as_str()) + .collect::>(), + vec!["wt", "renal"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["oral", "load", "iv"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.declaration_index) + .collect::>(), + vec![0, 1, 2] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.index) + .collect::>(), + vec![0, 1, 0] + ); + assert_eq!( + info.outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp"] + ); + assert_eq!(model.route_index("oral"), Some(0)); + assert_eq!(model.route_index("load"), Some(1)); + assert_eq!(model.route_index("iv"), Some(0)); + assert_eq!(model.output_index("cp"), Some(0)); + } + + fn assert_full_backend_parity( + case: CorpusCase, + assert_public_shape: fn(&CompiledRuntimeModel), + ) -> Result<(), Box> { + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let workspace = super::runtime_corpus::ArtifactWorkspace::new()?; + + let jit = corpus::compile_runtime_jit_model(case)?; + assert_eq!(jit.backend(), RuntimeBackend::Jit); + assert_public_shape(&jit); + corpus::assert_runtime_model_matches_reference(case, "runtime-jit", &jit)?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let aot = corpus::compile_runtime_native_aot_model(case, &workspace)?; + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + { + assert_eq!(aot.backend(), RuntimeBackend::NativeAot); + assert_public_shape(&aot); + corpus::assert_runtime_model_matches_reference(case, "runtime-native-aot", &aot)?; + } + + let wasm = corpus::compile_runtime_wasm_model(case)?; + assert_eq!(wasm.backend(), RuntimeBackend::Wasm); + assert_public_shape(&wasm); + corpus::assert_runtime_model_matches_reference(case, "runtime-wasm", &wasm)?; + + assert_info_matches("runtime-jit", &jit, "runtime-wasm", &wasm); + corpus::assert_runtime_models_match_each_other( + case, + "runtime-jit", + &jit, + "runtime-wasm", + &wasm, + )?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + { + assert_info_matches("runtime-jit", &jit, "runtime-native-aot", &aot); + assert_info_matches("runtime-native-aot", &aot, "runtime-wasm", &wasm); + corpus::assert_runtime_models_match_each_other( + case, + "runtime-jit", + &jit, + "runtime-native-aot", + &aot, + )?; + corpus::assert_runtime_models_match_each_other( + case, + "runtime-native-aot", + &aot, + "runtime-wasm", + &wasm, + )?; + } + + Ok(()) + } + + #[test] + fn ode_full_feature_dsl_matches_handwritten_across_backends( + ) -> Result<(), Box> { + assert_full_backend_parity(CorpusCase::OdeFull, assert_ode_full_public_shape) + } + + #[test] + fn analytical_full_feature_dsl_matches_handwritten_across_backends( + ) -> Result<(), Box> { + assert_full_backend_parity( + CorpusCase::AnalyticalFull, + assert_analytical_full_public_shape, + ) + } +} diff --git a/tests/full_feature_macro_parity.rs b/tests/full_feature_macro_parity.rs index 71a1afa7..5017902e 100644 --- a/tests/full_feature_macro_parity.rs +++ b/tests/full_feature_macro_parity.rs @@ -14,19 +14,19 @@ fn macro_ode_model() -> equation::ODE { covariates: [wt, renal], states: [depot, central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, bolus(load) -> central, infusion(iv) -> central, - }, - diffeq: |x, _t, dx, bolus, rateiv| { + ], + diffeq: |x, _t, dx| { let wt_scale = (wt / 70.0).powf(0.75); let renal_scale = (renal / 90.0).powf(0.25); let adjusted_ke = ke * wt_scale * renal_scale; let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); - dx[depot] = bolus[oral] - ka * x[depot]; - dx[central] = bolus[load] + ka * x[depot] + rateiv[iv] + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - (adjusted_ke + adjusted_kcp) * x[central] + kpc * x[peripheral]; dx[peripheral] = adjusted_kcp * x[central] - kpc * x[peripheral]; @@ -185,31 +185,31 @@ fn handwritten_ode_model() -> equation::ODE { .to_state("depot") .with_lag() .with_bioavailability() - .expect_explicit_input(), + .inject_input_to_destination(), equation::Route::bolus("load") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), equation::Route::infusion("iv") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ]), ) .expect("handwritten ODE metadata should validate") } -fn build_ode_subject(oral: usize, load: usize, iv: usize, cp: usize) -> Subject { +fn build_ode_subject() -> Subject { Subject::builder("macro-vs-handwritten-ode-full-features") - .bolus(0.0, 80.0, load) - .bolus(1.0, 120.0, oral) - .infusion(6.0, 150.0, iv, 2.5) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) + .bolus(0.0, 80.0, "load") + .bolus(1.0, 120.0, "oral") + .infusion(6.0, 150.0, "iv", 2.5) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) .covariate("renal", 0.0, 95.0) @@ -224,11 +224,11 @@ fn macro_analytical_model() -> equation::Analytical { covariates: [wt, renal], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, bolus(load) -> central, infusion(iv) -> central, - }, + ], structure: one_compartment_with_absorption, sec: |_t| { let wt_scale = (wt / 70.0).powf(0.75); @@ -368,19 +368,19 @@ fn handwritten_analytical_model() -> equation::Analytical { .expect("handwritten analytical metadata should validate") } -fn build_analytical_subject(oral: usize, load: usize, iv: usize, cp: usize) -> Subject { +fn build_analytical_subject() -> Subject { Subject::builder("macro-vs-handwritten-analytical-full-features") - .bolus(0.0, 60.0, load) - .bolus(1.0, 100.0, oral) - .infusion(6.0, 140.0, iv, 2.0) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) + .bolus(0.0, 60.0, "load") + .bolus(1.0, 100.0, "oral") + .infusion(6.0, 140.0, "iv", 2.0) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) .covariate("renal", 0.0, 95.0) @@ -407,7 +407,7 @@ fn ode_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::Pharmsol assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); - let subject = build_ode_subject(oral, load, iv, cp); + let subject = build_ode_subject(); let params = [1.1, 0.18, 0.07, 0.04, 35.0, 0.6, 0.85, 4.0, 18.0, 9.0]; let macro_predictions = macro_ode.estimate_predictions(&subject, ¶ms)?; @@ -453,7 +453,7 @@ fn analytical_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::P assert_eq!(handwritten_analytical.route_index("iv"), Some(iv)); assert_eq!(handwritten_analytical.output_index("cp"), Some(cp)); - let subject = build_analytical_subject(oral, load, iv, cp); + let subject = build_analytical_subject(); let params = [1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16]; let macro_predictions = macro_analytical.estimate_predictions(&subject, ¶ms)?; diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs index 7b068733..99e0eeab 100644 --- a/tests/ode_macro_lowering.rs +++ b/tests/ode_macro_lowering.rs @@ -1,47 +1,64 @@ use approx::assert_relative_eq; +use pharmsol::prelude::data::read_pmetrics; use pharmsol::prelude::*; +use tempfile::NamedTempFile; -fn subject_for_route(input: usize) -> Subject { +fn write_pmetrics_fixture(contents: &str) -> NamedTempFile { + let file = NamedTempFile::new().expect("create temporary Pmetrics fixture"); + std::fs::write(file.path(), contents).expect("write temporary Pmetrics fixture"); + file +} + +fn subject_for_route(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("macro-lowering") .infusion(0.0, 100.0, input, 1.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn subject_for_shared_channel(input: usize) -> Subject { - Subject::builder("macro-shared-channel") - .bolus(0.0, 100.0, input) - .infusion(6.0, 60.0, input, 2.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) - .missing_observation(6.5, 0) - .missing_observation(7.0, 0) - .missing_observation(8.0, 0) +fn subject_for_shared_input() -> Subject { + Subject::builder("macro-shared-input") + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") .build() } -fn subject_for_covariates(input: usize) -> Subject { +fn subject_for_covariates(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("macro-covariates") .bolus(0.0, 100.0, input) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .covariate("wt", 0.0, 70.0) .build() } +fn subject_for_numeric_bolus_route(input: impl ToString, outeq: impl ToString) -> Subject { + Subject::builder("numeric-bolus-route") + .bolus(0.0, 100.0, input) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) + .build() +} + fn injected_macro_ode() -> equation::ODE { ode! { name: "injected_one_cpt", params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _t, dx| { dx[central] = -ke * x[central]; }, @@ -82,25 +99,25 @@ fn injected_handwritten_ode() -> equation::ODE { .expect("handwritten injected metadata should validate") } -fn explicit_macro_ode() -> equation::ODE { +fn numeric_label_macro_ode() -> equation::ODE { ode! { - name: "explicit_one_cpt", + name: "numeric_label_one_cpt", params: [ke, v], states: [central], - outputs: [cp], - routes: { - infusion(iv) -> central, - }, - diffeq: |x, _t, dx, _bolus, rateiv| { - dx[central] = rateiv[iv] - ke * x[central]; + outputs: [1], + routes: [ + infusion(1) -> central, + ], + diffeq: |x, _t, dx| { + dx[central] = -ke * x[central]; }, out: |x, _t, y| { - y[cp] = x[central] / v; + y[1] = x[central] / v; }, } } -fn explicit_handwritten_ode() -> equation::ODE { +fn numeric_label_handwritten_ode() -> equation::ODE { equation::ODE::new( |x, p, _t, dx, _bolus, rateiv, _cov| { fetch_params!(p, ke, _v); @@ -118,32 +135,32 @@ fn explicit_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("explicit_one_cpt") + equation::metadata::new("numeric_label_one_cpt") .parameters(["ke", "v"]) .states(["central"]) - .outputs(["cp"]) + .outputs(["1"]) .route( - equation::Route::infusion("iv") + equation::Route::infusion("1") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ), ) - .expect("handwritten explicit metadata should validate") + .expect("handwritten numeric-label metadata should validate") } -fn shared_channel_macro_ode() -> equation::ODE { +fn shared_input_macro_ode() -> equation::ODE { ode! { - name: "shared_channel_one_cpt", + name: "shared_input_one_cpt", params: [ka, ke, v, tlag, f_oral], states: [depot, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, infusion(iv) -> central, - }, - diffeq: |x, _t, dx, bolus, rateiv| { - dx[depot] = bolus[oral] - ka * x[depot]; - dx[central] = ka * x[depot] + rateiv[iv] - ke * x[central]; + ], + diffeq: |x, _t, dx| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; }, lag: |_t| { lag! { oral => tlag } @@ -157,7 +174,7 @@ fn shared_channel_macro_ode() -> equation::ODE { } } -fn shared_channel_handwritten_ode() -> equation::ODE { +fn shared_input_handwritten_ode() -> equation::ODE { equation::ODE::new( |x, p, _t, dx, bolus, rateiv, _cov| { fetch_params!(p, ka, ke, _v, _tlag, _f_oral); @@ -182,7 +199,7 @@ fn shared_channel_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("shared_channel_one_cpt") + equation::metadata::new("shared_input_one_cpt") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) .outputs(["cp"]) @@ -191,13 +208,131 @@ fn shared_channel_handwritten_ode() -> equation::ODE { .to_state("depot") .with_lag() .with_bioavailability() - .expect_explicit_input(), + .inject_input_to_destination(), equation::Route::infusion("iv") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ]), ) - .expect("handwritten shared-channel metadata should validate") + .expect("handwritten shared-input metadata should validate") +} + +fn numeric_route_property_macro_ode() -> equation::ODE { + ode! { + name: "numeric_route_property_one_cpt", + params: [ka, ke, v, tlag, f_oral], + states: [depot, central], + outputs: [1], + routes: [ + bolus(1) -> depot, + ], + diffeq: |x, _t, dx| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; + }, + lag: |_t| { + lag! { 1 => tlag } + }, + fa: |_t| { + fa! { 1 => f_oral } + }, + out: |x, _t, y| { + y[1] = x[central] / v; + }, + } +} + +fn numeric_route_property_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, bolus, _rateiv, _cov| { + fetch_params!(p, ka, ke, _v, _tlag, _f_oral); + dx[0] = bolus[0] - ka * x[0]; + dx[1] = ka * x[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("numeric_route_property_one_cpt") + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["depot", "central"]) + .outputs(["1"]) + .route( + equation::Route::bolus("1") + .to_state("depot") + .with_lag() + .with_bioavailability() + .inject_input_to_destination(), + ), + ) + .expect("handwritten numeric route-property metadata should validate") +} + +fn mixed_output_labels_macro_ode() -> equation::ODE { + ode! { + name: "mixed_output_labels_one_cpt", + params: [ke, v], + states: [central], + outputs: [cp, 0, 1], + routes: [ + infusion(iv) -> central, + ], + diffeq: |x, _t, dx| { + dx[central] = -ke * x[central]; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + y[0] = 2.0 * x[central] / v; + y[1] = 3.0 * x[central] / v; + }, + } +} + +fn mixed_output_labels_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + y[1] = 2.0 * x[0] / v; + y[2] = 3.0 * x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(3) + .with_metadata( + equation::metadata::new("mixed_output_labels_one_cpt") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp", "0", "1"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .inject_input_to_destination(), + ), + ) + .expect("handwritten mixed-output metadata should validate") } fn covariate_macro_ode() -> equation::ODE { @@ -207,9 +342,9 @@ fn covariate_macro_ode() -> equation::ODE { covariates: [wt], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], diffeq: |x, _t, dx| { let scaled_ke = ke * (wt / 70.0); dx[gut] = -ka * x[gut]; @@ -267,7 +402,7 @@ fn assert_prediction_match(left: &[f64], right: &[f64]) { fn macro_injected_lowering_matches_handwritten_metadata_and_predictions() { let macro_ode = injected_macro_ode(); let handwritten_ode = injected_handwritten_ode(); - let subject = subject_for_route(0); + let subject = subject_for_route("iv", "cp"); let support_point = [0.2, 10.0]; assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); @@ -290,25 +425,25 @@ fn macro_injected_lowering_matches_handwritten_metadata_and_predictions() { } #[test] -fn macro_explicit_lowering_matches_handwritten_metadata_and_predictions() { - let macro_ode = explicit_macro_ode(); - let handwritten_ode = explicit_handwritten_ode(); - let subject = subject_for_route(0); +fn macro_numeric_labels_lower_to_dense_slots() { + let macro_ode = numeric_label_macro_ode(); + let handwritten_ode = numeric_label_handwritten_ode(); + let subject = subject_for_route("1", "1"); let support_point = [0.2, 10.0]; assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); - assert_eq!(macro_ode.route_index("iv"), Some(0)); - assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert_eq!(macro_ode.route_index("1"), Some(0)); + assert_eq!(macro_ode.output_index("1"), Some(0)); assert_eq!(macro_ode.state_index("central"), Some(0)); let macro_predictions = macro_ode .estimate_predictions(&subject, &support_point) - .expect("macro explicit model should simulate") + .expect("macro numeric-label model should simulate") .flat_predictions() .to_vec(); let handwritten_predictions = handwritten_ode .estimate_predictions(&subject, &support_point) - .expect("handwritten explicit model should simulate") + .expect("handwritten numeric-label model should simulate") .flat_predictions() .to_vec(); @@ -316,10 +451,10 @@ fn macro_explicit_lowering_matches_handwritten_metadata_and_predictions() { } #[test] -fn macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { - let macro_ode = shared_channel_macro_ode(); - let handwritten_ode = shared_channel_handwritten_ode(); - let subject = subject_for_shared_channel(0); +fn macro_shared_input_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = shared_input_macro_ode(); + let handwritten_ode = shared_input_handwritten_ode(); + let subject = subject_for_shared_input(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); @@ -331,23 +466,131 @@ fn macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() let macro_predictions = macro_ode .estimate_predictions(&subject, &support_point) - .expect("macro shared-channel model should simulate") + .expect("macro shared-input model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten shared-input model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_mixed_output_labels_lower_to_dense_slots() { + let macro_ode = mixed_output_labels_macro_ode(); + let handwritten_ode = mixed_output_labels_handwritten_ode(); + let subject = Subject::builder("mixed-output-labels") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "0") + .missing_observation(2.0, "1") + .build(); + let support_point = [0.2, 10.0]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert_eq!(macro_ode.output_index("0"), Some(1)); + assert_eq!(macro_ode.output_index("1"), Some(2)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro mixed-output model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten mixed-output model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_numeric_route_properties_lower_to_dense_slots() { + let macro_ode = numeric_route_property_macro_ode(); + let handwritten_ode = numeric_route_property_handwritten_ode(); + let subject = subject_for_numeric_bolus_route("1", "1"); + let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.route_index("1"), Some(0)); + assert_eq!(macro_ode.output_index("1"), Some(0)); + assert_eq!(macro_ode.state_index("depot"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(1)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro numeric route-property model should simulate") .flat_predictions() .to_vec(); let handwritten_predictions = handwritten_ode .estimate_predictions(&subject, &support_point) - .expect("handwritten shared-channel model should simulate") + .expect("handwritten numeric route-property model should simulate") .flat_predictions() .to_vec(); assert_prediction_match(¯o_predictions, &handwritten_predictions); } +#[test] +fn macro_named_labels_resolve_from_pmetrics_ingestion() { + let file = write_pmetrics_fixture( + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,1,100,.,.,iv,.,.,.,.,.,.,.\npt1,0,0.5,.,.,.,.,.,.,cp,0,.,.,.,.\npt1,0,1.0,.,.,.,.,.,.,cp,0,.,.,.,.\npt1,0,2.0,.,.,.,.,.,.,cp,0,.,.,.,.\n", + ); + + let data = + read_pmetrics(file.path().display().to_string()).expect("read named-label Pmetrics data"); + let subject = &data.subjects()[0]; + let support_point = [0.2, 10.0]; + + let pmetrics_predictions = injected_macro_ode() + .estimate_predictions(subject, &support_point) + .expect("macro named-label model should simulate") + .flat_predictions() + .to_vec(); + let manual_predictions = injected_macro_ode() + .estimate_predictions(&subject_for_route("iv", "cp"), &support_point) + .expect("macro internal-index model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(&pmetrics_predictions, &manual_predictions); +} + +#[test] +fn macro_numeric_labels_resolve_from_pmetrics_ingestion() { + let file = write_pmetrics_fixture( + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,1,100,.,.,1,.,.,.,.,.,.,.\npt1,0,0.5,.,.,.,.,.,.,1,0,.,.,.,.\npt1,0,1.0,.,.,.,.,.,.,1,0,.,.,.,.\npt1,0,2.0,.,.,.,.,.,.,1,0,.,.,.,.\n", + ); + + let data = + read_pmetrics(file.path().display().to_string()).expect("read numeric-label Pmetrics data"); + let subject = &data.subjects()[0]; + let support_point = [0.2, 10.0]; + + let pmetrics_predictions = numeric_label_macro_ode() + .estimate_predictions(subject, &support_point) + .expect("macro numeric-label model should simulate") + .flat_predictions() + .to_vec(); + let manual_predictions = numeric_label_macro_ode() + .estimate_predictions(&subject_for_route("1", "1"), &support_point) + .expect("macro internal-index numeric-label model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(&pmetrics_predictions, &manual_predictions); +} + #[test] fn macro_covariate_lowering_matches_handwritten_metadata_and_predictions() { let macro_ode = covariate_macro_ode(); let handwritten_ode = covariate_handwritten_ode(); - let subject = subject_for_covariates(0); + let subject = subject_for_covariates("oral", "cp"); let support_point = [1.0, 0.2, 10.0]; let macro_metadata = macro_ode .metadata() diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs index 876d2b23..474c7bab 100644 --- a/tests/sde_macro_lowering.rs +++ b/tests/sde_macro_lowering.rs @@ -2,47 +2,47 @@ use approx::assert_relative_eq; use pharmsol::prelude::*; use pharmsol::Predictions; -fn infusion_subject(input: usize) -> Subject { +fn infusion_subject(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("sde-macro-iv") .infusion(0.0, 120.0, input, 1.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn oral_subject(input: usize) -> Subject { +fn oral_subject(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("sde-macro-oral") .bolus(0.0, 100.0, input) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn shared_channel_subject(input: usize) -> Subject { +fn shared_input_subject() -> Subject { Subject::builder("sde-macro-shared") - .bolus(0.0, 100.0, input) - .infusion(6.0, 60.0, input, 2.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) - .missing_observation(6.5, 0) - .missing_observation(7.0, 0) - .missing_observation(8.0, 0) + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") .build() } -fn covariate_subject(oral: usize, iv: usize, cp: usize) -> Subject { +fn covariate_subject(oral: impl ToString, iv: impl ToString, cp: impl ToString) -> Subject { Subject::builder("sde-macro-covariates") .bolus(1.0, 100.0, oral) .infusion(6.0, 140.0, iv, 2.0) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) + .missing_observation(0.25, cp.to_string()) + .missing_observation(0.75, cp.to_string()) + .missing_observation(1.5, cp.to_string()) + .missing_observation(3.0, cp.to_string()) + .missing_observation(6.5, cp.to_string()) + .missing_observation(7.0, cp.to_string()) .missing_observation(8.0, cp) .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) @@ -73,9 +73,9 @@ fn macro_infusion_sde() -> equation::SDE { states: [central], outputs: [cp], particles: 16, - routes: { + routes: [ infusion(iv) -> central, - }, + ], drift: |x, _t, dx| { dx[central] = -ke * x[central]; }, @@ -133,9 +133,9 @@ fn macro_absorption_sde() -> equation::SDE { states: [gut, central], outputs: [cp], particles: 8, - routes: { + routes: [ bolus(oral) -> gut, - }, + ], drift: |x, _t, dx| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; @@ -211,17 +211,17 @@ fn handwritten_absorption_sde() -> equation::SDE { .expect("handwritten absorption SDE metadata should validate") } -fn macro_shared_channel_sde() -> equation::SDE { +fn macro_shared_input_sde() -> equation::SDE { sde! { name: "one_cmt_shared_sde", params: [ka, ke, sigma_ke, v, tlag, f_oral], states: [gut, central], outputs: [cp], particles: 8, - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], drift: |x, _t, dx| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; @@ -246,7 +246,7 @@ fn macro_shared_channel_sde() -> equation::SDE { } } -fn handwritten_shared_channel_sde() -> equation::SDE { +fn handwritten_shared_input_sde() -> equation::SDE { equation::SDE::new( |x, p, _t, dx, rateiv, _cov| { fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); @@ -296,7 +296,7 @@ fn handwritten_shared_channel_sde() -> equation::SDE { ]) .particles(8), ) - .expect("handwritten shared-channel SDE metadata should validate") + .expect("handwritten shared-input SDE metadata should validate") } fn macro_covariate_sde() -> equation::SDE { @@ -307,10 +307,10 @@ fn macro_covariate_sde() -> equation::SDE { states: [gut, central], outputs: [cp], particles: 8, - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], drift: |x, _t, dx| { let wt_scale = (wt / 70.0).powf(0.75); let renal_scale = (renal / 90.0).powf(0.25); @@ -491,7 +491,7 @@ fn handwritten_covariate_sde() -> equation::SDE { fn sde_macro_lowering_matches_handwritten_metadata_and_predictions() { let macro_model = macro_infusion_sde(); let handwritten_model = handwritten_infusion_sde(); - let subject = infusion_subject(0); + let subject = infusion_subject("iv", "cp"); let support_point = [0.2, 0.0, 10.0]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -516,7 +516,7 @@ fn sde_macro_lowering_matches_handwritten_metadata_and_predictions() { fn sde_macro_supports_lag_fa_init_and_named_sigma_bindings() { let macro_model = macro_absorption_sde(); let handwritten_model = handwritten_absorption_sde(); - let subject = oral_subject(0); + let subject = oral_subject("oral", "cp"); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -538,10 +538,10 @@ fn sde_macro_supports_lag_fa_init_and_named_sigma_bindings() { } #[test] -fn sde_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { - let macro_model = macro_shared_channel_sde(); - let handwritten_model = handwritten_shared_channel_sde(); - let subject = shared_channel_subject(0); +fn sde_macro_shared_input_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_shared_input_sde(); + let handwritten_model = handwritten_shared_input_sde(); + let subject = shared_input_subject(); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -553,10 +553,10 @@ fn sde_macro_shared_channel_lowering_matches_handwritten_metadata_and_prediction let macro_predictions = macro_model .estimate_predictions(&subject, &support_point) - .expect("macro shared-channel SDE should simulate"); + .expect("macro shared-input SDE should simulate"); let handwritten_predictions = handwritten_model .estimate_predictions(&subject, &support_point) - .expect("handwritten shared-channel SDE should simulate"); + .expect("handwritten shared-input SDE should simulate"); assert_prediction_match( &prediction_means(¯o_predictions), @@ -571,14 +571,9 @@ fn sde_macro_covariates_lower_to_handwritten_behavior() { assert_eq!(macro_model.metadata(), handwritten_model.metadata()); - let oral = macro_model.route_index("oral").expect("oral route exists"); - let iv = macro_model.route_index("iv").expect("iv route exists"); - let cp = macro_model.output_index("cp").expect("cp output exists"); - let subject = covariate_subject(oral, iv, cp); + let subject = covariate_subject("oral", "iv", "cp"); let support_point = [1.0, 0.16, 0.0, 32.0, 0.5, 0.8, 3.0, 14.0]; - assert_eq!(oral, iv); - let macro_predictions = macro_model .estimate_predictions(&subject, &support_point) .expect("macro covariate SDE should simulate"); diff --git a/tests/support/bimodal_ke.rs b/tests/support/bimodal_ke.rs index 4c82be4f..6e7e5f8e 100644 --- a/tests/support/bimodal_ke.rs +++ b/tests/support/bimodal_ke.rs @@ -55,6 +55,14 @@ fn subject_for_indices(route_index: usize, output_index: usize) -> Subject { builder.build() } +fn subject_for_labels(route_label: &str, output_label: &str) -> Subject { + let mut builder = Subject::builder(MODEL_NAME).infusion(0.0, 500.0, route_label, 0.5); + for time in OBSERVATION_TIMES { + builder = builder.missing_observation(time, output_label); + } + builder.build() +} + pub fn subject() -> Subject { subject_for_indices(0, 0) } @@ -65,12 +73,15 @@ pub fn subject() -> Subject { feature = "dsl-wasm" ))] pub fn subject_for_runtime_model(model: &pharmsol::dsl::CompiledRuntimeModel) -> Subject { - let route_index = model - .route_index("iv") - .or_else(|| model.route_index("input_0")) - .expect("bimodal_ke route is available"); - let output_index = model.output_index("cp").expect("cp output is available"); - subject_for_indices(route_index, output_index) + let route_label = if model.route_index("iv").is_some() { + "iv" + } else if model.route_index("input_0").is_some() { + "input_0" + } else { + panic!("bimodal_ke route is available"); + }; + model.output_index("cp").expect("cp output is available"); + subject_for_labels(route_label, "cp") } pub fn reference_values() -> Result, Box> { diff --git a/tests/support/runtime_corpus.rs b/tests/support/runtime_corpus.rs index 3ed75511..1ca8ae78 100644 --- a/tests/support/runtime_corpus.rs +++ b/tests/support/runtime_corpus.rs @@ -208,52 +208,52 @@ impl CorpusCase { } fn runtime_subject(self, model: &CompiledRuntimeModel) -> Result> { - let cp = model + model .output_index("cp") .ok_or_else(|| io::Error::other(format!("{}: missing cp output", self.label())))?; let subject = match self { Self::Ode => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; - let iv = model.route_index("iv").ok_or_else(|| { + model.route_index("iv").ok_or_else(|| { io::Error::other(format!("{}: missing iv route", self.label())) })?; Subject::builder(self.label()) .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(6.0, cp) - .missing_observation(7.0, cp) - .missing_observation(9.0, cp) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") .build() } Self::OdeFull => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; - let load = model.route_index("load").ok_or_else(|| { + model.route_index("load").ok_or_else(|| { io::Error::other(format!("{}: missing load route", self.label())) })?; - let iv = model.route_index("iv").ok_or_else(|| { + model.route_index("iv").ok_or_else(|| { io::Error::other(format!("{}: missing iv route", self.label())) })?; Subject::builder(self.label()) - .bolus(0.0, 80.0, load) - .bolus(1.0, 120.0, oral) - .infusion(6.0, 150.0, iv, 2.5) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) + .bolus(0.0, 80.0, "load") + .bolus(1.0, 120.0, "oral") + .infusion(6.0, 150.0, "iv", 2.5) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) .covariate("renal", 0.0, 95.0) @@ -261,39 +261,39 @@ impl CorpusCase { .build() } Self::Analytical => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; Subject::builder(self.label()) - .bolus(0.0, 100.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .bolus(0.0, 100.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build() } Self::AnalyticalFull => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; - let load = model.route_index("load").ok_or_else(|| { + model.route_index("load").ok_or_else(|| { io::Error::other(format!("{}: missing load route", self.label())) })?; - let iv = model.route_index("iv").ok_or_else(|| { + model.route_index("iv").ok_or_else(|| { io::Error::other(format!("{}: missing iv route", self.label())) })?; Subject::builder(self.label()) - .bolus(0.0, 60.0, load) - .bolus(1.0, 100.0, oral) - .infusion(6.0, 140.0, iv, 2.0) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) + .bolus(0.0, 60.0, "load") + .bolus(1.0, 100.0, "oral") + .infusion(6.0, 140.0, "iv", 2.0) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) .covariate("renal", 0.0, 95.0) @@ -301,16 +301,16 @@ impl CorpusCase { .build() } Self::Sde => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; Subject::builder(self.label()) .covariate("wt", 0.0, 70.0) - .bolus(0.0, 80.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .bolus(0.0, 80.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build() } };