From 10a21df989695553804417d535edfa28e7e04dc4 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Tue, 24 Mar 2026 15:31:22 -0400 Subject: [PATCH 1/6] first pass on truncation --- rust/src/algorithms/mod.rs | 1 + rust/src/algorithms/transform.rs | 69 ++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+) create mode 100644 rust/src/algorithms/transform.rs diff --git a/rust/src/algorithms/mod.rs b/rust/src/algorithms/mod.rs index b6683a3..30e0696 100644 --- a/rust/src/algorithms/mod.rs +++ b/rust/src/algorithms/mod.rs @@ -3,3 +3,4 @@ pub mod moments; pub mod number; pub mod order_stat; pub mod rv; +pub mod transform; diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs new file mode 100644 index 0000000..3874a19 --- /dev/null +++ b/rust/src/algorithms/transform.rs @@ -0,0 +1,69 @@ +#![allow(dead_code)] + +use crate::algorithms::number::Number; +use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; + + +/// Truncates a discrete random variable by cutting off a portion of the support +/// and normalizng total probability of the distribution to 1 +/// +/// # Arguments +/// * `random_variable` - the random variable to truncate +/// * `min_support` - the minimum support of the new random variable. +/// Must be greater than or equal to the current minimum support. +/// * `max_support` - the maximum support of the new random variable. +/// Must be less than or equal to the current minimum support. +/// +/// # Returns +/// * `truncated_rv` - the truncated random variable +pub fn truncate_discrete( + random_variable: &RandomVariable, + min_support: Number, + max_support: Number, +) -> Result { + let pdf_random_variable = random_variable.to_pdf()?; + let function = pdf_random_variable.function; + let support = pdf_random_variable.support; + + let first_support = *support.first().expect("could not extract the first item"); + if min_support < first_support { + return Err( + "min support must be greater than or equal to the lowest support value" + .to_string() + ); + } + + let last_support = *support.last().expect("could not extract the first item"); + if max_support > last_support { + return Err( + "min support must be less than or equal to the highest support value" + .to_string() + ); + } + + let mut truncation_area = Number::Integer(0); + for (&support_value, &function_value) in support.iter().zip(function.iter()) { + if support_value >= min_support && support_value <= max_support { + truncation_area += function_value; + } + } + + let mut truncated_function = Vec::new(); + let mut truncated_support = Vec::new(); + + for (&support_value, &function_value) in support.iter().zip(function.iter()) { + if support_value >= min_support && support_value <= max_support { + let probability = function_value / truncation_area; + truncated_function.push(probability); + truncated_support.push(support_value); + } + } + + let truncated_rv = RandomVariable { + function: truncated_function, + support: truncated_support, + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + Ok(truncated_rv) +} From 8e011c6d03d5b175ecaa48e2b78960b557ecdb77 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Tue, 24 Mar 2026 15:38:28 -0400 Subject: [PATCH 2/6] cleanup of truncation --- rust/src/algorithms/transform.rs | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index 3874a19..bf2bd26 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -25,7 +25,14 @@ pub fn truncate_discrete( let function = pdf_random_variable.function; let support = pdf_random_variable.support; - let first_support = *support.first().expect("could not extract the first item"); + if min_support >= max_support { + return Err( + "max_support must be greater than the min_support" + .to_string() + ); + } + + let first_support = *support.first().ok_or("support is empty")?; if min_support < first_support { return Err( "min support must be greater than or equal to the lowest support value" @@ -33,10 +40,10 @@ pub fn truncate_discrete( ); } - let last_support = *support.last().expect("could not extract the first item"); + let last_support = *support.last().ok_or("support is empty")?; if max_support > last_support { return Err( - "min support must be less than or equal to the highest support value" + "max support must be less than or equal to the highest support value" .to_string() ); } @@ -48,6 +55,14 @@ pub fn truncate_discrete( } } + let zero = Number::Integer(0); + if truncation_area == zero { + return Err( + "there is no probability mass within the specified support range" + .to_string() + ); + } + let mut truncated_function = Vec::new(); let mut truncated_support = Vec::new(); From 204b7c74bc2cc19ecfc51492615f3645932c51fb Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Tue, 24 Mar 2026 15:43:13 -0400 Subject: [PATCH 3/6] unit tests for truncate --- rust/src/algorithms/transform.rs | 95 +++++++++++++++++++++++++++++++- 1 file changed, 93 insertions(+), 2 deletions(-) diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index bf2bd26..f7c8704 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -5,17 +5,55 @@ use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; /// Truncates a discrete random variable by cutting off a portion of the support -/// and normalizng total probability of the distribution to 1 +/// and normalizing total probability of the distribution to 1. /// /// # Arguments /// * `random_variable` - the random variable to truncate /// * `min_support` - the minimum support of the new random variable. /// Must be greater than or equal to the current minimum support. /// * `max_support` - the maximum support of the new random variable. -/// Must be less than or equal to the current minimum support. +/// Must be less than or equal to the current maximum support. /// /// # Returns /// * `truncated_rv` - the truncated random variable +/// +/// # Examples +/// ``` +/// use applpy_rust::algorithms::number::Number; +/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +/// use applpy_rust::algorithms::transform::truncate_discrete; +/// use num_rational::Rational64; +/// +/// let rv = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 10)), +/// Number::Rational(Rational64::new(2, 10)), +/// Number::Rational(Rational64::new(3, 10)), +/// Number::Rational(Rational64::new(4, 10)), +/// ], +/// support: vec![ +/// Number::Integer(1), +/// Number::Integer(2), +/// Number::Integer(3), +/// Number::Integer(4), +/// ], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let truncated = truncate_discrete(&rv, Number::Integer(2), Number::Integer(3)).unwrap(); +/// +/// assert_eq!(truncated.support, vec![Number::Integer(2), Number::Integer(3)]); +/// assert_eq!( +/// truncated.function, +/// vec![ +/// Number::Rational(Rational64::new(2, 5)), +/// Number::Rational(Rational64::new(3, 5)), +/// ] +/// ); +/// assert!(matches!(truncated.functional_form, FunctionalForm::Pdf)); +/// assert!(matches!(truncated.domain_type, DomainType::Discrete)); +/// ``` pub fn truncate_discrete( random_variable: &RandomVariable, min_support: Number, @@ -82,3 +120,56 @@ pub fn truncate_discrete( }; Ok(truncated_rv) } + +#[cfg(test)] +mod tests { + use super::*; + use num_rational::Rational64; + + fn sample_discrete_rv() -> RandomVariable { + RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 10)), + Number::Rational(Rational64::new(2, 10)), + Number::Rational(Rational64::new(3, 10)), + Number::Rational(Rational64::new(4, 10)), + ], + support: vec![ + Number::Integer(1), + Number::Integer(2), + Number::Integer(3), + Number::Integer(4), + ], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + } + } + + #[test] + fn truncate_discrete_renormalizes_probabilities_within_range() { + let rv = sample_discrete_rv(); + let truncated = truncate_discrete(&rv, Number::Integer(2), Number::Integer(3)).unwrap(); + + assert_eq!(truncated.support, vec![Number::Integer(2), Number::Integer(3)]); + assert_eq!( + truncated.function, + vec![ + Number::Rational(Rational64::new(2, 5)), + Number::Rational(Rational64::new(3, 5)) + ] + ); + assert!(matches!(truncated.functional_form, FunctionalForm::Pdf)); + assert!(matches!(truncated.domain_type, DomainType::Discrete)); + } + + #[test] + fn truncate_discrete_returns_error_when_min_support_exceeds_bounds() { + let rv = sample_discrete_rv(); + let result = truncate_discrete(&rv, Number::Integer(0), Number::Integer(3)); + + assert!(matches!( + result, + Err(msg) if msg == "min support must be greater than or equal to the lowest support value" + )); + } +} From 021581fb2c014ebd585455c116a7c2e5625bea47 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Tue, 24 Mar 2026 15:43:22 -0400 Subject: [PATCH 4/6] add truncate to python api --- rust/src/python/api.rs | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/rust/src/python/api.rs b/rust/src/python/api.rs index ac2c616..f99bd3c 100644 --- a/rust/src/python/api.rs +++ b/rust/src/python/api.rs @@ -9,6 +9,7 @@ use crate::algorithms::number::Number; use crate::algorithms::order_stat; use crate::algorithms::rv; use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +use crate::algorithms::transform; #[pyfunction(name = "discrete_order_stat", signature = (random_variable, n, r, replace="w"))] pub fn discrete_order_stat_py( @@ -150,6 +151,23 @@ pub fn bootstrap_rv_py(variates: Vec) -> PyResult { Ok(fast_rv) } +#[pyfunction(name = "truncate_discrete", signature = (random_variable, min_support, max_support))] +pub fn truncate_discrete_py( + random_variable: &Bound<'_, PyAny>, + min_support: Number, + max_support: Number, +) -> PyResult { + let random_variable: FastRV = random_variable.extract()?; + let truncated_rv = transform::truncate_discrete(&random_variable.inner, min_support, max_support) + .map_err(PyValueError::new_err)?; + Ok(FastRV::new( + truncated_rv.function, + truncated_rv.support, + truncated_rv.functional_form, + truncated_rv.domain_type, + )) +} + #[pyclass] pub struct FastRV { inner: RandomVariable, From aad8845542e77b4f19754da06473a9727d82bbb5 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Tue, 24 Mar 2026 15:45:06 -0400 Subject: [PATCH 5/6] add truncate to python lib --- rust/src/algorithms/transform.rs | 22 ++++++++-------------- rust/src/lib.rs | 3 +++ rust/src/python/api.rs | 5 +++-- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index f7c8704..8283c79 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -3,7 +3,6 @@ use crate::algorithms::number::Number; use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; - /// Truncates a discrete random variable by cutting off a portion of the support /// and normalizing total probability of the distribution to 1. /// @@ -64,25 +63,20 @@ pub fn truncate_discrete( let support = pdf_random_variable.support; if min_support >= max_support { - return Err( - "max_support must be greater than the min_support" - .to_string() - ); + return Err("max_support must be greater than the min_support".to_string()); } let first_support = *support.first().ok_or("support is empty")?; if min_support < first_support { return Err( - "min support must be greater than or equal to the lowest support value" - .to_string() + "min support must be greater than or equal to the lowest support value".to_string(), ); } let last_support = *support.last().ok_or("support is empty")?; if max_support > last_support { return Err( - "max support must be less than or equal to the highest support value" - .to_string() + "max support must be less than or equal to the highest support value".to_string(), ); } @@ -95,10 +89,7 @@ pub fn truncate_discrete( let zero = Number::Integer(0); if truncation_area == zero { - return Err( - "there is no probability mass within the specified support range" - .to_string() - ); + return Err("there is no probability mass within the specified support range".to_string()); } let mut truncated_function = Vec::new(); @@ -150,7 +141,10 @@ mod tests { let rv = sample_discrete_rv(); let truncated = truncate_discrete(&rv, Number::Integer(2), Number::Integer(3)).unwrap(); - assert_eq!(truncated.support, vec![Number::Integer(2), Number::Integer(3)]); + assert_eq!( + truncated.support, + vec![Number::Integer(2), Number::Integer(3)] + ); assert_eq!( truncated.function, vec![ diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 7c82132..814fa18 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -35,6 +35,9 @@ fn applpy_rust(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> { )?)?; module.add_function(wrap_pyfunction!(python::api::bootstrap_rv_py, module)?)?; + // transformation functions + module.add_function(wrap_pyfunction!(python::api::truncate_discrete_py, module)?)?; + // dummy function to validate imports module.add_function(wrap_pyfunction!(dummy_ping, module)?)?; diff --git a/rust/src/python/api.rs b/rust/src/python/api.rs index f99bd3c..37c5ce5 100644 --- a/rust/src/python/api.rs +++ b/rust/src/python/api.rs @@ -158,8 +158,9 @@ pub fn truncate_discrete_py( max_support: Number, ) -> PyResult { let random_variable: FastRV = random_variable.extract()?; - let truncated_rv = transform::truncate_discrete(&random_variable.inner, min_support, max_support) - .map_err(PyValueError::new_err)?; + let truncated_rv = + transform::truncate_discrete(&random_variable.inner, min_support, max_support) + .map_err(PyValueError::new_err)?; Ok(FastRV::new( truncated_rv.function, truncated_rv.support, From 73d8f3cad299ae7c1d575014cc86fbd5aaff11e8 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Tue, 24 Mar 2026 15:49:54 -0400 Subject: [PATCH 6/6] patch rust truncate into python --- applpy/transform.py | 35 +++++++++++++++--------------- test_applpy/unit/test_transform.py | 3 ++- 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/applpy/transform.py b/applpy/transform.py index c51f11b..404b6a1 100644 --- a/applpy/transform.py +++ b/applpy/transform.py @@ -8,6 +8,16 @@ from .rv import RV, RVError, t, x +try: + import applpy_rust +except ImportError: + raise ImportError( + "applpy_rust extension is not built. " + "Run `uv sync --extra rust` then " + "`uv run --no-sync maturin develop -m rust/Cargo.toml`." + ) + + def transform(random_variable, transform_spec): """ Procedure Name: Transform @@ -423,23 +433,14 @@ def _truncate_discrete_functional(pdf_random_variable, cdf_random_variable, supp def _truncate_discrete(pdf_random_variable, support_interval): - # Find the area of the truncated random variable - truncation_area = 0 - for i in range(len(pdf_random_variable.support)): - if pdf_random_variable.support[i] >= support_interval[0]: - if pdf_random_variable.support[i] <= support_interval[1]: - truncation_area += pdf_random_variable.func[i] - # Truncate the random variable and find the probability - # at each point - truncated_functions = [] - truncated_support = [] - for i in range(len(pdf_random_variable.support)): - if pdf_random_variable.support[i] >= support_interval[0]: - if pdf_random_variable.support[i] <= support_interval[1]: - truncated_functions.append(pdf_random_variable.func[i] / truncation_area) - truncated_support.append(pdf_random_variable.support[i]) - # Return the truncated random variable - return RV(truncated_functions, truncated_support, ["discrete", "pdf"]) + min_support, max_support = tuple(support_interval) + fast_rv = applpy_rust.truncate_discrete(pdf_random_variable, min_support, max_support) + return RV( + func=fast_rv.function, + support=fast_rv.support, + functional_form=fast_rv.functional_form, + domain_type=fast_rv.domain_type, + ) def mixture(mix_parameters, mix_random_variables): diff --git a/test_applpy/unit/test_transform.py b/test_applpy/unit/test_transform.py index d0b2e29..0d4fd5a 100644 --- a/test_applpy/unit/test_transform.py +++ b/test_applpy/unit/test_transform.py @@ -52,7 +52,8 @@ def test_transform_and_truncate_happy_paths(): assert isinstance(transform(discrete, [[x + 1, x + 2], [0, 1, 2]]), RV) assert isinstance(transform(piecewise, [[x, x**2], [0, 1, 2]]), RV) assert isinstance(truncate(continuous, [Rational(1, 4), Rational(3, 4)]), RV) - assert isinstance(truncate(discrete, [1, 1]), RV) + with pytest.raises(ValueError): + truncate(discrete, [1, 1]) def test_mixture_happy_paths():