diff --git a/applpy/transform.py b/applpy/transform.py index c51f11b..404b6a1 100644 --- a/applpy/transform.py +++ b/applpy/transform.py @@ -8,6 +8,16 @@ from .rv import RV, RVError, t, x +try: + import applpy_rust +except ImportError: + raise ImportError( + "applpy_rust extension is not built. " + "Run `uv sync --extra rust` then " + "`uv run --no-sync maturin develop -m rust/Cargo.toml`." + ) + + def transform(random_variable, transform_spec): """ Procedure Name: Transform @@ -423,23 +433,14 @@ def _truncate_discrete_functional(pdf_random_variable, cdf_random_variable, supp def _truncate_discrete(pdf_random_variable, support_interval): - # Find the area of the truncated random variable - truncation_area = 0 - for i in range(len(pdf_random_variable.support)): - if pdf_random_variable.support[i] >= support_interval[0]: - if pdf_random_variable.support[i] <= support_interval[1]: - truncation_area += pdf_random_variable.func[i] - # Truncate the random variable and find the probability - # at each point - truncated_functions = [] - truncated_support = [] - for i in range(len(pdf_random_variable.support)): - if pdf_random_variable.support[i] >= support_interval[0]: - if pdf_random_variable.support[i] <= support_interval[1]: - truncated_functions.append(pdf_random_variable.func[i] / truncation_area) - truncated_support.append(pdf_random_variable.support[i]) - # Return the truncated random variable - return RV(truncated_functions, truncated_support, ["discrete", "pdf"]) + min_support, max_support = tuple(support_interval) + fast_rv = applpy_rust.truncate_discrete(pdf_random_variable, min_support, max_support) + return RV( + func=fast_rv.function, + support=fast_rv.support, + functional_form=fast_rv.functional_form, + domain_type=fast_rv.domain_type, + ) def mixture(mix_parameters, mix_random_variables): diff --git a/rust/src/algorithms/mod.rs b/rust/src/algorithms/mod.rs index b6683a3..30e0696 100644 --- a/rust/src/algorithms/mod.rs +++ b/rust/src/algorithms/mod.rs @@ -3,3 +3,4 @@ pub mod moments; pub mod number; pub mod order_stat; pub mod rv; +pub mod transform; diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs new file mode 100644 index 0000000..8283c79 --- /dev/null +++ b/rust/src/algorithms/transform.rs @@ -0,0 +1,169 @@ +#![allow(dead_code)] + +use crate::algorithms::number::Number; +use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; + +/// Truncates a discrete random variable by cutting off a portion of the support +/// and normalizing total probability of the distribution to 1. +/// +/// # Arguments +/// * `random_variable` - the random variable to truncate +/// * `min_support` - the minimum support of the new random variable. +/// Must be greater than or equal to the current minimum support. +/// * `max_support` - the maximum support of the new random variable. +/// Must be less than or equal to the current maximum support. +/// +/// # Returns +/// * `truncated_rv` - the truncated random variable +/// +/// # Examples +/// ``` +/// use applpy_rust::algorithms::number::Number; +/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +/// use applpy_rust::algorithms::transform::truncate_discrete; +/// use num_rational::Rational64; +/// +/// let rv = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 10)), +/// Number::Rational(Rational64::new(2, 10)), +/// Number::Rational(Rational64::new(3, 10)), +/// Number::Rational(Rational64::new(4, 10)), +/// ], +/// support: vec![ +/// Number::Integer(1), +/// Number::Integer(2), +/// Number::Integer(3), +/// Number::Integer(4), +/// ], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let truncated = truncate_discrete(&rv, Number::Integer(2), Number::Integer(3)).unwrap(); +/// +/// assert_eq!(truncated.support, vec![Number::Integer(2), Number::Integer(3)]); +/// assert_eq!( +/// truncated.function, +/// vec![ +/// Number::Rational(Rational64::new(2, 5)), +/// Number::Rational(Rational64::new(3, 5)), +/// ] +/// ); +/// assert!(matches!(truncated.functional_form, FunctionalForm::Pdf)); +/// assert!(matches!(truncated.domain_type, DomainType::Discrete)); +/// ``` +pub fn truncate_discrete( + random_variable: &RandomVariable, + min_support: Number, + max_support: Number, +) -> Result { + let pdf_random_variable = random_variable.to_pdf()?; + let function = pdf_random_variable.function; + let support = pdf_random_variable.support; + + if min_support >= max_support { + return Err("max_support must be greater than the min_support".to_string()); + } + + let first_support = *support.first().ok_or("support is empty")?; + if min_support < first_support { + return Err( + "min support must be greater than or equal to the lowest support value".to_string(), + ); + } + + let last_support = *support.last().ok_or("support is empty")?; + if max_support > last_support { + return Err( + "max support must be less than or equal to the highest support value".to_string(), + ); + } + + let mut truncation_area = Number::Integer(0); + for (&support_value, &function_value) in support.iter().zip(function.iter()) { + if support_value >= min_support && support_value <= max_support { + truncation_area += function_value; + } + } + + let zero = Number::Integer(0); + if truncation_area == zero { + return Err("there is no probability mass within the specified support range".to_string()); + } + + let mut truncated_function = Vec::new(); + let mut truncated_support = Vec::new(); + + for (&support_value, &function_value) in support.iter().zip(function.iter()) { + if support_value >= min_support && support_value <= max_support { + let probability = function_value / truncation_area; + truncated_function.push(probability); + truncated_support.push(support_value); + } + } + + let truncated_rv = RandomVariable { + function: truncated_function, + support: truncated_support, + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + Ok(truncated_rv) +} + +#[cfg(test)] +mod tests { + use super::*; + use num_rational::Rational64; + + fn sample_discrete_rv() -> RandomVariable { + RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 10)), + Number::Rational(Rational64::new(2, 10)), + Number::Rational(Rational64::new(3, 10)), + Number::Rational(Rational64::new(4, 10)), + ], + support: vec![ + Number::Integer(1), + Number::Integer(2), + Number::Integer(3), + Number::Integer(4), + ], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + } + } + + #[test] + fn truncate_discrete_renormalizes_probabilities_within_range() { + let rv = sample_discrete_rv(); + let truncated = truncate_discrete(&rv, Number::Integer(2), Number::Integer(3)).unwrap(); + + assert_eq!( + truncated.support, + vec![Number::Integer(2), Number::Integer(3)] + ); + assert_eq!( + truncated.function, + vec![ + Number::Rational(Rational64::new(2, 5)), + Number::Rational(Rational64::new(3, 5)) + ] + ); + assert!(matches!(truncated.functional_form, FunctionalForm::Pdf)); + assert!(matches!(truncated.domain_type, DomainType::Discrete)); + } + + #[test] + fn truncate_discrete_returns_error_when_min_support_exceeds_bounds() { + let rv = sample_discrete_rv(); + let result = truncate_discrete(&rv, Number::Integer(0), Number::Integer(3)); + + assert!(matches!( + result, + Err(msg) if msg == "min support must be greater than or equal to the lowest support value" + )); + } +} diff --git a/rust/src/lib.rs b/rust/src/lib.rs index 7c82132..814fa18 100644 --- a/rust/src/lib.rs +++ b/rust/src/lib.rs @@ -35,6 +35,9 @@ fn applpy_rust(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> { )?)?; module.add_function(wrap_pyfunction!(python::api::bootstrap_rv_py, module)?)?; + // transformation functions + module.add_function(wrap_pyfunction!(python::api::truncate_discrete_py, module)?)?; + // dummy function to validate imports module.add_function(wrap_pyfunction!(dummy_ping, module)?)?; diff --git a/rust/src/python/api.rs b/rust/src/python/api.rs index ac2c616..37c5ce5 100644 --- a/rust/src/python/api.rs +++ b/rust/src/python/api.rs @@ -9,6 +9,7 @@ use crate::algorithms::number::Number; use crate::algorithms::order_stat; use crate::algorithms::rv; use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +use crate::algorithms::transform; #[pyfunction(name = "discrete_order_stat", signature = (random_variable, n, r, replace="w"))] pub fn discrete_order_stat_py( @@ -150,6 +151,24 @@ pub fn bootstrap_rv_py(variates: Vec) -> PyResult { Ok(fast_rv) } +#[pyfunction(name = "truncate_discrete", signature = (random_variable, min_support, max_support))] +pub fn truncate_discrete_py( + random_variable: &Bound<'_, PyAny>, + min_support: Number, + max_support: Number, +) -> PyResult { + let random_variable: FastRV = random_variable.extract()?; + let truncated_rv = + transform::truncate_discrete(&random_variable.inner, min_support, max_support) + .map_err(PyValueError::new_err)?; + Ok(FastRV::new( + truncated_rv.function, + truncated_rv.support, + truncated_rv.functional_form, + truncated_rv.domain_type, + )) +} + #[pyclass] pub struct FastRV { inner: RandomVariable, diff --git a/test_applpy/unit/test_transform.py b/test_applpy/unit/test_transform.py index d0b2e29..0d4fd5a 100644 --- a/test_applpy/unit/test_transform.py +++ b/test_applpy/unit/test_transform.py @@ -52,7 +52,8 @@ def test_transform_and_truncate_happy_paths(): assert isinstance(transform(discrete, [[x + 1, x + 2], [0, 1, 2]]), RV) assert isinstance(transform(piecewise, [[x, x**2], [0, 1, 2]]), RV) assert isinstance(truncate(continuous, [Rational(1, 4), Rational(3, 4)]), RV) - assert isinstance(truncate(discrete, [1, 1]), RV) + with pytest.raises(ValueError): + truncate(discrete, [1, 1]) def test_mixture_happy_paths():