From 0a04f41a66a7fd67bc5c3bdcf46151c55e89cac2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 13:39:28 +0100 Subject: [PATCH 01/12] feature: INPUT/OUTEQ colums in the Pmetrics format are now read as Strings instead of Integer. This gives flexibility to modelers to define routes in terms of indices or named elements. All three surfaces for Model creation are also updated (macros, SDL, and ::new() constructors) --- examples/macro_vs_handwritten_one_cpt.rs | 15 +- examples/macro_vs_handwritten_two_cpt.rs | 30 +- pharmsol-dsl/src/authoring.rs | 51 +- pharmsol-dsl/src/parser.rs | 77 ++- .../tests/dsl_authoring_edge_cases.rs | 50 +- pharmsol-macros/src/lib.rs | 456 +++++++++++++++--- src/data/builder.rs | 12 +- src/data/event.rs | 145 +++++- src/data/parser/pmetrics.rs | 62 ++- src/data/row.rs | 34 +- src/data/structs.rs | 57 +-- src/dsl/native.rs | 182 +++++-- src/error/mod.rs | 4 + src/simulator/equation/analytical/mod.rs | 24 +- src/simulator/equation/mod.rs | 100 +++- src/simulator/equation/ode/closure.rs | 6 +- src/simulator/equation/ode/mod.rs | 44 +- src/simulator/equation/sde/mod.rs | 41 +- tests/analytical_macro_lowering.rs | 61 ++- tests/authoring_parity_corpus.rs | 81 +++- tests/ode_macro_lowering.rs | 360 +++++++++++++- tests/sde_macro_lowering.rs | 61 ++- 22 files changed, 1589 insertions(+), 364 deletions(-) diff --git a/examples/macro_vs_handwritten_one_cpt.rs b/examples/macro_vs_handwritten_one_cpt.rs index 4d8f74d0..ddff59f8 100644 --- a/examples/macro_vs_handwritten_one_cpt.rs +++ b/examples/macro_vs_handwritten_one_cpt.rs @@ -26,6 +26,9 @@ fn macro_model() -> equation::ODE { fn handwritten_model() -> equation::ODE { equation::ODE::new( + // Handwritten closures stay on dense internal channels. + // Public labels like `iv` and `cp` live in attached metadata, not in + // the low-level `rateiv[]` / `y[]` buffers. |x, p, _t, dx, _bolus, rateiv, _cov| { fetch_params!(p, ke, _v); dx[0] = rateiv[0] - ke * x[0]; @@ -75,12 +78,12 @@ fn main() -> Result<(), pharmsol::PharmsolError> { assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); let subject = Subject::builder("macro-vs-handwritten-one-cpt") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) - .missing_observation(8.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(8.0, "cp") .build(); let params = [1.022, 194.0]; diff --git a/examples/macro_vs_handwritten_two_cpt.rs b/examples/macro_vs_handwritten_two_cpt.rs index 9ab1a675..915267d6 100644 --- a/examples/macro_vs_handwritten_two_cpt.rs +++ b/examples/macro_vs_handwritten_two_cpt.rs @@ -29,6 +29,10 @@ fn macro_model() -> equation::ODE { fn handwritten_model() -> equation::ODE { equation::ODE::new( + // Handwritten closures stay on dense internal channels. + // Public route labels like `load` and `iv` are metadata names; the + // low-level `bolus[]`, `rateiv[]`, and `y[]` buffers remain indexed by + // dense internal slots. |x, p, _t, dx, bolus, rateiv, _cov| { fetch_params!(p, ke, kcp, kpc, _v); dx[0] = -ke * x[0] - kcp * x[0] + kpc * x[1] + rateiv[0] + bolus[0]; @@ -88,19 +92,19 @@ fn main() -> Result<(), pharmsol::PharmsolError> { assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); let subject = Subject::builder("macro-vs-handwritten-two-cpt") - .bolus(0.0, 100.0, load) - .infusion(12.0, 200.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) - .missing_observation(12.5, cp) - .missing_observation(13.0, cp) - .missing_observation(14.0, cp) - .missing_observation(16.0, cp) - .missing_observation(24.0, cp) + .bolus(0.0, 100.0, "load") + .infusion(12.0, 200.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") + .missing_observation(12.5, "cp") + .missing_observation(13.0, "cp") + .missing_observation(14.0, "cp") + .missing_observation(16.0, "cp") + .missing_observation(24.0, "cp") .build(); let params = [0.1, 0.05, 0.03, 50.0]; diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index c81c19eb..129f07c8 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -371,7 +371,7 @@ impl<'a> AuthoringParser<'a> { if lhs_trimmed == "outputs" { self.declared_outputs_span = Some(span); - for ident in parse_ident_list(rhs, rhs_abs)? { + for ident in parse_output_label_list(rhs, rhs_abs)? { self.declared_outputs.insert(ident.text.clone()); self.explicit_outputs.insert(ident.text, ident.span); } @@ -413,7 +413,20 @@ impl<'a> AuthoringParser<'a> { return self.parse_call_assignment(call, rhs, rhs_abs, span); } - let target = parse_ident_segment(lhs, lhs_abs)?; + let target = match parse_ident_segment(lhs, lhs_abs) { + Ok(target) => target, + Err(error) => { + if self.declared_outputs_span.is_none() { + return Err(error); + } + + let target = parse_output_label_segment(lhs, lhs_abs)?; + if !self.declared_outputs.contains(&target.text) { + return Err(self.undeclared_output_error(&target.text, target.span)); + } + target + } + }; let rhs = parse_surface_rhs(rhs, rhs_abs)?; let stmt = build_assignment_statement( AssignTarget { @@ -552,7 +565,7 @@ impl<'a> AuthoringParser<'a> { self.init_statements.push(stmt); } "out" => { - let output = parse_ident_segment(call.argument, call.argument_start)?; + let output = parse_output_label_segment(call.argument, call.argument_start)?; self.validate_output_target(&output)?; self.declared_outputs.insert(output.text.clone()); self.note_output_assignment(&output); @@ -839,6 +852,13 @@ fn parse_ident_list(src: &str, abs_start: usize) -> Result, ParseErro .collect() } +fn parse_output_label_list(src: &str, abs_start: usize) -> Result, ParseError> { + split_top_level(src, ',') + .into_iter() + .map(|(segment, start)| parse_output_label_segment(segment, abs_start + start)) + .collect() +} + fn parse_covariates_list(src: &str, abs_start: usize) -> Result, ParseError> { let mut covariates = Vec::new(); for (segment, start) in split_top_level(src, ',') { @@ -907,6 +927,27 @@ fn parse_ident_segment(src: &str, abs_start: usize) -> Result )) } +fn parse_output_label_segment(src: &str, abs_start: usize) -> Result { + let trimmed = src.trim(); + let leading = src.len() - src.trim_start().len(); + if trimmed.is_empty() { + return Err(ParseError::new( + "expected output label", + Span::new(abs_start, abs_start + src.len()), + )); + } + if !is_valid_output_label(trimmed) { + return Err(ParseError::new( + format!("expected output label, found `{trimmed}`"), + Span::new(abs_start + leading, abs_start + leading + trimmed.len()), + )); + } + Ok(Ident::new( + trimmed, + Span::new(abs_start + leading, abs_start + leading + trimmed.len()), + )) +} + fn parse_place_at(src: &str, abs_start: usize) -> Result { let mut place = parse_place_fragment(src).map_err(|error| error.shifted(abs_start))?; shift_place(&mut place, abs_start); @@ -1344,6 +1385,10 @@ fn is_valid_ident(src: &str) -> bool { chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_') } +fn is_valid_output_label(src: &str) -> bool { + is_valid_ident(src) || src.chars().all(|ch| ch.is_ascii_digit()) +} + fn is_ident_byte(byte: u8) -> bool { (byte as char).is_ascii_alphanumeric() || byte == b'_' } diff --git a/pharmsol-dsl/src/parser.rs b/pharmsol-dsl/src/parser.rs index f07fbd50..c265b4df 100644 --- a/pharmsol-dsl/src/parser.rs +++ b/pharmsol-dsl/src/parser.rs @@ -106,12 +106,18 @@ struct Parser { #[derive(Clone, Copy)] enum LayoutBoundary { ModelItem, - Statement, + Statement(StatementContext), Binding, IdentItem, RouteDecl, } +#[derive(Clone, Copy, PartialEq, Eq)] +enum StatementContext { + Standard, + Outputs, +} + impl Parser { fn new(src: &str) -> Result { Ok(Self::from_tokens(lex(src)?, src.len())) @@ -655,8 +661,13 @@ impl Parser { fn parse_statement_block(&mut self, name: &str) -> Result { let start = self.bump().unwrap().span; let open = self.expect_simple(|kind| matches!(kind, TokenKind::LBrace), "`{`")?; + let statement_context = if name == "outputs" { + StatementContext::Outputs + } else { + StatementContext::Standard + }; let (statements, mut errors) = - self.with_layout_boundary(LayoutBoundary::Statement, |parser| { + self.with_layout_boundary(LayoutBoundary::Statement(statement_context), |parser| { let mut statements = Vec::new(); let mut errors = Vec::new(); while !parser.is_eof() && !parser.at(|kind| matches!(kind, TokenKind::RBrace)) { @@ -790,8 +801,9 @@ impl Parser { fn parse_stmt_body(&mut self) -> Result, ParseError> { let open = self.expect_simple(|kind| matches!(kind, TokenKind::LBrace), "`{`")?; + let statement_context = self.current_statement_context(); let (statements, mut errors) = - self.with_layout_boundary(LayoutBoundary::Statement, |parser| { + self.with_layout_boundary(LayoutBoundary::Statement(statement_context), |parser| { let mut statements = Vec::new(); let mut errors = Vec::new(); while !parser.is_eof() && !parser.at(|kind| matches!(kind, TokenKind::RBrace)) { @@ -854,7 +866,11 @@ impl Parser { } fn parse_assign_target(&mut self) -> Result { - let name = self.parse_ident()?; + let name = if matches!(self.current_statement_context(), StatementContext::Outputs) { + self.parse_output_target_name()? + } else { + self.parse_ident()? + }; let mut span = name.span; let kind = if let Some(open) = self.take_if(|kind| matches!(kind, TokenKind::LParen)) { let args = self.parse_expr_list(&open, TokenKindMatcher::RPAREN)?; @@ -885,6 +901,30 @@ impl Parser { Ok(AssignTarget { kind, span }) } + fn parse_output_target_name(&mut self) -> Result { + let token = self + .bump() + .ok_or_else(|| ParseError::new("expected output label", Span::empty(self.src_len)))?; + match token.kind { + TokenKind::Ident(name) => Ok(Ident::new(name, token.span)), + TokenKind::Number(value) + if value.is_finite() + && value >= 0.0 + && value.fract() == 0.0 + && value <= usize::MAX as f64 => + { + Ok(Ident::new((value as usize).to_string(), token.span)) + } + other => Err(ParseError::new( + format!( + "expected output label identifier or non-negative integer, found {}", + other.describe() + ), + token.span, + )), + } + } + fn parse_ident(&mut self) -> Result { let token = self .bump() @@ -1320,9 +1360,12 @@ impl Parser { | TokenKind::Diffusion | TokenKind::Particles ), - LayoutBoundary::Statement => match &token.kind { + LayoutBoundary::Statement(context) => match &token.kind { TokenKind::If | TokenKind::For | TokenKind::Let => true, TokenKind::Ident(_) => self.line_starts_assignment_target(index), + TokenKind::Number(_) if matches!(context, StatementContext::Outputs) => { + self.line_starts_numeric_output_assignment(index) + } _ => false, }, LayoutBoundary::Binding => self.line_starts_named_assignment(index), @@ -1379,6 +1422,26 @@ impl Parser { } } + fn line_starts_numeric_output_assignment(&self, index: usize) -> bool { + matches!( + self.tokens.get(index).map(|token| &token.kind), + Some(TokenKind::Number(_)) + ) && self + .next_same_line_index(index) + .is_some_and(|next| matches!(self.tokens[next].kind, TokenKind::Eq)) + } + + fn current_statement_context(&self) -> StatementContext { + self.layout_boundaries + .iter() + .rev() + .find_map(|boundary| match boundary { + LayoutBoundary::Statement(context) => Some(*context), + _ => None, + }) + .unwrap_or(StatementContext::Standard) + } + fn next_same_line_index(&self, index: usize) -> Option { let next = index + 1; let token = self.tokens.get(next)?; @@ -1413,7 +1476,7 @@ impl Parser { fn current_boundary_label(&self) -> &'static str { match self.current_layout_boundary() { Some(LayoutBoundary::ModelItem) => "next model item starts here", - Some(LayoutBoundary::Statement) => "next statement starts here", + Some(LayoutBoundary::Statement(_)) => "next statement starts here", Some(LayoutBoundary::Binding) => "next binding starts here", Some(LayoutBoundary::IdentItem) => "next declaration starts here", Some(LayoutBoundary::RouteDecl) => "next route starts here", @@ -1424,7 +1487,7 @@ impl Parser { fn current_boundary_subject(&self) -> &'static str { match self.current_layout_boundary() { Some(LayoutBoundary::ModelItem) => "model item", - Some(LayoutBoundary::Statement) => "statement", + Some(LayoutBoundary::Statement(_)) => "statement", Some(LayoutBoundary::Binding) => "binding", Some(LayoutBoundary::IdentItem) => "declaration", Some(LayoutBoundary::RouteDecl) => "route", diff --git a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs index 797be3e9..404487dc 100644 --- a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs +++ b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs @@ -1,4 +1,4 @@ -use pharmsol_dsl::{analyze_model, parse_model, parse_module}; +use pharmsol_dsl::{analyze_model, lower_typed_model, parse_model, parse_module}; #[test] fn output_annotation_is_optional() { @@ -161,6 +161,54 @@ out(cp) = central ~ continous() ); } +#[test] +fn mixed_named_and_numeric_output_labels_lower_and_round_trip() { + let src = r#" +name = mixed_output_labels +kind = ode +params = ke, v +states = central +outputs = cp, 0, 1 +infusion(iv) -> central +ddt(central) = -ke * central +out(cp) = central / v +out(0) = 2 * central / v +out(1) = 3 * central / v +"#; + + let module = parse_module(src).expect("mixed output labels should parse in authoring DSL"); + let model = module + .models + .first() + .expect("authoring DSL should produce one model"); + let typed = analyze_model(&model).expect("mixed output labels should analyze"); + let lowered = lower_typed_model(&typed).expect("mixed output labels should lower"); + + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp", "0", "1"] + ); + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.index) + .collect::>(), + vec![0, 1, 2] + ); + + let rendered = module.to_string(); + let reparsed = parse_module(&rendered).expect("rendered mixed-output model should reparse"); + + assert_eq!(rendered, reparsed.to_string()); +} + #[test] fn unknown_route_destination_state_suggests_declared_state() { let src = r#" diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index 54a79fe3..96b9536e 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -6,13 +6,14 @@ use proc_macro::TokenStream; use proc_macro2::{Span, TokenStream as TokenStream2}; use quote::{quote, ToTokens}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; use syn::{ parse::{Parse, ParseStream, Parser}, punctuated::Punctuated, token, visit::Visit, - Expr, ExprClosure, Ident, LitStr, Pat, Stmt, Token, + visit_mut::VisitMut, + Expr, ExprClosure, Ident, Lit, LitInt, LitStr, Pat, Stmt, Token, }; // --------------------------------------------------------------------------- @@ -24,7 +25,7 @@ struct OdeInput { params: Vec, covariates: Vec, states: Vec, - outputs: Vec, + outputs: Vec, routes: Vec, diffeq_mode: OdeDiffeqMode, diffeq: ExprClosure, @@ -39,7 +40,7 @@ struct AnalyticalInput { params: Vec, covariates: Vec, states: Vec, - outputs: Vec, + outputs: Vec, routes: Vec, structure: Ident, sec: Option, @@ -54,7 +55,7 @@ struct SdeInput { params: Vec, covariates: Vec, states: Vec, - outputs: Vec, + outputs: Vec, routes: Vec, particles: Expr, drift: ExprClosure, @@ -73,7 +74,7 @@ enum OdeDiffeqMode { struct OdeRouteDecl { kind: OdeRouteKind, - input: Ident, + input: SymbolicIndex, destination: Ident, } @@ -91,10 +92,78 @@ struct AnalyticalKernelSpec { } struct RoutePropertyEntry { - route: Ident, + route: SymbolicIndex, value: Expr, } +#[derive(Clone)] +enum SymbolicIndex { + Ident(Ident), + Int(LitInt), +} + +impl SymbolicIndex { + fn name(&self) -> String { + match self { + Self::Ident(ident) => ident.to_string(), + Self::Int(lit) => lit.base10_digits().to_string(), + } + } + + fn ident(&self) -> Option<&Ident> { + match self { + Self::Ident(ident) => Some(ident), + Self::Int(_) => None, + } + } + + fn numeric_value(&self) -> Option { + match self { + Self::Ident(_) => None, + Self::Int(lit) => Some( + lit.base10_parse::() + .expect("validated numeric label should fit usize"), + ), + } + } + + fn numeric(value: usize) -> Self { + Self::Int(LitInt::new(&value.to_string(), Span::call_site())) + } +} + +impl Parse for SymbolicIndex { + fn parse(input: ParseStream) -> syn::Result { + if input.peek(LitInt) { + let lit: LitInt = input.parse()?; + lit.base10_parse::().map_err(|_| { + syn::Error::new_spanned( + &lit, + "numeric declaration-first labels must be non-negative base-10 integers that fit in usize", + ) + })?; + Ok(Self::Int(lit)) + } else { + Ok(Self::Ident(input.parse()?)) + } + } +} + +impl ToTokens for SymbolicIndex { + fn to_tokens(&self, tokens: &mut TokenStream2) { + match self { + Self::Ident(ident) => ident.to_tokens(tokens), + Self::Int(lit) => lit.to_tokens(tokens), + } + } +} + +impl std::fmt::Display for SymbolicIndex { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(&self.name()) + } +} + impl Parse for OdeRouteDecl { fn parse(input: ParseStream) -> syn::Result { let kind_ident: Ident = input.parse()?; @@ -111,7 +180,7 @@ impl Parse for OdeRouteDecl { let content; syn::parenthesized!(content in input); - let route_input: Ident = content.parse()?; + let route_input: SymbolicIndex = content.parse()?; if !content.is_empty() { return Err(content.error("expected a single route input name inside `(...)`")); } @@ -166,7 +235,12 @@ impl Parse for OdeInput { "covariates", )?, "states" => set_once_ode(&mut states, parse_ident_list(input)?, &key, "states")?, - "outputs" => set_once_ode(&mut outputs, parse_ident_list(input)?, &key, "outputs")?, + "outputs" => set_once_ode( + &mut outputs, + parse_symbolic_index_list(input)?, + &key, + "outputs", + )?, "routes" => set_once_ode(&mut routes, parse_route_list(input)?, &key, "routes")?, "diffeq" => set_once_ode(&mut diffeq, input.parse()?, &key, "diffeq")?, "lag" => set_once_ode(&mut lag, input.parse()?, &key, "lag")?, @@ -206,14 +280,16 @@ impl Parse for OdeInput { validate_unique_idents("parameter", ¶ms, "ode!")?; validate_unique_idents("covariate", &covariates, "ode!")?; validate_unique_idents("state", &states, "ode!")?; - validate_unique_idents("output", &outputs, "ode!")?; + let output_idents = symbolic_index_idents(&outputs); + + validate_unique_symbolic_indices("output", &outputs, "ode!")?; validate_routes(&routes, &states, "ode!")?; validate_named_binding_compatibility( NamedBindingSets { params: ¶ms, covariates: &covariates, states: &states, - outputs: &outputs, + outputs: &output_idents, routes: &routes, }, OdeBindingClosures { @@ -247,7 +323,7 @@ impl Parse for OdeInput { impl Parse for RoutePropertyEntry { fn parse(input: ParseStream) -> syn::Result { - let route: Ident = input.parse()?; + let route: SymbolicIndex = input.parse()?; input.parse::]>()?; let value: Expr = input.parse()?; Ok(Self { route, value }) @@ -287,9 +363,12 @@ impl Parse for AnalyticalInput { "states" => { set_once_analytical(&mut states, parse_ident_list(input)?, &key, "states")? } - "outputs" => { - set_once_analytical(&mut outputs, parse_ident_list(input)?, &key, "outputs")? - } + "outputs" => set_once_analytical( + &mut outputs, + parse_symbolic_index_list(input)?, + &key, + "outputs", + )?, "routes" => { set_once_analytical(&mut routes, parse_route_list(input)?, &key, "routes")? } @@ -328,7 +407,9 @@ impl Parse for AnalyticalInput { validate_unique_idents("parameter", ¶ms, "analytical!")?; validate_unique_idents("covariate", &covariates, "analytical!")?; validate_unique_idents("state", &states, "analytical!")?; - validate_unique_idents("output", &outputs, "analytical!")?; + let output_idents = symbolic_index_idents(&outputs); + + validate_unique_symbolic_indices("output", &outputs, "analytical!")?; validate_routes(&routes, &states, "analytical!")?; let kernel_spec = resolve_analytical_structure(&structure)?; @@ -358,7 +439,7 @@ impl Parse for AnalyticalInput { params: ¶ms, covariates: &covariates, states: &states, - outputs: &outputs, + outputs: &output_idents, routes: &routes, }, AnalyticalBindingClosures { @@ -431,7 +512,12 @@ impl Parse for SdeInput { "covariates", )?, "states" => set_once_sde(&mut states, parse_ident_list(input)?, &key, "states")?, - "outputs" => set_once_sde(&mut outputs, parse_ident_list(input)?, &key, "outputs")?, + "outputs" => set_once_sde( + &mut outputs, + parse_symbolic_index_list(input)?, + &key, + "outputs", + )?, "routes" => set_once_sde(&mut routes, parse_route_list(input)?, &key, "routes")?, "particles" => set_once_sde(&mut particles, input.parse()?, &key, "particles")?, "drift" => set_once_sde(&mut drift, input.parse()?, &key, "drift")?, @@ -469,14 +555,16 @@ impl Parse for SdeInput { validate_unique_idents("parameter", ¶ms, "sde!")?; validate_unique_idents("covariate", &covariates, "sde!")?; validate_unique_idents("state", &states, "sde!")?; - validate_unique_idents("output", &outputs, "sde!")?; + let output_idents = symbolic_index_idents(&outputs); + + validate_unique_symbolic_indices("output", &outputs, "sde!")?; validate_routes(&routes, &states, "sde!")?; validate_sde_named_binding_compatibility( NamedBindingSets { params: ¶ms, covariates: &covariates, states: &states, - outputs: &outputs, + outputs: &output_idents, routes: &routes, }, SdeBindingClosures { @@ -595,6 +683,16 @@ fn parse_ident_list(input: ParseStream) -> syn::Result> { .collect()) } +fn parse_symbolic_index_list(input: ParseStream) -> syn::Result> { + let content; + syn::bracketed!(content in input); + Ok( + Punctuated::::parse_terminated(&content)? + .into_iter() + .collect(), + ) +} + fn parse_route_list(input: ParseStream) -> syn::Result> { let content; syn::braced!(content in input); @@ -627,6 +725,29 @@ fn generated_ident(name: &str) -> Ident { Ident::new(name, Span::call_site()) } +fn symbolic_index_idents(labels: &[SymbolicIndex]) -> Vec { + labels + .iter() + .filter_map(|label| label.ident().cloned()) + .collect() +} + +fn symbolic_index_bindings(labels: &[SymbolicIndex]) -> Vec<(SymbolicIndex, usize)> { + labels + .iter() + .cloned() + .enumerate() + .map(|(index, label)| (label, index)) + .collect() +} + +fn symbolic_numeric_binding_map(bindings: &[(SymbolicIndex, usize)]) -> HashMap { + bindings + .iter() + .filter_map(|(label, index)| label.numeric_value().map(|value| (value, *index))) + .collect() +} + #[derive(Default)] struct ClosureBodyUsage { idents: HashSet, @@ -713,6 +834,124 @@ impl<'ast> Visit<'ast> for ClosureBodyUsage { } } +struct IndexRewriteTarget { + container: Ident, + labels: HashMap, +} + +impl IndexRewriteTarget { + fn new(container: Ident, labels: HashMap) -> Self { + Self { container, labels } + } +} + +struct NumericLabelRewriter { + index_targets: Vec, + route_labels: Option>, +} + +impl NumericLabelRewriter { + fn rewrite( + expr: &Expr, + index_targets: Vec, + route_labels: Option>, + ) -> Expr { + let mut rewritten = expr.clone(); + let mut rewriter = Self { + index_targets, + route_labels, + }; + rewriter.visit_expr_mut(&mut rewritten); + rewritten + } + + fn target_labels(&self, path: &syn::ExprPath) -> Option<&HashMap> { + if path.qself.is_some() + || path.path.leading_colon.is_some() + || path.path.segments.len() != 1 + { + return None; + } + + let ident = &path.path.segments[0].ident; + self.index_targets + .iter() + .find(|target| target.container == *ident) + .map(|target| &target.labels) + } + + fn rewrite_route_macro(&self, mac: &mut syn::Macro) { + let Some(route_labels) = self.route_labels.as_ref() else { + return; + }; + if !(mac.path.is_ident("lag") || mac.path.is_ident("fa")) { + return; + } + + let Ok(entries) = Punctuated::::parse_terminated + .parse2(mac.tokens.clone()) + else { + return; + }; + + let entries = entries.into_iter().map(|mut entry| { + if let Some(value) = entry.route.numeric_value() { + if let Some(internal_index) = route_labels.get(&value) { + entry.route = SymbolicIndex::numeric(*internal_index); + } + } + entry + }); + + let tokens = entries.map(|entry| { + let route = entry.route; + let value = entry.value; + quote! { #route => #value } + }); + mac.tokens = quote! { #(#tokens),* }; + } +} + +impl VisitMut for NumericLabelRewriter { + fn visit_expr_index_mut(&mut self, expr_index: &mut syn::ExprIndex) { + syn::visit_mut::visit_expr_index_mut(self, expr_index); + + let Expr::Path(expr_path) = expr_index.expr.as_ref() else { + return; + }; + let Some(labels) = self.target_labels(expr_path) else { + return; + }; + let Expr::Lit(expr_lit) = expr_index.index.as_ref() else { + return; + }; + let Lit::Int(lit) = &expr_lit.lit else { + return; + }; + let Ok(external_index) = lit.base10_parse::() else { + return; + }; + let Some(internal_index) = labels.get(&external_index) else { + return; + }; + + expr_index.index = Box::new(Expr::Lit(syn::ExprLit { + attrs: Vec::new(), + lit: Lit::Int(LitInt::new(&internal_index.to_string(), lit.span())), + })); + } + + fn visit_expr_macro_mut(&mut self, expr_macro: &mut syn::ExprMacro) { + self.rewrite_route_macro(&mut expr_macro.mac); + syn::visit_mut::visit_expr_macro_mut(self, expr_macro); + } + + fn visit_stmt_macro_mut(&mut self, stmt_macro: &mut syn::StmtMacro) { + self.rewrite_route_macro(&mut stmt_macro.mac); + syn::visit_mut::visit_stmt_macro_mut(self, stmt_macro); + } +} + fn generate_closure_input_aliases( closure: &ExprClosure, internal_names: &[Ident], @@ -856,10 +1095,17 @@ fn classify_diffeq_mode( } fn route_input_idents(routes: &[OdeRouteDecl]) -> Vec { - routes.iter().map(|route| route.input.clone()).collect() + routes + .iter() + .filter_map(|route| route.input.ident().cloned()) + .collect() +} + +fn route_input_names(routes: &[OdeRouteDecl]) -> Vec { + routes.iter().map(|route| route.input.name()).collect() } -fn ode_route_channel_bindings(routes: &[OdeRouteDecl]) -> Vec<(Ident, usize)> { +fn ode_route_channel_bindings(routes: &[OdeRouteDecl]) -> Vec<(SymbolicIndex, usize)> { let mut next_bolus_index = 0usize; let mut next_infusion_index = 0usize; @@ -883,7 +1129,7 @@ fn ode_route_channel_bindings(routes: &[OdeRouteDecl]) -> Vec<(Ident, usize)> { .collect() } -fn dense_index_len(bindings: &[(Ident, usize)]) -> usize { +fn dense_index_len(bindings: &[(SymbolicIndex, usize)]) -> usize { bindings .iter() .map(|(_, index)| index + 1) @@ -1361,12 +1607,14 @@ fn generate_index_consts(idents: &[Ident]) -> TokenStream2 { } } -fn generate_mapped_index_consts(bindings: &[(Ident, usize)]) -> TokenStream2 { - let bindings = bindings.iter().map(|(ident, index)| { - quote! { - #[allow(non_upper_case_globals, dead_code)] - const #ident: usize = #index; - } +fn generate_mapped_index_consts(bindings: &[(SymbolicIndex, usize)]) -> TokenStream2 { + let bindings = bindings.iter().filter_map(|(label, index)| { + label.ident().map(|ident| { + quote! { + #[allow(non_upper_case_globals, dead_code)] + const #ident: usize = #index; + } + }) }); quote! { @@ -1379,10 +1627,11 @@ fn expand_out( params: &[Ident], covariates: &[Ident], states: &[Ident], - outputs: &[Ident], + outputs: &[SymbolicIndex], ) -> syn::Result { let state_consts = generate_index_consts(states); - let output_consts = generate_index_consts(outputs); + let output_bindings = symbolic_index_bindings(outputs); + let output_consts = generate_mapped_index_consts(&output_bindings); let x = generated_ident("__pharmsol_x"); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); @@ -1397,7 +1646,19 @@ fn expand_out( )?; let parameter_bindings = generate_parameter_bindings(params, out, &p); let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); - let body = &out.body; + let y_binding = if out.inputs.len() == full_inputs.len() { + closure_param_ident(out, 4).unwrap_or_else(|| y.clone()) + } else { + closure_param_ident(out, 2).unwrap_or_else(|| y.clone()) + }; + let body = NumericLabelRewriter::rewrite( + out.body.as_ref(), + vec![IndexRewriteTarget::new( + y_binding, + symbolic_numeric_binding_map(&output_bindings), + )], + None, + ); Ok(quote! {{ let __pharmsol_out: fn( @@ -1480,14 +1741,13 @@ fn extract_route_property_routes( let macro_expr = find_terminal_macro_invocation(macro_name, label, closure)?; let entries = Punctuated::::parse_terminated .parse2(macro_expr.tokens.clone())?; - let known_routes = route_input_idents(routes) + let known_routes = route_input_names(routes) .into_iter() - .map(|route| route.to_string()) .collect::>(); let mut seen = HashSet::new(); for entry in entries { - let route_name = entry.route.to_string(); + let route_name = entry.route.name(); if !known_routes.contains(&route_name) { return Err(syn::Error::new_spanned( &entry.route, @@ -1515,7 +1775,7 @@ fn validate_route_property_kinds( property_routes: &HashSet, ) -> syn::Result<()> { for route in routes { - if property_routes.contains(&route.input.to_string()) + if property_routes.contains(&route.input.name()) && matches!(route.kind, OdeRouteKind::Infusion) { return Err(syn::Error::new_spanned( @@ -1536,7 +1796,7 @@ fn expand_ode_route_map( closure: &ExprClosure, params: &[Ident], covariates: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> syn::Result { let route_consts = generate_mapped_index_consts(route_bindings); let p = generated_ident("__pharmsol_p"); @@ -1553,7 +1813,11 @@ fn expand_ode_route_map( )?; let parameter_bindings = generate_parameter_bindings(params, closure, &p); let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); - let body = &closure.body; + let body = NumericLabelRewriter::rewrite( + closure.body.as_ref(), + Vec::new(), + Some(symbolic_numeric_binding_map(route_bindings)), + ); Ok(quote! {{ let __pharmsol_route_map: fn( @@ -1626,7 +1890,7 @@ fn expand_route_metadata( .map(|route| { let input = &route.input; let destination = &route.destination; - let route_name = route.input.to_string(); + let route_name = route.input.name(); let route_builder = match route.kind { OdeRouteKind::Bolus => { quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } @@ -1671,7 +1935,7 @@ fn expand_analytical_route_metadata( .map(|route| { let input = &route.input; let destination = &route.destination; - let route_name = route.input.to_string(); + let route_name = route.input.name(); let route_builder = match route.kind { OdeRouteKind::Bolus => { quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } @@ -1711,7 +1975,7 @@ fn expand_sde_route_metadata( .map(|route| { let input = &route.input; let destination = &route.destination; - let route_name = route.input.to_string(); + let route_name = route.input.name(); let route_builder = match route.kind { OdeRouteKind::Bolus => { quote! { ::pharmsol::equation::Route::bolus(stringify!(#input)) } @@ -1752,7 +2016,7 @@ fn route_destination_index(route: &OdeRouteDecl, states: &[Ident]) -> usize { fn expand_injected_ode_route_terms( routes: &[OdeRouteDecl], states: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], dx: &Ident, bolus: &Ident, rateiv: &Ident, @@ -1780,7 +2044,7 @@ fn expand_injected_ode_route_terms( fn expand_injected_sde_rate_terms( routes: &[OdeRouteDecl], states: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], dx: &Ident, rateiv: &Ident, ) -> TokenStream2 { @@ -1806,7 +2070,7 @@ fn expand_injected_sde_rate_terms( fn expand_injected_sde_bolus_mappings( routes: &[OdeRouteDecl], states: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> TokenStream2 { let mut destinations = vec![quote! { None }; dense_index_len(route_bindings)]; @@ -1836,12 +2100,30 @@ fn validate_unique_idents(kind: &str, idents: &[Ident], macro_name: &str) -> syn Ok(()) } +fn validate_unique_symbolic_indices( + kind: &str, + labels: &[SymbolicIndex], + macro_name: &str, +) -> syn::Result<()> { + let mut seen = HashSet::new(); + for label in labels { + let name = label.name(); + if !seen.insert(name.clone()) { + return Err(syn::Error::new_spanned( + label, + format!("duplicate {kind} `{name}` in declaration-first `{macro_name}`"), + )); + } + } + Ok(()) +} + fn validate_routes(routes: &[OdeRouteDecl], states: &[Ident], macro_name: &str) -> syn::Result<()> { let known_states = states.iter().map(Ident::to_string).collect::>(); let mut seen_routes = HashSet::new(); for route in routes { - let route_name = route.input.to_string(); + let route_name = route.input.name(); if !seen_routes.insert(route_name.clone()) { return Err(syn::Error::new_spanned( &route.input, @@ -1869,7 +2151,7 @@ fn expand_diffeq( covariates: &[Ident], states: &[Ident], routes: &[OdeRouteDecl], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], diffeq_mode: OdeDiffeqMode, ) -> syn::Result { let state_consts = generate_index_consts(states); @@ -1907,7 +2189,25 @@ fn expand_diffeq( )?; let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); - let body = &diffeq.body; + let bolus_binding = if diffeq.inputs.len() == full_inputs.len() { + closure_param_ident(diffeq, 4).unwrap_or_else(|| bolus.clone()) + } else { + closure_param_ident(diffeq, 3).unwrap_or_else(|| bolus.clone()) + }; + let rateiv_binding = if diffeq.inputs.len() == full_inputs.len() { + closure_param_ident(diffeq, 5).unwrap_or_else(|| rateiv.clone()) + } else { + closure_param_ident(diffeq, 4).unwrap_or_else(|| rateiv.clone()) + }; + let route_label_map = symbolic_numeric_binding_map(route_bindings); + let body = NumericLabelRewriter::rewrite( + diffeq.body.as_ref(), + vec![ + IndexRewriteTarget::new(bolus_binding, route_label_map.clone()), + IndexRewriteTarget::new(rateiv_binding, route_label_map), + ], + None, + ); Ok(quote! {{ let __pharmsol_diffeq: fn( @@ -2094,7 +2394,7 @@ fn expand_analytical_route_map( closure: &ExprClosure, params: &[Ident], covariates: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> syn::Result { let route_consts = generate_mapped_index_consts(route_bindings); let p = generated_ident("__pharmsol_p"); @@ -2111,7 +2411,11 @@ fn expand_analytical_route_map( )?; let parameter_bindings = generate_parameter_bindings(params, closure, &p); let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); - let body = &closure.body; + let body = NumericLabelRewriter::rewrite( + closure.body.as_ref(), + Vec::new(), + Some(symbolic_numeric_binding_map(route_bindings)), + ); Ok(quote! {{ let __pharmsol_route_map: fn( @@ -2221,10 +2525,11 @@ fn expand_analytical_out( params: &[Ident], covariates: &[Ident], states: &[Ident], - outputs: &[Ident], + outputs: &[SymbolicIndex], ) -> syn::Result { let state_consts = generate_index_consts(states); - let output_consts = generate_index_consts(outputs); + let output_bindings = symbolic_index_bindings(outputs); + let output_consts = generate_mapped_index_consts(&output_bindings); let x = generated_ident("__pharmsol_x"); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); @@ -2239,7 +2544,19 @@ fn expand_analytical_out( )?; let parameter_bindings = generate_parameter_bindings(params, out, &p); let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); - let body = &out.body; + let y_binding = if out.inputs.len() == full_inputs.len() { + closure_param_ident(out, 4).unwrap_or_else(|| y.clone()) + } else { + closure_param_ident(out, 2).unwrap_or_else(|| y.clone()) + }; + let body = NumericLabelRewriter::rewrite( + out.body.as_ref(), + vec![IndexRewriteTarget::new( + y_binding, + symbolic_numeric_binding_map(&output_bindings), + )], + None, + ); Ok(quote! {{ let __pharmsol_out: fn( @@ -2270,7 +2587,7 @@ fn expand_sde_drift( covariates: &[Ident], states: &[Ident], routes: &[OdeRouteDecl], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> syn::Result { let state_consts = generate_index_consts(states); let x = generated_ident("__pharmsol_x"); @@ -2360,7 +2677,7 @@ fn expand_sde_route_map( closure: &ExprClosure, params: &[Ident], covariates: &[Ident], - route_bindings: &[(Ident, usize)], + route_bindings: &[(SymbolicIndex, usize)], ) -> syn::Result { let route_consts = generate_mapped_index_consts(route_bindings); let p = generated_ident("__pharmsol_p"); @@ -2377,7 +2694,11 @@ fn expand_sde_route_map( )?; let parameter_bindings = generate_parameter_bindings(params, closure, &p); let covariate_bindings = generate_covariate_bindings(covariates, closure, &cov, &t); - let body = &closure.body; + let body = NumericLabelRewriter::rewrite( + closure.body.as_ref(), + Vec::new(), + Some(symbolic_numeric_binding_map(route_bindings)), + ); Ok(quote! {{ let __pharmsol_route_map: fn( @@ -2444,10 +2765,11 @@ fn expand_sde_out( params: &[Ident], covariates: &[Ident], states: &[Ident], - outputs: &[Ident], + outputs: &[SymbolicIndex], ) -> syn::Result { let state_consts = generate_index_consts(states); - let output_consts = generate_index_consts(outputs); + let output_bindings = symbolic_index_bindings(outputs); + let output_consts = generate_mapped_index_consts(&output_bindings); let x = generated_ident("__pharmsol_x"); let p = generated_ident("__pharmsol_p"); let t = generated_ident("__pharmsol_t"); @@ -2462,7 +2784,19 @@ fn expand_sde_out( )?; let parameter_bindings = generate_parameter_bindings(params, out, &p); let covariate_bindings = generate_covariate_bindings(covariates, out, &cov, &t); - let body = &out.body; + let y_binding = if out.inputs.len() == full_inputs.len() { + closure_param_ident(out, 4).unwrap_or_else(|| y.clone()) + } else { + closure_param_ident(out, 2).unwrap_or_else(|| y.clone()) + }; + let body = NumericLabelRewriter::rewrite( + out.body.as_ref(), + vec![IndexRewriteTarget::new( + y_binding, + symbolic_numeric_binding_map(&output_bindings), + )], + None, + ); Ok(quote! {{ let __pharmsol_out: fn( @@ -3039,11 +3373,11 @@ mod tests { let bindings = ode_route_channel_bindings(&input.routes); assert_eq!(dense_index_len(&bindings), 2); - assert_eq!(bindings[0].0.to_string(), "oral"); + assert_eq!(bindings[0].0.name(), "oral"); assert_eq!(bindings[0].1, 0); - assert_eq!(bindings[1].0.to_string(), "iv"); + assert_eq!(bindings[1].0.name(), "iv"); assert_eq!(bindings[1].1, 0); - assert_eq!(bindings[2].0.to_string(), "sc"); + assert_eq!(bindings[2].0.name(), "sc"); assert_eq!(bindings[2].1, 1); } diff --git a/src/data/builder.rs b/src/data/builder.rs index 18aa17fe..a1718dc7 100644 --- a/src/data/builder.rs +++ b/src/data/builder.rs @@ -67,7 +67,7 @@ impl SubjectBuilder { /// * `time` - Time of the bolus dose /// * `amount` - Amount of drug administered /// * `input` - The compartment number receiving the dose - pub fn bolus(self, time: f64, amount: f64, input: usize) -> Self { + pub fn bolus(self, time: f64, amount: f64, input: impl ToString) -> Self { let bolus = Bolus::new(time, amount, input, self.current_occasion.index()); let event = Event::Bolus(bolus); self.event(event) @@ -81,7 +81,7 @@ impl SubjectBuilder { /// * `amount` - Total amount of drug to be administered /// * `input` - The compartment number receiving the dose /// * `duration` - Duration of the infusion in time units - pub fn infusion(self, time: f64, amount: f64, input: usize, duration: f64) -> Self { + pub fn infusion(self, time: f64, amount: f64, input: impl ToString, duration: f64) -> Self { let infusion = Infusion::new(time, amount, input, duration, self.current_occasion.index()); let event = Event::Infusion(infusion); self.event(event) @@ -94,7 +94,7 @@ impl SubjectBuilder { /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) /// * `outeq` - Output equation number corresponding to this observation - pub fn observation(self, time: f64, value: f64, outeq: usize) -> Self { + pub fn observation(self, time: f64, value: f64, outeq: impl ToString) -> Self { let observation = Observation::new( time, Some(value), @@ -118,7 +118,7 @@ impl SubjectBuilder { self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, censoring: Censor, ) -> Self { let observation = Observation::new( @@ -139,7 +139,7 @@ impl SubjectBuilder { /// /// * `time` - Time of the observation /// * `outeq` - Output equation number (zero-indexed) corresponding to this observation - pub fn missing_observation(self, time: f64, outeq: usize) -> Self { + pub fn missing_observation(self, time: f64, outeq: impl ToString) -> Self { let observation = Observation::new( time, None, @@ -165,7 +165,7 @@ impl SubjectBuilder { self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, errorpoly: ErrorPoly, censored: Censor, ) -> Self { diff --git a/src/data/event.rs b/src/data/event.rs index ff88e097..46995ef5 100644 --- a/src/data/event.rs +++ b/src/data/event.rs @@ -93,6 +93,78 @@ pub enum Event { /// An observation of drug concentration or other measure Observation(Observation), } + +#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +pub struct ChannelId(String); + +impl ChannelId { + pub fn new(label: impl ToString) -> Self { + Self(label.to_string()) + } + + pub fn as_str(&self) -> &str { + &self.0 + } + + pub fn index(&self) -> Option { + self.0.parse::().ok() + } +} + +impl From for ChannelId { + fn from(value: String) -> Self { + Self(value) + } +} + +impl From<&str> for ChannelId { + fn from(value: &str) -> Self { + Self(value.to_string()) + } +} + +impl From for ChannelId { + fn from(value: usize) -> Self { + Self(value.to_string()) + } +} + +impl AsRef for ChannelId { + fn as_ref(&self) -> &str { + self.as_str() + } +} + +impl fmt::Display for ChannelId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + +impl PartialEq for ChannelId { + fn eq(&self, other: &usize) -> bool { + self.index() == Some(*other) + } +} + +impl PartialEq for usize { + fn eq(&self, other: &ChannelId) -> bool { + other == self + } +} + +impl PartialEq for &ChannelId { + fn eq(&self, other: &usize) -> bool { + (**self).eq(other) + } +} + +impl PartialEq<&ChannelId> for usize { + fn eq(&self, other: &&ChannelId) -> bool { + other.eq(self) + } +} + impl Event { /// Get the time of the event pub fn time(&self) -> f64 { @@ -152,7 +224,7 @@ impl Event { pub struct Bolus { time: f64, amount: f64, - input: usize, + input: ChannelId, occasion: usize, } impl Bolus { @@ -163,11 +235,11 @@ impl Bolus { /// * `time` - Time of the bolus dose /// * `amount` - Amount of drug administered /// * `input` - The compartment number receiving the dose - pub fn new(time: f64, amount: f64, input: usize, occasion: usize) -> Self { + pub fn new(time: f64, amount: f64, input: impl ToString, occasion: usize) -> Self { Bolus { time, amount, - input, + input: ChannelId::new(input), occasion, } } @@ -178,8 +250,12 @@ impl Bolus { } /// Get the compartment number that receives the bolus - pub fn input(&self) -> usize { - self.input + pub fn input(&self) -> &ChannelId { + &self.input + } + + pub fn input_index(&self) -> Option { + self.input.index() } /// Get the time of the bolus administration @@ -193,8 +269,8 @@ impl Bolus { } /// Set the compartment number that receives the bolus - pub fn set_input(&mut self, input: usize) { - self.input = input; + pub fn set_input(&mut self, input: impl ToString) { + self.input = ChannelId::new(input); } /// Set the time of the bolus administration @@ -208,7 +284,7 @@ impl Bolus { } /// Get a mutable reference to the compartment number (1-indexed) that receives the bolus - pub fn mut_input(&mut self) -> &mut usize { + pub fn mut_input(&mut self) -> &mut ChannelId { &mut self.input } @@ -235,7 +311,7 @@ impl Bolus { pub struct Infusion { time: f64, amount: f64, - input: usize, + input: ChannelId, duration: f64, occasion: usize, } @@ -248,11 +324,17 @@ impl Infusion { /// * `amount` - Total amount of drug to be administered /// * `input` - The compartment number receiving the dose /// * `duration` - Duration of the infusion in time units - pub fn new(time: f64, amount: f64, input: usize, duration: f64, occasion: usize) -> Self { + pub fn new( + time: f64, + amount: f64, + input: impl ToString, + duration: f64, + occasion: usize, + ) -> Self { Infusion { time, amount, - input, + input: ChannelId::new(input), duration, occasion, } @@ -264,8 +346,12 @@ impl Infusion { } /// Get the compartment number that receives the infusion - pub fn input(&self) -> usize { - self.input + pub fn input(&self) -> &ChannelId { + &self.input + } + + pub fn input_index(&self) -> Option { + self.input.index() } /// Get the duration of the infusion @@ -286,8 +372,8 @@ impl Infusion { } /// Set the compartment number that receives the infusion - pub fn set_input(&mut self, input: usize) { - self.input = input; + pub fn set_input(&mut self, input: impl ToString) { + self.input = ChannelId::new(input); } /// Set the time of the infusion administration @@ -306,7 +392,7 @@ impl Infusion { } /// Get a mutable reference to the compartment number (1-indexed) that receives the infusion - pub fn mut_input(&mut self) -> &mut usize { + pub fn mut_input(&mut self) -> &mut ChannelId { &mut self.input } @@ -348,7 +434,7 @@ pub enum Censor { pub struct Observation { time: f64, value: Option, - outeq: usize, + outeq: ChannelId, errorpoly: Option, occasion: usize, censoring: Censor, @@ -367,7 +453,7 @@ impl Observation { pub(crate) fn new( time: f64, value: Option, - outeq: usize, + outeq: impl ToString, errorpoly: Option, occasion: usize, censoring: Censor, @@ -375,7 +461,7 @@ impl Observation { Observation { time, value, - outeq, + outeq: ChannelId::new(outeq), errorpoly, occasion, censoring, @@ -393,8 +479,12 @@ impl Observation { } /// Get the output equation number corresponding to this observation - pub fn outeq(&self) -> usize { - self.outeq + pub fn outeq(&self) -> &ChannelId { + &self.outeq + } + + pub fn outeq_index(&self) -> Option { + self.outeq.index() } /// Get the error polynomial coefficients (c0, c1, c2, c3) if available @@ -415,8 +505,8 @@ impl Observation { } /// Set the output equation number corresponding to this observation - pub fn set_outeq(&mut self, outeq: usize) { - self.outeq = outeq; + pub fn set_outeq(&mut self, outeq: impl ToString) { + self.outeq = ChannelId::new(outeq); } /// Set the [ErrorPoly] for this observation @@ -435,7 +525,7 @@ impl Observation { } /// Get a mutable reference to the output equation number - pub fn mut_outeq(&mut self) -> &mut usize { + pub fn mut_outeq(&mut self) -> &mut ChannelId { &mut self.outeq } @@ -460,7 +550,9 @@ impl Observation { time: self.time(), observation: self.value(), prediction: pred, - outeq: self.outeq(), + outeq: self + .outeq_index() + .expect("prediction requires a resolved or numeric output label"), errorpoly: self.errorpoly(), state, occasion: self.occasion(), @@ -539,6 +631,7 @@ mod tests { assert_eq!(bolus.time(), 2.5); assert_eq!(bolus.amount(), 100.0); assert_eq!(bolus.input(), 1); + assert_eq!(bolus.input().as_str(), "1"); } #[test] @@ -561,6 +654,7 @@ mod tests { assert_eq!(infusion.time(), 1.0); assert_eq!(infusion.amount(), 200.0); assert_eq!(infusion.input(), 1); + assert_eq!(infusion.input().as_str(), "1"); assert_eq!(infusion.duration(), 2.5); } @@ -589,6 +683,7 @@ mod tests { assert_eq!(observation.time(), 5.0); assert_eq!(observation.value(), Some(75.5)); assert_eq!(observation.outeq(), 2); + assert_eq!(observation.outeq().as_str(), "2"); assert_eq!(observation.errorpoly(), error_poly); } diff --git a/src/data/parser/pmetrics.rs b/src/data/parser/pmetrics.rs index c410d689..4554e435 100644 --- a/src/data/parser/pmetrics.rs +++ b/src/data/parser/pmetrics.rs @@ -95,14 +95,14 @@ struct Row { #[serde(deserialize_with = "deserialize_option_f64")] ii: Option, /// Input compartment - #[serde(deserialize_with = "deserialize_option_usize")] - input: Option, + #[serde(deserialize_with = "deserialize_option_channel_id")] + input: Option, /// Observed value #[serde(deserialize_with = "deserialize_option_f64")] out: Option, /// Corresponding output equation for the observation - #[serde(deserialize_with = "deserialize_option_usize")] - outeq: Option, + #[serde(deserialize_with = "deserialize_option_channel_id")] + outeq: Option, /// Censoring output #[serde(default, deserialize_with = "deserialize_option_censor")] cens: Option, @@ -134,12 +134,12 @@ impl Row { dur: self.dur, addl: self.addl.map(|a| a as i64), ii: self.ii, - input: self.input, + input: self.input.clone(), // Treat -99 as missing value (Pmetrics convention) out: self .out .and_then(|v| if v == -99.0 { None } else { Some(v) }), - outeq: self.outeq, + outeq: self.outeq.clone(), cens: self.cens, c0: self.c0, c1: self.c1, @@ -196,11 +196,11 @@ where } } -fn deserialize_option_usize<'de, D>(deserializer: D) -> Result, D::Error> +fn deserialize_option_channel_id<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { - deserialize_option::(deserializer) + deserialize_option::(deserializer).map(|value| value.map(ChannelId::from)) } fn deserialize_option_isize<'de, D>(deserializer: D) -> Result, D::Error> @@ -496,4 +496,50 @@ mod tests { assert_eq!(second.get(11), Some(".")); assert_eq!(second.get(14), Some(".")); } + + #[test] + fn read_pmetrics_preserves_named_channel_labels() { + let file = NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,1,100,.,.,iv,.,.,.,.,.,.,.\npt1,0,1,.,.,.,.,.,42,cp,0,.,.,.,.\n", + ) + .unwrap(); + + let data = read_pmetrics(file.path().display().to_string()).unwrap(); + let events = data.subjects()[0].occasions()[0].events(); + + match &events[0] { + Event::Infusion(infusion) => assert_eq!(infusion.input().as_str(), "iv"), + _ => panic!("expected infusion event"), + } + + match &events[1] { + Event::Observation(observation) => assert_eq!(observation.outeq().as_str(), "cp"), + _ => panic!("expected observation event"), + } + } + + #[test] + fn read_pmetrics_preserves_numeric_labels_as_strings() { + let file = NamedTempFile::new().unwrap(); + std::fs::write( + file.path(), + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,.,100,.,.,1,.,.,.,.,.,.,.\npt1,0,1,.,.,.,.,.,42,1,0,.,.,.,.\n", + ) + .unwrap(); + + let data = read_pmetrics(file.path().display().to_string()).unwrap(); + let events = data.subjects()[0].occasions()[0].events(); + + match &events[0] { + Event::Bolus(bolus) => assert_eq!(bolus.input().as_str(), "1"), + _ => panic!("expected bolus event"), + } + + match &events[1] { + Event::Observation(observation) => assert_eq!(observation.outeq().as_str(), "1"), + _ => panic!("expected observation event"), + } + } } diff --git a/src/data/row.rs b/src/data/row.rs index b3b38ad8..f6e44e98 100644 --- a/src/data/row.rs +++ b/src/data/row.rs @@ -79,11 +79,11 @@ pub struct DataRow { /// Interdose interval for ADDL pub ii: Option, /// Input compartment - pub input: Option, + pub input: Option, /// Observed value (for EVID=0) pub out: Option, /// Output equation number - pub outeq: Option, + pub outeq: Option, /// Censoring indicator pub cens: Option, /// Error polynomial coefficients @@ -180,14 +180,17 @@ impl DataRow { match self.evid { 0 => { // Observation event - events.push(Event::Observation(Observation::new( - self.time, - self.out, + let outeq = self.outeq + .clone() .ok_or_else(|| DataError::MissingObservationOuteq { id: self.id.clone(), time: self.time, - })?, // Keep 1-indexed as provided by Pmetrics + })?; + events.push(Event::Observation(Observation::new( + self.time, + self.out, + outeq, self.get_errorpoly(), 0, // occasion set later self.cens.unwrap_or(Censor::None), @@ -196,10 +199,13 @@ impl DataRow { 1 | 4 => { // Dosing event (1) or reset with dose (4) - let input = self.input.ok_or_else(|| DataError::MissingBolusInput { - id: self.id.clone(), - time: self.time, - })?; // Keep 1-indexed as provided by Pmetrics + let input = self + .input + .clone() + .ok_or_else(|| DataError::MissingBolusInput { + id: self.id.clone(), + time: self.time, + })?; let event = if self.dur.unwrap_or(0.0) > 0.0 { // Infusion @@ -371,8 +377,8 @@ impl DataRowBuilder { /// /// Required for EVID=1 (dosing events). /// Kept as 1-indexed; user must size state arrays accordingly. - pub fn input(mut self, input: usize) -> Self { - self.row.input = Some(input); + pub fn input(mut self, input: impl ToString) -> Self { + self.row.input = Some(ChannelId::new(input)); self } @@ -388,8 +394,8 @@ impl DataRowBuilder { /// /// Required for EVID=0 (observation events). /// Will be converted to 0-indexed internally. - pub fn outeq(mut self, outeq: usize) -> Self { - self.row.outeq = Some(outeq); + pub fn outeq(mut self, outeq: impl ToString) -> Self { + self.row.outeq = Some(ChannelId::new(outeq)); self } diff --git a/src/data/structs.rs b/src/data/structs.rs index 82cd3faf..c977d89a 100644 --- a/src/data/structs.rs +++ b/src/data/structs.rs @@ -180,13 +180,13 @@ impl Data { let old_events = occasion.process_events(None, true); // Create a set of existing (time, outeq) pairs for fast lookup - let existing_obs: std::collections::HashSet<(u64, usize)> = old_events + let existing_obs: std::collections::HashSet<(u64, ChannelId)> = old_events .iter() .filter_map(|event| match event { Event::Observation(obs) => { // Convert to microseconds for consistent comparison let time_key = (obs.time() * 1e6).round() as u64; - Some((time_key, obs.outeq())) + Some((time_key, obs.outeq().clone())) } _ => None, }) @@ -198,13 +198,13 @@ impl Data { while time < last_time { let time_key = (time * 1e6).round() as u64; - for &outeq in &outeq_values { + for outeq in &outeq_values { // Only add if this (time, outeq) combination doesn't exist - if !existing_obs.contains(&(time_key, outeq)) { + if !existing_obs.contains(&(time_key, outeq.clone())) { let obs = Observation::new( time, None, - outeq, + outeq.clone(), None, occasion.index, Censor::None, @@ -274,14 +274,14 @@ impl Data { } /// Get a vector of all unique output equations (outeq) across all subjects - pub fn get_output_equations(&self) -> Vec { + pub fn get_output_equations(&self) -> Vec { // Collect all unique outeq values in order of occurrence - let mut outeq_values: Vec = self + let mut outeq_values: Vec = self .subjects .iter() .flat_map(|subject| subject.get_output_equations()) .collect(); - outeq_values.sort_unstable(); + outeq_values.sort(); outeq_values.dedup(); outeq_values } @@ -396,14 +396,14 @@ impl Subject { self.occasions.iter_mut() } - pub fn get_output_equations(&self) -> Vec { + pub fn get_output_equations(&self) -> Vec { // Collect all unique outeq values in order of occurrence - let outeq_values: Vec = self + let outeq_values: Vec = self .occasions .iter() .flat_map(|occasion| { occasion.events.iter().filter_map(|event| match event { - Event::Observation(obs) => Some(obs.outeq()), + Event::Observation(obs) => Some(obs.outeq().clone()), _ => None, }) }) @@ -598,8 +598,10 @@ impl Occasion { let time = event.time(); if let Event::Bolus(bolus) = event { let lagtime = fn_lag(&spp.clone().into(), time, covariates); - if let Some(l) = lagtime.get(&bolus.input()) { - *bolus.mut_time() += l; + if let Some(input) = bolus.input_index() { + if let Some(l) = lagtime.get(&input) { + *bolus.mut_time() += l; + } } } } @@ -615,8 +617,10 @@ impl Occasion { let time = event.time(); if let Event::Bolus(bolus) = event { let fa = fn_fa(&spp.clone().into(), time, covariates); - if let Some(f) = fa.get(&bolus.input()) { - bolus.set_amount(bolus.amount() * f); + if let Some(input) = bolus.input_index() { + if let Some(f) = fa.get(&input) { + bolus.set_amount(bolus.amount() * f); + } } } } @@ -703,7 +707,7 @@ impl Occasion { &mut self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, errorpoly: Option, censored: Censor, ) { @@ -713,7 +717,7 @@ impl Occasion { } /// Add a missing [Observation] event to the [Occasion] - pub fn add_missing_observation(&mut self, time: f64, outeq: usize) { + pub fn add_missing_observation(&mut self, time: f64, outeq: impl ToString) { let observation = Observation::new(time, None, outeq, None, self.index, Censor::None); self.add_event(Event::Observation(observation)); } @@ -725,7 +729,7 @@ impl Occasion { &mut self, time: f64, value: f64, - outeq: usize, + outeq: impl ToString, errorpoly: ErrorPoly, censored: Censor, ) { @@ -741,13 +745,13 @@ impl Occasion { } /// Add a [Bolus] event to the [Occasion] - pub fn add_bolus(&mut self, time: f64, amount: f64, input: usize) { + pub fn add_bolus(&mut self, time: f64, amount: f64, input: impl ToString) { let bolus = Bolus::new(time, amount, input, self.index); self.add_event(Event::Bolus(bolus)); } /// Add an [Infusion] event to the [Occasion] - pub fn add_infusion(&mut self, time: f64, amount: f64, input: usize, duration: f64) { + pub fn add_infusion(&mut self, time: f64, amount: f64, input: impl ToString, duration: f64) { let infusion = Infusion::new(time, amount, input, duration, self.index); self.add_event(Event::Infusion(infusion)); } @@ -775,17 +779,6 @@ impl Occasion { .unwrap_or(0.0) } - pub(crate) fn infusions_ref(&self) -> Vec<&Infusion> { - //TODO this can be pre-computed when the struct is initially created - self.events - .iter() - .filter_map(|event| match event { - Event::Infusion(infusion) => Some(infusion), - _ => None, - }) - .collect() - } - /// Get an iterator over all events /// /// # Returns @@ -967,7 +960,7 @@ impl Occasion { for event in &self.events { if let Event::Observation(obs) = event { - if obs.outeq() == outeq { + if obs.outeq_index() == Some(outeq) { if let Some(value) = obs.value() { times.push(obs.time()); concs.push(value); diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 202fd45a..d2600e67 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -20,7 +20,7 @@ pub use super::model_info::{ NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, }; use crate::{ - data::{Covariates, Infusion}, + data::{ChannelId, Covariates, Infusion}, simulator::{ equation::{ ode::{closure_helpers::PMProblem, ExplicitRkTableau, OdeSolver, SdirkTableau}, @@ -29,7 +29,7 @@ use crate::{ likelihood::{Prediction, SubjectPredictions}, M, V, }, - Event, Observation, PharmsolError, Subject, + Event, Observation, Occasion, PharmsolError, Subject, }; pub type DenseKernelFn = unsafe extern "C" fn( @@ -375,6 +375,16 @@ impl SharedNativeModel { Ok(()) } + fn validate_output(&self, outeq: usize) -> Result<(), PharmsolError> { + if outeq >= self.info.output_len { + return Err(PharmsolError::OuteqOutOfRange { + outeq, + nout: self.info.output_len, + }); + } + Ok(()) + } + fn validate_input_for_kind(&self, input: usize, kind: RouteKind) -> Result<(), PharmsolError> { self.validate_input(input)?; if self.route_semantics.supports_input(input, kind) { @@ -387,6 +397,62 @@ impl SharedNativeModel { ))) } + fn resolve_input_label( + &self, + label: &ChannelId, + kind: RouteKind, + ) -> Result { + if let Some(input) = self.route_index(label.as_str()) { + self.validate_input_for_kind(input, kind)?; + return Ok(input); + } + + let input = label + .index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + })?; + self.validate_input_for_kind(input, kind)?; + Ok(input) + } + + fn resolve_output_label(&self, label: &ChannelId) -> Result { + if let Some(outeq) = self.output_index(label.as_str()) { + return Ok(outeq); + } + + let outeq = label + .index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: label.to_string(), + })?; + self.validate_output(outeq)?; + Ok(outeq) + } + + fn resolve_events(&self, occasion: &Occasion) -> Result, PharmsolError> { + let mut events = occasion.process_events(None, true); + + for event in events.iter_mut() { + match event { + Event::Bolus(bolus) => { + let input = self.resolve_input_label(bolus.input(), RouteKind::Bolus)?; + bolus.set_input(input); + } + Event::Infusion(infusion) => { + let input = self.resolve_input_label(infusion.input(), RouteKind::Infusion)?; + infusion.set_input(input); + } + Event::Observation(observation) => { + let outeq = self.resolve_output_label(observation.outeq())?; + observation.set_outeq(outeq); + } + } + } + + Ok(events) + } + fn fill_cov_buffer(&self, covariates: &Covariates, time: f64, buf: &mut [f64]) { for covariate in &self.info.covariates { buf[covariate.index] = match covariates.get_covariate(&covariate.name) { @@ -530,7 +596,13 @@ impl SharedNativeModel { for event in events.iter_mut() { if let Event::Bolus(bolus) = event { - self.validate_input_for_kind(bolus.input(), RouteKind::Bolus)?; + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + self.validate_input_for_kind(input, RouteKind::Bolus)?; if self.artifact.has_kernel(KernelRole::RouteLag) { lag_values.fill(0.0); @@ -556,7 +628,7 @@ impl SharedNativeModel { lag_values.as_mut_ptr(), )?; } - let lag = lag_values[bolus.input()]; + let lag = lag_values[input]; if lag != 0.0 { *bolus.mut_time() += lag; } @@ -586,7 +658,7 @@ impl SharedNativeModel { fa_values.as_mut_ptr(), )?; } - let factor = fa_values[bolus.input()]; + let factor = fa_values[input]; if factor != 1.0 { bolus.set_amount(bolus.amount() * factor); } @@ -651,13 +723,13 @@ impl SharedNativeModel { &cov_buf, &mut outputs, )?; - if observation.outeq() >= outputs.len() { - return Err(PharmsolError::OuteqOutOfRange { - outeq: observation.outeq(), - nout: outputs.len(), - }); - } - Ok(observation.to_prediction(outputs[observation.outeq()], state.to_vec())) + let outeq = observation + .outeq_index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + })?; + self.validate_output(outeq)?; + Ok(observation.to_prediction(outputs[outeq], state.to_vec())) } } @@ -734,18 +806,15 @@ impl NativeOdeModel { let support_vector: V = DVector::from_vec(support_point.to_vec()).into(); for occasion in subject.occasions() { - let infusion_refs = occasion.infusions_ref(); - let infusions = infusion_refs + let mut events = self.shared.resolve_events(occasion)?; + let infusions = events .iter() - .map(|infusion| (*infusion).clone()) + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion.clone()), + _ => None, + }) .collect::>(); - - for infusion in &infusions { - self.shared - .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; - } - - let mut events = occasion.process_events(None, true); + let infusion_refs = infusions.iter().collect::>(); let session = RefCell::new(self.shared.artifact.start_session()?); let mut route_session = session.borrow_mut(); self.shared.apply_route_properties( @@ -901,9 +970,15 @@ impl NativeOdeModel { for (index, event) in events.iter().enumerate() { match event { Event::Bolus(bolus) => { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; self.shared.apply_bolus( solver.state_mut().y.as_mut_slice(), - bolus.input(), + input, bolus.amount(), )?; } @@ -1000,18 +1075,14 @@ impl NativeAnalyticalModel { let mut output = SubjectPredictions::default(); for occasion in subject.occasions() { - let infusions = occasion - .infusions_ref() + let mut events = self.shared.resolve_events(occasion)?; + let infusions = events .iter() - .map(|infusion| (*infusion).clone()) + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion.clone()), + _ => None, + }) .collect::>(); - - for infusion in &infusions { - self.shared - .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; - } - - let mut events = occasion.process_events(None, true); let mut session = self.shared.artifact.start_session()?; self.shared.apply_route_properties( &mut *session, @@ -1030,8 +1101,12 @@ impl NativeAnalyticalModel { for (index, event) in events.iter().enumerate() { match event { Event::Bolus(bolus) => { - self.shared - .apply_bolus(&mut state, bolus.input(), bolus.amount())? + let input = bolus.input_index().ok_or_else(|| { + PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + } + })?; + self.shared.apply_bolus(&mut state, input, bolus.amount())? } Event::Infusion(_) => {} Event::Observation(observation) => { @@ -1171,18 +1246,14 @@ impl NativeSdeModel { let mut output = Array2::from_shape_fn((self.nparticles, 0), |_| Prediction::default()); for occasion in subject.occasions() { - let infusions = occasion - .infusions_ref() + let mut events = self.shared.resolve_events(occasion)?; + let infusions = events .iter() - .map(|infusion| (*infusion).clone()) + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion.clone()), + _ => None, + }) .collect::>(); - - for infusion in &infusions { - self.shared - .validate_input_for_kind(infusion.input(), RouteKind::Infusion)?; - } - - let mut events = occasion.process_events(None, true); let mut session = self.shared.artifact.start_session()?; self.shared.apply_route_properties( &mut *session, @@ -1204,10 +1275,15 @@ impl NativeSdeModel { for (index, event) in events.iter().enumerate() { match event { Event::Bolus(bolus) => { + let input = bolus.input_index().ok_or_else(|| { + PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + } + })?; for particle in &mut particles { self.shared.apply_bolus( particle.as_mut_slice(), - bolus.input(), + input, bolus.amount(), )?; } @@ -1398,11 +1474,14 @@ impl NativeSdeModel { fn active_route_inputs(infusions: &[Infusion], time: f64, route_len: usize) -> Vec { let mut values = vec![0.0; route_len]; for infusion in infusions { - if infusion.input() < route_len + let input = infusion + .input_index() + .expect("resolved infusions should use numeric input labels"); + if input < route_len && time >= infusion.time() && time <= infusion.time() + infusion.duration() { - values[infusion.input()] += infusion.amount() / infusion.duration(); + values[input] += infusion.amount() / infusion.duration(); } } values @@ -1417,8 +1496,11 @@ fn interval_route_inputs( let mut values = vec![0.0; route_len]; for infusion in infusions { let finish = infusion.time() + infusion.duration(); - if infusion.input() < route_len && start_time >= infusion.time() && end_time <= finish { - values[infusion.input()] += infusion.amount() / infusion.duration(); + let input = infusion + .input_index() + .expect("resolved infusions should use numeric input labels"); + if input < route_len && start_time >= infusion.time() && end_time <= finish { + values[input] += infusion.amount() / infusion.duration(); } } values diff --git a/src/error/mod.rs b/src/error/mod.rs index 1316b8a4..5145626e 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -37,6 +37,10 @@ pub enum PharmsolError { ZeroLikelihood, #[error("Missing observation in prediction")] MissingObservation, + #[error("Input label `{label}` could not be resolved to a route channel")] + UnknownInputLabel { label: String }, + #[error("Output label `{label}` could not be resolved to an output channel")] + UnknownOutputLabel { label: String }, #[error("Input channel {input} is out of range (ndrugs = {ndrugs})")] InputOutOfRange { input: usize, ndrugs: usize }, #[error("Output equation {outeq} is out of range (nout = {nout})")] diff --git a/src/simulator/equation/analytical/mod.rs b/src/simulator/equation/analytical/mod.rs index 4734886c..b0d78481 100644 --- a/src/simulator/equation/analytical/mod.rs +++ b/src/simulator/equation/analytical/mod.rs @@ -278,6 +278,11 @@ impl EquationPriv for Analytical { fn get_nouteqs(&self) -> usize { self.neqs.nout } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + #[inline(always)] fn solve( &self, @@ -321,13 +326,19 @@ impl EquationPriv for Analytical { let s = inf.time(); let e = s + inf.duration(); if current_t >= s && next_t <= e { - if inf.input() >= self.get_ndrugs() { + let input = + inf.input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: inf.input().to_string(), + })?; + + if input >= self.get_ndrugs() { return Err(PharmsolError::InputOutOfRange { - input: inf.input(), + input, ndrugs: self.get_ndrugs(), }); } - rateiv[inf.input()] += inf.amount() / inf.duration(); + rateiv[input] += inf.amount() / inf.duration(); } } @@ -365,7 +376,12 @@ impl EquationPriv for Analytical { covariates, &mut y, ); - let pred = y[observation.outeq()]; + let outeq = observation + .outeq_index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + })?; + let pred = y[outeq]; let pred = observation.to_prediction(pred, x.as_slice().to_vec()); if let Some(error_models) = error_models { likelihood.push(pred.log_likelihood(error_models)?.exp()); diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index 60cb2d8f..f3532382 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -12,7 +12,7 @@ pub use sde::*; use crate::{ error_model::AssayErrorModels, simulator::{Fa, Lag}, - Covariates, Event, Infusion, Observation, PharmsolError, Subject, + ChannelId, Covariates, Event, Infusion, Observation, Occasion, PharmsolError, Subject, }; use super::likelihood::Prediction; @@ -129,6 +129,7 @@ pub(crate) trait EquationPriv: EquationTypes { fn get_nstates(&self) -> usize; fn get_ndrugs(&self) -> usize; fn get_nouteqs(&self) -> usize; + fn metadata(&self) -> Option<&ValidatedModelMetadata>; fn solve( &self, state: &mut Self::S, @@ -141,6 +142,85 @@ pub(crate) trait EquationPriv: EquationTypes { fn nparticles(&self) -> usize { 1 } + + fn resolve_input_label( + &self, + label: &ChannelId, + expected_kind: RouteKind, + ) -> Result { + if let Some(metadata) = self.metadata() { + let route = + metadata + .route(label.as_str()) + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + })?; + + if route.kind() != expected_kind { + return Err(PharmsolError::OtherError(format!( + "input label `{}` is declared as {:?} but used as {:?}", + label, + route.kind(), + expected_kind + ))); + } + + return Ok(route.channel_index()); + } + + label + .index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + }) + } + + fn resolve_output_label(&self, label: &ChannelId) -> Result { + if let Some(metadata) = self.metadata() { + return metadata.output_index(label.as_str()).ok_or_else(|| { + PharmsolError::UnknownOutputLabel { + label: label.to_string(), + } + }); + } + + label + .index() + .ok_or_else(|| PharmsolError::UnknownOutputLabel { + label: label.to_string(), + }) + } + + fn resolve_occasion_events( + &self, + occasion: &Occasion, + support_point: &[f64], + covariates: &Covariates, + ) -> Result, PharmsolError> { + let mut resolved = occasion.clone(); + + for event in resolved.events_iter_mut() { + match event { + Event::Bolus(bolus) => { + let input = self.resolve_input_label(bolus.input(), RouteKind::Bolus)?; + bolus.set_input(input); + } + Event::Infusion(infusion) => { + let input = self.resolve_input_label(infusion.input(), RouteKind::Infusion)?; + infusion.set_input(input); + } + Event::Observation(observation) => { + let outeq = self.resolve_output_label(observation.outeq())?; + observation.set_outeq(outeq); + } + } + } + + Ok(resolved.process_events( + Some((self.fa(), self.lag(), support_point, covariates)), + true, + )) + } #[allow(dead_code)] fn is_sde(&self) -> bool { false @@ -181,13 +261,20 @@ pub(crate) trait EquationPriv: EquationTypes { ) -> Result<(), PharmsolError> { match event { Event::Bolus(bolus) => { - if bolus.input() >= self.get_ndrugs() { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + + if input >= self.get_ndrugs() { return Err(PharmsolError::InputOutOfRange { - input: bolus.input(), + input, ndrugs: self.get_ndrugs(), }); } - x.add_bolus(bolus.input(), bolus.amount()); + x.add_bolus(input, bolus.amount()); } Event::Infusion(infusion) => { infusions.push(infusion.clone()); @@ -332,10 +419,7 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { let mut x = self.initial_state(support_point, covariates, occasion.index()); let mut infusions = Vec::new(); - let events = occasion.process_events( - Some((self.fa(), self.lag(), support_point, covariates)), - true, - ); + let events = self.resolve_occasion_events(occasion, support_point, covariates)?; for (index, event) in events.iter().enumerate() { self.simulate_event( support_point, diff --git a/src/simulator/equation/ode/closure.rs b/src/simulator/equation/ode/closure.rs index eed65e7a..cb9c0726 100644 --- a/src/simulator/equation/ode/closure.rs +++ b/src/simulator/equation/ode/closure.rs @@ -80,7 +80,11 @@ impl InfusionSchedule { continue; } - let input = infusion.input(); + let input = infusion + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: infusion.input().to_string(), + })?; if input >= ndrugs { return Err(PharmsolError::InputOutOfRange { input, ndrugs }); } diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/equation/ode/mod.rs index cafe6a96..853b3108 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -330,6 +330,11 @@ impl EquationPriv for ODE { fn get_nouteqs(&self) -> usize { self.neqs.nout } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + #[inline(always)] fn solve( &self, @@ -397,14 +402,21 @@ impl ODE { match event { Event::Bolus(bolus) => { - if bolus.input() >= bolus_v.len() { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + + if input >= bolus_v.len() { return Err(PharmsolError::InputOutOfRange { - input: bolus.input(), + input, ndrugs: bolus_v.len(), }); } bolus_v.fill(0.0); - bolus_v[bolus.input()] = bolus.amount(); + bolus_v[input] = bolus.amount(); state_with_bolus.fill(0.0); state_without_bolus.fill(0.0); @@ -444,7 +456,12 @@ impl ODE { covariates, y_out, ); - let pred = y_out[observation.outeq()]; + let outeq = observation.outeq_index().ok_or_else(|| { + PharmsolError::UnknownOutputLabel { + label: observation.outeq().to_string(), + } + })?; + let pred = y_out[outeq]; let pred = observation.to_prediction(pred, solver.state().y.as_slice().to_vec()); if let Some(error_models) = error_models { @@ -550,11 +567,14 @@ impl Equation for ODE { // Iterate over occasions for occasion in subject.occasions() { let covariates = occasion.covariates(); - let infusions = occasion.infusions_ref(); - let events = occasion.process_events( - Some((self.fa(), self.lag(), support_point, covariates)), - true, - ); + let events = self.resolve_occasion_events(occasion, support_point, covariates)?; + let infusions = events + .iter() + .filter_map(|event| match event { + Event::Infusion(infusion) => Some(infusion), + _ => None, + }) + .collect::>(); let problem = OdeBuilder::::new() .atol(vec![self.atol]) @@ -680,9 +700,9 @@ mod tests { fn route_policy_subject() -> Subject { Subject::builder("route_policy") - .bolus(0.0, 100.0, 0) - .infusion(0.0, 100.0, 0, 1.0) - .observation(1.0, 0.0, 0) + .bolus(0.0, 100.0, "oral") + .infusion(0.0, 100.0, "iv", 1.0) + .observation(1.0, 0.0, "cp") .build() } diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/equation/sde/mod.rs index bdafbbc3..c24b615d 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/equation/sde/mod.rs @@ -124,7 +124,10 @@ fn simulate_sde_event( let mut rateiv = V::zeros(ndrugs, NalgebraContext); for infusion in &infusion_events { if time >= infusion.time() && time <= infusion.duration() + infusion.time() { - rateiv[infusion.input()] += infusion.amount() / infusion.duration(); + let input = infusion + .input_index() + .expect("resolved infusions should use numeric input labels"); + rateiv[input] += infusion.amount() / infusion.duration(); } } @@ -466,6 +469,11 @@ impl EquationPriv for SDE { fn get_nouteqs(&self) -> usize { self.neqs.nout } + + fn metadata(&self) -> Option<&ValidatedModelMetadata> { + self.metadata.as_ref() + } + #[inline(always)] fn solve( &self, @@ -524,7 +532,10 @@ impl EquationPriv for SDE { covariates, &mut y, ); - *p = observation.to_prediction(y[observation.outeq()], x[i].as_slice().to_vec()); + let outeq = observation + .outeq_index() + .expect("resolved observations should use numeric output labels"); + *p = observation.to_prediction(y[outeq], x[i].as_slice().to_vec()); }); let out = Array2::from_shape_vec((self.nparticles, 1), pred.clone())?; *output = concatenate(Axis(1), &[output.view(), out.view()]).unwrap(); @@ -588,17 +599,21 @@ impl EquationPriv for SDE { ) -> Result<(), PharmsolError> { match event { crate::Event::Bolus(bolus) => { - if bolus.input() >= self.get_ndrugs() { + let input = + bolus + .input_index() + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: bolus.input().to_string(), + })?; + + if input >= self.get_ndrugs() { return Err(PharmsolError::InputOutOfRange { - input: bolus.input(), + input, ndrugs: self.get_ndrugs(), }); } - if !self - .injected_bolus_mappings - .apply(x, bolus.input(), bolus.amount()) - { - x.add_bolus(bolus.input(), bolus.amount()); + if !self.injected_bolus_mappings.apply(x, input, bolus.amount()) { + x.add_bolus(input, bolus.amount()); } } crate::Event::Infusion(infusion) => { @@ -909,8 +924,8 @@ mod tests { .expect("injected metadata should validate"); let subject = Subject::builder("bolus_route") - .bolus(0.0, 100.0, 0) - .missing_observation(0.1, 0) + .bolus(0.0, 100.0, "oral") + .missing_observation(0.1, "cp") .build(); let explicit_predictions = explicit.estimate_predictions(&subject, &[0.0]).unwrap(); @@ -954,8 +969,8 @@ mod tests { .expect("injected metadata should validate"); let subject = Subject::builder("infusion_route") - .infusion(0.0, 100.0, 0, 1.0) - .missing_observation(1.0, 0) + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(1.0, "cp") .build(); let explicit_predictions = explicit.estimate_predictions(&subject, &[0.0]).unwrap(); diff --git a/tests/analytical_macro_lowering.rs b/tests/analytical_macro_lowering.rs index e025ec4f..c55719f5 100644 --- a/tests/analytical_macro_lowering.rs +++ b/tests/analytical_macro_lowering.rs @@ -1,47 +1,47 @@ use approx::assert_relative_eq; use pharmsol::prelude::*; -fn infusion_subject(input: usize) -> Subject { +fn infusion_subject(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("analytical-macro-iv") .infusion(0.0, 120.0, input, 1.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn oral_subject(input: usize) -> Subject { +fn oral_subject(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("analytical-macro-oral") .bolus(0.0, 100.0, input) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn shared_channel_subject(input: usize) -> Subject { +fn shared_channel_subject() -> Subject { Subject::builder("analytical-macro-shared") - .bolus(0.0, 100.0, input) - .infusion(6.0, 60.0, input, 2.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) - .missing_observation(6.5, 0) - .missing_observation(7.0, 0) - .missing_observation(8.0, 0) + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") .build() } -fn covariate_subject(oral: usize, iv: usize, cp: usize) -> Subject { +fn covariate_subject(oral: impl ToString, iv: impl ToString, cp: impl ToString) -> Subject { Subject::builder("analytical-macro-covariates") .bolus(1.0, 100.0, oral) .infusion(6.0, 140.0, iv, 2.0) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) + .missing_observation(0.25, cp.to_string()) + .missing_observation(0.75, cp.to_string()) + .missing_observation(1.5, cp.to_string()) + .missing_observation(3.0, cp.to_string()) + .missing_observation(6.5, cp.to_string()) + .missing_observation(7.0, cp.to_string()) .missing_observation(8.0, cp) .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) @@ -382,7 +382,7 @@ fn assert_prediction_match(left: &[f64], right: &[f64]) { fn analytical_macro_lowering_matches_handwritten_metadata_and_predictions() { let macro_model = macro_one_compartment(); let handwritten_model = handwritten_one_compartment(); - let subject = infusion_subject(0); + let subject = infusion_subject("iv", "cp"); let support_point = [0.2, 10.0]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -408,7 +408,7 @@ fn analytical_macro_lowering_matches_handwritten_metadata_and_predictions() { fn analytical_macro_supports_extra_parameters_and_named_route_bindings() { let macro_model = macro_one_compartment_with_absorption(); let handwritten_model = handwritten_one_compartment_with_absorption(); - let subject = oral_subject(0); + let subject = oral_subject("oral", "cp"); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -441,7 +441,7 @@ fn analytical_macro_supports_extra_parameters_and_named_route_bindings() { fn analytical_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { let macro_model = macro_shared_channel_analytical(); let handwritten_model = handwritten_shared_channel_analytical(); - let subject = shared_channel_subject(0); + let subject = shared_channel_subject(); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -472,14 +472,9 @@ fn analytical_macro_covariates_lower_to_handwritten_behavior() { assert_eq!(macro_model.metadata(), handwritten_model.metadata()); - let oral = macro_model.route_index("oral").expect("oral route exists"); - let iv = macro_model.route_index("iv").expect("iv route exists"); - let cp = macro_model.output_index("cp").expect("cp output exists"); - let subject = covariate_subject(oral, iv, cp); + let subject = covariate_subject("oral", "iv", "cp"); let support_point = [1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16]; - assert_eq!(oral, iv); - let macro_predictions = macro_model .estimate_predictions(&subject, &support_point) .expect("macro covariate analytical model should simulate") diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs index 43621e8a..67f91c7a 100644 --- a/tests/authoring_parity_corpus.rs +++ b/tests/authoring_parity_corpus.rs @@ -1,3 +1,4 @@ +use approx::assert_relative_eq; #[cfg(feature = "dsl-jit")] use pharmsol::dsl::{self, RuntimeCompilationTarget, RuntimePredictions}; #[cfg(feature = "dsl-jit")] @@ -88,6 +89,24 @@ dx(central) = ka * depot - ke * central out(cp) = central / v ~ continuous() "#; +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_MIXED_OUTPUT_LABELS_DSL: &str = r#" +name = mixed_output_labels_runtime +kind = ode + +params = ke, v +states = central +outputs = cp, 0, 1 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +out(0) = 2 * central / v ~ continuous() +out(1) = 3 * central / v ~ continuous() +"#; + const ANALYTICAL_DSL: &str = r#" name = one_cmt_abs_parity kind = analytical @@ -267,16 +286,16 @@ fn compile_runtime_jit_model(src: &str, model_name: &str) -> dsl::CompiledRuntim } #[cfg(feature = "dsl-jit")] -fn shared_channel_prediction_subject(input: usize, output: usize) -> Subject { +fn shared_channel_prediction_subject() -> Subject { Subject::builder("authoring-parity-shared-channel") - .bolus(0.0, 100.0, input) - .infusion(6.0, 60.0, input, 2.0) - .missing_observation(0.5, output) - .missing_observation(1.0, output) - .missing_observation(2.0, output) - .missing_observation(6.5, output) - .missing_observation(7.0, output) - .missing_observation(8.0, output) + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") .build() } @@ -1192,11 +1211,12 @@ fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(oral, cp); + let subject = shared_channel_prediction_subject(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); + assert_eq!(cp, 0); assert_eq!(macro_model.route_index("oral"), Some(oral)); assert_eq!(macro_model.route_index("iv"), Some(iv)); assert_eq!(handwritten_model.route_index("oral"), Some(oral)); @@ -1241,11 +1261,12 @@ fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_chan let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(oral, cp); + let subject = shared_channel_prediction_subject(); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); + assert_eq!(cp, 0); assert_eq!(macro_model.route_index("oral"), Some(oral)); assert_eq!(macro_model.route_index("iv"), Some(iv)); assert_eq!(handwritten_model.route_index("oral"), Some(oral)); @@ -1292,11 +1313,12 @@ fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(oral, cp); + let subject = shared_channel_prediction_subject(); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); + assert_eq!(cp, 0); assert_eq!(macro_model.route_index("oral"), Some(oral)); assert_eq!(macro_model.route_index("iv"), Some(iv)); assert_eq!(handwritten_model.route_index("oral"), Some(oral)); @@ -1340,11 +1362,12 @@ fn route_input_policy_runtime_mismatches_are_detected_explicitly() { let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(oral, cp); + let subject = shared_channel_prediction_subject(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); assert_eq!(iv, oral); + assert_eq!(cp, 0); assert_eq!(mismatched_model.route_index("oral"), Some(oral)); assert_eq!(mismatched_model.route_index("iv"), Some(iv)); @@ -1363,3 +1386,35 @@ fn route_input_policy_runtime_mismatches_are_detected_explicitly() { assert_prediction_vectors_diverge(&runtime_predictions, &mismatched_predictions, 1e-4); } + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_preserves_mixed_output_labels() { + let runtime_model = compile_runtime_jit_model( + ODE_RUNTIME_MIXED_OUTPUT_LABELS_DSL, + "mixed_output_labels_runtime", + ); + let subject = Subject::builder("runtime-mixed-output-labels") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "cp") + .missing_observation(0.5, "0") + .missing_observation(0.5, "1") + .build(); + let support_point = [0.2, 10.0]; + + assert_eq!(runtime_model.output_index("cp"), Some(0)); + assert_eq!(runtime_model.output_index("0"), Some(1)); + assert_eq!(runtime_model.output_index("1"), Some(2)); + + let predictions = match runtime_model + .estimate_predictions(&subject, &support_point) + .expect("runtime mixed-output model should simulate") + { + RuntimePredictions::Subject(predictions) => predictions.flat_predictions().to_vec(), + RuntimePredictions::Particles(_) => panic!("ODE runtime should return subject predictions"), + }; + + assert_eq!(predictions.len(), 3); + assert_relative_eq!(predictions[1], 2.0 * predictions[0], epsilon = 1e-6); + assert_relative_eq!(predictions[2], 3.0 * predictions[0], epsilon = 1e-6); +} diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs index 7b068733..a556f428 100644 --- a/tests/ode_macro_lowering.rs +++ b/tests/ode_macro_lowering.rs @@ -1,38 +1,55 @@ use approx::assert_relative_eq; +use pharmsol::prelude::data::read_pmetrics; use pharmsol::prelude::*; +use tempfile::NamedTempFile; -fn subject_for_route(input: usize) -> Subject { +fn write_pmetrics_fixture(contents: &str) -> NamedTempFile { + let file = NamedTempFile::new().expect("create temporary Pmetrics fixture"); + std::fs::write(file.path(), contents).expect("write temporary Pmetrics fixture"); + file +} + +fn subject_for_route(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("macro-lowering") .infusion(0.0, 100.0, input, 1.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn subject_for_shared_channel(input: usize) -> Subject { +fn subject_for_shared_channel() -> Subject { Subject::builder("macro-shared-channel") - .bolus(0.0, 100.0, input) - .infusion(6.0, 60.0, input, 2.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) - .missing_observation(6.5, 0) - .missing_observation(7.0, 0) - .missing_observation(8.0, 0) + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") .build() } -fn subject_for_covariates(input: usize) -> Subject { +fn subject_for_covariates(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("macro-covariates") .bolus(0.0, 100.0, input) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .covariate("wt", 0.0, 70.0) .build() } +fn subject_for_numeric_bolus_route(input: impl ToString, outeq: impl ToString) -> Subject { + Subject::builder("numeric-bolus-route") + .bolus(0.0, 100.0, input) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) + .build() +} + fn injected_macro_ode() -> equation::ODE { ode! { name: "injected_one_cpt", @@ -131,6 +148,55 @@ fn explicit_handwritten_ode() -> equation::ODE { .expect("handwritten explicit metadata should validate") } +fn numeric_label_macro_ode() -> equation::ODE { + ode! { + name: "numeric_label_one_cpt", + params: [ke, v], + states: [central], + outputs: [1], + routes: { + infusion(1) -> central, + }, + diffeq: |x, _t, dx, _bolus, rateiv| { + dx[central] = rateiv[1] - ke * x[central]; + }, + out: |x, _t, y| { + y[1] = x[central] / v; + }, + } +} + +fn numeric_label_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("numeric_label_one_cpt") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["1"]) + .route( + equation::Route::infusion("1") + .to_state("central") + .expect_explicit_input(), + ), + ) + .expect("handwritten numeric-label metadata should validate") +} + fn shared_channel_macro_ode() -> equation::ODE { ode! { name: "shared_channel_one_cpt", @@ -200,6 +266,124 @@ fn shared_channel_handwritten_ode() -> equation::ODE { .expect("handwritten shared-channel metadata should validate") } +fn numeric_route_property_macro_ode() -> equation::ODE { + ode! { + name: "numeric_route_property_one_cpt", + params: [ka, ke, v, tlag, f_oral], + states: [depot, central], + outputs: [1], + routes: { + bolus(1) -> depot, + }, + diffeq: |x, _t, dx, bolus, _rateiv| { + dx[depot] = bolus[1] - ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; + }, + lag: |_t| { + lag! { 1 => tlag } + }, + fa: |_t| { + fa! { 1 => f_oral } + }, + out: |x, _t, y| { + y[1] = x[central] / v; + }, + } +} + +fn numeric_route_property_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, bolus, _rateiv, _cov| { + fetch_params!(p, ka, ke, _v, _tlag, _f_oral); + dx[0] = bolus[0] - ka * x[0]; + dx[1] = ka * x[0] - ke * x[1]; + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, tlag, _f_oral); + lag! { 0 => tlag } + }, + |p, _t, _cov| { + fetch_params!(p, _ka, _ke, _v, _tlag, f_oral); + fa! { 0 => f_oral } + }, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ka, _ke, v, _tlag, _f_oral); + y[0] = x[1] / v; + }, + ) + .with_nstates(2) + .with_ndrugs(1) + .with_nout(1) + .with_metadata( + equation::metadata::new("numeric_route_property_one_cpt") + .parameters(["ka", "ke", "v", "tlag", "f_oral"]) + .states(["depot", "central"]) + .outputs(["1"]) + .route( + equation::Route::bolus("1") + .to_state("depot") + .with_lag() + .with_bioavailability() + .expect_explicit_input(), + ), + ) + .expect("handwritten numeric route-property metadata should validate") +} + +fn mixed_output_labels_macro_ode() -> equation::ODE { + ode! { + name: "mixed_output_labels_one_cpt", + params: [ke, v], + states: [central], + outputs: [cp, 0, 1], + routes: { + infusion(iv) -> central, + }, + diffeq: |x, _t, dx, _bolus, rateiv| { + dx[central] = rateiv[iv] - ke * x[central]; + }, + out: |x, _t, y| { + y[cp] = x[central] / v; + y[0] = 2.0 * x[central] / v; + y[1] = 3.0 * x[central] / v; + }, + } +} + +fn mixed_output_labels_handwritten_ode() -> equation::ODE { + equation::ODE::new( + |x, p, _t, dx, _bolus, rateiv, _cov| { + fetch_params!(p, ke, _v); + dx[0] = rateiv[0] - ke * x[0]; + }, + |_p, _t, _cov| lag! {}, + |_p, _t, _cov| fa! {}, + |_p, _t, _cov, _x| {}, + |x, p, _t, _cov, y| { + fetch_params!(p, _ke, v); + y[0] = x[0] / v; + y[1] = 2.0 * x[0] / v; + y[2] = 3.0 * x[0] / v; + }, + ) + .with_nstates(1) + .with_ndrugs(1) + .with_nout(3) + .with_metadata( + equation::metadata::new("mixed_output_labels_one_cpt") + .parameters(["ke", "v"]) + .states(["central"]) + .outputs(["cp", "0", "1"]) + .route( + equation::Route::infusion("iv") + .to_state("central") + .expect_explicit_input(), + ), + ) + .expect("handwritten mixed-output metadata should validate") +} + fn covariate_macro_ode() -> equation::ODE { ode! { name: "covariate_one_cpt", @@ -267,7 +451,7 @@ fn assert_prediction_match(left: &[f64], right: &[f64]) { fn macro_injected_lowering_matches_handwritten_metadata_and_predictions() { let macro_ode = injected_macro_ode(); let handwritten_ode = injected_handwritten_ode(); - let subject = subject_for_route(0); + let subject = subject_for_route("iv", "cp"); let support_point = [0.2, 10.0]; assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); @@ -293,7 +477,7 @@ fn macro_injected_lowering_matches_handwritten_metadata_and_predictions() { fn macro_explicit_lowering_matches_handwritten_metadata_and_predictions() { let macro_ode = explicit_macro_ode(); let handwritten_ode = explicit_handwritten_ode(); - let subject = subject_for_route(0); + let subject = subject_for_route("iv", "cp"); let support_point = [0.2, 10.0]; assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); @@ -315,11 +499,37 @@ fn macro_explicit_lowering_matches_handwritten_metadata_and_predictions() { assert_prediction_match(¯o_predictions, &handwritten_predictions); } +#[test] +fn macro_numeric_labels_lower_to_dense_slots() { + let macro_ode = numeric_label_macro_ode(); + let handwritten_ode = numeric_label_handwritten_ode(); + let subject = subject_for_route("1", "1"); + let support_point = [0.2, 10.0]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.route_index("1"), Some(0)); + assert_eq!(macro_ode.output_index("1"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(0)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro numeric-label model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten numeric-label model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + #[test] fn macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { let macro_ode = shared_channel_macro_ode(); let handwritten_ode = shared_channel_handwritten_ode(); - let subject = subject_for_shared_channel(0); + let subject = subject_for_shared_channel(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); @@ -343,11 +553,119 @@ fn macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() assert_prediction_match(¯o_predictions, &handwritten_predictions); } +#[test] +fn macro_mixed_output_labels_lower_to_dense_slots() { + let macro_ode = mixed_output_labels_macro_ode(); + let handwritten_ode = mixed_output_labels_handwritten_ode(); + let subject = Subject::builder("mixed-output-labels") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "0") + .missing_observation(2.0, "1") + .build(); + let support_point = [0.2, 10.0]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.output_index("cp"), Some(0)); + assert_eq!(macro_ode.output_index("0"), Some(1)); + assert_eq!(macro_ode.output_index("1"), Some(2)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro mixed-output model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten mixed-output model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_numeric_route_properties_lower_to_dense_slots() { + let macro_ode = numeric_route_property_macro_ode(); + let handwritten_ode = numeric_route_property_handwritten_ode(); + let subject = subject_for_numeric_bolus_route("1", "1"); + let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; + + assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); + assert_eq!(macro_ode.route_index("1"), Some(0)); + assert_eq!(macro_ode.output_index("1"), Some(0)); + assert_eq!(macro_ode.state_index("depot"), Some(0)); + assert_eq!(macro_ode.state_index("central"), Some(1)); + + let macro_predictions = macro_ode + .estimate_predictions(&subject, &support_point) + .expect("macro numeric route-property model should simulate") + .flat_predictions() + .to_vec(); + let handwritten_predictions = handwritten_ode + .estimate_predictions(&subject, &support_point) + .expect("handwritten numeric route-property model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(¯o_predictions, &handwritten_predictions); +} + +#[test] +fn macro_named_labels_resolve_from_pmetrics_ingestion() { + let file = write_pmetrics_fixture( + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,1,100,.,.,iv,.,.,.,.,.,.,.\npt1,0,0.5,.,.,.,.,.,.,cp,0,.,.,.,.\npt1,0,1.0,.,.,.,.,.,.,cp,0,.,.,.,.\npt1,0,2.0,.,.,.,.,.,.,cp,0,.,.,.,.\n", + ); + + let data = + read_pmetrics(file.path().display().to_string()).expect("read named-label Pmetrics data"); + let subject = &data.subjects()[0]; + let support_point = [0.2, 10.0]; + + let pmetrics_predictions = explicit_macro_ode() + .estimate_predictions(subject, &support_point) + .expect("macro named-label model should simulate") + .flat_predictions() + .to_vec(); + let manual_predictions = explicit_macro_ode() + .estimate_predictions(&subject_for_route("iv", "cp"), &support_point) + .expect("macro internal-index model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(&pmetrics_predictions, &manual_predictions); +} + +#[test] +fn macro_numeric_labels_resolve_from_pmetrics_ingestion() { + let file = write_pmetrics_fixture( + "ID,EVID,TIME,DUR,DOSE,ADDL,II,INPUT,OUT,OUTEQ,CENS,C0,C1,C2,C3\npt1,1,0,1,100,.,.,1,.,.,.,.,.,.,.\npt1,0,0.5,.,.,.,.,.,.,1,0,.,.,.,.\npt1,0,1.0,.,.,.,.,.,.,1,0,.,.,.,.\npt1,0,2.0,.,.,.,.,.,.,1,0,.,.,.,.\n", + ); + + let data = + read_pmetrics(file.path().display().to_string()).expect("read numeric-label Pmetrics data"); + let subject = &data.subjects()[0]; + let support_point = [0.2, 10.0]; + + let pmetrics_predictions = numeric_label_macro_ode() + .estimate_predictions(subject, &support_point) + .expect("macro numeric-label model should simulate") + .flat_predictions() + .to_vec(); + let manual_predictions = numeric_label_macro_ode() + .estimate_predictions(&subject_for_route("1", "1"), &support_point) + .expect("macro internal-index numeric-label model should simulate") + .flat_predictions() + .to_vec(); + + assert_prediction_match(&pmetrics_predictions, &manual_predictions); +} + #[test] fn macro_covariate_lowering_matches_handwritten_metadata_and_predictions() { let macro_ode = covariate_macro_ode(); let handwritten_ode = covariate_handwritten_ode(); - let subject = subject_for_covariates(0); + let subject = subject_for_covariates("oral", "cp"); let support_point = [1.0, 0.2, 10.0]; let macro_metadata = macro_ode .metadata() diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs index 876d2b23..13d21a2b 100644 --- a/tests/sde_macro_lowering.rs +++ b/tests/sde_macro_lowering.rs @@ -2,47 +2,47 @@ use approx::assert_relative_eq; use pharmsol::prelude::*; use pharmsol::Predictions; -fn infusion_subject(input: usize) -> Subject { +fn infusion_subject(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("sde-macro-iv") .infusion(0.0, 120.0, input, 1.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn oral_subject(input: usize) -> Subject { +fn oral_subject(input: impl ToString, outeq: impl ToString) -> Subject { Subject::builder("sde-macro-oral") .bolus(0.0, 100.0, input) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) + .missing_observation(0.5, outeq.to_string()) + .missing_observation(1.0, outeq.to_string()) + .missing_observation(2.0, outeq) .build() } -fn shared_channel_subject(input: usize) -> Subject { +fn shared_channel_subject() -> Subject { Subject::builder("sde-macro-shared") - .bolus(0.0, 100.0, input) - .infusion(6.0, 60.0, input, 2.0) - .missing_observation(0.5, 0) - .missing_observation(1.0, 0) - .missing_observation(2.0, 0) - .missing_observation(6.5, 0) - .missing_observation(7.0, 0) - .missing_observation(8.0, 0) + .bolus(0.0, 100.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") .build() } -fn covariate_subject(oral: usize, iv: usize, cp: usize) -> Subject { +fn covariate_subject(oral: impl ToString, iv: impl ToString, cp: impl ToString) -> Subject { Subject::builder("sde-macro-covariates") .bolus(1.0, 100.0, oral) .infusion(6.0, 140.0, iv, 2.0) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) + .missing_observation(0.25, cp.to_string()) + .missing_observation(0.75, cp.to_string()) + .missing_observation(1.5, cp.to_string()) + .missing_observation(3.0, cp.to_string()) + .missing_observation(6.5, cp.to_string()) + .missing_observation(7.0, cp.to_string()) .missing_observation(8.0, cp) .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) @@ -491,7 +491,7 @@ fn handwritten_covariate_sde() -> equation::SDE { fn sde_macro_lowering_matches_handwritten_metadata_and_predictions() { let macro_model = macro_infusion_sde(); let handwritten_model = handwritten_infusion_sde(); - let subject = infusion_subject(0); + let subject = infusion_subject("iv", "cp"); let support_point = [0.2, 0.0, 10.0]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -516,7 +516,7 @@ fn sde_macro_lowering_matches_handwritten_metadata_and_predictions() { fn sde_macro_supports_lag_fa_init_and_named_sigma_bindings() { let macro_model = macro_absorption_sde(); let handwritten_model = handwritten_absorption_sde(); - let subject = oral_subject(0); + let subject = oral_subject("oral", "cp"); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -541,7 +541,7 @@ fn sde_macro_supports_lag_fa_init_and_named_sigma_bindings() { fn sde_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { let macro_model = macro_shared_channel_sde(); let handwritten_model = handwritten_shared_channel_sde(); - let subject = shared_channel_subject(0); + let subject = shared_channel_subject(); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -571,14 +571,9 @@ fn sde_macro_covariates_lower_to_handwritten_behavior() { assert_eq!(macro_model.metadata(), handwritten_model.metadata()); - let oral = macro_model.route_index("oral").expect("oral route exists"); - let iv = macro_model.route_index("iv").expect("iv route exists"); - let cp = macro_model.output_index("cp").expect("cp output exists"); - let subject = covariate_subject(oral, iv, cp); + let subject = covariate_subject("oral", "iv", "cp"); let support_point = [1.0, 0.16, 0.0, 32.0, 0.5, 0.8, 3.0, 14.0]; - assert_eq!(oral, iv); - let macro_predictions = macro_model .estimate_predictions(&subject, &support_point) .expect("macro covariate SDE should simulate"); From 8a94a3180eeabbceb55316af7b16b2b9b871abfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 13:57:19 +0100 Subject: [PATCH 02/12] chore: update test --- tests/full_feature_macro_parity.rs | 52 +++++++++++++++--------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/tests/full_feature_macro_parity.rs b/tests/full_feature_macro_parity.rs index 71a1afa7..e3175f84 100644 --- a/tests/full_feature_macro_parity.rs +++ b/tests/full_feature_macro_parity.rs @@ -197,19 +197,19 @@ fn handwritten_ode_model() -> equation::ODE { .expect("handwritten ODE metadata should validate") } -fn build_ode_subject(oral: usize, load: usize, iv: usize, cp: usize) -> Subject { +fn build_ode_subject() -> Subject { Subject::builder("macro-vs-handwritten-ode-full-features") - .bolus(0.0, 80.0, load) - .bolus(1.0, 120.0, oral) - .infusion(6.0, 150.0, iv, 2.5) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) + .bolus(0.0, 80.0, "load") + .bolus(1.0, 120.0, "oral") + .infusion(6.0, 150.0, "iv", 2.5) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) .covariate("renal", 0.0, 95.0) @@ -368,19 +368,19 @@ fn handwritten_analytical_model() -> equation::Analytical { .expect("handwritten analytical metadata should validate") } -fn build_analytical_subject(oral: usize, load: usize, iv: usize, cp: usize) -> Subject { +fn build_analytical_subject() -> Subject { Subject::builder("macro-vs-handwritten-analytical-full-features") - .bolus(0.0, 60.0, load) - .bolus(1.0, 100.0, oral) - .infusion(6.0, 140.0, iv, 2.0) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) + .bolus(0.0, 60.0, "load") + .bolus(1.0, 100.0, "oral") + .infusion(6.0, 140.0, "iv", 2.0) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) .covariate("renal", 0.0, 95.0) @@ -407,7 +407,7 @@ fn ode_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::Pharmsol assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); - let subject = build_ode_subject(oral, load, iv, cp); + let subject = build_ode_subject(); let params = [1.1, 0.18, 0.07, 0.04, 35.0, 0.6, 0.85, 4.0, 18.0, 9.0]; let macro_predictions = macro_ode.estimate_predictions(&subject, ¶ms)?; @@ -453,7 +453,7 @@ fn analytical_full_feature_macro_matches_handwritten() -> Result<(), pharmsol::P assert_eq!(handwritten_analytical.route_index("iv"), Some(iv)); assert_eq!(handwritten_analytical.output_index("cp"), Some(cp)); - let subject = build_analytical_subject(oral, load, iv, cp); + let subject = build_analytical_subject(); let params = [1.0, 0.16, 32.0, 0.5, 0.8, 3.0, 14.0, 0.16]; let macro_predictions = macro_analytical.estimate_predictions(&subject, ¶ms)?; From fe3916a80a9a938e5a61ba3a84ea0106c722ac76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 14:19:42 +0100 Subject: [PATCH 03/12] chore: channeless --- examples/compare_solvers.rs | 4 +- examples/macro_vs_handwritten_one_cpt.rs | 2 +- examples/macro_vs_handwritten_two_cpt.rs | 13 +- pharmsol-dsl/src/execution.rs | 2 +- pharmsol-macros/src/lib.rs | 47 +++--- src/data/event.rs | 177 ++++++++++++----------- src/data/parser/pmetrics.rs | 21 ++- src/data/row.rs | 24 +-- src/data/structs.rs | 33 +++-- src/dsl/jit.rs | 2 +- src/dsl/model_info.rs | 2 +- src/dsl/native.rs | 10 +- src/error/mod.rs | 6 +- src/simulator/equation/analytical/mod.rs | 8 +- src/simulator/equation/metadata.rs | 62 ++++---- src/simulator/equation/mod.rs | 9 +- src/simulator/equation/ode/closure.rs | 22 ++- src/simulator/equation/ode/mod.rs | 8 +- src/simulator/equation/sde/mod.rs | 8 +- src/simulator/mod.rs | 4 +- tests/analytical_macro_lowering.rs | 20 +-- tests/authoring_parity_corpus.rs | 94 ++++++------ tests/ode_macro_lowering.rs | 26 ++-- tests/sde_macro_lowering.rs | 20 +-- 24 files changed, 312 insertions(+), 312 deletions(-) diff --git a/examples/compare_solvers.rs b/examples/compare_solvers.rs index ebec4caa..ad705931 100644 --- a/examples/compare_solvers.rs +++ b/examples/compare_solvers.rs @@ -48,7 +48,7 @@ fn main() { let trbdf2 = two_cpt(OdeSolver::Sdirk(SdirkTableau::TrBdf2)); let esdirk34 = two_cpt(OdeSolver::Sdirk(SdirkTableau::Esdirk34)); - // Both declarations resolve to the same shared input channel, so subject + // Both declarations resolve to the same shared input, so subject // authoring still uses one numeric index for the loading bolus and the // maintenance infusion. let load = bdf.route_index("load").expect("load route exists"); @@ -57,7 +57,7 @@ fn main() { assert_eq!( load, iv, - "mixed IV declarations should share one numeric channel" + "mixed IV declarations should share one numeric input" ); let subject = Subject::builder("id1") diff --git a/examples/macro_vs_handwritten_one_cpt.rs b/examples/macro_vs_handwritten_one_cpt.rs index ddff59f8..be9edb2a 100644 --- a/examples/macro_vs_handwritten_one_cpt.rs +++ b/examples/macro_vs_handwritten_one_cpt.rs @@ -26,7 +26,7 @@ fn macro_model() -> equation::ODE { fn handwritten_model() -> equation::ODE { equation::ODE::new( - // Handwritten closures stay on dense internal channels. + // Handwritten closures stay on dense internal slots. // Public labels like `iv` and `cp` live in attached metadata, not in // the low-level `rateiv[]` / `y[]` buffers. |x, p, _t, dx, _bolus, rateiv, _cov| { diff --git a/examples/macro_vs_handwritten_two_cpt.rs b/examples/macro_vs_handwritten_two_cpt.rs index 915267d6..114024bd 100644 --- a/examples/macro_vs_handwritten_two_cpt.rs +++ b/examples/macro_vs_handwritten_two_cpt.rs @@ -1,5 +1,5 @@ //! Compares a declaration-first macro ODE with the equivalent handwritten ODE -//! on a two-compartment IV problem that shares one numeric input channel across +//! on a two-compartment IV problem that shares one numeric input across //! a loading bolus and a maintenance infusion. //! //! This keeps the macro story as the default surface while showing the @@ -9,7 +9,7 @@ use pharmsol::prelude::*; fn macro_model() -> equation::ODE { ode! { - name: "two_cpt_shared_channel_parity", + name: "two_cpt_shared_input_parity", params: [ke, kcp, kpc, v], states: [central, peripheral], outputs: [cp], @@ -29,7 +29,7 @@ fn macro_model() -> equation::ODE { fn handwritten_model() -> equation::ODE { equation::ODE::new( - // Handwritten closures stay on dense internal channels. + // Handwritten closures stay on dense internal slots. // Public route labels like `load` and `iv` are metadata names; the // low-level `bolus[]`, `rateiv[]`, and `y[]` buffers remain indexed by // dense internal slots. @@ -50,7 +50,7 @@ fn handwritten_model() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("two_cpt_shared_channel_parity") + equation::metadata::new("two_cpt_shared_input_parity") .parameters(["ke", "kcp", "kpc", "v"]) .states(["central", "peripheral"]) .outputs(["cp"]) @@ -83,10 +83,7 @@ fn main() -> Result<(), pharmsol::PharmsolError> { let iv = macro_ode.route_index("iv").expect("iv route exists"); let cp = macro_ode.output_index("cp").expect("cp output exists"); - assert_eq!( - load, iv, - "load and iv should share one numeric input channel" - ); + assert_eq!(load, iv, "load and iv should share one numeric input"); assert_eq!(handwritten_ode.route_index("load"), Some(load)); assert_eq!(handwritten_ode.route_index("iv"), Some(iv)); assert_eq!(handwritten_ode.output_index("cp"), Some(cp)); diff --git a/pharmsol-dsl/src/execution.rs b/pharmsol-dsl/src/execution.rs index 886d570a..8bac1d69 100644 --- a/pharmsol-dsl/src/execution.rs +++ b/pharmsol-dsl/src/execution.rs @@ -1516,7 +1516,7 @@ mod tests { } #[test] - fn authoring_routes_share_channel_indices_by_kind_local_ordinal() { + fn authoring_routes_share_input_indices_by_kind_local_ordinal() { let src = r#"name = shared_authoring kind = ode diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index 96b9536e..0d143184 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -1105,7 +1105,7 @@ fn route_input_names(routes: &[OdeRouteDecl]) -> Vec { routes.iter().map(|route| route.input.name()).collect() } -fn ode_route_channel_bindings(routes: &[OdeRouteDecl]) -> Vec<(SymbolicIndex, usize)> { +fn ode_route_input_bindings(routes: &[OdeRouteDecl]) -> Vec<(SymbolicIndex, usize)> { let mut next_bolus_index = 0usize; let mut next_infusion_index = 0usize; @@ -2024,14 +2024,14 @@ fn expand_injected_ode_route_terms( let terms = routes .iter() .zip(route_bindings.iter()) - .map(|(route, (_, channel_index))| { + .map(|(route, (_, input_index))| { let destination = route_destination_index(route, states); match route.kind { OdeRouteKind::Bolus => quote! { - #dx[#destination] += #bolus[#channel_index]; + #dx[#destination] += #bolus[#input_index]; }, OdeRouteKind::Infusion => quote! { - #dx[#destination] += #rateiv[#channel_index]; + #dx[#destination] += #rateiv[#input_index]; }, } }); @@ -2048,19 +2048,18 @@ fn expand_injected_sde_rate_terms( dx: &Ident, rateiv: &Ident, ) -> TokenStream2 { - let terms = - routes - .iter() - .zip(route_bindings.iter()) - .filter_map(|(route, (_, channel_index))| match route.kind { - OdeRouteKind::Bolus => None, - OdeRouteKind::Infusion => { - let destination = route_destination_index(route, states); - Some(quote! { - #dx[#destination] += #rateiv[#channel_index]; - }) - } - }); + let terms = routes + .iter() + .zip(route_bindings.iter()) + .filter_map(|(route, (_, input_index))| match route.kind { + OdeRouteKind::Bolus => None, + OdeRouteKind::Infusion => { + let destination = route_destination_index(route, states); + Some(quote! { + #dx[#destination] += #rateiv[#input_index]; + }) + } + }); quote! { #(#terms)* @@ -2074,10 +2073,10 @@ fn expand_injected_sde_bolus_mappings( ) -> TokenStream2 { let mut destinations = vec![quote! { None }; dense_index_len(route_bindings)]; - for (route, (_, channel_index)) in routes.iter().zip(route_bindings.iter()) { + for (route, (_, input_index)) in routes.iter().zip(route_bindings.iter()) { if let OdeRouteKind::Bolus = route.kind { let destination = route_destination_index(route, states); - destinations[*channel_index] = quote! { Some(#destination) }; + destinations[*input_index] = quote! { Some(#destination) }; } } @@ -2829,7 +2828,7 @@ fn expand_sde_out( pub fn ode(input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as OdeInput); - let route_bindings = ode_route_channel_bindings(&input.routes); + let route_bindings = ode_route_input_bindings(&input.routes); let lag_routes = match input.lag.as_ref() { Some(closure) => match extract_route_property_routes( @@ -2986,7 +2985,7 @@ pub fn ode(input: TokenStream) -> TokenStream { #[proc_macro] pub fn analytical(input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as AnalyticalInput); - let route_bindings = ode_route_channel_bindings(&input.routes); + let route_bindings = ode_route_input_bindings(&input.routes); let kernel_spec = match resolve_analytical_structure(&input.structure) { Ok(spec) => spec, @@ -3150,7 +3149,7 @@ pub fn analytical(input: TokenStream) -> TokenStream { #[proc_macro] pub fn sde(input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as SdeInput); - let route_bindings = ode_route_channel_bindings(&input.routes); + let route_bindings = ode_route_input_bindings(&input.routes); let lag_routes = match input.lag.as_ref() { Some(closure) => match extract_route_property_routes( @@ -3364,13 +3363,13 @@ mod tests { } #[test] - fn ode_route_bindings_share_channels_by_kind_local_ordinal() { + fn ode_route_bindings_share_inputs_by_kind_local_ordinal() { let input = syn::parse_str::( "name: \"demo\", params: [ka, ke, v], states: [depot, central], outputs: [cp], routes: { bolus(oral) -> depot, infusion(iv) -> central, bolus(sc) -> depot }, diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", ) .expect("declaration-first ode input should parse"); - let bindings = ode_route_channel_bindings(&input.routes); + let bindings = ode_route_input_bindings(&input.routes); assert_eq!(dense_index_len(&bindings), 2); assert_eq!(bindings[0].0.name(), "oral"); diff --git a/src/data/event.rs b/src/data/event.rs index 46995ef5..bff4c700 100644 --- a/src/data/event.rs +++ b/src/data/event.rs @@ -94,76 +94,85 @@ pub enum Event { Observation(Observation), } -#[derive(Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] -pub struct ChannelId(String); +macro_rules! impl_label_type { + ($name:ident) => { + #[derive( + Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, + )] + pub struct $name(String); + + impl $name { + pub fn new(label: impl ToString) -> Self { + Self(label.to_string()) + } -impl ChannelId { - pub fn new(label: impl ToString) -> Self { - Self(label.to_string()) - } + pub fn as_str(&self) -> &str { + &self.0 + } - pub fn as_str(&self) -> &str { - &self.0 - } + pub fn index(&self) -> Option { + self.0.parse::().ok() + } + } - pub fn index(&self) -> Option { - self.0.parse::().ok() - } -} + impl From for $name { + fn from(value: String) -> Self { + Self(value) + } + } -impl From for ChannelId { - fn from(value: String) -> Self { - Self(value) - } -} + impl From<&str> for $name { + fn from(value: &str) -> Self { + Self(value.to_string()) + } + } -impl From<&str> for ChannelId { - fn from(value: &str) -> Self { - Self(value.to_string()) - } -} + impl From for $name { + fn from(value: usize) -> Self { + Self(value.to_string()) + } + } -impl From for ChannelId { - fn from(value: usize) -> Self { - Self(value.to_string()) - } -} + impl AsRef for $name { + fn as_ref(&self) -> &str { + self.as_str() + } + } -impl AsRef for ChannelId { - fn as_ref(&self) -> &str { - self.as_str() - } -} + impl fmt::Display for $name { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } + } -impl fmt::Display for ChannelId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(self.as_str()) - } -} + impl PartialEq for $name { + fn eq(&self, other: &usize) -> bool { + self.index() == Some(*other) + } + } -impl PartialEq for ChannelId { - fn eq(&self, other: &usize) -> bool { - self.index() == Some(*other) - } -} + impl PartialEq<$name> for usize { + fn eq(&self, other: &$name) -> bool { + other == self + } + } -impl PartialEq for usize { - fn eq(&self, other: &ChannelId) -> bool { - other == self - } -} + impl PartialEq for &$name { + fn eq(&self, other: &usize) -> bool { + (**self).eq(other) + } + } -impl PartialEq for &ChannelId { - fn eq(&self, other: &usize) -> bool { - (**self).eq(other) - } + impl PartialEq<&$name> for usize { + fn eq(&self, other: &&$name) -> bool { + other.eq(self) + } + } + }; } -impl PartialEq<&ChannelId> for usize { - fn eq(&self, other: &&ChannelId) -> bool { - other.eq(self) - } -} +impl_label_type!(InputLabel); +impl_label_type!(OutputLabel); impl Event { /// Get the time of the event @@ -224,7 +233,7 @@ impl Event { pub struct Bolus { time: f64, amount: f64, - input: ChannelId, + input: InputLabel, occasion: usize, } impl Bolus { @@ -234,12 +243,12 @@ impl Bolus { /// /// * `time` - Time of the bolus dose /// * `amount` - Amount of drug administered - /// * `input` - The compartment number receiving the dose + /// * `input` - The route label receiving the dose pub fn new(time: f64, amount: f64, input: impl ToString, occasion: usize) -> Self { Bolus { time, amount, - input: ChannelId::new(input), + input: InputLabel::new(input), occasion, } } @@ -249,8 +258,8 @@ impl Bolus { self.amount } - /// Get the compartment number that receives the bolus - pub fn input(&self) -> &ChannelId { + /// Get the route label that receives the bolus + pub fn input(&self) -> &InputLabel { &self.input } @@ -268,9 +277,9 @@ impl Bolus { self.amount = amount; } - /// Set the compartment number that receives the bolus + /// Set the route label that receives the bolus pub fn set_input(&mut self, input: impl ToString) { - self.input = ChannelId::new(input); + self.input = InputLabel::new(input); } /// Set the time of the bolus administration @@ -283,8 +292,8 @@ impl Bolus { &mut self.amount } - /// Get a mutable reference to the compartment number (1-indexed) that receives the bolus - pub fn mut_input(&mut self) -> &mut ChannelId { + /// Get a mutable reference to the route label that receives the bolus + pub fn mut_input(&mut self) -> &mut InputLabel { &mut self.input } @@ -311,7 +320,7 @@ impl Bolus { pub struct Infusion { time: f64, amount: f64, - input: ChannelId, + input: InputLabel, duration: f64, occasion: usize, } @@ -322,7 +331,7 @@ impl Infusion { /// /// * `time` - Start time of the infusion /// * `amount` - Total amount of drug to be administered - /// * `input` - The compartment number receiving the dose + /// * `input` - The route label receiving the dose /// * `duration` - Duration of the infusion in time units pub fn new( time: f64, @@ -334,7 +343,7 @@ impl Infusion { Infusion { time, amount, - input: ChannelId::new(input), + input: InputLabel::new(input), duration, occasion, } @@ -345,8 +354,8 @@ impl Infusion { self.amount } - /// Get the compartment number that receives the infusion - pub fn input(&self) -> &ChannelId { + /// Get the route label that receives the infusion + pub fn input(&self) -> &InputLabel { &self.input } @@ -371,9 +380,9 @@ impl Infusion { self.amount = amount; } - /// Set the compartment number that receives the infusion + /// Set the route label that receives the infusion pub fn set_input(&mut self, input: impl ToString) { - self.input = ChannelId::new(input); + self.input = InputLabel::new(input); } /// Set the time of the infusion administration @@ -391,8 +400,8 @@ impl Infusion { &mut self.amount } - /// Get a mutable reference to the compartment number (1-indexed) that receives the infusion - pub fn mut_input(&mut self) -> &mut ChannelId { + /// Get a mutable reference to the route label that receives the infusion + pub fn mut_input(&mut self) -> &mut InputLabel { &mut self.input } @@ -434,7 +443,7 @@ pub enum Censor { pub struct Observation { time: f64, value: Option, - outeq: ChannelId, + outeq: OutputLabel, errorpoly: Option, occasion: usize, censoring: Censor, @@ -446,7 +455,7 @@ impl Observation { /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number corresponding to this observation + /// * `outeq` - Output label corresponding to this observation /// * `errorpoly` - Optional error polynomial coefficients (c0, c1, c2, c3) /// * `occasion` - Occasion index /// * `censoring` - Censoring type for this observation @@ -461,7 +470,7 @@ impl Observation { Observation { time, value, - outeq: ChannelId::new(outeq), + outeq: OutputLabel::new(outeq), errorpoly, occasion, censoring, @@ -478,8 +487,8 @@ impl Observation { self.value } - /// Get the output equation number corresponding to this observation - pub fn outeq(&self) -> &ChannelId { + /// Get the output label corresponding to this observation + pub fn outeq(&self) -> &OutputLabel { &self.outeq } @@ -504,9 +513,9 @@ impl Observation { self.value = value; } - /// Set the output equation number corresponding to this observation + /// Set the output label corresponding to this observation pub fn set_outeq(&mut self, outeq: impl ToString) { - self.outeq = ChannelId::new(outeq); + self.outeq = OutputLabel::new(outeq); } /// Set the [ErrorPoly] for this observation @@ -524,8 +533,8 @@ impl Observation { &mut self.value } - /// Get a mutable reference to the output equation number - pub fn mut_outeq(&mut self) -> &mut ChannelId { + /// Get a mutable reference to the output label + pub fn mut_outeq(&mut self) -> &mut OutputLabel { &mut self.outeq } diff --git a/src/data/parser/pmetrics.rs b/src/data/parser/pmetrics.rs index 4554e435..89943f6e 100644 --- a/src/data/parser/pmetrics.rs +++ b/src/data/parser/pmetrics.rs @@ -95,14 +95,14 @@ struct Row { #[serde(deserialize_with = "deserialize_option_f64")] ii: Option, /// Input compartment - #[serde(deserialize_with = "deserialize_option_channel_id")] - input: Option, + #[serde(deserialize_with = "deserialize_option_route_label")] + input: Option, /// Observed value #[serde(deserialize_with = "deserialize_option_f64")] out: Option, /// Corresponding output equation for the observation - #[serde(deserialize_with = "deserialize_option_channel_id")] - outeq: Option, + #[serde(deserialize_with = "deserialize_option_output_label")] + outeq: Option, /// Censoring output #[serde(default, deserialize_with = "deserialize_option_censor")] cens: Option, @@ -196,11 +196,18 @@ where } } -fn deserialize_option_channel_id<'de, D>(deserializer: D) -> Result, D::Error> +fn deserialize_option_route_label<'de, D>(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { - deserialize_option::(deserializer).map(|value| value.map(ChannelId::from)) + deserialize_option::(deserializer).map(|value| value.map(InputLabel::from)) +} + +fn deserialize_option_output_label<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + deserialize_option::(deserializer).map(|value| value.map(OutputLabel::from)) } fn deserialize_option_isize<'de, D>(deserializer: D) -> Result, D::Error> @@ -498,7 +505,7 @@ mod tests { } #[test] - fn read_pmetrics_preserves_named_channel_labels() { + fn read_pmetrics_preserves_named_route_and_output_labels() { let file = NamedTempFile::new().unwrap(); std::fs::write( file.path(), diff --git a/src/data/row.rs b/src/data/row.rs index f6e44e98..b9a807c1 100644 --- a/src/data/row.rs +++ b/src/data/row.rs @@ -32,8 +32,8 @@ use thiserror::Error; /// /// # Fields /// -/// All fields use Pmetrics conventions: -/// - `input` and `outeq` are **1-indexed** (kept as-is, user must size arrays accordingly) +/// All fields use the public labeling conventions: +/// - `input` and `outeq` preserve the route and output labels from the source data /// - `evid`: 0=observation, 1=dose, 4=reset/new occasion /// - `addl`: positive=forward in time, negative=backward in time /// @@ -78,12 +78,12 @@ pub struct DataRow { pub addl: Option, /// Interdose interval for ADDL pub ii: Option, - /// Input compartment - pub input: Option, + /// Input route label + pub input: Option, /// Observed value (for EVID=0) pub out: Option, - /// Output equation number - pub outeq: Option, + /// Output label + pub outeq: Option, /// Censoring indicator pub cens: Option, /// Error polynomial coefficients @@ -373,12 +373,12 @@ impl DataRowBuilder { self } - /// Set the input compartment (1-indexed) + /// Set the input route label /// /// Required for EVID=1 (dosing events). - /// Kept as 1-indexed; user must size state arrays accordingly. + /// Preserved as the public route label until model resolution. pub fn input(mut self, input: impl ToString) -> Self { - self.row.input = Some(ChannelId::new(input)); + self.row.input = Some(InputLabel::new(input)); self } @@ -390,12 +390,12 @@ impl DataRowBuilder { self } - /// Set the output equation (1-indexed) + /// Set the output label /// /// Required for EVID=0 (observation events). - /// Will be converted to 0-indexed internally. + /// Preserved as the public output label until model resolution. pub fn outeq(mut self, outeq: impl ToString) -> Self { - self.row.outeq = Some(ChannelId::new(outeq)); + self.row.outeq = Some(OutputLabel::new(outeq)); self } diff --git a/src/data/structs.rs b/src/data/structs.rs index c977d89a..d7d123b1 100644 --- a/src/data/structs.rs +++ b/src/data/structs.rs @@ -180,17 +180,18 @@ impl Data { let old_events = occasion.process_events(None, true); // Create a set of existing (time, outeq) pairs for fast lookup - let existing_obs: std::collections::HashSet<(u64, ChannelId)> = old_events - .iter() - .filter_map(|event| match event { - Event::Observation(obs) => { - // Convert to microseconds for consistent comparison - let time_key = (obs.time() * 1e6).round() as u64; - Some((time_key, obs.outeq().clone())) - } - _ => None, - }) - .collect(); + let existing_obs: std::collections::HashSet<(u64, OutputLabel)> = + old_events + .iter() + .filter_map(|event| match event { + Event::Observation(obs) => { + // Convert to microseconds for consistent comparison + let time_key = (obs.time() * 1e6).round() as u64; + Some((time_key, obs.outeq().clone())) + } + _ => None, + }) + .collect(); // Generate new observation times let mut new_events = Vec::new(); @@ -273,10 +274,10 @@ impl Data { self.subjects.is_empty() } - /// Get a vector of all unique output equations (outeq) across all subjects - pub fn get_output_equations(&self) -> Vec { + /// Get a vector of all unique output labels (outeq) across all subjects + pub fn get_output_equations(&self) -> Vec { // Collect all unique outeq values in order of occurrence - let mut outeq_values: Vec = self + let mut outeq_values: Vec = self .subjects .iter() .flat_map(|subject| subject.get_output_equations()) @@ -396,9 +397,9 @@ impl Subject { self.occasions.iter_mut() } - pub fn get_output_equations(&self) -> Vec { + pub fn get_output_equations(&self) -> Vec { // Collect all unique outeq values in order of occurrence - let outeq_values: Vec = self + let outeq_values: Vec = self .occasions .iter() .flat_map(|occasion| { diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index b0f1fe4a..5504ab08 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -1331,7 +1331,7 @@ mod tests { } #[test] - fn authoring_runtime_shares_channel_between_bolus_and_infusion_routes() { + fn authoring_runtime_shares_input_between_bolus_and_infusion_routes() { let source = r#" name = shared_authoring kind = ode diff --git a/src/dsl/model_info.rs b/src/dsl/model_info.rs index 0094059f..d9a2fdbd 100644 --- a/src/dsl/model_info.rs +++ b/src/dsl/model_info.rs @@ -243,7 +243,7 @@ model explicit_route_usage { } #[test] - fn authoring_shared_channel_routes_keep_declaration_specific_injection() { + fn authoring_shared_input_routes_keep_declaration_specific_injection() { let info = load_model_info( r#" name = shared_authoring diff --git a/src/dsl/native.rs b/src/dsl/native.rs index d2600e67..c1ce8eac 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -20,7 +20,7 @@ pub use super::model_info::{ NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, }; use crate::{ - data::{ChannelId, Covariates, Infusion}, + data::{Covariates, Infusion, InputLabel, OutputLabel}, simulator::{ equation::{ ode::{closure_helpers::PMProblem, ExplicitRkTableau, OdeSolver, SdirkTableau}, @@ -392,14 +392,14 @@ impl SharedNativeModel { } Err(PharmsolError::OtherError(format!( - "model `{}` does not declare a {:?} route for input channel {}", + "model `{}` does not declare a {:?} route for input {}", self.info.name, kind, input ))) } fn resolve_input_label( &self, - label: &ChannelId, + label: &InputLabel, kind: RouteKind, ) -> Result { if let Some(input) = self.route_index(label.as_str()) { @@ -416,7 +416,7 @@ impl SharedNativeModel { Ok(input) } - fn resolve_output_label(&self, label: &ChannelId) -> Result { + fn resolve_output_label(&self, label: &OutputLabel) -> Result { if let Some(outeq) = self.output_index(label.as_str()) { return Ok(outeq); } @@ -682,7 +682,7 @@ impl SharedNativeModel { .bolus_destination(input) .ok_or_else(|| { PharmsolError::OtherError(format!( - "model `{}` does not declare a bolus route for input channel {}", + "model `{}` does not declare a bolus route for input index {}", self.info.name, input )) })?; diff --git a/src/error/mod.rs b/src/error/mod.rs index 5145626e..c8f70b58 100644 --- a/src/error/mod.rs +++ b/src/error/mod.rs @@ -37,11 +37,11 @@ pub enum PharmsolError { ZeroLikelihood, #[error("Missing observation in prediction")] MissingObservation, - #[error("Input label `{label}` could not be resolved to a route channel")] + #[error("Input label `{label}` could not be resolved to a route input")] UnknownInputLabel { label: String }, - #[error("Output label `{label}` could not be resolved to an output channel")] + #[error("Output label `{label}` could not be resolved to an output")] UnknownOutputLabel { label: String }, - #[error("Input channel {input} is out of range (ndrugs = {ndrugs})")] + #[error("Input index {input} is out of range (ndrugs = {ndrugs})")] InputOutOfRange { input: usize, ndrugs: usize }, #[error("Output equation {outeq} is out of range (nout = {nout})")] OuteqOutOfRange { outeq: usize, nout: usize }, diff --git a/src/simulator/equation/analytical/mod.rs b/src/simulator/equation/analytical/mod.rs index b0d78481..1dd4bbb5 100644 --- a/src/simulator/equation/analytical/mod.rs +++ b/src/simulator/equation/analytical/mod.rs @@ -32,9 +32,7 @@ pub enum AnalyticalMetadataError { Validation(#[from] ModelMetadataError), #[error("analytical model declares {declared} state metadata entries but model has {expected} states")] StateCountMismatch { expected: usize, declared: usize }, - #[error( - "analytical model declares {declared} route metadata entries but model has {expected} input channels" - )] + #[error("analytical model declares {declared} route metadata entries but model has {expected} inputs")] RouteCountMismatch { expected: usize, declared: usize }, #[error("analytical model declares {declared} output metadata entries but model has {expected} outputs")] OutputCountMismatch { expected: usize, declared: usize }, @@ -119,7 +117,7 @@ impl Analytical { self } - /// Set the number of drug input channels (size of bolus[] and rateiv[]). + /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; self.invalidate_metadata(); @@ -186,7 +184,7 @@ fn validate_metadata_dimensions( }); } - let declared_routes = metadata.route_channel_count(); + let declared_routes = metadata.route_input_count(); if declared_routes != neqs.ndrugs { return Err(AnalyticalMetadataError::RouteCountMismatch { expected: neqs.ndrugs, diff --git a/src/simulator/equation/metadata.rs b/src/simulator/equation/metadata.rs index ecf51a52..fecab7e2 100644 --- a/src/simulator/equation/metadata.rs +++ b/src/simulator/equation/metadata.rs @@ -80,7 +80,7 @@ pub struct ValidatedModelMetadata { covariates: Vec, states: Vec, routes: Vec, - route_channel_count: usize, + route_input_count: usize, outputs: Vec, particles: Option, analytical: Option, @@ -111,8 +111,8 @@ impl ValidatedModelMetadata { &self.routes } - pub fn route_channel_count(&self) -> usize { - self.route_channel_count + pub fn route_input_count(&self) -> usize { + self.route_input_count } pub fn outputs(&self) -> &[Output] { @@ -144,7 +144,7 @@ impl ValidatedModelMetadata { } pub fn route_index(&self, name: &str) -> Option { - self.route(name).map(ValidatedRoute::channel_index) + self.route(name).map(ValidatedRoute::input_index) } pub fn route_declaration_index(&self, name: &str) -> Option { @@ -185,7 +185,7 @@ pub struct ValidatedRoute { name: String, kind: RouteKind, declaration_index: usize, - channel_index: usize, + input_index: usize, destination: String, destination_index: usize, has_lag: bool, @@ -206,8 +206,8 @@ impl ValidatedRoute { self.declaration_index } - pub fn channel_index(&self) -> usize { - self.channel_index + pub fn input_index(&self) -> usize { + self.input_index } pub fn destination(&self) -> &str { @@ -416,7 +416,7 @@ impl ModelMetadata { let particles = resolve_particles(kind, self.particles, fallback_particles)?; validate_kind_specific_fields(kind, self.analytical, particles)?; - let (routes, route_channel_count) = validate_routes(self.routes, &self.states)?; + let (routes, route_input_count) = validate_routes(self.routes, &self.states)?; Ok(ValidatedModelMetadata { name: self.name, @@ -425,7 +425,7 @@ impl ModelMetadata { covariates: self.covariates, states: self.states, routes, - route_channel_count, + route_input_count, outputs: self.outputs, particles, analytical: self.analytical, @@ -730,20 +730,20 @@ fn validate_routes( routes: Vec, states: &[State], ) -> Result<(Vec, usize), ModelMetadataError> { - let mut bolus_channels = 0; - let mut infusion_channels = 0; + let mut bolus_inputs = 0; + let mut infusion_inputs = 0; let mut validated_routes = Vec::with_capacity(routes.len()); for (declaration_index, route) in routes.into_iter().enumerate() { - let channel_index = match route.kind { + let input_index = match route.kind { RouteKind::Bolus => { - let index = bolus_channels; - bolus_channels += 1; + let index = bolus_inputs; + bolus_inputs += 1; index } RouteKind::Infusion => { - let index = infusion_channels; - infusion_channels += 1; + let index = infusion_inputs; + infusion_inputs += 1; index } }; @@ -751,18 +751,18 @@ fn validate_routes( validated_routes.push(validate_route( route, declaration_index, - channel_index, + input_index, states, )?); } - Ok((validated_routes, bolus_channels.max(infusion_channels))) + Ok((validated_routes, bolus_inputs.max(infusion_inputs))) } fn validate_route( route: Route, declaration_index: usize, - channel_index: usize, + input_index: usize, states: &[State], ) -> Result { if route.kind == RouteKind::Infusion && route.has_lag { @@ -796,7 +796,7 @@ fn validate_route( name: route.name, kind: route.kind, declaration_index, - channel_index, + input_index, destination, destination_index, has_lag: route.has_lag, @@ -902,7 +902,7 @@ mod tests { assert_eq!(metadata.state_index("central"), Some(0)); assert_eq!(metadata.route_index("iv"), Some(0)); assert_eq!(metadata.route_declaration_index("iv"), Some(0)); - assert_eq!(metadata.route_channel_count(), 1); + assert_eq!(metadata.route_input_count(), 1); assert_eq!(metadata.output_index("cp"), Some(0)); assert_eq!( metadata.route("iv").expect("route exists").destination(), @@ -915,10 +915,7 @@ mod tests { .declaration_index(), 0 ); - assert_eq!( - metadata.route("iv").expect("route exists").channel_index(), - 0 - ); + assert_eq!(metadata.route("iv").expect("route exists").input_index(), 0); assert_eq!( metadata .route("iv") @@ -988,8 +985,8 @@ mod tests { } #[test] - fn shared_channel_routes_preserve_declaration_and_channel_identity() { - let metadata = new("shared_channel") + fn shared_input_routes_preserve_declaration_and_input_identity() { + let metadata = new("shared_input") .kind(ModelKind::Ode) .parameters(["ke"]) .states(["gut", "central"]) @@ -999,19 +996,16 @@ mod tests { Route::infusion("iv").to_state("central"), ]) .validate() - .expect("shared-channel metadata should validate"); + .expect("shared-input metadata should validate"); assert_eq!(metadata.routes().len(), 2); - assert_eq!(metadata.route_channel_count(), 1); + assert_eq!(metadata.route_input_count(), 1); assert_eq!(metadata.route_index("oral"), Some(0)); assert_eq!(metadata.route_index("iv"), Some(0)); assert_eq!(metadata.route_declaration_index("oral"), Some(0)); assert_eq!(metadata.route_declaration_index("iv"), Some(1)); - assert_eq!( - metadata.route("oral").expect("oral route").channel_index(), - 0 - ); - assert_eq!(metadata.route("iv").expect("iv route").channel_index(), 0); + assert_eq!(metadata.route("oral").expect("oral route").input_index(), 0); + assert_eq!(metadata.route("iv").expect("iv route").input_index(), 0); assert_eq!( metadata .route("oral") diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index f3532382..c5a97958 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -12,7 +12,8 @@ pub use sde::*; use crate::{ error_model::AssayErrorModels, simulator::{Fa, Lag}, - ChannelId, Covariates, Event, Infusion, Observation, Occasion, PharmsolError, Subject, + Covariates, Event, Infusion, InputLabel, Observation, Occasion, OutputLabel, PharmsolError, + Subject, }; use super::likelihood::Prediction; @@ -145,7 +146,7 @@ pub(crate) trait EquationPriv: EquationTypes { fn resolve_input_label( &self, - label: &ChannelId, + label: &InputLabel, expected_kind: RouteKind, ) -> Result { if let Some(metadata) = self.metadata() { @@ -165,7 +166,7 @@ pub(crate) trait EquationPriv: EquationTypes { ))); } - return Ok(route.channel_index()); + return Ok(route.input_index()); } label @@ -175,7 +176,7 @@ pub(crate) trait EquationPriv: EquationTypes { }) } - fn resolve_output_label(&self, label: &ChannelId) -> Result { + fn resolve_output_label(&self, label: &OutputLabel) -> Result { if let Some(metadata) = self.metadata() { return metadata.output_index(label.as_str()).ok_or_else(|| { PharmsolError::UnknownOutputLabel { diff --git a/src/simulator/equation/ode/closure.rs b/src/simulator/equation/ode/closure.rs index cb9c0726..47f2a81e 100644 --- a/src/simulator/equation/ode/closure.rs +++ b/src/simulator/equation/ode/closure.rs @@ -11,13 +11,13 @@ type C = ::C; type T = ::T; #[derive(Debug, Clone)] -struct InfusionChannel { +struct InfusionTrack { input: usize, event_times: Vec, cumulative_rates: Vec, } -impl InfusionChannel { +impl InfusionTrack { fn new(input: usize, mut events: Vec<(f64, f64)>) -> Self { events.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal)); @@ -63,15 +63,13 @@ impl InfusionChannel { #[derive(Debug, Clone, Default)] struct InfusionSchedule { - channels: Vec, + tracks: Vec, } impl InfusionSchedule { fn new(ndrugs: usize, infusions: &[&Infusion]) -> Result { if ndrugs == 0 || infusions.is_empty() { - return Ok(Self { - channels: Vec::new(), - }); + return Ok(Self { tracks: Vec::new() }); } let mut per_input: Vec> = vec![Vec::new(); ndrugs]; @@ -94,27 +92,27 @@ impl InfusionSchedule { per_input[input].push((infusion.time() + infusion.duration(), -rate)); } - let channels = per_input + let tracks = per_input .into_iter() .enumerate() .filter_map(|(input, events)| { if events.is_empty() { None } else { - Some(InfusionChannel::new(input, events)) + Some(InfusionTrack::new(input, events)) } }) .collect(); - Ok(Self { channels }) + Ok(Self { tracks }) } fn fill_rate_vector(&self, time: f64, rateiv: &mut V) { rateiv.fill(0.0); - for channel in &self.channels { - let rate = channel.rate_at(time); + for track in &self.tracks { + let rate = track.rate_at(time); if rate != 0.0 { - rateiv[channel.input] = rate; + rateiv[track.input] = rate; } } } diff --git a/src/simulator/equation/ode/mod.rs b/src/simulator/equation/ode/mod.rs index 853b3108..c65f16a9 100644 --- a/src/simulator/equation/ode/mod.rs +++ b/src/simulator/equation/ode/mod.rs @@ -87,9 +87,7 @@ pub enum OdeMetadataError { Validation(#[from] ModelMetadataError), #[error("ODE declares {declared} state metadata entries but model has {expected} states")] StateCountMismatch { expected: usize, declared: usize }, - #[error( - "ODE declares {declared} route metadata entries but model has {expected} input channels" - )] + #[error("ODE declares {declared} route metadata entries but model has {expected} inputs")] RouteCountMismatch { expected: usize, declared: usize }, #[error("ODE declares {declared} output metadata entries but model has {expected} outputs")] OutputCountMismatch { expected: usize, declared: usize }, @@ -134,7 +132,7 @@ impl ODE { self } - /// Set the number of drug input channels (size of bolus[] and rateiv[]). + /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; self.invalidate_metadata(); @@ -211,7 +209,7 @@ fn validate_metadata_dimensions( }); } - let declared_routes = metadata.route_channel_count(); + let declared_routes = metadata.route_input_count(); if declared_routes != neqs.ndrugs { return Err(OdeMetadataError::RouteCountMismatch { expected: neqs.ndrugs, diff --git a/src/simulator/equation/sde/mod.rs b/src/simulator/equation/sde/mod.rs index c24b615d..43a1d48a 100644 --- a/src/simulator/equation/sde/mod.rs +++ b/src/simulator/equation/sde/mod.rs @@ -34,9 +34,7 @@ pub enum SdeMetadataError { Validation(#[from] ModelMetadataError), #[error("SDE declares {declared} state metadata entries but model has {expected} states")] StateCountMismatch { expected: usize, declared: usize }, - #[error( - "SDE declares {declared} route metadata entries but model has {expected} input channels" - )] + #[error("SDE declares {declared} route metadata entries but model has {expected} inputs")] RouteCountMismatch { expected: usize, declared: usize }, #[error("SDE declares {declared} output metadata entries but model has {expected} outputs")] OutputCountMismatch { expected: usize, declared: usize }, @@ -236,7 +234,7 @@ impl SDE { self } - /// Set the number of drug input channels (size of bolus[] and rateiv[]). + /// Set the number of drug inputs (size of bolus[] and rateiv[]). pub fn with_ndrugs(mut self, ndrugs: usize) -> Self { self.neqs.ndrugs = ndrugs; self.invalidate_metadata(); @@ -309,7 +307,7 @@ fn validate_metadata_dimensions( }); } - let declared_routes = metadata.route_channel_count(); + let declared_routes = metadata.route_input_count(); if declared_routes != neqs.ndrugs { return Err(SdeMetadataError::RouteCountMismatch { expected: neqs.ndrugs, diff --git a/src/simulator/mod.rs b/src/simulator/mod.rs index 5cea84fe..058ca125 100644 --- a/src/simulator/mod.rs +++ b/src/simulator/mod.rs @@ -200,7 +200,7 @@ pub type Fa = fn(&V, T, &Covariates) -> HashMap; /// /// # Fields /// - `nstates`: Number of state variables (ODE compartments) -/// - `ndrugs`: Number of drug input channels (size of bolus[] and rateiv[]) +/// - `ndrugs`: Number of drug inputs (size of bolus[] and rateiv[]) /// - `nout`: Number of output equations /// /// # Defaults @@ -218,7 +218,7 @@ pub type Fa = fn(&V, T, &Covariates) -> HashMap; pub struct Neqs { /// Number of state variables pub nstates: usize, - /// Number of drug input channels (bolus/rateiv size) + /// Number of drug inputs (bolus/rateiv size) pub ndrugs: usize, /// Number of output equations pub nout: usize, diff --git a/tests/analytical_macro_lowering.rs b/tests/analytical_macro_lowering.rs index c55719f5..796cb55e 100644 --- a/tests/analytical_macro_lowering.rs +++ b/tests/analytical_macro_lowering.rs @@ -19,7 +19,7 @@ fn oral_subject(input: impl ToString, outeq: impl ToString) -> Subject { .build() } -fn shared_channel_subject() -> Subject { +fn shared_input_subject() -> Subject { Subject::builder("analytical-macro-shared") .bolus(0.0, 100.0, "oral") .infusion(6.0, 60.0, "iv", 2.0) @@ -160,7 +160,7 @@ fn handwritten_one_compartment_with_absorption() -> equation::Analytical { .expect("handwritten absorption metadata should validate") } -fn macro_shared_channel_analytical() -> equation::Analytical { +fn macro_shared_input_analytical() -> equation::Analytical { analytical! { name: "one_cmt_abs_shared", params: [ka, ke, v, tlag, f_oral], @@ -183,7 +183,7 @@ fn macro_shared_channel_analytical() -> equation::Analytical { } } -fn handwritten_shared_channel_analytical() -> equation::Analytical { +fn handwritten_shared_input_analytical() -> equation::Analytical { equation::Analytical::new( equation::one_compartment_with_absorption, |_p, _t, _cov| {}, @@ -219,7 +219,7 @@ fn handwritten_shared_channel_analytical() -> equation::Analytical { ]) .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), ) - .expect("handwritten shared-channel analytical metadata should validate") + .expect("handwritten shared-input analytical metadata should validate") } fn macro_covariate_analytical() -> equation::Analytical { @@ -438,10 +438,10 @@ fn analytical_macro_supports_extra_parameters_and_named_route_bindings() { } #[test] -fn analytical_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { - let macro_model = macro_shared_channel_analytical(); - let handwritten_model = handwritten_shared_channel_analytical(); - let subject = shared_channel_subject(); +fn analytical_macro_shared_input_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_shared_input_analytical(); + let handwritten_model = handwritten_shared_input_analytical(); + let subject = shared_input_subject(); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -453,12 +453,12 @@ fn analytical_macro_shared_channel_lowering_matches_handwritten_metadata_and_pre let macro_predictions = macro_model .estimate_predictions(&subject, &support_point) - .expect("macro shared-channel analytical model should simulate") + .expect("macro shared-input analytical model should simulate") .flat_predictions() .to_vec(); let handwritten_predictions = handwritten_model .estimate_predictions(&subject, &support_point) - .expect("handwritten shared-channel analytical model should simulate") + .expect("handwritten shared-input analytical model should simulate") .flat_predictions() .to_vec(); diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs index 67f91c7a..37a5891a 100644 --- a/tests/authoring_parity_corpus.rs +++ b/tests/authoring_parity_corpus.rs @@ -70,8 +70,8 @@ out(cp) = central / v ~ continuous() "#; #[cfg(feature = "dsl-jit")] -const ODE_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" -name = shared_channel_one_cpt +const ODE_RUNTIME_SHARED_INPUT_DSL: &str = r#" +name = shared_input_one_cpt kind = ode params = ka, ke, v, tlag, f_oral @@ -122,7 +122,7 @@ out(cp) = central / v ~ continuous() "#; #[cfg(feature = "dsl-jit")] -const ANALYTICAL_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" +const ANALYTICAL_RUNTIME_SHARED_INPUT_DSL: &str = r#" name = one_cmt_abs_shared kind = analytical @@ -177,7 +177,7 @@ out(cp) = central / v ~ continuous() "#; #[cfg(feature = "dsl-jit")] -const SDE_RUNTIME_SHARED_CHANNEL_DSL: &str = r#" +const SDE_RUNTIME_SHARED_INPUT_DSL: &str = r#" name = one_cmt_shared_sde kind = sde @@ -205,7 +205,7 @@ struct MetadataParityView { parameters: Vec, covariates: Vec, states: Vec, - route_channel_count: usize, + route_input_count: usize, routes: Vec, outputs: Vec, analytical_kernel: Option, @@ -230,7 +230,7 @@ struct RouteParity { name: String, kind: Option, declaration_index: usize, - channel_index: usize, + input_index: usize, destination_name: String, destination_index: usize, has_lag: bool, @@ -242,7 +242,7 @@ struct RouteParity { struct RouteInputPolicyParity { name: String, declaration_index: usize, - channel_index: usize, + input_index: usize, input_policy: RouteInputPolicy, } @@ -286,8 +286,8 @@ fn compile_runtime_jit_model(src: &str, model_name: &str) -> dsl::CompiledRuntim } #[cfg(feature = "dsl-jit")] -fn shared_channel_prediction_subject() -> Subject { - Subject::builder("authoring-parity-shared-channel") +fn shared_input_prediction_subject() -> Subject { + Subject::builder("authoring-parity-shared-input") .bolus(0.0, 100.0, "oral") .infusion(6.0, 60.0, "iv", 2.0) .missing_observation(0.5, "cp") @@ -347,7 +347,7 @@ fn dsl_metadata_view(src: &str) -> MetadataParityView { name: route.name.clone(), kind: route.kind.map(RouteKindParity::from_dsl), declaration_index: route.declaration_index, - channel_index: route.index, + input_index: route.index, destination_name: route.destination.state_name.clone(), destination_index: route.destination.state_offset, has_lag: route.has_lag, @@ -361,7 +361,7 @@ fn dsl_metadata_view(src: &str) -> MetadataParityView { parameters, covariates, states, - route_channel_count: model.abi.route_buffer.len, + route_input_count: model.abi.route_buffer.len, routes, outputs, analytical_kernel: model.metadata.analytical, @@ -379,7 +379,7 @@ fn dsl_route_input_policy_view(src: &str) -> Vec { .map(|route| RouteInputPolicyParity { name: route.name, declaration_index: route.declaration_index, - channel_index: route.index, + input_index: route.index, input_policy: if route.inject_input_to_destination { RouteInputPolicy::InjectToDestination } else { @@ -421,7 +421,7 @@ fn validated_metadata_view(metadata: &ValidatedModelMetadata) -> MetadataParityV index, }) .collect(), - route_channel_count: metadata.route_channel_count(), + route_input_count: metadata.route_input_count(), routes: metadata .routes() .iter() @@ -429,7 +429,7 @@ fn validated_metadata_view(metadata: &ValidatedModelMetadata) -> MetadataParityV name: route.name().to_string(), kind: Some(RouteKindParity::from_handwritten(route.kind())), declaration_index: route.declaration_index(), - channel_index: route.channel_index(), + input_index: route.input_index(), destination_name: route.destination().to_string(), destination_index: route.destination_index(), has_lag: route.has_lag(), @@ -460,7 +460,7 @@ fn handwritten_route_input_policy_view( .map(|route| RouteInputPolicyParity { name: route.name().to_string(), declaration_index: route.declaration_index(), - channel_index: route.channel_index(), + input_index: route.input_index(), input_policy: route .input_policy() .expect("route input policy should be explicit in this handwritten fixture"), @@ -565,9 +565,9 @@ fn handwritten_ode_model() -> equation::ODE { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_macro_ode() -> equation::ODE { +fn runtime_shared_input_macro_ode() -> equation::ODE { ode! { - name: "shared_channel_one_cpt", + name: "shared_input_one_cpt", params: [ka, ke, v, tlag, f_oral], states: [depot, central], outputs: [cp], @@ -592,7 +592,7 @@ fn runtime_shared_channel_macro_ode() -> equation::ODE { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_handwritten_ode() -> equation::ODE { +fn runtime_shared_input_handwritten_ode() -> equation::ODE { equation::ODE::new( |x, p, _t, dx, bolus, rateiv, _cov| { fetch_params!(p, ka, ke, _v, _tlag, _f_oral); @@ -617,7 +617,7 @@ fn runtime_shared_channel_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("shared_channel_one_cpt") + equation::metadata::new("shared_input_one_cpt") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) .outputs(["cp"]) @@ -632,11 +632,11 @@ fn runtime_shared_channel_handwritten_ode() -> equation::ODE { .expect_explicit_input(), ]), ) - .expect("handwritten shared-channel ODE metadata should validate") + .expect("handwritten shared-input ODE metadata should validate") } #[cfg(feature = "dsl-jit")] -fn runtime_mismatched_shared_channel_ode() -> equation::ODE { +fn runtime_mismatched_shared_input_ode() -> equation::ODE { equation::ODE::new( |x, p, _t, dx, _bolus, _rateiv, _cov| { fetch_params!(p, ka, ke, _v, _tlag, _f_oral); @@ -661,7 +661,7 @@ fn runtime_mismatched_shared_channel_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("shared_channel_one_cpt_mismatched") + equation::metadata::new("shared_input_one_cpt_mismatched") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) .outputs(["cp"]) @@ -676,11 +676,11 @@ fn runtime_mismatched_shared_channel_ode() -> equation::ODE { .expect_explicit_input(), ]), ) - .expect("mismatched shared-channel ODE metadata should validate") + .expect("mismatched shared-input ODE metadata should validate") } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_macro_analytical() -> equation::Analytical { +fn runtime_shared_input_macro_analytical() -> equation::Analytical { analytical! { name: "one_cmt_abs_shared", params: [ka, ke, v, tlag, f_oral], @@ -704,7 +704,7 @@ fn runtime_shared_channel_macro_analytical() -> equation::Analytical { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_handwritten_analytical() -> equation::Analytical { +fn runtime_shared_input_handwritten_analytical() -> equation::Analytical { equation::Analytical::new( equation::one_compartment_with_absorption, |_p, _t, _cov| {}, @@ -740,11 +740,11 @@ fn runtime_shared_channel_handwritten_analytical() -> equation::Analytical { ]) .analytical_kernel(equation::AnalyticalKernel::OneCompartmentWithAbsorption), ) - .expect("handwritten shared-channel analytical metadata should validate") + .expect("handwritten shared-input analytical metadata should validate") } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_macro_sde() -> equation::SDE { +fn runtime_shared_input_macro_sde() -> equation::SDE { sde! { name: "one_cmt_shared_sde", params: [ka, ke, sigma_ke, v, tlag, f_oral], @@ -780,7 +780,7 @@ fn runtime_shared_channel_macro_sde() -> equation::SDE { } #[cfg(feature = "dsl-jit")] -fn runtime_shared_channel_handwritten_sde() -> equation::SDE { +fn runtime_shared_input_handwritten_sde() -> equation::SDE { equation::SDE::new( |x, p, _t, dx, rateiv, _cov| { fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); @@ -830,7 +830,7 @@ fn runtime_shared_channel_handwritten_sde() -> equation::SDE { ]) .particles(8), ) - .expect("handwritten shared-channel SDE metadata should validate") + .expect("handwritten shared-input SDE metadata should validate") } #[cfg(feature = "dsl-jit")] @@ -1196,11 +1196,11 @@ fn invalid_dsl_infusion_route_properties_fail_explicitly() { #[cfg(feature = "dsl-jit")] #[test] -fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { +fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape() { let runtime_model = - compile_runtime_jit_model(ODE_RUNTIME_SHARED_CHANNEL_DSL, "shared_channel_one_cpt"); - let macro_model = runtime_shared_channel_macro_ode(); - let handwritten_model = runtime_shared_channel_handwritten_ode(); + compile_runtime_jit_model(ODE_RUNTIME_SHARED_INPUT_DSL, "shared_input_one_cpt"); + let macro_model = runtime_shared_input_macro_ode(); + let handwritten_model = runtime_shared_input_handwritten_ode(); let oral = runtime_model .route_index("oral") @@ -1211,7 +1211,7 @@ fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(); + let subject = shared_input_prediction_subject(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); @@ -1246,11 +1246,11 @@ fn ode_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha #[cfg(feature = "dsl-jit")] #[test] -fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { +fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape() { let runtime_model = - compile_runtime_jit_model(ANALYTICAL_RUNTIME_SHARED_CHANNEL_DSL, "one_cmt_abs_shared"); - let macro_model = runtime_shared_channel_macro_analytical(); - let handwritten_model = runtime_shared_channel_handwritten_analytical(); + compile_runtime_jit_model(ANALYTICAL_RUNTIME_SHARED_INPUT_DSL, "one_cmt_abs_shared"); + let macro_model = runtime_shared_input_macro_analytical(); + let handwritten_model = runtime_shared_input_handwritten_analytical(); let oral = runtime_model .route_index("oral") @@ -1261,7 +1261,7 @@ fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_chan let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(); + let subject = shared_input_prediction_subject(); let support_point = [1.1, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); @@ -1298,11 +1298,11 @@ fn analytical_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_chan #[cfg(feature = "dsl-jit")] #[test] -fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_shape() { +fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_input_shape() { let runtime_model = - compile_runtime_jit_model(SDE_RUNTIME_SHARED_CHANNEL_DSL, "one_cmt_shared_sde"); - let macro_model = runtime_shared_channel_macro_sde(); - let handwritten_model = runtime_shared_channel_handwritten_sde(); + compile_runtime_jit_model(SDE_RUNTIME_SHARED_INPUT_DSL, "one_cmt_shared_sde"); + let macro_model = runtime_shared_input_macro_sde(); + let handwritten_model = runtime_shared_input_handwritten_sde(); let oral = runtime_model .route_index("oral") @@ -1313,7 +1313,7 @@ fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(); + let subject = shared_input_prediction_subject(); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); @@ -1350,8 +1350,8 @@ fn sde_runtime_jit_macro_and_handwritten_predictions_agree_on_shared_channel_sha #[test] fn route_input_policy_runtime_mismatches_are_detected_explicitly() { let runtime_model = - compile_runtime_jit_model(ODE_RUNTIME_SHARED_CHANNEL_DSL, "shared_channel_one_cpt"); - let mismatched_model = runtime_mismatched_shared_channel_ode(); + compile_runtime_jit_model(ODE_RUNTIME_SHARED_INPUT_DSL, "shared_input_one_cpt"); + let mismatched_model = runtime_mismatched_shared_input_ode(); let oral = runtime_model .route_index("oral") @@ -1362,7 +1362,7 @@ fn route_input_policy_runtime_mismatches_are_detected_explicitly() { let cp = runtime_model .output_index("cp") .expect("runtime cp output should exist"); - let subject = shared_channel_prediction_subject(); + let subject = shared_input_prediction_subject(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(oral, 0); diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs index a556f428..480f7e80 100644 --- a/tests/ode_macro_lowering.rs +++ b/tests/ode_macro_lowering.rs @@ -18,8 +18,8 @@ fn subject_for_route(input: impl ToString, outeq: impl ToString) -> Subject { .build() } -fn subject_for_shared_channel() -> Subject { - Subject::builder("macro-shared-channel") +fn subject_for_shared_input() -> Subject { + Subject::builder("macro-shared-input") .bolus(0.0, 100.0, "oral") .infusion(6.0, 60.0, "iv", 2.0) .missing_observation(0.5, "cp") @@ -197,9 +197,9 @@ fn numeric_label_handwritten_ode() -> equation::ODE { .expect("handwritten numeric-label metadata should validate") } -fn shared_channel_macro_ode() -> equation::ODE { +fn shared_input_macro_ode() -> equation::ODE { ode! { - name: "shared_channel_one_cpt", + name: "shared_input_one_cpt", params: [ka, ke, v, tlag, f_oral], states: [depot, central], outputs: [cp], @@ -223,7 +223,7 @@ fn shared_channel_macro_ode() -> equation::ODE { } } -fn shared_channel_handwritten_ode() -> equation::ODE { +fn shared_input_handwritten_ode() -> equation::ODE { equation::ODE::new( |x, p, _t, dx, bolus, rateiv, _cov| { fetch_params!(p, ka, ke, _v, _tlag, _f_oral); @@ -248,7 +248,7 @@ fn shared_channel_handwritten_ode() -> equation::ODE { .with_ndrugs(1) .with_nout(1) .with_metadata( - equation::metadata::new("shared_channel_one_cpt") + equation::metadata::new("shared_input_one_cpt") .parameters(["ka", "ke", "v", "tlag", "f_oral"]) .states(["depot", "central"]) .outputs(["cp"]) @@ -263,7 +263,7 @@ fn shared_channel_handwritten_ode() -> equation::ODE { .expect_explicit_input(), ]), ) - .expect("handwritten shared-channel metadata should validate") + .expect("handwritten shared-input metadata should validate") } fn numeric_route_property_macro_ode() -> equation::ODE { @@ -526,10 +526,10 @@ fn macro_numeric_labels_lower_to_dense_slots() { } #[test] -fn macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { - let macro_ode = shared_channel_macro_ode(); - let handwritten_ode = shared_channel_handwritten_ode(); - let subject = subject_for_shared_channel(); +fn macro_shared_input_lowering_matches_handwritten_metadata_and_predictions() { + let macro_ode = shared_input_macro_ode(); + let handwritten_ode = shared_input_handwritten_ode(); + let subject = subject_for_shared_input(); let support_point = [1.0, 0.2, 10.0, 0.25, 0.8]; assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); @@ -541,12 +541,12 @@ fn macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() let macro_predictions = macro_ode .estimate_predictions(&subject, &support_point) - .expect("macro shared-channel model should simulate") + .expect("macro shared-input model should simulate") .flat_predictions() .to_vec(); let handwritten_predictions = handwritten_ode .estimate_predictions(&subject, &support_point) - .expect("handwritten shared-channel model should simulate") + .expect("handwritten shared-input model should simulate") .flat_predictions() .to_vec(); diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs index 13d21a2b..05b5cb27 100644 --- a/tests/sde_macro_lowering.rs +++ b/tests/sde_macro_lowering.rs @@ -20,7 +20,7 @@ fn oral_subject(input: impl ToString, outeq: impl ToString) -> Subject { .build() } -fn shared_channel_subject() -> Subject { +fn shared_input_subject() -> Subject { Subject::builder("sde-macro-shared") .bolus(0.0, 100.0, "oral") .infusion(6.0, 60.0, "iv", 2.0) @@ -211,7 +211,7 @@ fn handwritten_absorption_sde() -> equation::SDE { .expect("handwritten absorption SDE metadata should validate") } -fn macro_shared_channel_sde() -> equation::SDE { +fn macro_shared_input_sde() -> equation::SDE { sde! { name: "one_cmt_shared_sde", params: [ka, ke, sigma_ke, v, tlag, f_oral], @@ -246,7 +246,7 @@ fn macro_shared_channel_sde() -> equation::SDE { } } -fn handwritten_shared_channel_sde() -> equation::SDE { +fn handwritten_shared_input_sde() -> equation::SDE { equation::SDE::new( |x, p, _t, dx, rateiv, _cov| { fetch_params!(p, ka, ke, _sigma_ke, _v, _tlag, _f_oral); @@ -296,7 +296,7 @@ fn handwritten_shared_channel_sde() -> equation::SDE { ]) .particles(8), ) - .expect("handwritten shared-channel SDE metadata should validate") + .expect("handwritten shared-input SDE metadata should validate") } fn macro_covariate_sde() -> equation::SDE { @@ -538,10 +538,10 @@ fn sde_macro_supports_lag_fa_init_and_named_sigma_bindings() { } #[test] -fn sde_macro_shared_channel_lowering_matches_handwritten_metadata_and_predictions() { - let macro_model = macro_shared_channel_sde(); - let handwritten_model = handwritten_shared_channel_sde(); - let subject = shared_channel_subject(); +fn sde_macro_shared_input_lowering_matches_handwritten_metadata_and_predictions() { + let macro_model = macro_shared_input_sde(); + let handwritten_model = handwritten_shared_input_sde(); + let subject = shared_input_subject(); let support_point = [1.1, 0.2, 0.0, 10.0, 0.25, 0.8]; assert_eq!(macro_model.metadata(), handwritten_model.metadata()); @@ -553,10 +553,10 @@ fn sde_macro_shared_channel_lowering_matches_handwritten_metadata_and_prediction let macro_predictions = macro_model .estimate_predictions(&subject, &support_point) - .expect("macro shared-channel SDE should simulate"); + .expect("macro shared-input SDE should simulate"); let handwritten_predictions = handwritten_model .estimate_predictions(&subject, &support_point) - .expect("handwritten shared-channel SDE should simulate"); + .expect("handwritten shared-input SDE should simulate"); assert_prediction_match( &prediction_means(¯o_predictions), From dac11081f25ef10c386479f9431d8fffa86155a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 15:22:09 +0100 Subject: [PATCH 04/12] fix: The new string-based mapper for inputs/outeq had an issue when ordering indices, lso the implementation was not complete over the DSL frontend --- pharmsol-dsl/src/authoring.rs | 44 +++- pharmsol-dsl/src/parser.rs | 10 +- src/dsl/aot.rs | 16 +- src/dsl/jit.rs | 122 +++++++--- src/dsl/native.rs | 23 +- src/dsl/runtime.rs | 285 +++++++++++++++++++++-- tests/authoring_parity_corpus.rs | 238 +++++++++++++++++++ tests/full_feature_dsl_backend_parity.rs | 201 ++++++++++++++++ tests/support/bimodal_ke.rs | 23 +- tests/support/runtime_corpus.rs | 102 ++++---- 10 files changed, 924 insertions(+), 140 deletions(-) create mode 100644 tests/full_feature_dsl_backend_parity.rs diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index 129f07c8..7b0b4dd6 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -20,6 +20,7 @@ struct AuthoringParser<'a> { states: Vec, declared_derived: BTreeSet, declared_outputs: BTreeSet, + explicit_output_order: Vec, explicit_outputs: BTreeMap, assigned_outputs: BTreeMap, declared_outputs_span: Option, @@ -77,6 +78,7 @@ impl<'a> AuthoringParser<'a> { states: Vec::new(), declared_derived: BTreeSet::new(), declared_outputs: BTreeSet::new(), + explicit_output_order: Vec::new(), explicit_outputs: BTreeMap::new(), assigned_outputs: BTreeMap::new(), declared_outputs_span: None, @@ -175,6 +177,20 @@ impl<'a> AuthoringParser<'a> { )); } + if !self.explicit_output_order.is_empty() { + let output_order = self + .explicit_output_order + .iter() + .enumerate() + .map(|(index, name)| (name.clone(), index)) + .collect::>(); + self.output_statements.sort_by_key(|statement| { + output_statement_name(statement) + .and_then(|name| output_order.get(name).copied()) + .unwrap_or(usize::MAX) + }); + } + let mut derivative_statements = std::mem::take(&mut self.derivative_statements); inject_infusion_rates(&surface_routes, &routes, &mut derivative_statements); @@ -372,6 +388,7 @@ impl<'a> AuthoringParser<'a> { if lhs_trimmed == "outputs" { self.declared_outputs_span = Some(span); for ident in parse_output_label_list(rhs, rhs_abs)? { + self.explicit_output_order.push(ident.text.clone()); self.declared_outputs.insert(ident.text.clone()); self.explicit_outputs.insert(ident.text, ident.span); } @@ -467,7 +484,7 @@ impl<'a> AuthoringParser<'a> { } }; - let input = parse_ident_segment(call.argument, call.argument_start)?; + let input = parse_label_segment(call.argument, call.argument_start, "route label")?; let route_name = input.text.clone(); let destination = parse_place_at(rhs, line_start + arrow + 2)?; if self.routes.contains_key(&route_name) { @@ -498,7 +515,8 @@ impl<'a> AuthoringParser<'a> { ) -> Result<(), ParseError> { match call.callee.text.as_str() { "lag" | "fa" => { - let route_name = parse_ident_segment(call.argument, call.argument_start)?; + let route_name = + parse_label_segment(call.argument, call.argument_start, "route label")?; let value = parse_expr_at(rhs, rhs_abs)?; let property_name = match call.callee.text.as_str() { "lag" => "lag", @@ -928,17 +946,25 @@ fn parse_ident_segment(src: &str, abs_start: usize) -> Result } fn parse_output_label_segment(src: &str, abs_start: usize) -> Result { + parse_label_segment(src, abs_start, "output label") +} + +fn parse_label_segment( + src: &str, + abs_start: usize, + expected: &str, +) -> Result { let trimmed = src.trim(); let leading = src.len() - src.trim_start().len(); if trimmed.is_empty() { return Err(ParseError::new( - "expected output label", + format!("expected {expected}"), Span::new(abs_start, abs_start + src.len()), )); } if !is_valid_output_label(trimmed) { return Err(ParseError::new( - format!("expected output label, found `{trimmed}`"), + format!("expected {expected}, found `{trimmed}`"), Span::new(abs_start + leading, abs_start + leading + trimmed.len()), )); } @@ -1417,6 +1443,16 @@ fn join_covariate_spans(items: &[CovariateDecl]) -> Span { .unwrap_or_else(|| Span::empty(0)) } +fn output_statement_name(statement: &Stmt) -> Option<&str> { + match &statement.kind { + StmtKind::Assign(assign) => match &assign.target.kind { + AssignTargetKind::Name(name) => Some(name.text.as_str()), + _ => None, + }, + _ => None, + } +} + fn join_state_spans(items: &[StateDecl]) -> Span { items .iter() diff --git a/pharmsol-dsl/src/parser.rs b/pharmsol-dsl/src/parser.rs index c265b4df..fe844c37 100644 --- a/pharmsol-dsl/src/parser.rs +++ b/pharmsol-dsl/src/parser.rs @@ -563,7 +563,7 @@ impl Parser { } fn parse_route_decl(&mut self) -> Result { - let input = self.parse_ident()?; + let input = self.parse_label_name("route label")?; let arrow = self.expect_simple(|kind| matches!(kind, TokenKind::Arrow), "`->`")?; self.ensure_not_layout_boundary( arrow.span, @@ -902,9 +902,13 @@ impl Parser { } fn parse_output_target_name(&mut self) -> Result { + self.parse_label_name("output label") + } + + fn parse_label_name(&mut self, expected: &str) -> Result { let token = self .bump() - .ok_or_else(|| ParseError::new("expected output label", Span::empty(self.src_len)))?; + .ok_or_else(|| ParseError::new(format!("expected {expected}"), Span::empty(self.src_len)))?; match token.kind { TokenKind::Ident(name) => Ok(Ident::new(name, token.span)), TokenKind::Number(value) @@ -917,7 +921,7 @@ impl Parser { } other => Err(ParseError::new( format!( - "expected output label identifier or non-negative integer, found {}", + "expected {expected} identifier or non-negative integer, found {}", other.describe() ), token.span, diff --git a/src/dsl/aot.rs b/src/dsl/aot.rs index 3749f183..2a46409a 100644 --- a/src/dsl/aot.rs +++ b/src/dsl/aot.rs @@ -543,14 +543,14 @@ mod tests { let subject = crate::Subject::builder("ode") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(6.0, cp) - .missing_observation(7.0, cp) - .missing_observation(9.0, cp) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") .build(); let support = vec![1.2, 5.0, 40.0, 0.5, 0.8]; diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index 5504ab08..a440c51d 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -1360,21 +1360,33 @@ out(cp) = central / v ~ continuous() let cp = jit.output_index("cp").expect("cp output"); assert_eq!(oral, 0); assert_eq!(iv, 0); + assert_eq!(cp, 0); + + let jit_subject = Subject::builder("ode") + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .observation(0.5, 0.0, "cp") + .observation(1.0, 0.0, "cp") + .observation(2.0, 0.0, "cp") + .observation(6.0, 0.0, "cp") + .observation(7.0, 0.0, "cp") + .observation(9.0, 0.0, "cp") + .build(); - let subject = Subject::builder("ode") - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .observation(0.5, 0.0, cp) - .observation(1.0, 0.0, cp) - .observation(2.0, 0.0, cp) - .observation(6.0, 0.0, cp) - .observation(7.0, 0.0, cp) - .observation(9.0, 0.0, cp) + let reference_subject = Subject::builder("ode") + .bolus(0.0, 120.0, 0) + .infusion(6.0, 60.0, 0, 2.0) + .observation(0.5, 0.0, 0) + .observation(1.0, 0.0, 0) + .observation(2.0, 0.0, 0) + .observation(6.0, 0.0, 0) + .observation(7.0, 0.0, 0) + .observation(9.0, 0.0, 0) .build(); let support = vec![1.2, 0.15, 40.0]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit predictions"); let reference = ODE::new( @@ -1397,7 +1409,7 @@ out(cp) = central / v ~ continuous() .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference ode predictions"); for (jit_pred, reference_pred) in jit_predictions @@ -1491,22 +1503,35 @@ out(cp) = central / v ~ continuous() let cp = jit.output_index("cp").expect("cp output"); assert_eq!(oral, 0); assert_eq!(iv, 1); + assert_eq!(cp, 0); - let subject = Subject::builder("ode") + let jit_subject = Subject::builder("ode") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(6.0, cp) - .missing_observation(7.0, cp) - .missing_observation(9.0, cp) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") + .build(); + + let reference_subject = Subject::builder("ode") + .covariate("wt", 0.0, 70.0) + .bolus(0.0, 120.0, 0) + .infusion(6.0, 60.0, 1, 2.0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(6.0, 0) + .missing_observation(7.0, 0) + .missing_observation(9.0, 0) .build(); let support = vec![1.2, 5.0, 40.0, 0.5, 0.8]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit predictions"); let reference = ODE::new( @@ -1551,7 +1576,7 @@ out(cp) = central / v ~ continuous() .with_solver(OdeSolver::ExplicitRk(ExplicitRkTableau::Tsit45)); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference ode predictions"); for (jit_pred, reference_pred) in jit_predictions @@ -1574,18 +1599,28 @@ out(cp) = central / v ~ continuous() let oral = jit.route_index("oral").expect("oral route"); let cp = jit.output_index("cp").expect("cp output"); + assert_eq!(oral, 0); + assert_eq!(cp, 0); + + let jit_subject = Subject::builder("analytical") + .bolus(0.0, 100.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .build(); - let subject = Subject::builder("analytical") - .bolus(0.0, 100.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + let reference_subject = Subject::builder("analytical") + .bolus(0.0, 100.0, 0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(4.0, 0) .build(); let support = vec![1.0, 0.15, 25.0]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit analytical predictions"); let reference = equation::Analytical::new( @@ -1603,7 +1638,7 @@ out(cp) = central / v ~ continuous() .with_nout(1); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference analytical predictions"); for (jit_pred, reference_pred) in jit_predictions @@ -1628,19 +1663,30 @@ out(cp) = central / v ~ continuous() let oral = jit.route_index("oral").expect("oral route"); let cp = jit.output_index("cp").expect("cp output"); + assert_eq!(oral, 0); + assert_eq!(cp, 0); + + let jit_subject = Subject::builder("sde") + .covariate("wt", 0.0, 70.0) + .bolus(0.0, 80.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .build(); - let subject = Subject::builder("sde") + let reference_subject = Subject::builder("sde") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 80.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .bolus(0.0, 80.0, 0) + .missing_observation(0.5, 0) + .missing_observation(1.0, 0) + .missing_observation(2.0, 0) + .missing_observation(4.0, 0) .build(); let support = vec![1.1, 0.2, 0.12, 0.08, 15.0, 0.0]; let jit_predictions = jit - .estimate_predictions(&subject, &support) + .estimate_predictions(&jit_subject, &support) .expect("jit sde predictions"); let reference = SDE::new( @@ -1677,7 +1723,7 @@ out(cp) = central / v ~ continuous() .with_nout(1); let reference_predictions = reference - .estimate_predictions(&subject, &support) + .estimate_predictions(&reference_subject, &support) .expect("reference sde predictions"); for (jit_pred, reference_pred) in jit_predictions diff --git a/src/dsl/native.rs b/src/dsl/native.rs index c1ce8eac..6df2f05d 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -402,13 +402,8 @@ impl SharedNativeModel { label: &InputLabel, kind: RouteKind, ) -> Result { - if let Some(input) = self.route_index(label.as_str()) { - self.validate_input_for_kind(input, kind)?; - return Ok(input); - } - - let input = label - .index() + let input = self + .route_index(label.as_str()) .ok_or_else(|| PharmsolError::UnknownInputLabel { label: label.to_string(), })?; @@ -417,17 +412,11 @@ impl SharedNativeModel { } fn resolve_output_label(&self, label: &OutputLabel) -> Result { - if let Some(outeq) = self.output_index(label.as_str()) { - return Ok(outeq); - } - - let outeq = label - .index() - .ok_or_else(|| PharmsolError::UnknownOutputLabel { + self.output_index(label.as_str()).ok_or_else(|| { + PharmsolError::UnknownOutputLabel { label: label.to_string(), - })?; - self.validate_output(outeq)?; - Ok(outeq) + } + }) } fn resolve_events(&self, occasion: &Occasion) -> Result, PharmsolError> { diff --git a/src/dsl/runtime.rs b/src/dsl/runtime.rs index 1d49d82a..1d8a1327 100644 --- a/src/dsl/runtime.rs +++ b/src/dsl/runtime.rs @@ -376,12 +376,96 @@ fn runtime_model_from_parts( mod tests { use super::*; use crate::dsl::compile_sde_model_to_jit; + use crate::PharmsolError; use crate::test_fixtures::STRUCTURED_BLOCK_CORPUS; use crate::SubjectBuilderExt; use approx::assert_relative_eq; use pharmsol_dsl::{DiagnosticPhase, DSL_BACKEND_GENERIC, DSL_PARSE_GENERIC}; use tempfile::tempdir; + const MULTI_DIGIT_OUTPUT_ORDER_RUNTIME_DSL: &str = r#" +name = multi_digit_output_runtime +kind = ode + +params = ke, v +states = central +outputs = 2, 10, 11 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(10) = central / v ~ continuous() +out(2) = central / v ~ continuous() +out(11) = central / v ~ continuous() +"#; + + const NUMERIC_ROUTE_LABELS_RUNTIME_DSL: &str = r#" +name = numeric_route_runtime +kind = ode + +params = ke, v +states = central +outputs = cp + +bolus(10) -> central +bolus(11) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + + const UNDECLARED_NUMERIC_OUTPUT_LABEL_RUNTIME_DSL: &str = r#" +name = undeclared_numeric_output_runtime +kind = ode + +params = ke, v +states = central +outputs = a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(a0) = central / v ~ continuous() +out(a1) = central / v ~ continuous() +out(a2) = central / v ~ continuous() +out(a3) = central / v ~ continuous() +out(a4) = central / v ~ continuous() +out(a5) = central / v ~ continuous() +out(a6) = central / v ~ continuous() +out(a7) = central / v ~ continuous() +out(a8) = central / v ~ continuous() +out(a9) = central / v ~ continuous() +out(a10) = central / v ~ continuous() +"#; + + const UNDECLARED_NUMERIC_INPUT_LABEL_RUNTIME_DSL: &str = r#" +name = undeclared_numeric_input_runtime +kind = ode + +params = ke, v +states = central +outputs = cp + +bolus(r0) -> central +bolus(r1) -> central +bolus(r2) -> central +bolus(r3) -> central +bolus(r4) -> central +bolus(r5) -> central +bolus(r6) -> central +bolus(r7) -> central +bolus(r8) -> central +bolus(r9) -> central +bolus(r10) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + fn corpus_source() -> &'static str { STRUCTURED_BLOCK_CORPUS } @@ -397,17 +481,17 @@ mod tests { pharmsol_dsl::lower_typed_model(model).expect("lower corpus model") } - fn ode_subject(output: usize, oral: usize, iv: usize) -> Subject { + fn ode_subject() -> Subject { Subject::builder("ode") .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, output) - .missing_observation(1.0, output) - .missing_observation(2.0, output) - .missing_observation(6.0, output) - .missing_observation(7.0, output) - .missing_observation(9.0, output) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") .build() } @@ -421,6 +505,80 @@ mod tests { .collect() } + fn compile_runtime_backend_matrix( + source: &str, + model_name: &str, + work_dir: &std::path::Path, + ) -> (CompiledRuntimeModel, CompiledRuntimeModel, CompiledRuntimeModel) { + let jit = compile_module_source_to_runtime( + source, + Some(model_name), + RuntimeCompilationTarget::Jit, + |_, _| {}, + ) + .expect("compile jit runtime model"); + let aot = compile_module_source_to_runtime( + source, + Some(model_name), + RuntimeCompilationTarget::NativeAot( + NativeAotCompileOptions::new(work_dir.join(format!("{model_name}-aot-build"))) + .with_output(work_dir.join(format!("{model_name}.pkm"))), + ), + |_, _| {}, + ) + .expect("compile aot runtime model"); + let wasm = compile_module_source_to_runtime( + source, + Some(model_name), + RuntimeCompilationTarget::Wasm, + |_, _| {}, + ) + .expect("compile wasm runtime model"); + + (jit, aot, wasm) + } + + fn numeric_route_subject() -> Subject { + Subject::builder("numeric-route-runtime") + .bolus(0.0, 120.0, "10") + .bolus(1.0, 80.0, "11") + .missing_observation(0.5, "cp") + .missing_observation(1.5, "cp") + .build() + } + + fn assert_unknown_output_label( + model: &CompiledRuntimeModel, + subject: &Subject, + support: &[f64], + expected_label: &str, + ) { + let error = model + .estimate_predictions(subject, support) + .expect_err("undeclared numeric output label should fail"); + + assert!(matches!( + error, + RuntimeError::Runtime(PharmsolError::UnknownOutputLabel { label }) if label == expected_label + )); + } + + fn assert_unknown_input_label( + model: &CompiledRuntimeModel, + subject: &Subject, + support: &[f64], + expected_label: &str, + ) { + let error = model + .estimate_predictions(subject, support) + .expect_err("undeclared numeric input label should fail"); + + assert!(matches!( + error, + RuntimeError::Runtime(PharmsolError::UnknownInputLabel { label }) if label == expected_label + )); + } + #[test] fn runtime_backend_matrix_matches_ode_predictions() { let work_dir = tempdir().expect("tempdir"); @@ -460,10 +618,73 @@ mod tests { vec!["ka", "cl", "v", "tlag", "f_oral"] ); - let oral = jit.route_index("oral").expect("oral route"); - let iv = jit.route_index("iv").expect("iv route"); - let cp = jit.output_index("cp").expect("cp output"); - let subject = ode_subject(cp, oral, iv); + assert!(jit.route_index("oral").is_some()); + assert!(jit.route_index("iv").is_some()); + assert_eq!(jit.output_index("cp"), Some(0)); + let subject = ode_subject(); + + let jit_values = subject_values( + &jit.estimate_predictions(&subject, &support) + .expect("jit predictions"), + ); + let aot_values = subject_values( + &aot.estimate_predictions(&subject, &support) + .expect("aot predictions"), + ); + let wasm_values = subject_values( + &wasm + .estimate_predictions(&subject, &support) + .expect("wasm predictions"), + ); + + for ((jit_value, aot_value), wasm_value) in jit_values + .iter() + .zip(aot_values.iter()) + .zip(wasm_values.iter()) + { + assert_relative_eq!(jit_value, aot_value, max_relative = 1e-4); + assert_relative_eq!(jit_value, wasm_value, max_relative = 1e-4); + } + } + + #[test] + fn runtime_backend_matrix_preserves_multi_digit_output_label_order() { + let work_dir = tempdir().expect("tempdir"); + let (jit, aot, wasm) = compile_runtime_backend_matrix( + MULTI_DIGIT_OUTPUT_ORDER_RUNTIME_DSL, + "multi_digit_output_runtime", + work_dir.path(), + ); + + assert_eq!(jit.output_index("2"), Some(0)); + assert_eq!(jit.output_index("10"), Some(1)); + assert_eq!(jit.output_index("11"), Some(2)); + assert_eq!(aot.output_index("2"), Some(0)); + assert_eq!(aot.output_index("10"), Some(1)); + assert_eq!(aot.output_index("11"), Some(2)); + assert_eq!(wasm.output_index("2"), Some(0)); + assert_eq!(wasm.output_index("10"), Some(1)); + assert_eq!(wasm.output_index("11"), Some(2)); + } + + #[test] + fn runtime_backend_matrix_supports_multi_digit_numeric_route_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + NUMERIC_ROUTE_LABELS_RUNTIME_DSL, + "numeric_route_runtime", + work_dir.path(), + ); + + assert_eq!(jit.route_index("10"), Some(0)); + assert_eq!(jit.route_index("11"), Some(1)); + assert_eq!(aot.route_index("10"), Some(0)); + assert_eq!(aot.route_index("11"), Some(1)); + assert_eq!(wasm.route_index("10"), Some(0)); + assert_eq!(wasm.route_index("11"), Some(1)); + + let subject = numeric_route_subject(); let jit_values = subject_values( &jit.estimate_predictions(&subject, &support) @@ -489,6 +710,44 @@ mod tests { } } + #[test] + fn runtime_backend_matrix_rejects_undeclared_numeric_output_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + UNDECLARED_NUMERIC_OUTPUT_LABEL_RUNTIME_DSL, + "undeclared_numeric_output_runtime", + work_dir.path(), + ); + let subject = Subject::builder("runtime-undeclared-numeric-output") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "10") + .build(); + + assert_unknown_output_label(&jit, &subject, &support, "10"); + assert_unknown_output_label(&aot, &subject, &support, "10"); + assert_unknown_output_label(&wasm, &subject, &support, "10"); + } + + #[test] + fn runtime_backend_matrix_rejects_undeclared_numeric_input_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + UNDECLARED_NUMERIC_INPUT_LABEL_RUNTIME_DSL, + "undeclared_numeric_input_runtime", + work_dir.path(), + ); + let subject = Subject::builder("runtime-undeclared-numeric-input") + .bolus(0.0, 100.0, "10") + .missing_observation(0.5, "cp") + .build(); + + assert_unknown_input_label(&jit, &subject, &support, "10"); + assert_unknown_input_label(&aot, &subject, &support, "10"); + assert_unknown_input_label(&wasm, &subject, &support, "10"); + } + #[test] fn runtime_compile_preserves_parse_diagnostic_structure() { let source = "model broken { kind ode outputs { cp = 1 + } }"; diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs index 37a5891a..c7164d71 100644 --- a/tests/authoring_parity_corpus.rs +++ b/tests/authoring_parity_corpus.rs @@ -53,6 +53,59 @@ dx(central) = ka * depot - (cl / v) * central out(cp) = central / (v * (wt / 70.0)) ~ continuous() "#; +const ODE_MULTI_DIGIT_OUTPUT_ORDER_DSL: &str = r#" +name = multi_digit_output_order +kind = ode + +params = ke, v +states = central +outputs = 2, 10, 11 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(10) = central / v ~ continuous() +out(2) = central / v ~ continuous() +out(11) = central / v ~ continuous() +"#; + +const ODE_NUMERIC_ROUTE_LABELS_AUTHORING_DSL: &str = r#" +name = authoring_numeric_routes +kind = ode + +states = first, second +outputs = cp + +bolus(10) -> first +bolus(11) -> second + +dx(first) = 0 +dx(second) = 0 + +out(cp) = first + second ~ continuous() +"#; + +const ODE_NUMERIC_ROUTE_LABELS_STRUCTURED_DSL: &str = r#"model structured_numeric_routes { + kind ode + states { + first, + second, + } + routes { + 10 -> first + 11 -> second + } + dynamics { + ddt(first) = 0 + ddt(second) = 0 + } + outputs { + cp = first + second + } +} +"#; + const ODE_INVALID_INFUSION_LAG_DSL: &str = r#" name = invalid_infusion_lag_parity kind = ode @@ -107,6 +160,58 @@ out(0) = 2 * central / v ~ continuous() out(1) = 3 * central / v ~ continuous() "#; +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_UNDECLARED_NUMERIC_OUTPUT_LABEL_DSL: &str = r#" +name = undeclared_numeric_output_runtime +kind = ode + +params = ke, v +states = central +outputs = a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, a10 + +infusion(iv) -> central + +dx(central) = -ke * central + +out(a0) = central / v ~ continuous() +out(a1) = central / v ~ continuous() +out(a2) = central / v ~ continuous() +out(a3) = central / v ~ continuous() +out(a4) = central / v ~ continuous() +out(a5) = central / v ~ continuous() +out(a6) = central / v ~ continuous() +out(a7) = central / v ~ continuous() +out(a8) = central / v ~ continuous() +out(a9) = central / v ~ continuous() +out(a10) = central / v ~ continuous() +"#; + +#[cfg(feature = "dsl-jit")] +const ODE_RUNTIME_UNDECLARED_NUMERIC_INPUT_LABEL_DSL: &str = r#" +name = undeclared_numeric_input_runtime +kind = ode + +params = ke, v +states = central +outputs = cp + +bolus(r0) -> central +bolus(r1) -> central +bolus(r2) -> central +bolus(r3) -> central +bolus(r4) -> central +bolus(r5) -> central +bolus(r6) -> central +bolus(r7) -> central +bolus(r8) -> central +bolus(r9) -> central +bolus(r10) -> central + +dx(central) = -ke * central + +out(cp) = central / v ~ continuous() +"#; + const ANALYTICAL_DSL: &str = r#" name = one_cmt_abs_parity kind = analytical @@ -1056,6 +1161,93 @@ fn ode_dsl_and_handwritten_metadata_agree_on_public_shape() { assert_eq!(handwritten_view, dsl_view); } +#[test] +fn ode_dsl_declared_output_order_controls_dense_indices_for_multi_digit_labels() { + let dsl_view = dsl_metadata_view(ODE_MULTI_DIGIT_OUTPUT_ORDER_DSL); + + assert_eq!( + dsl_view.outputs, + vec![ + NamedIndex { + name: "2".to_string(), + index: 0, + }, + NamedIndex { + name: "10".to_string(), + index: 1, + }, + NamedIndex { + name: "11".to_string(), + index: 2, + }, + ] + ); +} + +#[test] +fn ode_authoring_dsl_supports_multi_digit_numeric_route_labels() { + let dsl_view = dsl_metadata_view(ODE_NUMERIC_ROUTE_LABELS_AUTHORING_DSL); + + assert_eq!(dsl_view.route_input_count, 2); + assert_eq!( + dsl_view.routes, + vec![ + RouteParity { + name: "10".to_string(), + kind: Some(RouteKindParity::Bolus), + declaration_index: 0, + input_index: 0, + destination_name: "first".to_string(), + destination_index: 0, + has_lag: false, + has_bioavailability: false, + }, + RouteParity { + name: "11".to_string(), + kind: Some(RouteKindParity::Bolus), + declaration_index: 1, + input_index: 1, + destination_name: "second".to_string(), + destination_index: 1, + has_lag: false, + has_bioavailability: false, + }, + ] + ); +} + +#[test] +fn ode_structured_dsl_supports_multi_digit_numeric_route_labels() { + let dsl_view = dsl_metadata_view(ODE_NUMERIC_ROUTE_LABELS_STRUCTURED_DSL); + + assert_eq!(dsl_view.route_input_count, 2); + assert_eq!( + dsl_view.routes, + vec![ + RouteParity { + name: "10".to_string(), + kind: None, + declaration_index: 0, + input_index: 0, + destination_name: "first".to_string(), + destination_index: 0, + has_lag: false, + has_bioavailability: false, + }, + RouteParity { + name: "11".to_string(), + kind: None, + declaration_index: 1, + input_index: 1, + destination_name: "second".to_string(), + destination_index: 1, + has_lag: false, + has_bioavailability: false, + }, + ] + ); +} + #[test] fn ode_macro_dsl_and_handwritten_metadata_agree_on_macro_authorable_shape() { let handwritten = handwritten_ode_macro_model(); @@ -1418,3 +1610,49 @@ fn ode_runtime_jit_preserves_mixed_output_labels() { assert_relative_eq!(predictions[1], 2.0 * predictions[0], epsilon = 1e-6); assert_relative_eq!(predictions[2], 3.0 * predictions[0], epsilon = 1e-6); } + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_rejects_undeclared_numeric_output_labels_even_when_dense_index_exists() { + let runtime_model = compile_runtime_jit_model( + ODE_RUNTIME_UNDECLARED_NUMERIC_OUTPUT_LABEL_DSL, + "undeclared_numeric_output_runtime", + ); + let subject = Subject::builder("runtime-undeclared-numeric-output") + .infusion(0.0, 100.0, "iv", 1.0) + .missing_observation(0.5, "10") + .build(); + let support_point = [0.2, 10.0]; + + let error = runtime_model + .estimate_predictions(&subject, &support_point) + .expect_err("undeclared numeric output label should fail"); + + assert!(matches!( + error, + dsl::RuntimeError::Runtime(PharmsolError::UnknownOutputLabel { label }) if label == "10" + )); +} + +#[cfg(feature = "dsl-jit")] +#[test] +fn ode_runtime_jit_rejects_undeclared_numeric_input_labels_even_when_dense_index_exists() { + let runtime_model = compile_runtime_jit_model( + ODE_RUNTIME_UNDECLARED_NUMERIC_INPUT_LABEL_DSL, + "undeclared_numeric_input_runtime", + ); + let subject = Subject::builder("runtime-undeclared-numeric-input") + .bolus(0.0, 100.0, "10") + .missing_observation(0.5, "cp") + .build(); + let support_point = [0.2, 10.0]; + + let error = runtime_model + .estimate_predictions(&subject, &support_point) + .expect_err("undeclared numeric input label should fail"); + + assert!(matches!( + error, + dsl::RuntimeError::Runtime(PharmsolError::UnknownInputLabel { label }) if label == "10" + )); +} diff --git a/tests/full_feature_dsl_backend_parity.rs b/tests/full_feature_dsl_backend_parity.rs new file mode 100644 index 00000000..1aba8213 --- /dev/null +++ b/tests/full_feature_dsl_backend_parity.rs @@ -0,0 +1,201 @@ +#[path = "support/runtime_corpus.rs"] +mod runtime_corpus; + +#[cfg(all(feature = "dsl-jit", feature = "dsl-wasm"))] +mod tests { + use super::runtime_corpus::{self as corpus, CorpusCase}; + use pharmsol::dsl::{CompiledRuntimeModel, RuntimeBackend}; + + fn owned_names(names: &[&str]) -> Vec { + names.iter().map(|name| (*name).to_owned()).collect() + } + + fn assert_info_matches( + left_label: &str, + left: &CompiledRuntimeModel, + right_label: &str, + right: &CompiledRuntimeModel, + ) { + assert_eq!( + left.info(), + right.info(), + "{left_label} model info diverged from {right_label}" + ); + } + + fn assert_ode_full_public_shape(model: &CompiledRuntimeModel) { + let info = model.info(); + + assert_eq!(info.name, "ode_full_feature_parity"); + assert_eq!(info.parameters, owned_names(&[ + "ka", + "ke", + "kcp", + "kpc", + "v", + "tlag", + "f_oral", + "base_depot", + "base_central", + "base_peripheral", + ])); + assert_eq!( + info.covariates + .iter() + .map(|covariate| covariate.name.as_str()) + .collect::>(), + vec!["wt", "renal"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["oral", "load", "iv"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.declaration_index) + .collect::>(), + vec![0, 1, 2] + ); + assert_eq!( + info.routes.iter().map(|route| route.index).collect::>(), + vec![0, 1, 0] + ); + assert_eq!( + info.outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp"] + ); + assert_eq!(model.route_index("oral"), Some(0)); + assert_eq!(model.route_index("load"), Some(1)); + assert_eq!(model.route_index("iv"), Some(0)); + assert_eq!(model.output_index("cp"), Some(0)); + } + + fn assert_analytical_full_public_shape(model: &CompiledRuntimeModel) { + let info = model.info(); + + assert_eq!(info.name, "analytical_full_feature_parity"); + assert_eq!(info.parameters, owned_names(&[ + "ka", + "ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + "tvke", + ])); + assert_eq!( + info.covariates + .iter() + .map(|covariate| covariate.name.as_str()) + .collect::>(), + vec!["wt", "renal"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["oral", "load", "iv"] + ); + assert_eq!( + info.routes + .iter() + .map(|route| route.declaration_index) + .collect::>(), + vec![0, 1, 2] + ); + assert_eq!( + info.routes.iter().map(|route| route.index).collect::>(), + vec![0, 1, 0] + ); + assert_eq!( + info.outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["cp"] + ); + assert_eq!(model.route_index("oral"), Some(0)); + assert_eq!(model.route_index("load"), Some(1)); + assert_eq!(model.route_index("iv"), Some(0)); + assert_eq!(model.output_index("cp"), Some(0)); + } + + fn assert_full_backend_parity( + case: CorpusCase, + assert_public_shape: fn(&CompiledRuntimeModel), + ) -> Result<(), Box> { + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let workspace = super::runtime_corpus::ArtifactWorkspace::new()?; + + let jit = corpus::compile_runtime_jit_model(case)?; + assert_eq!(jit.backend(), RuntimeBackend::Jit); + assert_public_shape(&jit); + corpus::assert_runtime_model_matches_reference(case, "runtime-jit", &jit)?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + let aot = corpus::compile_runtime_native_aot_model(case, &workspace)?; + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + { + assert_eq!(aot.backend(), RuntimeBackend::NativeAot); + assert_public_shape(&aot); + corpus::assert_runtime_model_matches_reference(case, "runtime-native-aot", &aot)?; + } + + let wasm = corpus::compile_runtime_wasm_model(case)?; + assert_eq!(wasm.backend(), RuntimeBackend::Wasm); + assert_public_shape(&wasm); + corpus::assert_runtime_model_matches_reference(case, "runtime-wasm", &wasm)?; + + assert_info_matches("runtime-jit", &jit, "runtime-wasm", &wasm); + corpus::assert_runtime_models_match_each_other( + case, + "runtime-jit", + &jit, + "runtime-wasm", + &wasm, + )?; + + #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] + { + assert_info_matches("runtime-jit", &jit, "runtime-native-aot", &aot); + assert_info_matches("runtime-native-aot", &aot, "runtime-wasm", &wasm); + corpus::assert_runtime_models_match_each_other( + case, + "runtime-jit", + &jit, + "runtime-native-aot", + &aot, + )?; + corpus::assert_runtime_models_match_each_other( + case, + "runtime-native-aot", + &aot, + "runtime-wasm", + &wasm, + )?; + } + + Ok(()) + } + + #[test] + fn ode_full_feature_dsl_matches_handwritten_across_backends( + ) -> Result<(), Box> { + assert_full_backend_parity(CorpusCase::OdeFull, assert_ode_full_public_shape) + } + + #[test] + fn analytical_full_feature_dsl_matches_handwritten_across_backends( + ) -> Result<(), Box> { + assert_full_backend_parity(CorpusCase::AnalyticalFull, assert_analytical_full_public_shape) + } +} \ No newline at end of file diff --git a/tests/support/bimodal_ke.rs b/tests/support/bimodal_ke.rs index 4c82be4f..6e7e5f8e 100644 --- a/tests/support/bimodal_ke.rs +++ b/tests/support/bimodal_ke.rs @@ -55,6 +55,14 @@ fn subject_for_indices(route_index: usize, output_index: usize) -> Subject { builder.build() } +fn subject_for_labels(route_label: &str, output_label: &str) -> Subject { + let mut builder = Subject::builder(MODEL_NAME).infusion(0.0, 500.0, route_label, 0.5); + for time in OBSERVATION_TIMES { + builder = builder.missing_observation(time, output_label); + } + builder.build() +} + pub fn subject() -> Subject { subject_for_indices(0, 0) } @@ -65,12 +73,15 @@ pub fn subject() -> Subject { feature = "dsl-wasm" ))] pub fn subject_for_runtime_model(model: &pharmsol::dsl::CompiledRuntimeModel) -> Subject { - let route_index = model - .route_index("iv") - .or_else(|| model.route_index("input_0")) - .expect("bimodal_ke route is available"); - let output_index = model.output_index("cp").expect("cp output is available"); - subject_for_indices(route_index, output_index) + let route_label = if model.route_index("iv").is_some() { + "iv" + } else if model.route_index("input_0").is_some() { + "input_0" + } else { + panic!("bimodal_ke route is available"); + }; + model.output_index("cp").expect("cp output is available"); + subject_for_labels(route_label, "cp") } pub fn reference_values() -> Result, Box> { diff --git a/tests/support/runtime_corpus.rs b/tests/support/runtime_corpus.rs index 3ed75511..1ca8ae78 100644 --- a/tests/support/runtime_corpus.rs +++ b/tests/support/runtime_corpus.rs @@ -208,52 +208,52 @@ impl CorpusCase { } fn runtime_subject(self, model: &CompiledRuntimeModel) -> Result> { - let cp = model + model .output_index("cp") .ok_or_else(|| io::Error::other(format!("{}: missing cp output", self.label())))?; let subject = match self { Self::Ode => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; - let iv = model.route_index("iv").ok_or_else(|| { + model.route_index("iv").ok_or_else(|| { io::Error::other(format!("{}: missing iv route", self.label())) })?; Subject::builder(self.label()) .covariate("wt", 0.0, 70.0) - .bolus(0.0, 120.0, oral) - .infusion(6.0, 60.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(6.0, cp) - .missing_observation(7.0, cp) - .missing_observation(9.0, cp) + .bolus(0.0, 120.0, "oral") + .infusion(6.0, 60.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(7.0, "cp") + .missing_observation(9.0, "cp") .build() } Self::OdeFull => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; - let load = model.route_index("load").ok_or_else(|| { + model.route_index("load").ok_or_else(|| { io::Error::other(format!("{}: missing load route", self.label())) })?; - let iv = model.route_index("iv").ok_or_else(|| { + model.route_index("iv").ok_or_else(|| { io::Error::other(format!("{}: missing iv route", self.label())) })?; Subject::builder(self.label()) - .bolus(0.0, 80.0, load) - .bolus(1.0, 120.0, oral) - .infusion(6.0, 150.0, iv, 2.5) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) + .bolus(0.0, 80.0, "load") + .bolus(1.0, 120.0, "oral") + .infusion(6.0, 150.0, "iv", 2.5) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) .covariate("renal", 0.0, 95.0) @@ -261,39 +261,39 @@ impl CorpusCase { .build() } Self::Analytical => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; Subject::builder(self.label()) - .bolus(0.0, 100.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .bolus(0.0, 100.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build() } Self::AnalyticalFull => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; - let load = model.route_index("load").ok_or_else(|| { + model.route_index("load").ok_or_else(|| { io::Error::other(format!("{}: missing load route", self.label())) })?; - let iv = model.route_index("iv").ok_or_else(|| { + model.route_index("iv").ok_or_else(|| { io::Error::other(format!("{}: missing iv route", self.label())) })?; Subject::builder(self.label()) - .bolus(0.0, 60.0, load) - .bolus(1.0, 100.0, oral) - .infusion(6.0, 140.0, iv, 2.0) - .missing_observation(0.25, cp) - .missing_observation(0.75, cp) - .missing_observation(1.5, cp) - .missing_observation(3.0, cp) - .missing_observation(6.5, cp) - .missing_observation(7.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) + .bolus(0.0, 60.0, "load") + .bolus(1.0, 100.0, "oral") + .infusion(6.0, 140.0, "iv", 2.0) + .missing_observation(0.25, "cp") + .missing_observation(0.75, "cp") + .missing_observation(1.5, "cp") + .missing_observation(3.0, "cp") + .missing_observation(6.5, "cp") + .missing_observation(7.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") .covariate("wt", 0.0, 68.0) .covariate("wt", 8.0, 74.0) .covariate("renal", 0.0, 95.0) @@ -301,16 +301,16 @@ impl CorpusCase { .build() } Self::Sde => { - let oral = model.route_index("oral").ok_or_else(|| { + model.route_index("oral").ok_or_else(|| { io::Error::other(format!("{}: missing oral route", self.label())) })?; Subject::builder(self.label()) .covariate("wt", 0.0, 70.0) - .bolus(0.0, 80.0, oral) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .bolus(0.0, 80.0, "oral") + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build() } }; From ecf05753b1db510e9cffbe873fdb480b3561a9a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 15:22:23 +0100 Subject: [PATCH 05/12] chore: fmt --- pharmsol-dsl/src/authoring.rs | 6 +-- pharmsol-dsl/src/parser.rs | 6 +-- src/dsl/native.rs | 17 +++--- src/dsl/runtime.rs | 8 ++- tests/full_feature_dsl_backend_parity.rs | 67 +++++++++++++++--------- 5 files changed, 59 insertions(+), 45 deletions(-) diff --git a/pharmsol-dsl/src/authoring.rs b/pharmsol-dsl/src/authoring.rs index 7b0b4dd6..0496c0fc 100644 --- a/pharmsol-dsl/src/authoring.rs +++ b/pharmsol-dsl/src/authoring.rs @@ -949,11 +949,7 @@ fn parse_output_label_segment(src: &str, abs_start: usize) -> Result Result { +fn parse_label_segment(src: &str, abs_start: usize, expected: &str) -> Result { let trimmed = src.trim(); let leading = src.len() - src.trim_start().len(); if trimmed.is_empty() { diff --git a/pharmsol-dsl/src/parser.rs b/pharmsol-dsl/src/parser.rs index fe844c37..98c6b0a4 100644 --- a/pharmsol-dsl/src/parser.rs +++ b/pharmsol-dsl/src/parser.rs @@ -906,9 +906,9 @@ impl Parser { } fn parse_label_name(&mut self, expected: &str) -> Result { - let token = self - .bump() - .ok_or_else(|| ParseError::new(format!("expected {expected}"), Span::empty(self.src_len)))?; + let token = self.bump().ok_or_else(|| { + ParseError::new(format!("expected {expected}"), Span::empty(self.src_len)) + })?; match token.kind { TokenKind::Ident(name) => Ok(Ident::new(name, token.span)), TokenKind::Number(value) diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 6df2f05d..97c41013 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -402,21 +402,20 @@ impl SharedNativeModel { label: &InputLabel, kind: RouteKind, ) -> Result { - let input = self - .route_index(label.as_str()) - .ok_or_else(|| PharmsolError::UnknownInputLabel { - label: label.to_string(), - })?; + let input = + self.route_index(label.as_str()) + .ok_or_else(|| PharmsolError::UnknownInputLabel { + label: label.to_string(), + })?; self.validate_input_for_kind(input, kind)?; Ok(input) } fn resolve_output_label(&self, label: &OutputLabel) -> Result { - self.output_index(label.as_str()).ok_or_else(|| { - PharmsolError::UnknownOutputLabel { + self.output_index(label.as_str()) + .ok_or_else(|| PharmsolError::UnknownOutputLabel { label: label.to_string(), - } - }) + }) } fn resolve_events(&self, occasion: &Occasion) -> Result, PharmsolError> { diff --git a/src/dsl/runtime.rs b/src/dsl/runtime.rs index 1d8a1327..ba6dd5cd 100644 --- a/src/dsl/runtime.rs +++ b/src/dsl/runtime.rs @@ -376,8 +376,8 @@ fn runtime_model_from_parts( mod tests { use super::*; use crate::dsl::compile_sde_model_to_jit; - use crate::PharmsolError; use crate::test_fixtures::STRUCTURED_BLOCK_CORPUS; + use crate::PharmsolError; use crate::SubjectBuilderExt; use approx::assert_relative_eq; use pharmsol_dsl::{DiagnosticPhase, DSL_BACKEND_GENERIC, DSL_PARSE_GENERIC}; @@ -509,7 +509,11 @@ out(cp) = central / v ~ continuous() source: &str, model_name: &str, work_dir: &std::path::Path, - ) -> (CompiledRuntimeModel, CompiledRuntimeModel, CompiledRuntimeModel) { + ) -> ( + CompiledRuntimeModel, + CompiledRuntimeModel, + CompiledRuntimeModel, + ) { let jit = compile_module_source_to_runtime( source, Some(model_name), diff --git a/tests/full_feature_dsl_backend_parity.rs b/tests/full_feature_dsl_backend_parity.rs index 1aba8213..929e7243 100644 --- a/tests/full_feature_dsl_backend_parity.rs +++ b/tests/full_feature_dsl_backend_parity.rs @@ -27,18 +27,21 @@ mod tests { let info = model.info(); assert_eq!(info.name, "ode_full_feature_parity"); - assert_eq!(info.parameters, owned_names(&[ - "ka", - "ke", - "kcp", - "kpc", - "v", - "tlag", - "f_oral", - "base_depot", - "base_central", - "base_peripheral", - ])); + assert_eq!( + info.parameters, + owned_names(&[ + "ka", + "ke", + "kcp", + "kpc", + "v", + "tlag", + "f_oral", + "base_depot", + "base_central", + "base_peripheral", + ]) + ); assert_eq!( info.covariates .iter() @@ -61,7 +64,10 @@ mod tests { vec![0, 1, 2] ); assert_eq!( - info.routes.iter().map(|route| route.index).collect::>(), + info.routes + .iter() + .map(|route| route.index) + .collect::>(), vec![0, 1, 0] ); assert_eq!( @@ -81,16 +87,19 @@ mod tests { let info = model.info(); assert_eq!(info.name, "analytical_full_feature_parity"); - assert_eq!(info.parameters, owned_names(&[ - "ka", - "ke", - "v", - "tlag", - "f_oral", - "base_gut", - "base_central", - "tvke", - ])); + assert_eq!( + info.parameters, + owned_names(&[ + "ka", + "ke", + "v", + "tlag", + "f_oral", + "base_gut", + "base_central", + "tvke", + ]) + ); assert_eq!( info.covariates .iter() @@ -113,7 +122,10 @@ mod tests { vec![0, 1, 2] ); assert_eq!( - info.routes.iter().map(|route| route.index).collect::>(), + info.routes + .iter() + .map(|route| route.index) + .collect::>(), vec![0, 1, 0] ); assert_eq!( @@ -196,6 +208,9 @@ mod tests { #[test] fn analytical_full_feature_dsl_matches_handwritten_across_backends( ) -> Result<(), Box> { - assert_full_backend_parity(CorpusCase::AnalyticalFull, assert_analytical_full_public_shape) + assert_full_backend_parity( + CorpusCase::AnalyticalFull, + assert_analytical_full_public_shape, + ) } -} \ No newline at end of file +} From b65d795657cc537b6c4de2929d2e608bd7b8575c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 15:26:14 +0100 Subject: [PATCH 06/12] chore: update README.md --- README.md | 25 +++++++++++-------------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index d24aea87..cf9865d0 100644 --- a/README.md +++ b/README.md @@ -36,15 +36,12 @@ let analytical = analytical! { }, }; -let iv = analytical.route_index("iv").unwrap(); -let cp = analytical.output_index("cp").unwrap(); - let subject = Subject::builder("patient_001") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build(); let predictions = analytical @@ -121,12 +118,12 @@ use pharmsol::prelude::*; use pharmsol::nca::NCAOptions; let subject = Subject::builder("patient_001") - .bolus(0.0, 100.0, 0) // 100 mg oral dose - .observation(0.5, 5.0, 0) - .observation(1.0, 10.0, 0) - .observation(2.0, 8.0, 0) - .observation(4.0, 4.0, 0) - .observation(8.0, 2.0, 0) + .bolus(0.0, 100.0, "oral") // 100 mg oral dose + .observation(0.5, 5.0, "cp") + .observation(1.0, 10.0, "cp") + .observation(2.0, 8.0, "cp") + .observation(4.0, 4.0, "cp") + .observation(8.0, 2.0, "cp") .build(); let result = subject.nca(&NCAOptions::default()).expect("NCA failed"); From 627085bf7d5fedd9f5a4e56b141295564ae20014 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 15:52:07 +0100 Subject: [PATCH 07/12] chore: {}->[] :D --- README.md | 8 +- examples/analytical_readme.rs | 4 +- examples/analytical_vs_ode.rs | 32 +-- examples/compare_solvers.rs | 4 +- examples/covariates.rs | 4 +- examples/macro_vs_handwritten_one_cpt.rs | 4 +- examples/macro_vs_handwritten_two_cpt.rs | 4 +- examples/ode_readme.rs | 4 +- examples/one_compartment.rs | 8 +- examples/sde_readme.rs | 4 +- examples/two_compartment.rs | 4 +- pharmsol-macros/src/lib.rs | 305 ++++++++--------------- tests/analytical_macro_lowering.rs | 16 +- tests/authoring_parity_corpus.rs | 34 +-- tests/full_feature_macro_parity.rs | 20 +- tests/ode_macro_lowering.rs | 131 +++------- tests/sde_macro_lowering.rs | 16 +- 17 files changed, 213 insertions(+), 389 deletions(-) diff --git a/README.md b/README.md index cf9865d0..73932de2 100644 --- a/README.md +++ b/README.md @@ -27,9 +27,9 @@ let analytical = analytical! { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -61,9 +61,9 @@ let ode = ode! { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, diff --git a/examples/analytical_readme.rs b/examples/analytical_readme.rs index 8e5b97f7..676f07b9 100644 --- a/examples/analytical_readme.rs +++ b/examples/analytical_readme.rs @@ -6,9 +6,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; diff --git a/examples/analytical_vs_ode.rs b/examples/analytical_vs_ode.rs index 290d6632..3fd58fd1 100644 --- a/examples/analytical_vs_ode.rs +++ b/examples/analytical_vs_ode.rs @@ -72,9 +72,9 @@ fn one_cmt_iv(params: &[f64]) { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -86,9 +86,9 @@ fn one_cmt_iv(params: &[f64]) { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, @@ -114,9 +114,9 @@ fn one_cmt_oral(params: &[f64]) { params: [ka, ke, v], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], structure: one_compartment_with_absorption, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -128,9 +128,9 @@ fn one_cmt_oral(params: &[f64]) { params: [ka, ke, v], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; @@ -157,9 +157,9 @@ fn two_cmt_iv(params: &[f64]) { params: [ke, k12, k21, v], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: two_compartments, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -171,9 +171,9 @@ fn two_cmt_iv(params: &[f64]) { params: [ke, k12, k21, v], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central] - k12 * x[central] + k21 * x[peripheral]; dx[peripheral] = k12 * x[central] - k21 * x[peripheral]; @@ -200,9 +200,9 @@ fn two_cmt_oral(params: &[f64]) { params: [ka, ke, k12, k21, v], states: [gut, central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], structure: two_compartments_with_absorption, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -214,9 +214,9 @@ fn two_cmt_oral(params: &[f64]) { params: [ka, ke, k12, k21, v], states: [gut, central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central] - k12 * x[central] + k21 * x[peripheral]; diff --git a/examples/compare_solvers.rs b/examples/compare_solvers.rs index ad705931..a8067485 100644 --- a/examples/compare_solvers.rs +++ b/examples/compare_solvers.rs @@ -24,10 +24,10 @@ fn two_cpt(solver: OdeSolver) -> equation::ODE { params: [ke, kcp, kpc, v], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(load) -> central, infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; diff --git a/examples/covariates.rs b/examples/covariates.rs index 9aabf491..180a0173 100644 --- a/examples/covariates.rs +++ b/examples/covariates.rs @@ -7,9 +7,9 @@ fn main() { covariates: [creatinine, age], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], diffeq: |x, _t, dx| { let scaled_ke = ke * (creatinine / 75.0).powf(0.75) * (age / 25.0).powf(0.5); diff --git a/examples/macro_vs_handwritten_one_cpt.rs b/examples/macro_vs_handwritten_one_cpt.rs index be9edb2a..c7b088a5 100644 --- a/examples/macro_vs_handwritten_one_cpt.rs +++ b/examples/macro_vs_handwritten_one_cpt.rs @@ -12,9 +12,9 @@ fn macro_model() -> equation::ODE { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, diff --git a/examples/macro_vs_handwritten_two_cpt.rs b/examples/macro_vs_handwritten_two_cpt.rs index 114024bd..377e1e88 100644 --- a/examples/macro_vs_handwritten_two_cpt.rs +++ b/examples/macro_vs_handwritten_two_cpt.rs @@ -13,10 +13,10 @@ fn macro_model() -> equation::ODE { params: [ke, kcp, kpc, v], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(load) -> central, infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central] - kcp * x[central] + kpc * x[peripheral]; dx[peripheral] = kcp * x[central] - kpc * x[peripheral]; diff --git a/examples/ode_readme.rs b/examples/ode_readme.rs index a0174801..7b436d0b 100644 --- a/examples/ode_readme.rs +++ b/examples/ode_readme.rs @@ -6,9 +6,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, diff --git a/examples/one_compartment.rs b/examples/one_compartment.rs index aafdf2b2..021e06f2 100644 --- a/examples/one_compartment.rs +++ b/examples/one_compartment.rs @@ -6,9 +6,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -20,9 +20,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, diff --git a/examples/sde_readme.rs b/examples/sde_readme.rs index 6106b17a..cc47cdda 100644 --- a/examples/sde_readme.rs +++ b/examples/sde_readme.rs @@ -7,9 +7,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { states: [central], outputs: [cp], particles: 16, - routes: { + routes: [ infusion(iv) -> central, - }, + ], drift: |x, _p, _t, dx, _cov| { dx[central] = -ke * x[central]; }, diff --git a/examples/two_compartment.rs b/examples/two_compartment.rs index 64d554af..fdba715e 100644 --- a/examples/two_compartment.rs +++ b/examples/two_compartment.rs @@ -27,9 +27,9 @@ fn main() -> Result<(), pharmsol::PharmsolError> { covariates: [wt], states: [central, peripheral], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _t, dx| { // CL: Clearance (L/hr), V: Central volume (L) // Vp: Peripheral volume (L), Q: Inter-compartmental clearance (L/hr) diff --git a/pharmsol-macros/src/lib.rs b/pharmsol-macros/src/lib.rs index 0d143184..7e483951 100644 --- a/pharmsol-macros/src/lib.rs +++ b/pharmsol-macros/src/lib.rs @@ -27,7 +27,6 @@ struct OdeInput { states: Vec, outputs: Vec, routes: Vec, - diffeq_mode: OdeDiffeqMode, diffeq: ExprClosure, lag: Option, fa: Option, @@ -66,12 +65,6 @@ struct SdeInput { out: ExprClosure, } -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -enum OdeDiffeqMode { - InjectedRouteInputs, - ExplicitRouteVectors, -} - struct OdeRouteDecl { kind: OdeRouteKind, input: SymbolicIndex, @@ -275,7 +268,7 @@ impl Parse for OdeInput { let routes = routes.ok_or_else(|| missing_required_ode_field("routes"))?; let diffeq = diffeq.ok_or_else(|| missing_required_ode_field("diffeq"))?; let out = out.ok_or_else(|| missing_required_ode_field("out"))?; - let diffeq_mode = classify_diffeq_mode(&diffeq, &routes)?; + validate_ode_diffeq_uses_automatic_injection(&diffeq, &routes)?; validate_unique_idents("parameter", ¶ms, "ode!")?; validate_unique_idents("covariate", &covariates, "ode!")?; @@ -300,7 +293,6 @@ impl Parse for OdeInput { init: init.as_ref(), out: &out, }, - diffeq_mode, }, )?; @@ -311,7 +303,6 @@ impl Parse for OdeInput { states, outputs, routes, - diffeq_mode, diffeq, lag, fa, @@ -694,8 +685,18 @@ fn parse_symbolic_index_list(input: ParseStream) -> syn::Result syn::Result> { + if input.peek(token::Brace) { + return Err(input.error("declaration-first macro `routes` must use `[...]`, not `{...}`")); + } + + if !input.peek(token::Bracket) { + return Err( + input.error("expected a bracketed route list like `routes: [infusion(iv) -> central]`") + ); + } + let content; - syn::braced!(content in input); + syn::bracketed!(content in input); Ok( Punctuated::::parse_terminated(&content)? .into_iter() @@ -1063,13 +1064,12 @@ fn generate_covariate_bindings( } } -fn classify_diffeq_mode( +fn validate_ode_diffeq_uses_automatic_injection( diffeq: &ExprClosure, routes: &[OdeRouteDecl], -) -> syn::Result { +) -> syn::Result<()> { match closure_param_names(diffeq).len() { - 3 => Ok(OdeDiffeqMode::InjectedRouteInputs), - 7 => Ok(OdeDiffeqMode::ExplicitRouteVectors), + 3 => Ok(()), 5 => { let usage = ClosureBodyUsage::analyze(diffeq.body.as_ref()); let route_inputs = route_input_idents(routes); @@ -1082,14 +1082,17 @@ fn classify_diffeq_mode( .is_some_and(|ident| usage.indexes(ident) && !usage.assigns_index(ident)); if mentions_route_inputs || indexes_fifth_param || reads_fourth_param_as_input { - Ok(OdeDiffeqMode::ExplicitRouteVectors) + Err(syn::Error::new_spanned( + diffeq, + "declaration-first `ode!` only supports automatic route injection in `diffeq`; use either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx| and remove manual `bolus[...]` / `rateiv[...]` terms", + )) } else { - Ok(OdeDiffeqMode::InjectedRouteInputs) + Ok(()) } } _ => Err(syn::Error::new_spanned( diffeq, - "declaration-first `ode!` requires `diffeq` to have either 3 parameters: |x, t, dx|, 5 parameters: |x, p, t, dx, cov| or |x, t, dx, bolus, rateiv|, or 7 parameters: |x, p, t, dx, bolus, rateiv, cov|", + "declaration-first `ode!` only supports automatic route injection in `diffeq`; use either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", )), } } @@ -1214,7 +1217,6 @@ struct AnalyticalBindingClosures<'a> { struct OdeBindingClosures<'a> { diffeq: &'a ExprClosure, common: CommonBindingClosures<'a>, - diffeq_mode: OdeDiffeqMode, } #[derive(Clone, Copy)] @@ -1238,7 +1240,6 @@ fn validate_named_binding_compatibility( let OdeBindingClosures { diffeq, common: CommonBindingClosures { lag, fa, init, out }, - diffeq_mode, } = closures; let route_inputs = route_input_idents(routes); @@ -1289,31 +1290,6 @@ fn validate_named_binding_compatibility( validate_closure_param_conflicts("diffeq", diffeq, covariates, "covariate")?; validate_closure_param_conflicts("diffeq", diffeq, states, "state")?; - if diffeq_mode == OdeDiffeqMode::ExplicitRouteVectors { - validate_binding_conflicts( - "parameter", - params, - "route", - &route_inputs, - "`diffeq` named binding generation", - )?; - validate_binding_conflicts( - "state", - states, - "route", - &route_inputs, - "`diffeq` named binding generation", - )?; - validate_binding_conflicts( - "covariate", - covariates, - "route", - &route_inputs, - "`diffeq` named binding generation", - )?; - validate_closure_param_conflicts("diffeq", diffeq, &route_inputs, "route")?; - } - if let Some(lag) = lag { validate_binding_conflicts( "covariate", @@ -1881,7 +1857,6 @@ fn expand_ode_init( fn expand_route_metadata( routes: &[OdeRouteDecl], - diffeq_mode: OdeDiffeqMode, lag_routes: &HashSet, fa_routes: &HashSet, ) -> Vec { @@ -1899,10 +1874,6 @@ fn expand_route_metadata( quote! { ::pharmsol::equation::Route::infusion(stringify!(#input)) } } }; - let input_policy = match diffeq_mode { - OdeDiffeqMode::InjectedRouteInputs => quote! { .inject_input_to_destination() }, - OdeDiffeqMode::ExplicitRouteVectors => quote! { .expect_explicit_input() }, - }; let lag_flag = if lag_routes.contains(&route_name) { quote! { .with_lag() } } else { @@ -1919,7 +1890,7 @@ fn expand_route_metadata( .to_state(stringify!(#destination)) #lag_flag #fa_flag - #input_policy + .inject_input_to_destination() } }) .collect() @@ -2151,148 +2122,64 @@ fn expand_diffeq( states: &[Ident], routes: &[OdeRouteDecl], route_bindings: &[(SymbolicIndex, usize)], - diffeq_mode: OdeDiffeqMode, ) -> syn::Result { let state_consts = generate_index_consts(states); + let x = generated_ident("__pharmsol_x"); + let p = generated_ident("__pharmsol_p"); + let t = generated_ident("__pharmsol_t"); + let dx = generated_ident("__pharmsol_dx"); + let bolus = generated_ident("__pharmsol_bolus"); + let rateiv = generated_ident("__pharmsol_rateiv"); + let cov = generated_ident("__pharmsol_cov"); + let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()]; + let reduced_inputs = [x.clone(), t.clone(), dx.clone()]; + let input_aliases = generate_supported_input_aliases( + diffeq, + &[&full_inputs, &reduced_inputs], + "declaration-first `ode!` injected-route `diffeq` requires either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", + )?; + let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); + let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); + let body = &diffeq.body; + let dx_binding = if diffeq.inputs.len() == full_inputs.len() { + closure_param_ident(diffeq, 3).unwrap_or_else(|| dx.clone()) + } else { + closure_param_ident(diffeq, 2).unwrap_or_else(|| dx.clone()) + }; + let route_terms = expand_injected_ode_route_terms( + routes, + states, + route_bindings, + &dx_binding, + &bolus, + &rateiv, + ); - match diffeq_mode { - OdeDiffeqMode::ExplicitRouteVectors => { - let route_consts = generate_mapped_index_consts(route_bindings); - let x = generated_ident("__pharmsol_x"); - let p = generated_ident("__pharmsol_p"); - let t = generated_ident("__pharmsol_t"); - let dx = generated_ident("__pharmsol_dx"); - let bolus = generated_ident("__pharmsol_bolus"); - let rateiv = generated_ident("__pharmsol_rateiv"); - let cov = generated_ident("__pharmsol_cov"); - let full_inputs = [ - x.clone(), - p.clone(), - t.clone(), - dx.clone(), - bolus.clone(), - rateiv.clone(), - cov.clone(), - ]; - let reduced_inputs = [ - x.clone(), - t.clone(), - dx.clone(), - bolus.clone(), - rateiv.clone(), - ]; - let input_aliases = generate_supported_input_aliases( - diffeq, - &[&full_inputs, &reduced_inputs], - "declaration-first `ode!` explicit-route `diffeq` requires either 7 parameters: |x, p, t, dx, bolus, rateiv, cov| or 5 parameters: |x, t, dx, bolus, rateiv|", - )?; - let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); - let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); - let bolus_binding = if diffeq.inputs.len() == full_inputs.len() { - closure_param_ident(diffeq, 4).unwrap_or_else(|| bolus.clone()) - } else { - closure_param_ident(diffeq, 3).unwrap_or_else(|| bolus.clone()) - }; - let rateiv_binding = if diffeq.inputs.len() == full_inputs.len() { - closure_param_ident(diffeq, 5).unwrap_or_else(|| rateiv.clone()) - } else { - closure_param_ident(diffeq, 4).unwrap_or_else(|| rateiv.clone()) - }; - let route_label_map = symbolic_numeric_binding_map(route_bindings); - let body = NumericLabelRewriter::rewrite( - diffeq.body.as_ref(), - vec![ - IndexRewriteTarget::new(bolus_binding, route_label_map.clone()), - IndexRewriteTarget::new(rateiv_binding, route_label_map), - ], - None, - ); - - Ok(quote! {{ - let __pharmsol_diffeq: fn( - &::pharmsol::simulator::V, - &::pharmsol::simulator::V, - f64, - &mut ::pharmsol::simulator::V, - &::pharmsol::simulator::V, - &::pharmsol::simulator::V, - &::pharmsol::data::Covariates, - ) = |#x: &::pharmsol::simulator::V, - #p: &::pharmsol::simulator::V, - #t: f64, - #dx: &mut ::pharmsol::simulator::V, - #bolus: &::pharmsol::simulator::V, - #rateiv: &::pharmsol::simulator::V, - #cov: &::pharmsol::data::Covariates| { - #input_aliases - #state_consts - #route_consts - #parameter_bindings - #covariate_bindings - #body - }; - __pharmsol_diffeq - }}) - } - OdeDiffeqMode::InjectedRouteInputs => { - let x = generated_ident("__pharmsol_x"); - let p = generated_ident("__pharmsol_p"); - let t = generated_ident("__pharmsol_t"); - let dx = generated_ident("__pharmsol_dx"); - let bolus = generated_ident("__pharmsol_bolus"); - let rateiv = generated_ident("__pharmsol_rateiv"); - let cov = generated_ident("__pharmsol_cov"); - let full_inputs = [x.clone(), p.clone(), t.clone(), dx.clone(), cov.clone()]; - let reduced_inputs = [x.clone(), t.clone(), dx.clone()]; - let input_aliases = generate_supported_input_aliases( - diffeq, - &[&full_inputs, &reduced_inputs], - "declaration-first `ode!` injected-route `diffeq` requires either 5 parameters: |x, p, t, dx, cov| or 3 parameters: |x, t, dx|", - )?; - let parameter_bindings = generate_parameter_bindings(params, diffeq, &p); - let covariate_bindings = generate_covariate_bindings(covariates, diffeq, &cov, &t); - let body = &diffeq.body; - let dx_binding = if diffeq.inputs.len() == full_inputs.len() { - closure_param_ident(diffeq, 3).unwrap_or_else(|| dx.clone()) - } else { - closure_param_ident(diffeq, 2).unwrap_or_else(|| dx.clone()) - }; - let route_terms = expand_injected_ode_route_terms( - routes, - states, - route_bindings, - &dx_binding, - &bolus, - &rateiv, - ); - - Ok(quote! {{ - let __pharmsol_diffeq: fn( - &::pharmsol::simulator::V, - &::pharmsol::simulator::V, - f64, - &mut ::pharmsol::simulator::V, - &::pharmsol::simulator::V, - &::pharmsol::simulator::V, - &::pharmsol::data::Covariates, - ) = |#x: &::pharmsol::simulator::V, - #p: &::pharmsol::simulator::V, - #t: f64, - #dx: &mut ::pharmsol::simulator::V, - #bolus: &::pharmsol::simulator::V, - #rateiv: &::pharmsol::simulator::V, - #cov: &::pharmsol::data::Covariates| { - #input_aliases - #state_consts - #parameter_bindings - #covariate_bindings - #body - #route_terms - }; - __pharmsol_diffeq - }}) - } - } + Ok(quote! {{ + let __pharmsol_diffeq: fn( + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + f64, + &mut ::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::simulator::V, + &::pharmsol::data::Covariates, + ) = |#x: &::pharmsol::simulator::V, + #p: &::pharmsol::simulator::V, + #t: f64, + #dx: &mut ::pharmsol::simulator::V, + #bolus: &::pharmsol::simulator::V, + #rateiv: &::pharmsol::simulator::V, + #cov: &::pharmsol::data::Covariates| { + #input_aliases + #state_consts + #parameter_bindings + #covariate_bindings + #body + #route_terms + }; + __pharmsol_diffeq + }}) } fn resolve_analytical_structure(structure: &Ident) -> syn::Result { @@ -2883,7 +2770,6 @@ pub fn ode(input: TokenStream) -> TokenStream { &input.states, &input.routes, &route_bindings, - input.diffeq_mode, ) { Ok(diffeq) => diffeq, Err(error) => return error.to_compile_error().into(), @@ -2909,7 +2795,7 @@ pub fn ode(input: TokenStream) -> TokenStream { let covariates = &input.covariates; let states = &input.states; let outputs = &input.outputs; - let routes = expand_route_metadata(&input.routes, input.diffeq_mode, &lag_routes, &fa_routes); + let routes = expand_route_metadata(&input.routes, &lag_routes, &fa_routes); let covariate_metadata = if covariates.is_empty() { quote! {} } else { @@ -3339,7 +3225,7 @@ mod tests { #[test] fn validates_route_destinations() { let error = syn::parse_str::( - "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> peripheral }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: [infusion(iv) -> peripheral], diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", ) .err() .expect("unknown route destination must fail"); @@ -3352,7 +3238,7 @@ mod tests { #[test] fn rejects_named_binding_collisions() { let error = syn::parse_str::( - "name: \"demo\", params: [central, v], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [central, v], states: [central], outputs: [cp], routes: [infusion(iv) -> central], diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", ) .err() .expect("parameter/state binding collisions must fail"); @@ -3365,7 +3251,7 @@ mod tests { #[test] fn ode_route_bindings_share_inputs_by_kind_local_ordinal() { let input = syn::parse_str::( - "name: \"demo\", params: [ka, ke, v], states: [depot, central], outputs: [cp], routes: { bolus(oral) -> depot, infusion(iv) -> central, bolus(sc) -> depot }, diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ka, ke, v], states: [depot, central], outputs: [cp], routes: [bolus(oral) -> depot, infusion(iv) -> central, bolus(sc) -> depot], diffeq: |x, p, t, dx, b, rateiv, cov| {}, out: |x, p, t, cov, y| {}", ) .expect("declaration-first ode input should parse"); @@ -3416,7 +3302,7 @@ mod tests { #[test] fn analytical_accepts_extra_parameters_beyond_kernel_arity() { let input = syn::parse_str::( - "name: \"demo\", params: [ka, ke, v, tlag, tvke], covariates: [wt, renal], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, sec: |_t| { ke = tvke; }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ka, ke, v, tlag, tvke], covariates: [wt, renal], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, sec: |_t| { ke = tvke; }, out: |x, p, t, cov, y| {}", ) .expect("extra declared parameters should be allowed"); @@ -3429,7 +3315,7 @@ mod tests { #[test] fn analytical_rejects_unknown_structure() { let error = syn::parse_str::( - "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, structure: mystery, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: [infusion(iv) -> central], structure: mystery, out: |x, p, t, cov, y| {}", ) .err() .expect("unknown analytical structure must fail"); @@ -3442,7 +3328,7 @@ mod tests { #[test] fn analytical_rejects_insufficient_kernel_parameters() { let error = syn::parse_str::( - "name: \"demo\", params: [ke], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, out: |x, p, t, cov, y| {}", ) .err() .expect("insufficient kernel parameters must fail"); @@ -3455,7 +3341,7 @@ mod tests { #[test] fn analytical_rejects_unknown_route_property_binding() { let error = syn::parse_str::( - "name: \"demo\", params: [ka, ke, v], states: [gut, central], outputs: [cp], routes: { bolus(oral) -> gut }, structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { iv => 1.0 } }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ka, ke, v], states: [gut, central], outputs: [cp], routes: [bolus(oral) -> gut], structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { iv => 1.0 } }, out: |x, p, t, cov, y| {}", ) .err() .expect("unknown lag route must fail"); @@ -3468,7 +3354,7 @@ mod tests { #[test] fn analytical_rejects_infusion_lag_binding() { let error = syn::parse_str::( - "name: \"demo\", params: [ke, v, tlag], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, structure: one_compartment, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke, v, tlag], states: [central], outputs: [cp], routes: [infusion(iv) -> central], structure: one_compartment, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", ) .err() .expect("infusion lag must fail"); @@ -3481,7 +3367,7 @@ mod tests { #[test] fn sde_requires_particles() { let error = syn::parse_str::( - "name: \"demo\", params: [ke, theta], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke, theta], states: [central], outputs: [cp], routes: [infusion(iv) -> central], drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, out: |x, p, t, cov, y| {}", ) .err() .expect("missing particles must fail"); @@ -3494,7 +3380,7 @@ mod tests { #[test] fn sde_rejects_unknown_route_property_binding() { let error = syn::parse_str::( - "name: \"demo\", params: [ke, sigma_ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { oral => 1.0 } }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke, sigma_ke], states: [central], outputs: [cp], routes: [infusion(iv) -> central], particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { oral => 1.0 } }, out: |x, p, t, cov, y| {}", ) .err() .expect("unknown lag route must fail"); @@ -3507,7 +3393,7 @@ mod tests { #[test] fn sde_rejects_infusion_lag_binding() { let error = syn::parse_str::( - "name: \"demo\", params: [ke, sigma_ke, tlag], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", + "name: \"demo\", params: [ke, sigma_ke, tlag], states: [central], outputs: [cp], routes: [infusion(iv) -> central], particles: 16, drift: |x, p, t, dx, cov| {}, diffusion: |p, sigma| {}, lag: |_p, _t, _cov| { lag! { iv => tlag } }, out: |x, p, t, cov, y| {}", ) .err() .expect("infusion lag must fail"); @@ -3516,4 +3402,17 @@ mod tests { .to_string() .contains("declaration-first `sde!` does not allow `lag` on infusion route `iv`")); } + + #[test] + fn rejects_braced_route_lists() { + let error = syn::parse_str::( + "name: \"demo\", params: [ke], states: [central], outputs: [cp], routes: { infusion(iv) -> central }, diffeq: |x, p, t, dx, cov| {}, out: |x, p, t, cov, y| {}", + ) + .err() + .expect("braced route lists must fail"); + + assert!(error + .to_string() + .contains("declaration-first macro `routes` must use `[...]`, not `{...}`")); + } } diff --git a/tests/analytical_macro_lowering.rs b/tests/analytical_macro_lowering.rs index 796cb55e..f527978f 100644 --- a/tests/analytical_macro_lowering.rs +++ b/tests/analytical_macro_lowering.rs @@ -56,9 +56,9 @@ fn macro_one_compartment() -> equation::Analytical { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], structure: one_compartment, out: |x, _t, y| { y[cp] = x[central] / v; @@ -99,9 +99,9 @@ fn macro_one_compartment_with_absorption() -> equation::Analytical { params: [ka, ke, v, tlag, f_oral], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], structure: one_compartment_with_absorption, lag: |_t| { lag! { oral => tlag } @@ -166,10 +166,10 @@ fn macro_shared_input_analytical() -> equation::Analytical { params: [ka, ke, v, tlag, f_oral], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], structure: one_compartment_with_absorption, lag: |_t| { lag! { oral => tlag } @@ -229,10 +229,10 @@ fn macro_covariate_analytical() -> equation::Analytical { covariates: [wt, renal], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], structure: one_compartment_with_absorption, sec: |_t| { let wt_scale = (wt / 70.0).powf(0.75); diff --git a/tests/authoring_parity_corpus.rs b/tests/authoring_parity_corpus.rs index c7164d71..be80f10e 100644 --- a/tests/authoring_parity_corpus.rs +++ b/tests/authoring_parity_corpus.rs @@ -580,9 +580,9 @@ fn macro_ode_model() -> equation::ODE { covariates: [wt], states: [depot, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, - }, + ], diffeq: |x, _p, _t, dx, _cov| { dx[depot] = -ka * x[depot]; dx[central] = ka * x[depot] - (cl / v) * x[central]; @@ -676,13 +676,13 @@ fn runtime_shared_input_macro_ode() -> equation::ODE { params: [ka, ke, v, tlag, f_oral], states: [depot, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, infusion(iv) -> central, - }, - diffeq: |x, _p, _t, dx, bolus, rateiv, _cov| { - dx[depot] = bolus[oral] - ka * x[depot]; - dx[central] = ka * x[depot] + rateiv[iv] - ke * x[central]; + ], + diffeq: |x, _p, _t, dx, _cov| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; }, lag: |_p, _t, _cov| { lag! { oral => tlag } @@ -731,10 +731,10 @@ fn runtime_shared_input_handwritten_ode() -> equation::ODE { .to_state("depot") .with_lag() .with_bioavailability() - .expect_explicit_input(), + .inject_input_to_destination(), equation::Route::infusion("iv") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ]), ) .expect("handwritten shared-input ODE metadata should validate") @@ -791,10 +791,10 @@ fn runtime_shared_input_macro_analytical() -> equation::Analytical { params: [ka, ke, v, tlag, f_oral], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], structure: one_compartment_with_absorption, lag: |_p, _t, _cov| { lag! { oral => tlag } @@ -856,10 +856,10 @@ fn runtime_shared_input_macro_sde() -> equation::SDE { states: [gut, central], outputs: [cp], particles: 8, - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], drift: |x, _p, _t, dx, _cov| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; @@ -976,9 +976,9 @@ fn macro_analytical_model() -> equation::Analytical { params: [ka, ke, v], states: [depot, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, - }, + ], structure: one_compartment_with_absorption, out: |x, _p, _t, _cov, y| { y[cp] = x[central] / v; @@ -1019,9 +1019,9 @@ fn macro_sde_model() -> equation::SDE { states: [depot, central], outputs: [cp], particles: 256, - routes: { + routes: [ bolus(oral) -> depot, - }, + ], drift: |x, _p, _t, dx, _cov| { dx[depot] = -ka * x[depot]; dx[central] = ka * x[depot] - ke * x[central]; diff --git a/tests/full_feature_macro_parity.rs b/tests/full_feature_macro_parity.rs index e3175f84..5017902e 100644 --- a/tests/full_feature_macro_parity.rs +++ b/tests/full_feature_macro_parity.rs @@ -14,19 +14,19 @@ fn macro_ode_model() -> equation::ODE { covariates: [wt, renal], states: [depot, central, peripheral], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, bolus(load) -> central, infusion(iv) -> central, - }, - diffeq: |x, _t, dx, bolus, rateiv| { + ], + diffeq: |x, _t, dx| { let wt_scale = (wt / 70.0).powf(0.75); let renal_scale = (renal / 90.0).powf(0.25); let adjusted_ke = ke * wt_scale * renal_scale; let adjusted_kcp = kcp * (wt / 70.0).powf(0.25); - dx[depot] = bolus[oral] - ka * x[depot]; - dx[central] = bolus[load] + ka * x[depot] + rateiv[iv] + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - (adjusted_ke + adjusted_kcp) * x[central] + kpc * x[peripheral]; dx[peripheral] = adjusted_kcp * x[central] - kpc * x[peripheral]; @@ -185,13 +185,13 @@ fn handwritten_ode_model() -> equation::ODE { .to_state("depot") .with_lag() .with_bioavailability() - .expect_explicit_input(), + .inject_input_to_destination(), equation::Route::bolus("load") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), equation::Route::infusion("iv") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ]), ) .expect("handwritten ODE metadata should validate") @@ -224,11 +224,11 @@ fn macro_analytical_model() -> equation::Analytical { covariates: [wt, renal], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, bolus(load) -> central, infusion(iv) -> central, - }, + ], structure: one_compartment_with_absorption, sec: |_t| { let wt_scale = (wt / 70.0).powf(0.75); diff --git a/tests/ode_macro_lowering.rs b/tests/ode_macro_lowering.rs index 480f7e80..99e0eeab 100644 --- a/tests/ode_macro_lowering.rs +++ b/tests/ode_macro_lowering.rs @@ -56,9 +56,9 @@ fn injected_macro_ode() -> equation::ODE { params: [ke, v], states: [central], outputs: [cp], - routes: { + routes: [ infusion(iv) -> central, - }, + ], diffeq: |x, _t, dx| { dx[central] = -ke * x[central]; }, @@ -99,66 +99,17 @@ fn injected_handwritten_ode() -> equation::ODE { .expect("handwritten injected metadata should validate") } -fn explicit_macro_ode() -> equation::ODE { - ode! { - name: "explicit_one_cpt", - params: [ke, v], - states: [central], - outputs: [cp], - routes: { - infusion(iv) -> central, - }, - diffeq: |x, _t, dx, _bolus, rateiv| { - dx[central] = rateiv[iv] - ke * x[central]; - }, - out: |x, _t, y| { - y[cp] = x[central] / v; - }, - } -} - -fn explicit_handwritten_ode() -> equation::ODE { - equation::ODE::new( - |x, p, _t, dx, _bolus, rateiv, _cov| { - fetch_params!(p, ke, _v); - dx[0] = rateiv[0] - ke * x[0]; - }, - |_p, _t, _cov| lag! {}, - |_p, _t, _cov| fa! {}, - |_p, _t, _cov, _x| {}, - |x, p, _t, _cov, y| { - fetch_params!(p, _ke, v); - y[0] = x[0] / v; - }, - ) - .with_nstates(1) - .with_ndrugs(1) - .with_nout(1) - .with_metadata( - equation::metadata::new("explicit_one_cpt") - .parameters(["ke", "v"]) - .states(["central"]) - .outputs(["cp"]) - .route( - equation::Route::infusion("iv") - .to_state("central") - .expect_explicit_input(), - ), - ) - .expect("handwritten explicit metadata should validate") -} - fn numeric_label_macro_ode() -> equation::ODE { ode! { name: "numeric_label_one_cpt", params: [ke, v], states: [central], outputs: [1], - routes: { + routes: [ infusion(1) -> central, - }, - diffeq: |x, _t, dx, _bolus, rateiv| { - dx[central] = rateiv[1] - ke * x[central]; + ], + diffeq: |x, _t, dx| { + dx[central] = -ke * x[central]; }, out: |x, _t, y| { y[1] = x[central] / v; @@ -191,7 +142,7 @@ fn numeric_label_handwritten_ode() -> equation::ODE { .route( equation::Route::infusion("1") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ), ) .expect("handwritten numeric-label metadata should validate") @@ -203,13 +154,13 @@ fn shared_input_macro_ode() -> equation::ODE { params: [ka, ke, v, tlag, f_oral], states: [depot, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> depot, infusion(iv) -> central, - }, - diffeq: |x, _t, dx, bolus, rateiv| { - dx[depot] = bolus[oral] - ka * x[depot]; - dx[central] = ka * x[depot] + rateiv[iv] - ke * x[central]; + ], + diffeq: |x, _t, dx| { + dx[depot] = -ka * x[depot]; + dx[central] = ka * x[depot] - ke * x[central]; }, lag: |_t| { lag! { oral => tlag } @@ -257,10 +208,10 @@ fn shared_input_handwritten_ode() -> equation::ODE { .to_state("depot") .with_lag() .with_bioavailability() - .expect_explicit_input(), + .inject_input_to_destination(), equation::Route::infusion("iv") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ]), ) .expect("handwritten shared-input metadata should validate") @@ -272,11 +223,11 @@ fn numeric_route_property_macro_ode() -> equation::ODE { params: [ka, ke, v, tlag, f_oral], states: [depot, central], outputs: [1], - routes: { + routes: [ bolus(1) -> depot, - }, - diffeq: |x, _t, dx, bolus, _rateiv| { - dx[depot] = bolus[1] - ka * x[depot]; + ], + diffeq: |x, _t, dx| { + dx[depot] = -ka * x[depot]; dx[central] = ka * x[depot] - ke * x[central]; }, lag: |_t| { @@ -325,7 +276,7 @@ fn numeric_route_property_handwritten_ode() -> equation::ODE { .to_state("depot") .with_lag() .with_bioavailability() - .expect_explicit_input(), + .inject_input_to_destination(), ), ) .expect("handwritten numeric route-property metadata should validate") @@ -337,11 +288,11 @@ fn mixed_output_labels_macro_ode() -> equation::ODE { params: [ke, v], states: [central], outputs: [cp, 0, 1], - routes: { + routes: [ infusion(iv) -> central, - }, - diffeq: |x, _t, dx, _bolus, rateiv| { - dx[central] = rateiv[iv] - ke * x[central]; + ], + diffeq: |x, _t, dx| { + dx[central] = -ke * x[central]; }, out: |x, _t, y| { y[cp] = x[central] / v; @@ -378,7 +329,7 @@ fn mixed_output_labels_handwritten_ode() -> equation::ODE { .route( equation::Route::infusion("iv") .to_state("central") - .expect_explicit_input(), + .inject_input_to_destination(), ), ) .expect("handwritten mixed-output metadata should validate") @@ -391,9 +342,9 @@ fn covariate_macro_ode() -> equation::ODE { covariates: [wt], states: [gut, central], outputs: [cp], - routes: { + routes: [ bolus(oral) -> gut, - }, + ], diffeq: |x, _t, dx| { let scaled_ke = ke * (wt / 70.0); dx[gut] = -ka * x[gut]; @@ -473,32 +424,6 @@ fn macro_injected_lowering_matches_handwritten_metadata_and_predictions() { assert_prediction_match(¯o_predictions, &handwritten_predictions); } -#[test] -fn macro_explicit_lowering_matches_handwritten_metadata_and_predictions() { - let macro_ode = explicit_macro_ode(); - let handwritten_ode = explicit_handwritten_ode(); - let subject = subject_for_route("iv", "cp"); - let support_point = [0.2, 10.0]; - - assert_eq!(macro_ode.metadata(), handwritten_ode.metadata()); - assert_eq!(macro_ode.route_index("iv"), Some(0)); - assert_eq!(macro_ode.output_index("cp"), Some(0)); - assert_eq!(macro_ode.state_index("central"), Some(0)); - - let macro_predictions = macro_ode - .estimate_predictions(&subject, &support_point) - .expect("macro explicit model should simulate") - .flat_predictions() - .to_vec(); - let handwritten_predictions = handwritten_ode - .estimate_predictions(&subject, &support_point) - .expect("handwritten explicit model should simulate") - .flat_predictions() - .to_vec(); - - assert_prediction_match(¯o_predictions, &handwritten_predictions); -} - #[test] fn macro_numeric_labels_lower_to_dense_slots() { let macro_ode = numeric_label_macro_ode(); @@ -622,12 +547,12 @@ fn macro_named_labels_resolve_from_pmetrics_ingestion() { let subject = &data.subjects()[0]; let support_point = [0.2, 10.0]; - let pmetrics_predictions = explicit_macro_ode() + let pmetrics_predictions = injected_macro_ode() .estimate_predictions(subject, &support_point) .expect("macro named-label model should simulate") .flat_predictions() .to_vec(); - let manual_predictions = explicit_macro_ode() + let manual_predictions = injected_macro_ode() .estimate_predictions(&subject_for_route("iv", "cp"), &support_point) .expect("macro internal-index model should simulate") .flat_predictions() diff --git a/tests/sde_macro_lowering.rs b/tests/sde_macro_lowering.rs index 05b5cb27..474c7bab 100644 --- a/tests/sde_macro_lowering.rs +++ b/tests/sde_macro_lowering.rs @@ -73,9 +73,9 @@ fn macro_infusion_sde() -> equation::SDE { states: [central], outputs: [cp], particles: 16, - routes: { + routes: [ infusion(iv) -> central, - }, + ], drift: |x, _t, dx| { dx[central] = -ke * x[central]; }, @@ -133,9 +133,9 @@ fn macro_absorption_sde() -> equation::SDE { states: [gut, central], outputs: [cp], particles: 8, - routes: { + routes: [ bolus(oral) -> gut, - }, + ], drift: |x, _t, dx| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; @@ -218,10 +218,10 @@ fn macro_shared_input_sde() -> equation::SDE { states: [gut, central], outputs: [cp], particles: 8, - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], drift: |x, _t, dx| { dx[gut] = -ka * x[gut]; dx[central] = ka * x[gut] - ke * x[central]; @@ -307,10 +307,10 @@ fn macro_covariate_sde() -> equation::SDE { states: [gut, central], outputs: [cp], particles: 8, - routes: { + routes: [ bolus(oral) -> gut, infusion(iv) -> central, - }, + ], drift: |x, _t, dx| { let wt_scale = (wt / 70.0).powf(0.75); let renal_scale = (renal / 90.0).powf(0.25); From 8cac67c5cf8a806e03769e4ad8fc54e287baca90 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 16:16:31 +0100 Subject: [PATCH 08/12] chore: update examples to use new route API --- examples/analytical_readme.rs | 13 +++++-------- examples/compare_solvers.rs | 34 +++++++++++++--------------------- examples/dsl_runtime_jit.rs | 24 ++++++++---------------- examples/ode_readme.rs | 13 +++++-------- examples/sde_readme.rs | 13 +++++-------- 5 files changed, 36 insertions(+), 61 deletions(-) diff --git a/examples/analytical_readme.rs b/examples/analytical_readme.rs index 676f07b9..8451b478 100644 --- a/examples/analytical_readme.rs +++ b/examples/analytical_readme.rs @@ -15,15 +15,12 @@ fn main() -> Result<(), pharmsol::PharmsolError> { }, }; - let iv = analytical.route_index("iv").expect("iv route exists"); - let cp = analytical.output_index("cp").expect("cp output exists"); - let subject = Subject::builder("analytical_readme") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build(); let predictions = analytical.estimate_predictions(&subject, &[1.022, 194.0])?; diff --git a/examples/compare_solvers.rs b/examples/compare_solvers.rs index a8067485..58813081 100644 --- a/examples/compare_solvers.rs +++ b/examples/compare_solvers.rs @@ -51,29 +51,21 @@ fn main() { // Both declarations resolve to the same shared input, so subject // authoring still uses one numeric index for the loading bolus and the // maintenance infusion. - let load = bdf.route_index("load").expect("load route exists"); - let iv = bdf.route_index("iv").expect("iv route exists"); - let cp = bdf.output_index("cp").expect("cp output exists"); - - assert_eq!( - load, iv, - "mixed IV declarations should share one numeric input" - ); let subject = Subject::builder("id1") - .bolus(0.0, 100.0, iv) - .infusion(12.0, 200.0, iv, 2.0) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) - .missing_observation(8.0, cp) - .missing_observation(12.0, cp) - .missing_observation(12.5, cp) - .missing_observation(13.0, cp) - .missing_observation(14.0, cp) - .missing_observation(16.0, cp) - .missing_observation(24.0, cp) + .bolus(0.0, 100.0, "iv") + .infusion(12.0, 200.0, "iv", 2.0) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(8.0, "cp") + .missing_observation(12.0, "cp") + .missing_observation(12.5, "cp") + .missing_observation(13.0, "cp") + .missing_observation(14.0, "cp") + .missing_observation(16.0, "cp") + .missing_observation(24.0, "cp") .build(); let spp = vec![0.1, 0.05, 0.03, 50.0]; // ke, kcp, kpc, V diff --git a/examples/dsl_runtime_jit.rs b/examples/dsl_runtime_jit.rs index 932acaae..3f7d1efe 100644 --- a/examples/dsl_runtime_jit.rs +++ b/examples/dsl_runtime_jit.rs @@ -43,24 +43,16 @@ out(cp) = central / v on_compile_event, )?; - // 2. Resolve the route and output indices declared by the model. - let iv = model - .route_index("iv") - .ok_or_else(|| io::Error::other("missing iv route"))?; - let cp = model - .output_index("cp") - .ok_or_else(|| io::Error::other("missing cp output"))?; - // 3. Define the subject data. let subject = Subject::builder("bimodal_ke") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(3.0, cp) - .missing_observation(4.0, cp) - .missing_observation(6.0, cp) - .missing_observation(8.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(3.0, "cp") + .missing_observation(4.0, "cp") + .missing_observation(6.0, "cp") + .missing_observation(8.0, "cp") .build(); // 4. Estimate predictions for one support point. diff --git a/examples/ode_readme.rs b/examples/ode_readme.rs index 7b436d0b..2989895f 100644 --- a/examples/ode_readme.rs +++ b/examples/ode_readme.rs @@ -17,15 +17,12 @@ fn main() -> Result<(), pharmsol::PharmsolError> { }, }; - let iv = ode.route_index("iv").expect("iv route exists"); - let cp = ode.output_index("cp").expect("cp output exists"); - let subject = Subject::builder("id1") - .infusion(0.0, 100.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .infusion(0.0, 100.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build(); let predictions = ode.estimate_predictions(&subject, &[1.022, 194.0])?; diff --git a/examples/sde_readme.rs b/examples/sde_readme.rs index cc47cdda..97b5fed4 100644 --- a/examples/sde_readme.rs +++ b/examples/sde_readme.rs @@ -21,15 +21,12 @@ fn main() -> Result<(), pharmsol::PharmsolError> { }, }; - let iv = sde.route_index("iv").expect("iv route exists"); - let cp = sde.output_index("cp").expect("cp output exists"); - let subject = Subject::builder("sde_readme") - .infusion(0.0, 500.0, iv, 0.5) - .missing_observation(0.5, cp) - .missing_observation(1.0, cp) - .missing_observation(2.0, cp) - .missing_observation(4.0, cp) + .infusion(0.0, 500.0, "iv", 0.5) + .missing_observation(0.5, "cp") + .missing_observation(1.0, "cp") + .missing_observation(2.0, "cp") + .missing_observation(4.0, "cp") .build(); let predictions = sde.estimate_predictions(&subject, &[1.022, 0.0, 194.0])?; From 4b6d962dbe63437198e1180b982011ab59281791 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 16:21:31 +0100 Subject: [PATCH 09/12] chore: Julian made a mistake :P --- examples/compare_solvers.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/compare_solvers.rs b/examples/compare_solvers.rs index 58813081..5d8fdbb6 100644 --- a/examples/compare_solvers.rs +++ b/examples/compare_solvers.rs @@ -53,7 +53,7 @@ fn main() { // maintenance infusion. let subject = Subject::builder("id1") - .bolus(0.0, 100.0, "iv") + .bolus(0.0, 100.0, "load") .infusion(12.0, 200.0, "iv", 2.0) .missing_observation(0.5, "cp") .missing_observation(1.0, "cp") From 7b24e09715b7c1cf217f20ebe00bb05072d1028c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 20:42:01 +0100 Subject: [PATCH 10/12] feature: DSL supports numeric input/outeq. Implement Equation Trait for runtime environments. --- pharmsol-dsl/src/semantic.rs | 20 +- .../tests/dsl_authoring_edge_cases.rs | 71 +++++++ src/dsl/native.rs | 177 +++++++++++++++++- src/dsl/runtime.rs | 66 +++++++ 4 files changed, 328 insertions(+), 6 deletions(-) diff --git a/pharmsol-dsl/src/semantic.rs b/pharmsol-dsl/src/semantic.rs index ac9223dd..6a5e3b91 100644 --- a/pharmsol-dsl/src/semantic.rs +++ b/pharmsol-dsl/src/semantic.rs @@ -1617,11 +1617,13 @@ impl<'a> Analyzer<'a> { span, ))); } - if let Some(existing) = self.globals.all_names.get(name) { + if let Some(existing) = self.globals.all_names.get(name).copied() { + let existing_kind = self.symbols.get(existing).expect("valid symbol id").kind; + if !allows_route_output_name_overlap(existing_kind, kind) { return Err(SemanticAssist::default() .context_label( - self.symbol_span(*existing), - self.symbol_declared_here(*existing), + self.symbol_span(existing), + self.symbol_declared_here(existing), ) .help(format!( "rename this declaration to a unique name such as `{}_2`", @@ -1636,10 +1638,11 @@ impl<'a> Analyzer<'a> { .apply(SemanticError::new( format!( "symbol name `{name}` collides with existing `{}`", - self.symbol_name(*existing) + self.symbol_name(existing) ), span, ))); + } } let id = self.symbols.len(); self.symbols.push(PendingSymbol { @@ -1649,7 +1652,7 @@ impl<'a> Analyzer<'a> { ty, span, }); - self.globals.all_names.insert(name.to_string(), id); + self.globals.all_names.entry(name.to_string()).or_insert(id); Ok(id) } @@ -2132,6 +2135,13 @@ impl<'a> Analyzer<'a> { } } +fn allows_route_output_name_overlap(existing: SymbolKind, new: SymbolKind) -> bool { + matches!( + (existing, new), + (SymbolKind::Route, SymbolKind::Output) | (SymbolKind::Output, SymbolKind::Route) + ) +} + #[derive(Default)] struct Globals { all_names: BTreeMap, diff --git a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs index 404487dc..3f4cb494 100644 --- a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs +++ b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs @@ -209,6 +209,77 @@ out(1) = 3 * central / v assert_eq!(rendered, reparsed.to_string()); } +#[test] +fn shared_numeric_route_and_output_labels_lower_and_round_trip() { + let src = r#" +name = shared_numeric_route_output_labels +kind = ode +params = ke, v +states = central +outputs = 1 +infusion(1) -> central +ddt(central) = -ke * central +out(1) = central / v +"#; + + let module = parse_module(src).expect("shared numeric route/output labels should parse"); + let model = module + .models + .first() + .expect("authoring DSL should produce one model"); + let typed = analyze_model(model).expect("shared numeric route/output labels should analyze"); + let lowered = lower_typed_model(&typed).expect("shared numeric route/output labels should lower"); + + assert_eq!( + lowered + .metadata + .routes + .iter() + .map(|route| route.name.as_str()) + .collect::>(), + vec!["1"] + ); + assert_eq!( + lowered + .metadata + .outputs + .iter() + .map(|output| output.name.as_str()) + .collect::>(), + vec!["1"] + ); + + let rendered = module.to_string(); + let reparsed = parse_module(&rendered).expect("rendered shared-label model should reparse"); + + assert_eq!(rendered, reparsed.to_string()); +} + +#[test] +fn route_labels_still_collide_with_scalar_symbol_names() { + let src = r#" +name = route_state_collision +kind = ode +params = ke +states = central, iv +outputs = cp +infusion(iv) -> central +ddt(central) = -ke * central +ddt(iv) = 0 +out(cp) = central +"#; + + let model = parse_model(src).expect("route/state collision model parses"); + let err = analyze_model(&model).expect_err("route label should still collide with state name"); + let rendered = err.render(src); + + assert!( + rendered.contains("symbol name `iv` collides with existing `iv`"), + "{}", + rendered + ); +} + #[test] fn unknown_route_destination_state_suggests_declared_state() { let src = r#" diff --git a/src/dsl/native.rs b/src/dsl/native.rs index 97c41013..d9598172 100644 --- a/src/dsl/native.rs +++ b/src/dsl/native.rs @@ -1,4 +1,5 @@ use std::cell::RefCell; +use std::collections::HashMap; use std::sync::Arc; use diffsol::{ @@ -20,14 +21,17 @@ pub use super::model_info::{ NativeCovariateInfo, NativeModelInfo, NativeOutputInfo, NativeRouteInfo, }; use crate::{ + data::error_model::AssayErrorModels, data::{Covariates, Infusion, InputLabel, OutputLabel}, simulator::{ + cache::{PredictionCache, DEFAULT_CACHE_SIZE}, equation::{ ode::{closure_helpers::PMProblem, ExplicitRkTableau, OdeSolver, SdirkTableau}, sde::simulate_sde_event_with, + EqnKind, Equation, EquationPriv, EquationTypes, }, likelihood::{Prediction, SubjectPredictions}, - M, V, + Fa, Lag, M, T, V, }, Event, Observation, Occasion, PharmsolError, Subject, }; @@ -727,6 +731,7 @@ pub struct NativeOdeModel { solver: OdeSolver, rtol: f64, atol: f64, + cache: Option, } #[derive(Clone, Debug)] @@ -754,6 +759,7 @@ impl NativeOdeModel { solver: OdeSolver::default(), rtol: DEFAULT_ODE_RTOL, atol: DEFAULT_ODE_ATOL, + cache: Some(PredictionCache::new(DEFAULT_CACHE_SIZE)), } } @@ -1031,6 +1037,175 @@ impl NativeOdeModel { } } +fn runtime_no_lag(_: &V, _: T, _: &Covariates) -> HashMap { + HashMap::new() +} + +fn runtime_no_fa(_: &V, _: T, _: &Covariates) -> HashMap { + HashMap::new() +} + +#[inline(always)] +fn runtime_ode_predictions( + model: &NativeOdeModel, + subject: &Subject, + support_point: &[f64], +) -> Result { + if let Some(cache) = &model.cache { + let key = ( + subject.hash(), + crate::simulator::equation::spphash(support_point), + ); + if let Some(cached) = cache.get(&key) { + return Ok(cached); + } + + let result = model.estimate_predictions(subject, support_point)?; + cache.insert(key, result.clone()); + Ok(result) + } else { + model.estimate_predictions(subject, support_point) + } +} + +impl crate::simulator::equation::Cache for NativeOdeModel { + fn with_cache_capacity(mut self, size: u64) -> Self { + self.cache = Some(PredictionCache::new(size)); + self + } + + fn enable_cache(mut self) -> Self { + self.cache = Some(PredictionCache::new(DEFAULT_CACHE_SIZE)); + self + } + + fn clear_cache(&self) { + if let Some(cache) = &self.cache { + cache.invalidate_all(); + } + } + + fn disable_cache(mut self) -> Self { + self.cache = None; + self + } +} + +impl EquationTypes for NativeOdeModel { + type S = V; + type P = SubjectPredictions; +} + +impl EquationPriv for NativeOdeModel { + fn lag(&self) -> &Lag { + &(runtime_no_lag as Lag) + } + + fn fa(&self) -> &Fa { + &(runtime_no_fa as Fa) + } + + fn get_nstates(&self) -> usize { + self.shared.info.state_len + } + + fn get_ndrugs(&self) -> usize { + self.shared.info.route_len + } + + fn get_nouteqs(&self) -> usize { + self.shared.info.output_len + } + + fn metadata(&self) -> Option<&crate::ValidatedModelMetadata> { + None + } + + fn solve( + &self, + _state: &mut Self::S, + _support_point: &[f64], + _covariates: &Covariates, + _infusions: &[Infusion], + _start_time: f64, + _end_time: f64, + ) -> Result<(), PharmsolError> { + unimplemented!("solve is not used for runtime ODE models") + } + + fn process_observation( + &self, + _support_point: &[f64], + _observation: &Observation, + _error_models: Option<&AssayErrorModels>, + _time: f64, + _covariates: &Covariates, + _x: &mut Self::S, + _likelihood: &mut Vec, + _output: &mut Self::P, + ) -> Result<(), PharmsolError> { + unimplemented!("process_observation is not used for runtime ODE models") + } + + fn initial_state( + &self, + _support_point: &[f64], + _covariates: &Covariates, + _occasion_index: usize, + ) -> Self::S { + V::zeros(self.shared.info.state_len, NalgebraContext) + } +} + +impl Equation for NativeOdeModel { + fn estimate_likelihood( + &self, + subject: &Subject, + support_point: &[f64], + error_models: &AssayErrorModels, + ) -> Result { + Ok(self + .estimate_log_likelihood(subject, support_point, error_models)? + .exp()) + } + + fn estimate_log_likelihood( + &self, + subject: &Subject, + support_point: &[f64], + error_models: &AssayErrorModels, + ) -> Result { + let predictions = runtime_ode_predictions(self, subject, support_point)?; + predictions.log_likelihood(error_models) + } + + fn kind() -> EqnKind { + EqnKind::ODE + } + + fn estimate_predictions( + &self, + subject: &Subject, + support_point: &[f64], + ) -> Result { + runtime_ode_predictions(self, subject, support_point) + } + + fn simulate_subject( + &self, + subject: &Subject, + support_point: &[f64], + error_models: Option<&AssayErrorModels>, + ) -> Result<(Self::P, Option), PharmsolError> { + let predictions = runtime_ode_predictions(self, subject, support_point)?; + let likelihood = match error_models { + Some(error_models) => Some(predictions.log_likelihood(error_models)?.exp()), + None => None, + }; + Ok((predictions, likelihood)) + } +} + impl NativeAnalyticalModel { pub(crate) fn new(info: NativeModelInfo, artifact: impl RuntimeArtifact + 'static) -> Self { Self { diff --git a/src/dsl/runtime.rs b/src/dsl/runtime.rs index ba6dd5cd..59c399ab 100644 --- a/src/dsl/runtime.rs +++ b/src/dsl/runtime.rs @@ -414,6 +414,21 @@ bolus(11) -> central dx(central) = -ke * central out(cp) = central / v ~ continuous() +"#; + + const SHARED_NUMERIC_ROUTE_OUTPUT_LABEL_RUNTIME_DSL: &str = r#" +name = shared_numeric_route_output_runtime +kind = ode + +params = ke, v +states = central +outputs = 1 + +infusion(1) -> central + +dx(central) = -ke * central + +out(1) = central / v ~ continuous() "#; const UNDECLARED_NUMERIC_OUTPUT_LABEL_RUNTIME_DSL: &str = r#" @@ -551,6 +566,14 @@ out(cp) = central / v ~ continuous() .build() } + fn shared_numeric_route_output_subject() -> Subject { + Subject::builder("shared-numeric-route-output-runtime") + .infusion(0.0, 120.0, "1", 1.0) + .missing_observation(0.5, "1") + .missing_observation(1.5, "1") + .build() + } + fn assert_unknown_output_label( model: &CompiledRuntimeModel, subject: &Subject, @@ -714,6 +737,49 @@ out(cp) = central / v ~ continuous() } } + #[test] + fn runtime_backend_matrix_supports_shared_numeric_route_and_output_labels() { + let work_dir = tempdir().expect("tempdir"); + let support = vec![0.2, 10.0]; + let (jit, aot, wasm) = compile_runtime_backend_matrix( + SHARED_NUMERIC_ROUTE_OUTPUT_LABEL_RUNTIME_DSL, + "shared_numeric_route_output_runtime", + work_dir.path(), + ); + + assert_eq!(jit.route_index("1"), Some(0)); + assert_eq!(jit.output_index("1"), Some(0)); + assert_eq!(aot.route_index("1"), Some(0)); + assert_eq!(aot.output_index("1"), Some(0)); + assert_eq!(wasm.route_index("1"), Some(0)); + assert_eq!(wasm.output_index("1"), Some(0)); + + let subject = shared_numeric_route_output_subject(); + + let jit_values = subject_values( + &jit.estimate_predictions(&subject, &support) + .expect("jit predictions"), + ); + let aot_values = subject_values( + &aot.estimate_predictions(&subject, &support) + .expect("aot predictions"), + ); + let wasm_values = subject_values( + &wasm + .estimate_predictions(&subject, &support) + .expect("wasm predictions"), + ); + + for ((jit_value, aot_value), wasm_value) in jit_values + .iter() + .zip(aot_values.iter()) + .zip(wasm_values.iter()) + { + assert_relative_eq!(jit_value, aot_value, max_relative = 1e-4); + assert_relative_eq!(jit_value, wasm_value, max_relative = 1e-4); + } + } + #[test] fn runtime_backend_matrix_rejects_undeclared_numeric_output_labels() { let work_dir = tempdir().expect("tempdir"); From 6a067e9543e1938d6b5116534d93da59c0f2efd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Fri, 1 May 2026 20:42:18 +0100 Subject: [PATCH 11/12] chore: fmt --- pharmsol-dsl/src/semantic.rs | 44 +++++++++---------- .../tests/dsl_authoring_edge_cases.rs | 3 +- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/pharmsol-dsl/src/semantic.rs b/pharmsol-dsl/src/semantic.rs index 6a5e3b91..f20288a9 100644 --- a/pharmsol-dsl/src/semantic.rs +++ b/pharmsol-dsl/src/semantic.rs @@ -1620,28 +1620,28 @@ impl<'a> Analyzer<'a> { if let Some(existing) = self.globals.all_names.get(name).copied() { let existing_kind = self.symbols.get(existing).expect("valid symbol id").kind; if !allows_route_output_name_overlap(existing_kind, kind) { - return Err(SemanticAssist::default() - .context_label( - self.symbol_span(existing), - self.symbol_declared_here(existing), - ) - .help(format!( - "rename this declaration to a unique name such as `{}_2`", - name - )) - .replacement_suggestion( - span, - format!("{}_2", name), - format!("rename this declaration to `{}_2`", name), - Applicability::MaybeIncorrect, - ) - .apply(SemanticError::new( - format!( - "symbol name `{name}` collides with existing `{}`", - self.symbol_name(existing) - ), - span, - ))); + return Err(SemanticAssist::default() + .context_label( + self.symbol_span(existing), + self.symbol_declared_here(existing), + ) + .help(format!( + "rename this declaration to a unique name such as `{}_2`", + name + )) + .replacement_suggestion( + span, + format!("{}_2", name), + format!("rename this declaration to `{}_2`", name), + Applicability::MaybeIncorrect, + ) + .apply(SemanticError::new( + format!( + "symbol name `{name}` collides with existing `{}`", + self.symbol_name(existing) + ), + span, + ))); } } let id = self.symbols.len(); diff --git a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs index 3f4cb494..4d1651f5 100644 --- a/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs +++ b/pharmsol-dsl/tests/dsl_authoring_edge_cases.rs @@ -228,7 +228,8 @@ out(1) = central / v .first() .expect("authoring DSL should produce one model"); let typed = analyze_model(model).expect("shared numeric route/output labels should analyze"); - let lowered = lower_typed_model(&typed).expect("shared numeric route/output labels should lower"); + let lowered = + lower_typed_model(&typed).expect("shared numeric route/output labels should lower"); assert_eq!( lowered From 444c031b65f2932365ec474d821e8b350f9acdc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juli=C3=A1n=20D=2E=20Ot=C3=A1lvaro?= Date: Wed, 6 May 2026 20:21:42 +0100 Subject: [PATCH 12/12] chore: documentation --- src/data/builder.rs | 117 +++++++++++++++++++-------- src/data/event.rs | 102 +++++++++++++++++++----- src/data/mod.rs | 85 ++++++++++++++++---- src/data/parser/mod.rs | 12 +++ src/data/parser/pmetrics.rs | 59 +++++++++++--- src/data/row.rs | 121 +++++++++++++++++----------- src/dsl/aot.rs | 63 +++++++++++++++ src/dsl/jit.rs | 47 +++++++++++ src/dsl/mod.rs | 93 +++++++++++++++++++++- src/dsl/model_info.rs | 33 ++++++++ src/dsl/runtime.rs | 113 ++++++++++++++++++++++++++ src/dsl/wasm.rs | 28 +++++++ src/dsl/wasm_compile.rs | 55 +++++++++++++ src/lib.rs | 124 +++++++++++++++++++++++++++-- src/simulator/equation/metadata.rs | 103 +++++++++++++++++++++--- src/simulator/equation/mod.rs | 67 ++++++++++++++-- 16 files changed, 1066 insertions(+), 156 deletions(-) diff --git a/src/data/builder.rs b/src/data/builder.rs index a1718dc7..ed0a57a8 100644 --- a/src/data/builder.rs +++ b/src/data/builder.rs @@ -1,6 +1,21 @@ +//! Builder API for constructing [`Subject`] schedules in Rust. +//! +//! Use `Subject::builder(...)` when you want to describe a subject directly in +//! code with a schedule-oriented API. This is the preferred high-level +//! path for hand-written datasets. +//! +//! Builder methods accept public input and output labels. Prefer stable strings +//! such as `"depot"`, `"iv"`, and `"cp"`. Numeric values are accepted, but +//! they remain public labels rather than automatically becoming dense internal +//! indices. + use crate::{data::*, Censor}; -/// Extension trait for creating [Subject] instances using the builder pattern +/// Extension trait that enables `Subject::builder(...)`. +/// +/// Most users do not need to import [`SubjectBuilder`] directly. Import this +/// trait from the crate root or [`crate::prelude`] and then start with +/// `Subject::builder("id")`. pub trait SubjectBuilderExt { /// Create a new SubjectBuilder with the specified ID /// @@ -14,8 +29,8 @@ pub trait SubjectBuilderExt { /// use pharmsol::*; /// /// let subject = Subject::builder("patient_001") - /// .bolus(0.0, 100.0, 0) - /// .observation(1.0, 10.5, 0) + /// .bolus(0.0, 100.0, "depot") + /// .observation(1.0, 10.5, "cp") /// .build(); /// ``` fn builder(id: impl Into) -> SubjectBuilder; @@ -34,11 +49,37 @@ impl SubjectBuilderExt for Subject { } } -/// Builder for creating [Subject] instances with a fluent API +/// Builder for creating [`Subject`] values with a fluent API. +/// +/// Use [`SubjectBuilder`] when you want to author common dose and observation +/// schedules directly in Rust without constructing low-level event values by +/// hand. +/// +/// A builder instance accumulates events inside the current [`Occasion`]. +/// [`SubjectBuilder::repeat`] duplicates the most recently added event at later +/// times, and [`SubjectBuilder::reset`] closes the current occasion and starts a +/// new one with fresh occasion-local state. +/// +/// Input and output arguments are public labels. Prefer stable model-facing +/// names such as `"depot"`, `"iv"`, and `"cp"`. +/// +/// # Example +/// +/// ```rust +/// use pharmsol::*; +/// +/// let subject = Subject::builder("patient_001") +/// .bolus(0.0, 100.0, "depot") +/// .repeat(1, 24.0) +/// .observation(1.0, 12.3, "cp") +/// .missing_observation(25.0, "cp") +/// .reset() +/// .bolus(0.0, 80.0, "depot") +/// .observation(1.0, 10.1, "cp") +/// .build(); /// -/// The [SubjectBuilder] allows for constructing complex subject data with a -/// chainable, readable syntax. Events like doses and observations can be -/// added sequentially, and the builder handles organizing them into occasions. +/// assert_eq!(subject.occasions().len(), 2); +/// ``` #[derive(Debug, Clone)] pub struct SubjectBuilder { id: String, @@ -49,37 +90,39 @@ pub struct SubjectBuilder { } impl SubjectBuilder { - /// Add an event to the current occasion + /// Add a fully constructed event to the current occasion. /// - /// # Arguments - /// - /// * `event` - The event to add + /// Use this when you want to mix builder convenience methods with direct + /// [`Event`] values. pub fn event(mut self, event: Event) -> Self { self.last_added_event = Some(event.clone()); self.current_occasion.add_event(event); self } - /// Add a bolus dosing event + /// Add an instantaneous dose. /// /// # Arguments /// /// * `time` - Time of the bolus dose /// * `amount` - Amount of drug administered - /// * `input` - The compartment number receiving the dose + /// * `input` - Public input label receiving the dose + /// + /// Prefer stable route names such as `"depot"` or `"iv"` when the model + /// declares named routes. pub fn bolus(self, time: f64, amount: f64, input: impl ToString) -> Self { let bolus = Bolus::new(time, amount, input, self.current_occasion.index()); let event = Event::Bolus(bolus); self.event(event) } - /// Add an infusion event + /// Add a continuous dose over a duration. /// /// # Arguments /// /// * `time` - Start time of the infusion /// * `amount` - Total amount of drug to be administered - /// * `input` - The compartment number receiving the dose + /// * `input` - Public input label receiving the dose /// * `duration` - Duration of the infusion in time units pub fn infusion(self, time: f64, amount: f64, input: impl ToString, duration: f64) -> Self { let infusion = Infusion::new(time, amount, input, duration, self.current_occasion.index()); @@ -87,13 +130,13 @@ impl SubjectBuilder { self.event(event) } - /// Add an observation + /// Add an observed value at a given time. /// /// # Arguments /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number corresponding to this observation + /// * `outeq` - Public output label for this observation pub fn observation(self, time: f64, value: f64, outeq: impl ToString) -> Self { let observation = Observation::new( time, @@ -107,13 +150,14 @@ impl SubjectBuilder { self.event(event) } - /// Add a censored observation + /// Add an observed value with explicit censoring information. + /// /// # Arguments /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number (zero-indexed) corresponding to this - /// observation + /// * `outeq` - Public output label for this observation + /// * `censoring` - Censoring status for the observation value pub fn censored_observation( self, time: f64, @@ -133,12 +177,15 @@ impl SubjectBuilder { self.event(event) } - /// Add an observation + /// Add a prediction-only observation slot. /// /// # Arguments /// /// * `time` - Time of the observation - /// * `outeq` - Output equation number (zero-indexed) corresponding to this observation + /// * `outeq` - Public output label for this observation + /// + /// Use this when you want a prediction at a time point but do not have an + /// observed value. pub fn missing_observation(self, time: f64, outeq: impl ToString) -> Self { let observation = Observation::new( time, @@ -152,15 +199,15 @@ impl SubjectBuilder { self.event(event) } - /// Add an observation with a specific error polynomial + /// Add an observed value with an explicit assay error polynomial. /// /// # Arguments /// /// * `time` - Time of the observation /// * `value` - Observed value (e.g., drug concentration) - /// * `outeq` - Output equation number (zero-indexed) corresponding to this observation + /// * `outeq` - Public output label for this observation /// * `errorpoly` - Error polynomial coefficients (c0, c1, c2, c3) - /// * `censored` - Whether the observation is censored + /// * `censored` - Censoring status for the observation value pub fn observation_with_error( self, time: f64, @@ -181,7 +228,10 @@ impl SubjectBuilder { self.event(event) } - /// Repeat the last event `n` times, separated by some interval `delta` + /// Repeat the last event `n` times, separated by `delta`. + /// + /// The repeated events keep the same label, value, censoring state, and + /// error polynomial as the original event. Only the event time changes. /// /// # Arguments /// @@ -193,9 +243,8 @@ impl SubjectBuilder { /// ```rust /// use pharmsol::*; /// - /// /// let subject = Subject::builder("patient_001") - /// .bolus(0.0, 100.0, 0) // First dose at time 0 + /// .bolus(0.0, 100.0, "depot") // First dose at time 0 /// .repeat(3, 24.0) // Repeat the dose at times 24, 48, and 72 /// .build(); /// ``` @@ -255,12 +304,14 @@ impl SubjectBuilder { self } - /// Complete the current occasion and start a new one + /// Complete the current occasion and start a new one. /// /// This finalizes the current occasion, adds it to the subject, /// and creates a new occasion for subsequent events. - /// This is useful if a patient has new observations at some other occasion. - /// Note that all states are reset! + /// Use this when the subject should begin a new occasion with reset state. + /// + /// Covariates collected since the previous reset are attached to the + /// finished occasion. The new occasion starts empty and its state is reset. pub fn reset(mut self) -> Self { let block_index = self.current_occasion.index() + 1; self.current_occasion.sort(); @@ -274,7 +325,7 @@ impl SubjectBuilder { self } - /// Add a covariate value at a specific time + /// Add a covariate value at a specific time. /// /// Multiple calls for the same covariate at different times will create /// linear interpolation between the time points. @@ -300,7 +351,7 @@ impl SubjectBuilder { self } - /// Finalize and build the Subject + /// Finalize and build the [`Subject`]. /// /// This completes the current occasion and returns a new Subject with all /// the accumulated data. diff --git a/src/data/event.rs b/src/data/event.rs index bff4c700..02a4c9a7 100644 --- a/src/data/event.rs +++ b/src/data/event.rs @@ -1,3 +1,15 @@ +//! Event types and public label wrappers for subject schedules. +//! +//! These types are the low-level representation behind the higher-level +//! builder and parsing APIs. Most users can start with +//! [`crate::data::builder::SubjectBuilder`], then inspect or transform +//! [`Event`] values after construction. +//! +//! Dose events carry an [`InputLabel`], and observations carry an +//! [`OutputLabel`]. Prefer stable strings such as `"depot"`, `"iv"`, and +//! `"cp"`. Numeric values are accepted, but they remain labels until a +//! downstream workflow explicitly interprets them as indices. + use crate::data::error_model::ErrorPoly; use crate::prelude::simulator::Prediction; use serde::{Deserialize, Serialize}; @@ -7,12 +19,16 @@ use std::fmt; // Shared Analysis Types // ============================================================================ -/// Administration route for a dosing event +/// Administration route classification used by downstream analyses. +/// +/// [`Route`] is a coarse route category, not the original public input label. +/// In the current data-side heuristic: +/// - [`Event::Infusion`] maps to [`Route::IVInfusion`] +/// - [`Event::Bolus`] with input label `0` maps to [`Route::Extravascular`] +/// - [`Event::Bolus`] with any other label maps to [`Route::IVBolus`] /// -/// Determined by the type of dose events and their target compartment: -/// - [`Event::Infusion`] → [`Route::IVInfusion`] -/// - [`Event::Bolus`] with `input >= 1` (central compartment) → [`Route::IVBolus`] -/// - [`Event::Bolus`] with `input == 0` (depot compartment) → [`Route::Extravascular`] +/// If you need the original model-facing label, read [`Bolus::input`] or +/// [`Infusion::input`] instead. #[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] pub enum Route { /// Intravenous bolus @@ -78,12 +94,15 @@ pub enum BLQRule { }, } -/// Represents a pharmacokinetic/pharmacodynamic event +/// One scheduled item in a subject record. +/// +/// Events are the low-level representation for doses and observations: +/// - [`Bolus`] for instantaneous input +/// - [`Infusion`] for input over a duration +/// - [`Observation`] for measured or missing outputs /// -/// Events represent key occurrences in a PK/PD profile, including: -/// - [Bolus] doses (instantaneous drug input) -/// - [Infusion]s (continuous drug input over a duration) -/// - [Observation]s (measured concentrations or other values) +/// Most users create these through `Subject::builder(...)`, row ingestion, or +/// file parsing rather than constructing them all by hand. #[derive(Serialize, Debug, Clone, Deserialize)] pub enum Event { /// A bolus dose (instantaneous drug input) @@ -95,21 +114,31 @@ pub enum Event { } macro_rules! impl_label_type { - ($name:ident) => { + ($(#[$meta:meta])* $name:ident) => { + $(#[$meta])* #[derive( Debug, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize, )] pub struct $name(String); impl $name { + /// Create a new public label. + /// + /// Prefer stable names when the model declares named routes or + /// outputs. pub fn new(label: impl ToString) -> Self { Self(label.to_string()) } + /// Borrow the stored label as a string. pub fn as_str(&self) -> &str { &self.0 } + /// Try to interpret the label as a numeric index. + /// + /// This is mainly a compatibility helper for lower-level paths that + /// still operate on dense indices after label resolution. pub fn index(&self) -> Option { self.0.parse::().ok() } @@ -171,8 +200,20 @@ macro_rules! impl_label_type { }; } -impl_label_type!(InputLabel); -impl_label_type!(OutputLabel); +impl_label_type!( + /// Public label for a dosing input or route. + /// + /// [`Bolus`] and [`Infusion`] store the original user-facing route name in + /// this type. + InputLabel +); +impl_label_type!( + /// Public label for an observation output. + /// + /// [`Observation`] stores the original user-facing output name in this + /// type. + OutputLabel +); impl Event { /// Get the time of the event @@ -226,9 +267,10 @@ impl Event { } } -/// Represents an instantaneous input of drug +/// Instantaneous dose input. /// -/// A [Bolus] is a discrete amount of drug added to a specific compartment at a specific time. +/// A [`Bolus`] records one discrete amount at one time, tagged with the public +/// input label that should be matched against the model. #[derive(Serialize, Debug, Clone, Deserialize)] pub struct Bolus { time: f64, @@ -263,6 +305,9 @@ impl Bolus { &self.input } + /// Try to interpret the input label as a numeric index. + /// + /// Prefer [`Bolus::input`] when working with the public label itself. pub fn input_index(&self) -> Option { self.input.index() } @@ -313,9 +358,10 @@ impl Bolus { } } -/// Represents a continuous dose of drug over time +/// Continuous dose input over a duration. /// -/// An [Infusion] administers drug at a constant rate over a specified duration. +/// An [`Infusion`] records the total amount, start time, duration, and public +/// input label for one infusion event. #[derive(Serialize, Debug, Clone, Deserialize)] pub struct Infusion { time: f64, @@ -359,6 +405,9 @@ impl Infusion { &self.input } + /// Try to interpret the input label as a numeric index. + /// + /// Prefer [`Infusion::input`] when working with the public label itself. pub fn input_index(&self) -> Option { self.input.index() } @@ -438,7 +487,11 @@ pub enum Censor { ALOQ, } -/// Represents an observation of drug concentration or other measured value + /// Observation of a model output. + /// + /// An [`Observation`] can carry a measured value or `None` for a prediction-only + /// time point. Observations also carry the public output label, optional assay + /// error polynomial, occasion index, and censoring state. #[derive(Serialize, Debug, Clone, Deserialize)] pub struct Observation { time: f64, @@ -482,7 +535,9 @@ impl Observation { self.time } - /// Get the value of the observation (e.g., drug concentration) + /// Get the value of the observation. + /// + /// `None` means this is a prediction-only or missing-observation slot. pub fn value(&self) -> Option { self.value } @@ -492,6 +547,9 @@ impl Observation { &self.outeq } + /// Try to interpret the output label as a numeric index. + /// + /// Prefer [`Observation::outeq`] when working with the public label itself. pub fn outeq_index(&self) -> Option { self.outeq.index() } @@ -553,7 +611,11 @@ impl Observation { &mut self.occasion } - /// Create a [Prediction] from this observation + /// Create a [`Prediction`] from this observation. + /// + /// This is a low-level helper for code paths that already operate on a + /// resolved or numeric output index. Named output labels must be resolved by + /// the caller before this conversion happens. pub fn to_prediction(&self, pred: f64, state: Vec) -> Prediction { Prediction { time: self.time(), diff --git a/src/data/mod.rs b/src/data/mod.rs index 996c791d..28a80b32 100644 --- a/src/data/mod.rs +++ b/src/data/mod.rs @@ -1,32 +1,83 @@ -//! Data structures and utilities for pharmacometric modeling +//! Data structures for building pharmacometric input data. //! -//! This module provides types for representing pharmacokinetic/pharmacodynamic data, -//! including subjects, dosing events, observations, and covariates. It also includes -//! utilities for reading and manipulating this data. +//! Use this module when you need to describe what happened to each subject: +//! doses, infusions, observations, covariates, and occasion boundaries. //! -//! # Key Components +//! This module is the input side of `pharmsol`. It is where you assemble +//! subjects and datasets before simulation, estimation, or NCA. It is not where +//! you define model equations or choose a backend. For those workflows, move to +//! [`crate::simulator`], [`crate::nca`], or the feature-gated `pharmsol::dsl` +//! surface. //! -//! - **Events**: Dosing events (bolus, infusion) and observations -//! - **Covariates**: Time-varying subject characteristics -//! - **Subjects**: Collections of events and covariates for a single individual -//! - **Data**: Collections of subjects, representing a complete dataset -//! - **Error Models**: Two types for different algorithm families: -//! - [`ErrorModel`]: Observation-based (assay error) for non-parametric algorithms -//! - [`ResidualErrorModel`]: Prediction-based (residual error) for parametric algorithms +//! # Start Here //! -//! # Examples +//! Most users only need three entrypoints first: //! -//! Creating a subject with the builder pattern: +//! - [`Subject`] for one individual and their full schedule. +//! - [`Data`] for a dataset containing many subjects. +//! - `Subject::builder` for the smallest fluent API to create doses, +//! observations, and covariates in Rust. +//! +//! The main supporting types are: +//! +//! - [`Occasion`] for repeated periods within one subject. +//! - [`Event`], [`Bolus`], [`Infusion`], and [`Observation`] for explicit +//! event-level control. +//! - [`Covariate`] and [`Covariates`] for time-varying subject characteristics. +//! - [`ErrorModel`], [`ResidualErrorModel`], and [`ObservationError`] for the +//! different error surfaces used by downstream workflows. +//! +//! # Choose A Data Input Path +//! +//! - Use `Subject::builder` when you are authoring a schedule directly in Rust. +//! - Use [`row::DataRow`] and [`row::DataRowBuilder`] when your source data is +//! already row-shaped in memory. +//! - Use [`parser::read_pmetrics`] when you are loading a Pmetrics-style file +//! from disk. +//! - Use [`Event`] variants directly when you already have validated event +//! records and need lower-level control than the builder offers. +//! +//! # Label Semantics +//! +//! Dosing inputs and observation outputs use public labels. +//! +//! - The `input` on [`Bolus`] and [`Infusion`] is the route or input label that +//! will be matched against the model. +//! - The `outeq` on [`Observation`] is the output label that identifies which +//! model output the observation belongs to. +//! - Prefer stable names such as `"depot"`, `"central"`, `"iv"`, or `"cp"`. +//! - If you pass a number, it is still treated as a public label string. Use +//! numeric values only when your model intentionally declares numeric labels. +//! +//! [`Occasion`] indices are different: they are integer period markers used to +//! separate repeated dosing blocks within one subject. +//! +//! # Error Surfaces +//! +//! This module exposes three related but different error families: +//! +//! - [`ErrorModel`] for assay or measurement error driven by the observation +//! value, commonly used in non-parametric workflows. +//! - [`ResidualErrorModel`] for residual unexplained variability driven by the +//! prediction value, commonly used in parametric workflows. +//! - [`ObservationError`] for invalid or insufficient observation data during +//! profile construction and related preprocessing. +//! +//! # Example //! //! ```rust //! use pharmsol::*; //! //! let subject = Subject::builder("patient_001") -//! .bolus(0.0, 100.0, 0) -//! .observation(1.0, 10.5, 0) -//! .observation(2.0, 8.2, 0) +//! .bolus(0.0, 100.0, "depot") +//! .observation(1.0, 12.3, "cp") +//! .missing_observation(2.0, "cp") //! .covariate("weight", 0.0, 70.0) //! .build(); +//! +//! let data = Data::new(vec![subject]); +//! +//! assert_eq!(data.subjects().len(), 1); //! ``` pub mod auc; diff --git a/src/data/parser/mod.rs b/src/data/parser/mod.rs index 7bfde3ca..74a50a84 100644 --- a/src/data/parser/mod.rs +++ b/src/data/parser/mod.rs @@ -1,3 +1,15 @@ +//! File-based parsers and parser-facing row utilities. +//! +//! Use this module when your source data starts as files or parser-shaped rows. +//! It re-exports the row ingestion API from [`crate::data::row`] and provides +//! format-specific loaders such as [`read_pmetrics`]. +//! +//! Choose the entrypoint by source shape: +//! - Use [`DataRow`] or [`build_data`] when you already mapped external data into +//! canonical row fields yourself. +//! - Use [`read_pmetrics`] when the source file already follows the Pmetrics CSV +//! convention. + pub mod pmetrics; pub use crate::data::row::{build_data, DataError, DataRow, DataRowBuilder}; diff --git a/src/data/parser/pmetrics.rs b/src/data/parser/pmetrics.rs index 89943f6e..2c90e2a7 100644 --- a/src/data/parser/pmetrics.rs +++ b/src/data/parser/pmetrics.rs @@ -1,3 +1,12 @@ +//! Pmetrics CSV parsing and export helpers. +//! +//! This module reads and writes the Pmetrics-style tabular format while keeping +//! pharmsol's public input and output labels intact. +//! +//! `INPUT` and `OUTEQ` values are parsed as labels, not rewritten to dense +//! indices. Named values such as `iv` and `cp` are preserved exactly, and +//! numeric values such as `1` are preserved as numeric-looking labels. + use crate::{data::*, PharmsolError}; use csv::WriterBuilder; use serde::de::{MapAccess, Visitor}; @@ -10,19 +19,27 @@ use crate::data::row::DataRow; use std::fmt; use std::str::FromStr; -/// Read a Pmetrics datafile and convert it to a [Data] object +/// Read a Pmetrics CSV file into [`Data`]. +/// +/// Use [`read_pmetrics`] when the source file already follows the usual +/// Pmetrics column convention instead of mapping the file into [`DataRow`] +/// values yourself. /// -/// This function parses a Pmetrics-formatted CSV file and constructs a [Data] object containing the structured -/// pharmacokinetic/pharmacodynamic data. The function handles various data formats including doses, observations, -/// and covariates. +/// The parser normalizes header names to lowercase, preserves `INPUT` and +/// `OUTEQ` as public labels, expands `ADDL` dosing rows through the shared row +/// ingestion path, and groups rows into occasions using `EVID=4`. +/// +/// All columns not claimed by the core Pmetrics schema are treated as +/// covariates. /// /// # Arguments /// -/// * `path` - The path to the Pmetrics CSV file +/// * `path` - Path to the Pmetrics CSV file /// /// # Returns /// -/// * `Result` - A result containing either the parsed [Data] object or an error +/// A parsed [`Data`] object or a [`DataError`] if the file cannot be read or a +/// required row field is missing. /// /// # Example /// @@ -33,14 +50,25 @@ use std::str::FromStr; /// println!("Number of subjects: {}", data.subjects().len()); /// ``` /// -/// # Format details +/// # Expected columns +/// +/// The canonical columns are `ID`, `TIME`, `EVID`, `DOSE`, `DUR`, `ADDL`, +/// `II`, `INPUT`, `OUT`, `OUTEQ`, `CENS`, and optional `C0..C3` error +/// coefficients. /// -/// The Pmetrics format expects columns like ID, TIME, EVID, DOSE, DUR, etc. The function will: +/// All other numeric columns are treated as covariates. +/// +/// # Parsing behavior +/// +/// The parser will: /// - Convert all headers to lowercase for case-insensitivity /// - Group rows by subject ID /// - Create occasions based on EVID=4 events /// - Parse covariates and create appropriate interpolations /// - Handle additional doses via ADDL and II fields +/// - Preserve raw `INPUT` and `OUTEQ` labels as strings until model resolution +/// - Treat `OUT=-99` as a missing observation value, matching the common +/// Pmetrics convention /// /// For specific column definitions, see the `Row` struct. #[allow(dead_code)] @@ -72,7 +100,7 @@ pub fn read_pmetrics(path: impl Into) -> Result { build_data(data_rows) } -/// A [Row] represents a row in the Pmetrics data format +/// One row from a Pmetrics file after serde deserialization. #[derive(Deserialize, Debug, Serialize, Default, Clone)] #[serde(rename_all = "lowercase")] struct Row { @@ -94,13 +122,13 @@ struct Row { /// Dosing interval #[serde(deserialize_with = "deserialize_option_f64")] ii: Option, - /// Input compartment + /// Input label from the `INPUT` column #[serde(deserialize_with = "deserialize_option_route_label")] input: Option, /// Observed value #[serde(deserialize_with = "deserialize_option_f64")] out: Option, - /// Corresponding output equation for the observation + /// Output label from the `OUTEQ` column #[serde(deserialize_with = "deserialize_option_output_label")] outeq: Option, /// Censoring output @@ -264,7 +292,14 @@ where } impl Data { - /// Write the dataset to a file in Pmetrics format + /// Write the dataset to a file in Pmetrics format. + /// + /// `INPUT` and `OUTEQ` are written using their stored public labels. Named + /// labels such as `iv` and `cp` remain named labels, and numeric-looking + /// labels are written back exactly as stored. + /// + /// Missing optional fields are emitted as `.` placeholders to match the + /// usual Pmetrics text convention. /// /// # Arguments /// diff --git a/src/data/row.rs b/src/data/row.rs index b9a807c1..fcb610ea 100644 --- a/src/data/row.rs +++ b/src/data/row.rs @@ -1,34 +1,51 @@ -//! Row representation of [Data] for flexible parsing +//! Row-shaped data ingestion for [`Data`] and [`Subject`] assembly. +//! +//! Use this module when your source data already looks like rows from a table, +//! CSV file, database export, or ETL pipeline. +//! +//! Choose the ingestion path by source shape: +//! - Use [`crate::data::builder::SubjectBuilder`] when you want to author a +//! schedule directly in Rust. +//! - Use [`DataRow`] and [`build_data`] when your application already has +//! validated row records in memory. +//! - Use [`crate::data::parser::read_pmetrics`] when the source file already +//! follows the Pmetrics column convention. +//! +//! [`DataRow`] keeps public route and output labels as strings. Labels such as +//! `"iv"`, `"depot"`, and `"cp"` are preserved through row parsing and later +//! resolved against model metadata by downstream workflows. //! //! # Example //! //! ```rust //! use pharmsol::data::parser::DataRow; //! -//! // Create a dosing row with ADDL expansion //! let row = DataRow::builder("subject_1", 0.0) //! .evid(1) //! .dose(100.0) -//! .input(1) -//! .addl(3) // 3 additional doses -//! .ii(12.0) // 12 hours apart +//! .input("iv") +//! .addl(3) +//! .ii(12.0) //! .build(); //! //! let events = row.into_events().unwrap(); -//! assert_eq!(events.len(), 4); // Original + 3 additional doses +//! assert_eq!(events.len(), 4); //! ``` -//! use crate::data::*; use std::collections::HashMap; use thiserror::Error; -/// A format-agnostic representation of a single data row +/// A format-agnostic representation of one input row. +/// +/// [`DataRow`] collects the canonical fields needed to turn one external row +/// into one or more [`Event`] values. /// -/// This struct represents the canonical fields needed to create pharmsol Events. -/// Consumers construct this from their source data (regardless of column names), -/// then call [`into_events()`](DataRow::into_events) to get properly parsed -/// Events with full ADDL expansion, EVID handling, censoring, etc. +/// Build this type from your own column mapping or external schema, then call +/// [`DataRow::into_events`] or [`build_data`] to assemble subjects and datasets. +/// +/// A single row can expand into several events when `ADDL` and `II` are both +/// present. /// /// # Fields /// @@ -42,24 +59,22 @@ use thiserror::Error; /// ```rust /// use pharmsol::data::parser::DataRow; /// -/// // Observation row /// let obs = DataRow::builder("pt1", 1.0) /// .evid(0) /// .out(25.5) -/// .outeq(1) +/// .outeq("cp") /// .build(); /// -/// // Dosing row with negative ADDL (doses before time 0) /// let dose = DataRow::builder("pt1", 0.0) /// .evid(1) /// .dose(100.0) -/// .input(1) -/// .addl(-10) // 10 doses BEFORE time 0 +/// .input("iv") +/// .addl(-10) /// .ii(12.0) /// .build(); /// /// let events = dose.into_events().unwrap(); -/// // Events at times: -120, -108, -96, ..., -12, 0 +/// assert_eq!(obs.outeq.as_ref().map(|label| label.as_str()), Some("cp")); /// assert_eq!(events.len(), 11); /// ``` #[derive(Debug, Clone, Default)] @@ -99,7 +114,7 @@ pub struct DataRow { } impl DataRow { - /// Create a new builder for constructing a DataRow + /// Create a builder for constructing one [`DataRow`]. /// /// # Arguments /// @@ -114,7 +129,7 @@ impl DataRow { /// let row = DataRow::builder("patient_001", 0.0) /// .evid(1) /// .dose(100.0) - /// .input(1) + /// .input("depot") /// .build(); /// ``` pub fn builder(id: impl Into, time: f64) -> DataRowBuilder { @@ -129,13 +144,14 @@ impl DataRow { } } - /// Convert this row into pharmsol Events + /// Convert this row into one or more [`Event`] values. /// - /// This method contains all the complex parsing logic: + /// This method performs the row-level translation logic: /// - EVID interpretation (0=observation, 1=dose, 4=reset) /// - ADDL/II expansion (both positive and negative directions) /// - Infusion vs bolus detection based on DUR /// - Censoring and error polynomial handling + /// - Preservation of public input and output labels /// /// # ADDL Expansion /// @@ -163,13 +179,13 @@ impl DataRow { /// let row = DataRow::builder("pt1", 0.0) /// .evid(1) /// .dose(100.0) - /// .input(1) + /// .input("iv") /// .addl(2) /// .ii(24.0) /// .build(); /// /// let events = row.into_events().unwrap(); - /// assert_eq!(events.len(), 3); // doses at 24, 48, and 0 + /// assert_eq!(events.len(), 3); /// /// let times: Vec = events.iter().map(|e| e.time()).collect(); /// assert_eq!(times, vec![24.0, 48.0, 0.0]); @@ -287,7 +303,11 @@ impl DataRow { } } -/// Builder for constructing DataRow with a fluent API +/// Fluent builder for [`DataRow`]. +/// +/// Use [`DataRowBuilder`] when you have row-shaped data in memory and want to +/// construct rows incrementally before calling [`DataRow::into_events`] or +/// [`build_data`]. /// /// # Example /// @@ -298,7 +318,7 @@ impl DataRow { /// let row = DataRow::builder("patient_001", 1.5) /// .evid(0) /// .out(25.5) -/// .outeq(1) +/// .outeq("cp") /// .cens(Censor::None) /// .covariate("weight", 70.0) /// .covariate("age", 45.0) @@ -373,10 +393,11 @@ impl DataRowBuilder { self } - /// Set the input route label + /// Set the input route label. /// - /// Required for EVID=1 (dosing events). - /// Preserved as the public route label until model resolution. + /// Required for EVID=1 dosing rows. + /// The provided value is preserved as the public label until downstream + /// model resolution. pub fn input(mut self, input: impl ToString) -> Self { self.row.input = Some(InputLabel::new(input)); self @@ -390,10 +411,11 @@ impl DataRowBuilder { self } - /// Set the output label + /// Set the output label. /// - /// Required for EVID=0 (observation events). - /// Preserved as the public output label until model resolution. + /// Required for EVID=0 observation rows. + /// The provided value is preserved as the public label until downstream + /// model resolution. pub fn outeq(mut self, outeq: impl ToString) -> Self { self.row.outeq = Some(OutputLabel::new(outeq)); self @@ -436,13 +458,18 @@ impl DataRowBuilder { } } -/// Build a [Data] object from an iterator of [DataRow]s +/// Build a [`Data`] object from row-shaped input. /// -/// This function handles all the complex assembly logic: +/// This function assembles rows into subjects and occasions: /// - Groups rows by subject ID /// - Splits into occasions at EVID=4 boundaries /// - Converts rows to events via [`DataRow::into_events()`] /// - Builds covariates from row covariate data +/// - Preserves per-subject row order within each occasion block +/// +/// Use this when you already have a collection of [`DataRow`] values in memory. +/// If your source file is a Pmetrics CSV, use [`crate::data::parser::read_pmetrics`] +/// instead. /// /// # Example /// @@ -450,23 +477,21 @@ impl DataRowBuilder { /// use pharmsol::data::parser::{DataRow, build_data}; /// /// let rows = vec![ -/// // Subject 1, Occasion 0 /// DataRow::builder("pt1", 0.0) -/// .evid(1).dose(100.0).input(1).build(), +/// .evid(1).dose(100.0).input("iv").build(), /// DataRow::builder("pt1", 1.0) -/// .evid(0).out(50.0).outeq(1).build(), -/// // Subject 1, Occasion 1 (EVID=4 starts new occasion) +/// .evid(0).out(50.0).outeq("cp").build(), /// DataRow::builder("pt1", 24.0) -/// .evid(4).dose(100.0).input(1).build(), +/// .evid(4).dose(100.0).input("iv").build(), /// DataRow::builder("pt1", 25.0) -/// .evid(0).out(48.0).outeq(1).build(), -/// // Subject 2 +/// .evid(0).out(48.0).outeq("cp").build(), /// DataRow::builder("pt2", 0.0) -/// .evid(1).dose(50.0).input(1).build(), +/// .evid(1).dose(50.0).input("iv").build(), /// ]; /// /// let data = build_data(rows).unwrap(); /// assert_eq!(data.subjects().len(), 2); +/// assert_eq!(data.subjects()[0].occasions().len(), 2); /// ``` pub fn build_data(rows: impl IntoIterator) -> Result { // Group rows by subject ID @@ -562,14 +587,14 @@ pub enum DataError { /// Required observation value (OUT) is missing #[error("Observation OUT is missing for {id} at time {time}")] MissingObservationOut { id: String, time: f64 }, - /// Required observation output equation (OUTEQ) is missing - #[error("Observation OUTEQ is missing in for {id} at time {time}")] + /// Required observation output label (`OUTEQ`) is missing + #[error("Observation OUTEQ is missing for {id} at time {time}")] MissingObservationOuteq { id: String, time: f64 }, /// Required infusion dose amount is missing #[error("Infusion amount (DOSE) is missing for {id} at time {time}")] MissingInfusionDose { id: String, time: f64 }, - /// Required infusion input compartment is missing - #[error("Infusion compartment (INPUT) is missing for {id} at time {time}")] + /// Required infusion input label (`INPUT`) is missing + #[error("Infusion input label (INPUT) is missing for {id} at time {time}")] MissingInfusionInput { id: String, time: f64 }, /// Required infusion duration is missing #[error("Infusion duration (DUR) is missing for {id} at time {time}")] @@ -577,8 +602,8 @@ pub enum DataError { /// Required bolus dose amount is missing #[error("Bolus amount (DOSE) is missing for {id} at time {time}")] MissingBolusDose { id: String, time: f64 }, - /// Required bolus input compartment is missing - #[error("Bolus compartment (INPUT) is missing for {id} at time {time}")] + /// Required bolus input label (`INPUT`) is missing + #[error("Bolus input label (INPUT) is missing for {id} at time {time}")] MissingBolusInput { id: String, time: f64 }, } diff --git a/src/dsl/aot.rs b/src/dsl/aot.rs index 2a46409a..6557e015 100644 --- a/src/dsl/aot.rs +++ b/src/dsl/aot.rs @@ -37,18 +37,23 @@ use pharmsol_dsl::ModelKind; use pharmsol_dsl::{analyze_module, lower_typed_model, parse_module, ExecutionModel}; use pharmsol_dsl::{Diagnostic, DiagnosticReport, LoweringError, ParseError, SemanticError}; +/// ABI version for native AoT artifacts produced by this crate. pub const AOT_API_VERSION: u32 = 1; #[cfg(feature = "dsl-aot")] +/// Selects the compilation target for a native ahead-of-time artifact. #[derive(Debug, Clone, PartialEq, Eq, Default)] pub enum NativeAotTarget { + /// Compile for the current host toolchain target. #[default] Host, + /// Compile for an explicit Rust target triple. Triple(String), } #[cfg(feature = "dsl-aot")] impl NativeAotTarget { + /// Create a target selector for an explicit Rust target triple. pub fn triple(target: impl Into) -> Self { Self::Triple(target.into()) } @@ -62,15 +67,24 @@ impl NativeAotTarget { } #[cfg(feature = "dsl-aot")] +/// Options that control native ahead-of-time artifact export. +/// +/// AoT export writes a small template crate under [`template_root`](Self::template_root), +/// builds a native shared library, and then copies the resulting artifact to +/// [`output`](Self::output) or a generated default path. #[derive(Debug, Clone, PartialEq, Eq)] pub struct NativeAotCompileOptions { + /// Target triple selection for the emitted artifact. pub target: NativeAotTarget, + /// Optional final artifact location. pub output: Option, + /// Working directory used for the temporary template crate and build output. pub template_root: PathBuf, } #[cfg(feature = "dsl-aot")] impl NativeAotCompileOptions { + /// Create AoT options rooted at a template build directory. pub fn new(template_root: PathBuf) -> Self { Self { target: NativeAotTarget::Host, @@ -79,17 +93,20 @@ impl NativeAotCompileOptions { } } + /// Set the final artifact output path. pub fn with_output(mut self, output: PathBuf) -> Self { self.output = Some(output); self } + /// Set the compilation target triple. pub fn with_target(mut self, target: NativeAotTarget) -> Self { self.target = target; self } } +/// Error produced while exporting, reading, or loading a native AoT artifact. #[derive(Error)] pub enum AotError { #[error(transparent)] @@ -151,6 +168,43 @@ impl fmt::Debug for AotError { } #[cfg(feature = "dsl-aot")] +/// Parse DSL source, lower one selected model, and export a native AoT artifact. +/// +/// Use this when you want a reusable native artifact that can be loaded later +/// with [`load_aot_model`] or [`crate::dsl::load_runtime_artifact`]. +/// +/// This function requires the `dsl-aot` feature. Loading the resulting artifact +/// later requires `dsl-aot-load`. +/// +/// ```rust,no_run +/// use std::path::PathBuf; +/// +/// use pharmsol::dsl::{compile_module_source_to_aot, load_aot_model, NativeAotCompileOptions}; +/// +/// let source = r#" +/// name = bimodal_ke +/// kind = ode +/// +/// params = ke, v +/// states = central +/// outputs = cp +/// +/// infusion(iv) -> central +/// +/// dx(central) = -ke * central +/// out(cp) = central / v +/// "#; +/// +/// let artifact = compile_module_source_to_aot( +/// source, +/// Some("bimodal_ke"), +/// NativeAotCompileOptions::new(PathBuf::from("target/doc-aot-build")), +/// |_, _| {}, +/// )?; +/// let loaded = load_aot_model(&artifact)?; +/// # let _ = loaded; +/// # Ok::<(), Box>(()) +/// ``` pub fn compile_module_source_to_aot( source: &str, model_name: Option<&str>, @@ -184,6 +238,10 @@ pub fn compile_module_source_to_aot( } #[cfg(feature = "dsl-aot")] +/// Export a lowered execution model as a native AoT artifact. +/// +/// Use this lower-level entrypoint when you already own the frontend pipeline +/// and only need artifact generation. pub fn export_execution_model_to_aot( model: &ExecutionModel, options: NativeAotCompileOptions, @@ -240,6 +298,10 @@ pub fn export_execution_model_to_aot( } #[cfg(feature = "dsl-aot-load")] +/// Read only the metadata from a native AoT artifact. +/// +/// This is useful when you need to inspect model identity, routes, outputs, or +/// buffer sizes without loading the executable kernels. pub fn read_aot_model_info(path: impl AsRef) -> Result { let library = unsafe { Library::new(path.as_ref()) } .map_err(|error| AotError::Load(error.to_string()))?; @@ -248,6 +310,7 @@ pub fn read_aot_model_info(path: impl AsRef) -> Result) -> Result { let path = path.as_ref(); let library = diff --git a/src/dsl/jit.rs b/src/dsl/jit.rs index a440c51d..684b7810 100644 --- a/src/dsl/jit.rs +++ b/src/dsl/jit.rs @@ -83,6 +83,11 @@ pub type JitAnalyticalModel = NativeAnalyticalModel; pub type JitSdeModel = NativeSdeModel; pub type CompiledJitModel = CompiledNativeModel; +/// Error reported while lowering an execution model into native in-process JIT +/// code. +/// +/// The error retains the backend diagnostic so callers can render the message +/// against the original DSL source when available. #[derive(Clone, PartialEq, Eq)] pub struct JitCompileError { diagnostic: Box, @@ -214,6 +219,10 @@ struct LoweredValue { ty: ValueType, } +/// Compile one lowered execution model into a reusable JIT kernel artifact. +/// +/// This builds the raw Cranelift-compiled kernel bundle for all roles present in +/// the model. Most callers should use [`compile_execution_model_to_jit`] instead. pub fn compile_execution_artifact( model: &ExecutionModel, ) -> Result { @@ -1217,6 +1226,41 @@ fn state_address( Ok(builder.ins().iadd(base, byte_offset)) } +/// Compile an [`ExecutionModel`](pharmsol_dsl::ExecutionModel) to the native +/// in-process JIT backend. +/// +/// Use this low-level entrypoint when you already own the parse, analyze, and +/// lower steps and want the JIT backend directly instead of the higher-level +/// runtime facade. +/// +/// This function requires the `dsl-jit` feature. +/// +/// ```rust,no_run +/// use pharmsol::dsl::{ +/// analyze_model, compile_execution_model_to_jit, lower_typed_model, parse_model, +/// }; +/// +/// let parsed = parse_model( +/// r#" +/// model implicit_route_injection { +/// kind ode +/// states { central } +/// routes { iv -> central } +/// dynamics { +/// ddt(central) = 0 +/// } +/// outputs { +/// cp = central +/// } +/// } +/// "#, +/// )?; +/// let typed = analyze_model(&parsed)?; +/// let execution = lower_typed_model(&typed)?; +/// let compiled = compile_execution_model_to_jit(&execution)?; +/// # let _ = compiled; +/// # Ok::<(), Box>(()) +/// ``` pub fn compile_execution_model_to_jit( model: &ExecutionModel, ) -> Result { @@ -1229,6 +1273,7 @@ pub fn compile_execution_model_to_jit( } } +/// Compile an ODE execution model to the native in-process JIT backend. pub fn compile_ode_model_to_jit(model: &ExecutionModel) -> Result { if model.kind != ModelKind::Ode { return Err(JitCompileError::new( @@ -1245,6 +1290,7 @@ pub fn compile_ode_model_to_jit(model: &ExecutionModel) -> Result Result { @@ -1263,6 +1309,7 @@ pub fn compile_analytical_model_to_jit( )) } +/// Compile an SDE execution model to the native in-process JIT backend. pub fn compile_sde_model_to_jit(model: &ExecutionModel) -> Result { if model.kind != ModelKind::Sde { return Err(JitCompileError::new( diff --git a/src/dsl/mod.rs b/src/dsl/mod.rs index 563e4cf6..f536c377 100644 --- a/src/dsl/mod.rs +++ b/src/dsl/mod.rs @@ -1,9 +1,94 @@ //! Public DSL facade for pharmsol. //! -//! The backend-neutral frontend is being extracted into `pharmsol-dsl`. -//! Frontend syntax, diagnostics, semantic analysis, and lowering now come -//! from `pharmsol-dsl`, while runtime and backend compilation entrypoints -//! remain owned by `pharmsol`. +//! Use this module when you want to work with pharmsol models as source text +//! and stay inside the main crate for the full workflow: parse DSL source, +//! inspect diagnostics, lower to the execution model, compile to a runtime +//! backend, load saved artifacts, and run predictions. +//! +//! Use the `pharmsol-dsl` crate directly only when you need the backend-neutral +//! frontend as an engineering API. That crate owns parsing, diagnostics, +//! semantic analysis, and lowering. This module re-exports that stable +//! frontend surface and adds the backend-specific entrypoints that stay owned +//! by `pharmsol`. +//! +//! Main entrypoints: +//! +//! - [`parse_model`], [`parse_module`], [`analyze_model`], and +//! [`analyze_module`] for frontend-only validation and inspection. +//! - [`lower_typed_model`] and [`lower_typed_module`] for lowering typed models +//! into the execution representation used by the runtime backends. +//! - [`compile_module_source_to_runtime`] and [`compile_execution_model_to_runtime`] +//! for the one-stop compile-and-run path. +//! - [`load_runtime_artifact`], [`load_aot_model`], and +//! [`load_runtime_wasm_bytes`] for loading saved artifacts back into a model +//! you can execute. +//! +//! Common workflow choices: +//! +//! - Frontend only: parse, analyze, and lower when you need diagnostics, +//! authoring tools, or your own backend. +//! - In-process execution: compile straight to [`RuntimeCompilationTarget`] and +//! keep everything inside the current process. +//! - Native artifact shipping: export a native AoT artifact, then load it later +//! on a compatible host. +//! - WASM artifact shipping: emit `.wasm` bytes or a bundled module for browser +//! or portable runtime use. +//! +//! Feature map: +//! +//! - `dsl-core`: enables this facade and the frontend re-exports from +//! `pharmsol-dsl`. +//! - `dsl-jit`: enables in-process JIT compilation through +//! [`compile_module_source_to_runtime`] with +//! [`RuntimeCompilationTarget::Jit`], plus the lower-level JIT compile +//! entrypoints. +//! - `dsl-aot`: enables native ahead-of-time artifact export through +//! [`compile_module_source_to_aot`] and [`export_execution_model_to_aot`]. +//! - `dsl-aot-load`: enables native AoT artifact loading through +//! [`load_aot_model`] and [`read_aot_model_info`]. +//! - `dsl-wasm-compile`: enables WASM artifact emission through +//! [`compile_module_source_to_wasm_bytes`], +//! [`compile_module_source_to_wasm_module`], and the browser loader helpers. +//! - `dsl-wasm`: enables host-side WASM loading and runtime execution on +//! non-browser native hosts. This includes +//! [`compile_module_source_to_runtime_wasm`], [`load_runtime_wasm_bytes`], +//! [`read_wasm_model_info`], and [`read_wasm_model_info_bytes`]. +//! +//! Smallest compile-to-runtime example: +//! +//! This example requires `dsl-jit`. +//! +//! ```rust,no_run +//! use pharmsol::dsl::{compile_module_source_to_runtime, RuntimeCompilationTarget}; +//! +//! let source = r#" +//! name = bimodal_ke +//! kind = ode +//! +//! params = ke, v +//! states = central +//! outputs = cp +//! +//! infusion(iv) -> central +//! +//! dx(central) = -ke * central +//! out(cp) = central / v +//! "#; +//! +//! let model = compile_module_source_to_runtime( +//! source, +//! Some("bimodal_ke"), +//! RuntimeCompilationTarget::Jit, +//! |_, _| {}, +//! )?; +//! +//! # let _ = model; +//! # Ok::<(), pharmsol::dsl::RuntimeError>(()) +//! ``` +//! +//! For a lower-level frontend pipeline without backend selection, use +//! `pharmsol-dsl`. For a complete runtime path inside the main crate, stay in +//! [`pharmsol::dsl`](self). #[cfg(any(feature = "dsl-aot", feature = "dsl-aot-load"))] mod aot; diff --git a/src/dsl/model_info.rs b/src/dsl/model_info.rs index d9a2fdbd..7dd3f72a 100644 --- a/src/dsl/model_info.rs +++ b/src/dsl/model_info.rs @@ -8,47 +8,80 @@ use pharmsol_dsl::execution::{ }; use pharmsol_dsl::{AnalyticalKernel, ModelKind, RouteKind}; +/// Public metadata extracted from a compiled backend model. +/// +/// This is the shared inspection surface returned by the native AoT, WASM, and +/// runtime loaders. It keeps public labels and buffer sizes available without +/// exposing backend-specific kernel details. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeModelInfo { + /// Public model name. pub name: String, + /// High-level model family. pub kind: ModelKind, + /// Parameter names in support-point order. pub parameters: Vec, + /// Declared covariates and their dense runtime indices. pub covariates: Vec, + /// Declared routes together with declaration-order and dense runtime indices. pub routes: Vec, + /// Declared outputs and their dense runtime indices. pub outputs: Vec, + /// Length of the state buffer used during execution. pub state_len: usize, + /// Length of the derived-value buffer used during execution. pub derived_len: usize, + /// Length of the output buffer used during execution. pub output_len: usize, + /// Length of the dense route-input buffer used during execution. pub route_len: usize, + /// Analytical kernel metadata when the compiled model is analytical. pub analytical: Option, + /// Particle count when the compiled model is stochastic. pub particles: Option, } +/// Metadata for one compiled covariate. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeCovariateInfo { + /// Public covariate name. pub name: String, + /// Dense runtime covariate index. pub index: usize, } +/// Metadata for one compiled route. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeRouteInfo { + /// Public route label. pub name: String, + /// Route position in declaration order. #[serde(default)] pub declaration_index: usize, + /// Dense runtime route-input index. pub index: usize, + /// Coarse route kind when declared in metadata. #[serde(default)] pub kind: Option, + /// Dense destination state offset used by compiled kernels. pub destination_offset: usize, + /// Whether the compiled backend injects the route input into the destination + /// state automatically when the model does not read the route input + /// explicitly. pub inject_input_to_destination: bool, } +/// Metadata for one compiled output. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct NativeOutputInfo { + /// Public output label. pub name: String, + /// Dense runtime output index. pub index: usize, } impl NativeModelInfo { + /// Build public compiled-model metadata from a lowered execution model. pub fn from_execution_model(model: &ExecutionModel) -> Self { let explicit_route_input_usage = explicit_route_input_usage(model); Self { diff --git a/src/dsl/runtime.rs b/src/dsl/runtime.rs index 59c399ab..1cef784e 100644 --- a/src/dsl/runtime.rs +++ b/src/dsl/runtime.rs @@ -1,3 +1,82 @@ +//! Unified runtime entrypoints for DSL-backed models. +//! +//! Use this module when you already know you want an executable model and need +//! one backend-neutral surface for compile, load, and prediction workflows. +//! It normalizes the backend-specific JIT, native AoT, and WASM entrypoints so +//! callers can choose a deployment target without rewriting the downstream +//! prediction code. +//! +//! Use the backend modules directly only when you need a backend-specific +//! artifact or compile control: +//! +//! - [`super::jit`] for direct in-process JIT compilation. +//! - [`compile_module_source_to_aot`][crate::dsl::compile_module_source_to_aot] for native artifact export and reload. +//! - [`compile_module_source_to_wasm_bytes`][crate::dsl::compile_module_source_to_wasm_bytes] and [`load_runtime_wasm_bytes`] for portable WASM bytes, +//! browser-loader assets, and host-side WASM loading. +//! +//! Main entrypoints: +//! +//! - [`compile_module_source_to_runtime`] for the one-stop source-to-runtime +//! path. +//! - [`compile_execution_model_to_runtime`] when you already have an +//! [`ExecutionModel`](pharmsol_dsl::ExecutionModel). +//! - [`load_runtime_artifact`] and [`load_runtime_wasm_bytes`] when the model +//! has already been compiled and stored elsewhere. +//! - [`CompiledRuntimeModel::estimate_predictions`] for backend-neutral +//! execution against a [`Subject`](crate::Subject). +//! +//! Backend choice guide: +//! +//! - [`RuntimeCompilationTarget::Jit`] keeps compilation and execution inside +//! the current process. Use it for native interactive workflows and tests. +//! - [`RuntimeCompilationTarget::NativeAot`] emits a native artifact and reloads +//! it into the same runtime model shape. Use it when you want reusable native +//! artifacts and can control the target platform. +//! - [`RuntimeCompilationTarget::Wasm`] emits portable WASM bytes and reloads +//! them into the host-side runtime adapter. Use it when you need a portable +//! artifact or browser-aligned deployment story. +//! +//! Smallest compile-and-run example: +//! +//! This example requires `dsl-jit`. +//! +//! ```rust,no_run +//! use pharmsol::dsl::{compile_module_source_to_runtime, RuntimeCompilationTarget}; +//! use pharmsol::prelude::*; +//! +//! let source = r#" +//! name = bimodal_ke +//! kind = ode +//! +//! params = ke, v +//! states = central +//! outputs = cp +//! +//! infusion(iv) -> central +//! +//! dx(central) = -ke * central +//! out(cp) = central / v +//! "#; +//! +//! let model = compile_module_source_to_runtime( +//! source, +//! Some("bimodal_ke"), +//! RuntimeCompilationTarget::Jit, +//! |_, _| {}, +//! )?; +//! +//! let subject = Subject::builder("patient_001") +//! .infusion(0.0, 500.0, "iv", 0.5) +//! .missing_observation(0.5, "cp") +//! .missing_observation(1.0, "cp") +//! .missing_observation(2.0, "cp") +//! .build(); +//! +//! let predictions = model.estimate_predictions(&subject, &[1.2, 50.0])?; +//! assert!(predictions.as_subject().is_some()); +//! # Ok::<(), pharmsol::dsl::RuntimeError>(()) +//! ``` + use std::fmt; use std::path::Path; @@ -39,24 +118,39 @@ pub type RuntimeOdeModel = NativeOdeModel; pub type RuntimeAnalyticalModel = NativeAnalyticalModel; pub type RuntimeSdeModel = NativeSdeModel; +/// Selects which backend should produce the executable runtime model. +/// +/// This enum is the main backend-switching point for +/// [`compile_module_source_to_runtime`] and +/// [`compile_execution_model_to_runtime`]. #[derive(Debug, Clone, PartialEq, Eq)] pub enum RuntimeCompilationTarget { + /// Compile and execute the model inside the current native process. #[cfg(feature = "dsl-jit")] Jit, + /// Export a native artifact and reload it as a runtime model. #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] NativeAot(NativeAotCompileOptions), + /// Emit WASM bytes and reload them through the host-side WASM runtime. #[cfg(feature = "dsl-wasm")] Wasm, } +/// Identifies the on-disk artifact format for [`load_runtime_artifact`]. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RuntimeArtifactFormat { + /// A native ahead-of-time artifact produced by the AoT compiler. #[cfg(all(feature = "dsl-aot", feature = "dsl-aot-load"))] NativeAot, + /// A WASM artifact produced by the WASM compiler. #[cfg(feature = "dsl-wasm")] Wasm, } +/// Backend-neutral prediction output from a compiled runtime model. +/// +/// ODE and analytical models return subject predictions. SDE models return the +/// particle matrix used by the stochastic workflow. #[derive(Clone, Debug)] pub enum RuntimePredictions { Subject(SubjectPredictions), @@ -93,6 +187,10 @@ impl RuntimePredictions { } } +/// Executable runtime model returned by the backend-neutral runtime surface. +/// +/// This type hides the concrete backend and keeps the prediction entrypoint the +/// same across JIT, native AoT, and WASM-based flows. #[derive(Clone, Debug)] pub enum CompiledRuntimeModel { Ode(RuntimeOdeModel), @@ -166,6 +264,8 @@ impl CompiledRuntimeModel { } } +/// Errors produced while parsing, lowering, compiling, loading, or executing a +/// runtime DSL model. #[derive(Error)] pub enum RuntimeError { #[error("failed to parse DSL source: {0}")] @@ -231,6 +331,10 @@ impl fmt::Debug for RuntimeError { } } +/// Parse, analyze, lower, compile, and return a runtime model in one step. +/// +/// Use this when your input is DSL source text and you want the shortest path +/// from source to predictions. pub fn compile_module_source_to_runtime( source: &str, model_name: Option<&str>, @@ -269,6 +373,10 @@ pub fn compile_module_source_to_runtime( }) } +/// Compile a lowered execution model to a selected runtime backend. +/// +/// Use this when you already own the frontend pipeline and only need the final +/// backend step. pub fn compile_execution_model_to_runtime( model: &ExecutionModel, target: RuntimeCompilationTarget, @@ -309,6 +417,7 @@ pub fn compile_execution_model_to_runtime( } } +/// Load a previously compiled native AoT or WASM artifact from disk. pub fn load_runtime_artifact( path: impl AsRef, format: RuntimeArtifactFormat, @@ -330,6 +439,7 @@ pub fn load_runtime_artifact( } #[cfg(feature = "dsl-wasm")] +/// Compile DSL source straight to a host-side runtime model via the WASM path. pub fn compile_module_source_to_runtime_wasm( source: &str, model_name: Option<&str>, @@ -339,6 +449,8 @@ pub fn compile_module_source_to_runtime_wasm( } #[cfg(feature = "dsl-wasm")] +/// Compile a lowered execution model straight to a host-side runtime model via +/// the WASM path. pub fn compile_execution_model_to_runtime_wasm( model: &ExecutionModel, ) -> Result { @@ -347,6 +459,7 @@ pub fn compile_execution_model_to_runtime_wasm( } #[cfg(feature = "dsl-wasm")] +/// Load a runtime model from in-memory WASM bytes. pub fn load_runtime_wasm_bytes(bytes: &[u8]) -> Result { let (info, artifact) = load_wasm_artifact_bytes(bytes)?; Ok(runtime_model_from_parts(info, artifact)) diff --git a/src/dsl/wasm.rs b/src/dsl/wasm.rs index f2504d44..e95b799a 100644 --- a/src/dsl/wasm.rs +++ b/src/dsl/wasm.rs @@ -406,11 +406,39 @@ impl RuntimeArtifact for WasmExecutionArtifact { } } +/// Read only the metadata from a compiled WASM artifact on disk. +/// +/// Use this when you need model identity, route labels, output labels, or +/// buffer sizes without loading the executable runtime wrapper. pub fn read_wasm_model_info(path: impl AsRef) -> Result { let (info, _) = load_wasm_artifact(path)?; Ok(info) } +/// Read only the metadata from in-memory compiled WASM bytes. +/// +/// ```rust,no_run +/// use pharmsol::dsl::{compile_module_source_to_wasm_bytes, read_wasm_model_info_bytes}; +/// +/// let source = r#" +/// name = bimodal_ke +/// kind = ode +/// +/// params = ke, v +/// states = central +/// outputs = cp +/// +/// infusion(iv) -> central +/// +/// dx(central) = -ke * central +/// out(cp) = central / v +/// "#; +/// +/// let bytes = compile_module_source_to_wasm_bytes(source, Some("bimodal_ke"))?; +/// let info = read_wasm_model_info_bytes(&bytes)?; +/// assert_eq!(info.name, "bimodal_ke"); +/// # Ok::<(), Box>(()) +/// ``` pub fn read_wasm_model_info_bytes(bytes: &[u8]) -> Result { let (info, _) = load_wasm_artifact_bytes(bytes)?; Ok(info) diff --git a/src/dsl/wasm_compile.rs b/src/dsl/wasm_compile.rs index caa60216..cda4727d 100644 --- a/src/dsl/wasm_compile.rs +++ b/src/dsl/wasm_compile.rs @@ -19,15 +19,24 @@ use pharmsol_dsl::{ LoweringError, ParseError, SemanticError, }; +/// ABI version for compiled WASM artifacts produced by this crate. pub const WASM_API_VERSION: u32 = 1; +/// Default entry capacity for [`WasmCompileCache`]. pub const DEFAULT_WASM_COMPILE_CACHE_CAPACITY: usize = 32; static BROWSER_LOADER_SOURCE: OnceLock = OnceLock::new(); +/// Portable WASM artifact bundle produced by the WASM compiler path. +/// +/// The bundle includes the raw WASM bytes, model metadata, and a browser loader +/// source string that can instantiate the model in JavaScript. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] pub struct CompiledWasmModule { + /// Raw compiled WASM bytes. pub wasm_bytes: Vec, + /// Serialized model metadata and kernel availability. pub metadata: CompiledModelInfoEnvelope, + /// JavaScript loader source for browser-side instantiation. pub browser_loader_source: String, } @@ -52,6 +61,7 @@ struct WasmCompileCacheState { lru: VecDeque, } +/// In-memory LRU cache for repeated WASM compilation from the same DSL source. #[derive(Debug)] pub struct WasmCompileCache { capacity: usize, @@ -65,6 +75,7 @@ impl Default for WasmCompileCache { } impl WasmCompileCache { + /// Create a new compile cache with at least one entry of capacity. pub fn new(capacity: usize) -> Self { Self { capacity: capacity.max(1), @@ -72,10 +83,12 @@ impl WasmCompileCache { } } + /// Return the configured cache capacity. pub fn capacity(&self) -> usize { self.capacity } + /// Return the number of cached compiled modules. pub fn entry_count(&self) -> usize { self.state .lock() @@ -84,6 +97,7 @@ impl WasmCompileCache { .len() } + /// Remove all cached compiled modules. pub fn clear(&self) { let mut state = self .state @@ -93,6 +107,8 @@ impl WasmCompileCache { state.lru.clear(); } + /// Compile DSL source to a full WASM module bundle, reusing the cache when + /// possible. pub fn compile_module_source_to_wasm_module( &self, source: &str, @@ -108,6 +124,7 @@ impl WasmCompileCache { Ok(compiled) } + /// Compile DSL source to raw WASM bytes, reusing the cache when possible. pub fn compile_module_source_to_wasm_bytes( &self, source: &str, @@ -145,6 +162,8 @@ impl WasmCompileCache { } } +/// Error produced while compiling, inspecting, or loading a DSL-backed WASM +/// artifact. #[derive(Error)] pub enum WasmError { #[error(transparent)] @@ -224,10 +243,12 @@ impl fmt::Debug for WasmError { } } +/// Compile a lowered execution model to raw WASM bytes. pub fn compile_execution_model_to_wasm_bytes(model: &ExecutionModel) -> Result, WasmError> { emit_execution_model_to_wasm_bytes(model, WASM_API_VERSION) } +/// Compile a lowered execution model to a portable WASM bundle. pub fn compile_execution_model_to_wasm_module( model: &ExecutionModel, ) -> Result { @@ -238,6 +259,7 @@ pub fn compile_execution_model_to_wasm_module( }) } +/// Parse DSL source, lower one selected model, and return raw WASM bytes. pub fn compile_module_source_to_wasm_bytes( source: &str, model_name: Option<&str>, @@ -245,6 +267,35 @@ pub fn compile_module_source_to_wasm_bytes( Ok(compile_module_source_to_wasm_module(source, model_name)?.wasm_bytes) } +/// Parse DSL source, lower one selected model, and return the full WASM bundle. +/// +/// Use this when you want a portable artifact for browser or host-side loading +/// together with the browser loader source. +/// +/// This function requires `dsl-wasm-compile`. +/// +/// ```rust,no_run +/// use pharmsol::dsl::{browser_loader_source, compile_module_source_to_wasm_module}; +/// +/// let source = r#" +/// name = bimodal_ke +/// kind = ode +/// +/// params = ke, v +/// states = central +/// outputs = cp +/// +/// infusion(iv) -> central +/// +/// dx(central) = -ke * central +/// out(cp) = central / v +/// "#; +/// +/// let compiled = compile_module_source_to_wasm_module(source, Some("bimodal_ke"))?; +/// let loader = browser_loader_source(); +/// # let _ = (compiled, loader); +/// # Ok::<(), pharmsol::dsl::WasmError>(()) +/// ``` pub fn compile_module_source_to_wasm_module( source: &str, model_name: Option<&str>, @@ -282,6 +333,10 @@ fn compile_module_source_to_wasm_module_uncached( compile_execution_model_to_wasm_module(&execution) } +/// Return the JavaScript loader source for browser-side WASM model execution. +/// +/// This helper is useful when you want to ship compiled WASM bytes together +/// with the minimal browser glue code that understands the pharmsol ABI. pub fn browser_loader_source() -> String { BROWSER_LOADER_SOURCE .get_or_init(build_browser_loader_source) diff --git a/src/lib.rs b/src/lib.rs index c84d4ee1..9c9e40b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,105 @@ +//! `pharmsol` is a Rust library for pharmacometric work. +//! +//! You can use it to: +//! +//! - build PK/PD datasets from dose and observation events +//! - simulate analytical, ODE, and SDE models +//! - run non-compartmental analysis (NCA) +//! - compile and run models from the pharmsol DSL when the DSL features are enabled +//! +//! Most users start in one of these places: +//! +//! - [`prelude`] for the common types, traits, and macros +//! - [`data`] to build subjects, occasions, events, and covariates +//! - [`simulator`] to define models and generate predictions +//! - [`nca`] to calculate NCA metrics from the same data structures +//! - [`optimize`] for optimizer-oriented workflows +//! +//! The DSL runtime surface is feature-gated. When you enable `dsl-core`, the +//! `pharmsol::dsl` module adds parsing, analysis, lowering, compile, and runtime +//! entrypoints for models written as DSL source text. +//! +//! ## Quick Start +//! +//! This example shows the smallest full workflow: define a model, build a +//! subject, and generate predictions. +//! +//! ```rust +//! use pharmsol::prelude::*; +//! +//! let model = analytical! { +//! name: "one_cmt_iv", +//! params: [ke, v], +//! states: [central], +//! outputs: [cp], +//! routes: [ +//! infusion(iv) -> central, +//! ], +//! structure: one_compartment, +//! out: |x, _p, _t, _cov, y| { +//! y[cp] = x[central] / v; +//! }, +//! }; +//! +//! let subject = Subject::builder("patient_001") +//! .infusion(0.0, 500.0, "iv", 0.5) +//! .missing_observation(0.5, "cp") +//! .missing_observation(1.0, "cp") +//! .build(); +//! +//! let predictions = model.estimate_predictions(&subject, &[1.022, 194.0])?; +//! assert_eq!(predictions.flat_predictions().len(), 2); +//! # Ok::<(), pharmsol::PharmsolError>(()) +//! ``` +//! +//! ## Choose A Workflow +//! +//! Use this guide when you are deciding where to start. +//! +//! | Task | Start Here | Notes | +//! | --- | --- | --- | +//! | Build subject data | [`data`] or [`prelude`] | Best when you already know dose times, labels, and observations. | +//! | Simulate a model written in Rust | [`simulator`] or [`prelude`] | Supports analytical, ODE, and SDE models. | +//! | Run NCA | [`nca`] or [`prelude`] | Reuses the same `Subject`, `Occasion`, and `Data` types. | +//! | Use optimization helpers | [`optimize`] | Intended for advanced workflows. | +//! | Parse or compile DSL source | `pharmsol::dsl` | Requires one or more DSL features. | +//! +//! ## Feature Guide +//! +//! Core simulation and NCA APIs do not need extra crate features on native +//! targets. +//! +//! DSL work is feature-gated: +//! +//! - `dsl-core`: exposes the `pharmsol::dsl` facade and frontend types +//! - `dsl-jit`: adds in-process JIT compilation +//! - `dsl-aot`: adds native ahead-of-time artifact compilation +//! - `dsl-aot-load`: adds native artifact loading +//! - `dsl-wasm-compile`: adds WASM artifact generation +//! - `dsl-wasm`: adds WASM runtime loading and execution +//! +//! ## Labels And Indices +//! +//! Public data APIs use route labels and output labels such as `"iv"`, +//! `"oral"`, and `"cp"`. +//! +//! Use labels in builders and parsed data unless you are deliberately working +//! with dense internal indices from a lower-level API. +//! +//! ## Platform Notes +//! +//! The main `data`, `simulator`, `nca`, and `optimize` modules are documented +//! for native targets. Some surfaces are not built on `wasm32-unknown-unknown`. +//! The DSL runtime also has feature-specific platform limits. +//! +//! ## Next Stops +//! +//! - Start with [`prelude`] if you want one import for the common workflow. +//! - Open [`data`] if you need to construct subjects or parse input files. +//! - Open [`simulator`] if you need predictions from analytical, ODE, or SDE models. +//! - Open [`nca`] if you need exposure and terminal metrics. +//! - Use `pharmsol::dsl` if the model comes from source text instead of Rust code. + #[cfg(feature = "dsl-aot")] mod build_support; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] @@ -49,19 +151,31 @@ pub use pharmsol_macros::{analytical, ode, sde}; #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub use std::collections::HashMap; -/// Prelude module that re-exports all commonly used types and traits. +/// Common imports for the main pharmsol workflow. +/// +/// Use the prelude when you want one import that covers the common public API: +/// +/// - subject and dataset types +/// - subject builders and events +/// - simulation types and prediction results +/// - NCA traits and option types +/// - declaration-first macros such as [`crate::ode`] and [`crate::analytical`] /// -/// Importing `pharmsol::prelude::*` brings the main modeling, simulation, -/// and data APIs into scope. +/// This is the fastest way to get started with examples, scripts, and small +/// applications. +/// +/// If you need a narrower import surface, use the modules directly instead. /// /// # Example /// ```rust /// use pharmsol::prelude::*; /// /// let subject = Subject::builder("patient_001") -/// .bolus(0.0, 100.0, 0) -/// .observation(1.0, 10.5, 0) +/// .infusion(0.0, 100.0, "iv", 1.0) +/// .missing_observation(1.0, "cp") /// .build(); +/// +/// assert_eq!(subject.id(), "patient_001"); /// ``` #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))] pub mod prelude { diff --git a/src/simulator/equation/metadata.rs b/src/simulator/equation/metadata.rs index fecab7e2..c7fbd4c9 100644 --- a/src/simulator/equation/metadata.rs +++ b/src/simulator/equation/metadata.rs @@ -1,17 +1,40 @@ -//! Shared model metadata for handwritten simulator models. +//! Metadata builders and validated metadata views for handwritten models. //! -//! This module defines the public metadata contract that handwritten ODE, -//! analytical, and SDE models can attach to. The field set is intentionally -//! aligned with the public subset of the DSL/runtime metadata surface. +//! Use this module when a handwritten [`crate::ODE`], [`crate::Analytical`], or +//! [`crate::SDE`] model should expose the same public names that appear in data +//! rows, subject builders, or parsed files. //! -//! Internal runtime layout details such as dense buffer lengths, derived buffer -//! shape, or ABI-specific offsets remain internal for now. +//! Metadata gives names to parameters, covariates, states, routes, and outputs. +//! After validation, the execution layer can resolve public labels such as +//! `"iv"` and `"cp"` against those declarations before simulation. +//! +//! Without metadata, handwritten models fall back to numeric labels. With +//! metadata, labels are matched by name. +//! +//! # Example +//! +//! ```rust +//! use pharmsol::{metadata, ModelKind}; +//! +//! let metadata = metadata::new("one_cmt") +//! .kind(ModelKind::Ode) +//! .parameters(["cl", "v"]) +//! .states(["central"]) +//! .outputs(["cp"]) +//! .route(metadata::Route::infusion("iv").to_state("central")) +//! .validate() +//! .unwrap(); +//! +//! assert_eq!(metadata.name(), "one_cmt"); +//! assert_eq!(metadata.route("iv").unwrap().destination(), "central"); +//! assert_eq!(metadata.output_index("cp"), Some(0)); +//! ``` use pharmsol_dsl::{AnalyticalKernel, CovariateInterpolation, ModelKind}; use std::fmt; use thiserror::Error; -/// Create a new handwritten-model metadata builder. +/// Shorthand for [`ModelMetadata::new`]. pub fn new(name: impl Into) -> ModelMetadata { ModelMetadata::new(name) } @@ -71,7 +94,17 @@ impl fmt::Display for NameDomain { } } -/// Immutable validated metadata view used by later attachment slices. +/// Validated metadata view used by the execution layer. +/// +/// This type is what handwritten equation builders store after metadata has +/// passed validation. It provides stable lookup helpers from public names to the +/// dense indices used during execution. +/// +/// Route lookups expose two different indices: +/// - [`ValidatedModelMetadata::route_declaration_index`] is the route position in +/// declaration order. +/// - [`ValidatedModelMetadata::route_index`] is the dense execution input index +/// for that route kind. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ValidatedModelMetadata { name: String, @@ -87,10 +120,12 @@ pub struct ValidatedModelMetadata { } impl ValidatedModelMetadata { + /// Get the public model name. pub fn name(&self) -> &str { &self.name } + /// Get the validated model family. pub fn kind(&self) -> ModelKind { self.kind } @@ -111,6 +146,9 @@ impl ValidatedModelMetadata { &self.routes } + /// Get the number of dense execution input slots needed for routes. + /// + /// This is the maximum of the bolus-route count and infusion-route count. pub fn route_input_count(&self) -> usize { self.route_input_count } @@ -143,14 +181,17 @@ impl ValidatedModelMetadata { self.states.iter().position(|state| state.name() == name) } + /// Look up a route by public name and return its dense execution input index. pub fn route_index(&self, name: &str) -> Option { self.route(name).map(ValidatedRoute::input_index) } + /// Look up a route by public name and return its declaration-order index. pub fn route_declaration_index(&self, name: &str) -> Option { self.routes.iter().position(|route| route.name() == name) } + /// Look up an output by public name and return its dense output index. pub fn output_index(&self, name: &str) -> Option { self.outputs.iter().position(|output| output.name() == name) } @@ -179,7 +220,11 @@ impl ValidatedModelMetadata { } } -/// One validated route declaration with resolved destination state index. +/// One validated route declaration with resolved execution details. +/// +/// A validated route keeps both the declaration-order index and the dense input +/// index used during execution. Those values can differ from each other when a +/// model mixes bolus and infusion routes. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ValidatedRoute { name: String, @@ -194,6 +239,7 @@ pub struct ValidatedRoute { } impl ValidatedRoute { + /// Get the public route name used for label matching. pub fn name(&self) -> &str { &self.name } @@ -202,18 +248,22 @@ impl ValidatedRoute { self.kind } + /// Get the declaration-order index for this route. pub fn declaration_index(&self) -> usize { self.declaration_index } + /// Get the dense execution input index for this route kind. pub fn input_index(&self) -> usize { self.input_index } + /// Get the destination state name. pub fn destination(&self) -> &str { &self.destination } + /// Get the destination state index in model order. pub fn destination_index(&self) -> usize { self.destination_index } @@ -231,7 +281,12 @@ impl ValidatedRoute { } } -/// Metadata describing one handwritten simulator model. +/// Builder for handwritten model metadata. +/// +/// Use [`ModelMetadata`] to declare the public names that should be attached to +/// a handwritten equation. After validation, the resulting metadata can be +/// attached to handwritten [`crate::ODE`], [`crate::Analytical`], and +/// [`crate::SDE`] models through their `with_metadata(...)` methods. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ModelMetadata { name: String, @@ -379,11 +434,17 @@ impl ModelMetadata { } /// Validate this metadata using its declared kind. + /// + /// Use this when the metadata itself already declares whether the model is + /// ODE, analytical, or SDE. pub fn validate(self) -> Result { self.validate_internal(None, None) } /// Validate this metadata for a specific model kind. + /// + /// Use this when the equation type determines the model family and you want + /// validation to enforce that family explicitly. pub fn validate_for( self, kind: ModelKind, @@ -440,6 +501,7 @@ pub struct Parameter { } impl Parameter { + /// Create a named parameter declaration. pub fn new(name: impl Into) -> Self { Self { name: name.into() } } @@ -466,6 +528,7 @@ pub struct Covariate { } impl Covariate { + /// Create a named covariate without an explicit interpolation policy. pub fn new(name: impl Into) -> Self { Self { name: name.into(), @@ -473,14 +536,17 @@ impl Covariate { } } + /// Create a continuous covariate that uses linear interpolation. pub fn continuous(name: impl Into) -> Self { Self::new(name).with_interpolation(CovariateInterpolation::Linear) } + /// Create a covariate that uses last-observation-carried-forward semantics. pub fn locf(name: impl Into) -> Self { Self::new(name).with_interpolation(CovariateInterpolation::Locf) } + /// Set the interpolation policy explicitly. pub fn with_interpolation(mut self, interpolation: CovariateInterpolation) -> Self { self.interpolation = Some(interpolation); self @@ -502,6 +568,7 @@ pub struct State { } impl State { + /// Create a named state declaration. pub fn new(name: impl Into) -> Self { Self { name: name.into() } } @@ -527,6 +594,7 @@ pub struct Output { } impl Output { + /// Create a named output declaration. pub fn new(name: impl Into) -> Self { Self { name: name.into() } } @@ -548,18 +616,25 @@ where /// Route declaration kind. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RouteKind { + /// Instantaneous dose input. Bolus, + /// Dose input over a duration. Infusion, } /// How route inputs should be interpreted by the execution layer. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum RouteInputPolicy { + /// Inject the resolved input directly into the declared destination state. InjectToDestination, + /// Expect the low-level execution path to provide an explicit input vector. ExplicitInputVector, } /// One named route declaration. +/// +/// Route names are the public labels matched against dose events such as `iv` +/// or `oral`. #[derive(Debug, Clone, PartialEq, Eq)] pub struct Route { name: String, @@ -571,14 +646,17 @@ pub struct Route { } impl Route { + /// Create a named bolus route declaration. pub fn bolus(name: impl Into) -> Self { Self::new(name, RouteKind::Bolus) } + /// Create a named infusion route declaration. pub fn infusion(name: impl Into) -> Self { Self::new(name, RouteKind::Infusion) } + /// Create a route declaration with an explicit kind. pub fn new(name: impl Into, kind: RouteKind) -> Self { Self { name: name.into(), @@ -590,26 +668,31 @@ impl Route { } } + /// Declare which state this route targets. pub fn to_state(mut self, destination: impl Into) -> Self { self.destination = Some(destination.into()); self } + /// Mark this route as supporting lag handling. pub fn with_lag(mut self) -> Self { self.has_lag = true; self } + /// Mark this route as supporting bioavailability handling. pub fn with_bioavailability(mut self) -> Self { self.has_bioavailability = true; self } + /// Request direct injection into the destination state at execution time. pub fn inject_input_to_destination(mut self) -> Self { self.input_policy = Some(RouteInputPolicy::InjectToDestination); self } + /// Request an explicit low-level input vector at execution time. pub fn expect_explicit_input(mut self) -> Self { self.input_policy = Some(RouteInputPolicy::ExplicitInputVector); self diff --git a/src/simulator/equation/mod.rs b/src/simulator/equation/mod.rs index c5a97958..03e5318c 100644 --- a/src/simulator/equation/mod.rs +++ b/src/simulator/equation/mod.rs @@ -1,3 +1,51 @@ +//! Handwritten equation families and their shared simulation interfaces. +//! +//! This module is the public home for handwritten [`ODE`], [`Analytical`], and +//! [`SDE`] models, plus the shared [`Equation`] trait and the metadata types +//! that attach public names to parameters, states, routes, and outputs. +//! +//! Use this module when you want to: +//! - choose between deterministic ODE, analytical, and stochastic SDE models +//! - attach metadata so dataset labels such as `"iv"` and `"cp"` resolve by +//! name instead of by dense numeric index +//! - work with prediction or likelihood APIs across equation families +//! +//! # Equation Families +//! +//! - [`ODE`] for deterministic models that must be numerically integrated. +//! - [`Analytical`] for supported closed-form models. +//! - [`SDE`] for stochastic models that use particles. +//! +//! # Labels And Metadata +//! +//! Input and output labels arrive from public data APIs as strings. +//! +//! - Without metadata, handwritten models fall back to numeric labels such as +//! `0` or `1`. +//! - With [`metadata::ModelMetadata`] attached, route and output labels are +//! resolved by name against the declared routes and outputs before +//! simulation. +//! +//! That label-first path is the preferred public workflow for current authoring. +//! +//! # Example +//! +//! ```rust +//! use pharmsol::{metadata, ModelKind}; +//! +//! let metadata = metadata::new("one_cmt") +//! .kind(ModelKind::Ode) +//! .parameters(["cl", "v"]) +//! .states(["central"]) +//! .outputs(["cp"]) +//! .route(metadata::Route::infusion("iv").to_state("central")) +//! .validate() +//! .unwrap(); +//! +//! assert_eq!(metadata.route_index("iv"), Some(0)); +//! assert_eq!(metadata.output_index("cp"), Some(0)); +//! ``` + use std::fmt::Debug; pub mod analytical; pub mod metadata; @@ -20,10 +68,10 @@ use super::likelihood::Prediction; /// Trait for state vectors that can receive bolus doses. pub trait State { - /// Add a bolus dose to the state at the specified input compartment. + /// Add a bolus dose to the state at the specified resolved input index. /// /// # Parameters - /// - `input`: The compartment index + /// - `input`: The resolved dense input index used by the execution layer /// - `amount`: The bolus amount fn add_bolus(&mut self, input: usize, amount: f64); } @@ -114,7 +162,7 @@ pub trait Cache: Sized { fn disable_cache(self) -> Self; } -/// Trait defining the associated types for equations. +/// Associated state and prediction container types for an equation family. pub trait EquationTypes { /// The state vector type type S: State + Debug; @@ -308,11 +356,15 @@ pub(crate) trait EquationPriv: EquationTypes { } } -/// Trait for model equations that can be simulated. +/// Trait for handwritten model equations that can be simulated. +/// +/// [`Equation`] is the shared interface implemented by handwritten [`ODE`], +/// [`Analytical`], and [`SDE`] models. /// -/// This trait defines the interface for different types of model equations -/// (ODE, SDE, analytical) that can be simulated to generate predictions -/// and estimate parameters. +/// Subject data enters this layer through public labels on dose and observation +/// events. If metadata is attached to the equation, those labels are resolved by +/// name before simulation. Otherwise, the execution layer expects numeric labels +/// that can be interpreted as dense indices. /// /// # Likelihood Calculation /// @@ -440,6 +492,7 @@ pub trait Equation: EquationPriv + 'static + Clone + Sync { } } +/// Runtime family tag for handwritten equations. #[repr(C)] #[derive(Clone, Debug)] pub enum EqnKind {