Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 18 additions & 17 deletions applpy/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,16 @@
from .rv import RV, RVError, t, x


try:
import applpy_rust
except ImportError:
raise ImportError(
"applpy_rust extension is not built. "
"Run `uv sync --extra rust` then "
"`uv run --no-sync maturin develop -m rust/Cargo.toml`."
)


def transform(random_variable, transform_spec):
"""
Procedure Name: Transform
Expand Down Expand Up @@ -423,23 +433,14 @@ def _truncate_discrete_functional(pdf_random_variable, cdf_random_variable, supp


def _truncate_discrete(pdf_random_variable, support_interval):
# Find the area of the truncated random variable
truncation_area = 0
for i in range(len(pdf_random_variable.support)):
if pdf_random_variable.support[i] >= support_interval[0]:
if pdf_random_variable.support[i] <= support_interval[1]:
truncation_area += pdf_random_variable.func[i]
# Truncate the random variable and find the probability
# at each point
truncated_functions = []
truncated_support = []
for i in range(len(pdf_random_variable.support)):
if pdf_random_variable.support[i] >= support_interval[0]:
if pdf_random_variable.support[i] <= support_interval[1]:
truncated_functions.append(pdf_random_variable.func[i] / truncation_area)
truncated_support.append(pdf_random_variable.support[i])
# Return the truncated random variable
return RV(truncated_functions, truncated_support, ["discrete", "pdf"])
min_support, max_support = tuple(support_interval)
fast_rv = applpy_rust.truncate_discrete(pdf_random_variable, min_support, max_support)
return RV(
func=fast_rv.function,
support=fast_rv.support,
functional_form=fast_rv.functional_form,
domain_type=fast_rv.domain_type,
)


def mixture(mix_parameters, mix_random_variables):
Expand Down
1 change: 1 addition & 0 deletions rust/src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pub mod moments;
pub mod number;
pub mod order_stat;
pub mod rv;
pub mod transform;
169 changes: 169 additions & 0 deletions rust/src/algorithms/transform.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
#![allow(dead_code)]

use crate::algorithms::number::Number;
use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable};

/// Truncates a discrete random variable by cutting off a portion of the support
/// and normalizing total probability of the distribution to 1.
///
/// # Arguments
/// * `random_variable` - the random variable to truncate
/// * `min_support` - the minimum support of the new random variable.
/// Must be greater than or equal to the current minimum support.
/// * `max_support` - the maximum support of the new random variable.
/// Must be less than or equal to the current maximum support.
///
/// # Returns
/// * `truncated_rv` - the truncated random variable
///
/// # Examples
/// ```
/// use applpy_rust::algorithms::number::Number;
/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable};
/// use applpy_rust::algorithms::transform::truncate_discrete;
/// use num_rational::Rational64;
///
/// let rv = RandomVariable {
/// function: vec![
/// Number::Rational(Rational64::new(1, 10)),
/// Number::Rational(Rational64::new(2, 10)),
/// Number::Rational(Rational64::new(3, 10)),
/// Number::Rational(Rational64::new(4, 10)),
/// ],
/// support: vec![
/// Number::Integer(1),
/// Number::Integer(2),
/// Number::Integer(3),
/// Number::Integer(4),
/// ],
/// functional_form: FunctionalForm::Pdf,
/// domain_type: DomainType::Discrete,
/// };
///
/// let truncated = truncate_discrete(&rv, Number::Integer(2), Number::Integer(3)).unwrap();
///
/// assert_eq!(truncated.support, vec![Number::Integer(2), Number::Integer(3)]);
/// assert_eq!(
/// truncated.function,
/// vec![
/// Number::Rational(Rational64::new(2, 5)),
/// Number::Rational(Rational64::new(3, 5)),
/// ]
/// );
/// assert!(matches!(truncated.functional_form, FunctionalForm::Pdf));
/// assert!(matches!(truncated.domain_type, DomainType::Discrete));
/// ```
pub fn truncate_discrete(
random_variable: &RandomVariable,
min_support: Number,
max_support: Number,
) -> Result<RandomVariable, String> {
let pdf_random_variable = random_variable.to_pdf()?;
let function = pdf_random_variable.function;
let support = pdf_random_variable.support;

if min_support >= max_support {
return Err("max_support must be greater than the min_support".to_string());
}

let first_support = *support.first().ok_or("support is empty")?;
if min_support < first_support {
return Err(
"min support must be greater than or equal to the lowest support value".to_string(),
);
}

let last_support = *support.last().ok_or("support is empty")?;
if max_support > last_support {
return Err(
"max support must be less than or equal to the highest support value".to_string(),
);
}

let mut truncation_area = Number::Integer(0);
for (&support_value, &function_value) in support.iter().zip(function.iter()) {
if support_value >= min_support && support_value <= max_support {
truncation_area += function_value;
}
}

let zero = Number::Integer(0);
if truncation_area == zero {
return Err("there is no probability mass within the specified support range".to_string());
}

let mut truncated_function = Vec::new();
let mut truncated_support = Vec::new();

for (&support_value, &function_value) in support.iter().zip(function.iter()) {
if support_value >= min_support && support_value <= max_support {
let probability = function_value / truncation_area;
truncated_function.push(probability);
truncated_support.push(support_value);
}
}

let truncated_rv = RandomVariable {
function: truncated_function,
support: truncated_support,
functional_form: FunctionalForm::Pdf,
domain_type: DomainType::Discrete,
};
Ok(truncated_rv)
}

#[cfg(test)]
mod tests {
use super::*;
use num_rational::Rational64;

fn sample_discrete_rv() -> RandomVariable {
RandomVariable {
function: vec![
Number::Rational(Rational64::new(1, 10)),
Number::Rational(Rational64::new(2, 10)),
Number::Rational(Rational64::new(3, 10)),
Number::Rational(Rational64::new(4, 10)),
],
support: vec![
Number::Integer(1),
Number::Integer(2),
Number::Integer(3),
Number::Integer(4),
],
functional_form: FunctionalForm::Pdf,
domain_type: DomainType::Discrete,
}
}

#[test]
fn truncate_discrete_renormalizes_probabilities_within_range() {
let rv = sample_discrete_rv();
let truncated = truncate_discrete(&rv, Number::Integer(2), Number::Integer(3)).unwrap();

assert_eq!(
truncated.support,
vec![Number::Integer(2), Number::Integer(3)]
);
assert_eq!(
truncated.function,
vec![
Number::Rational(Rational64::new(2, 5)),
Number::Rational(Rational64::new(3, 5))
]
);
assert!(matches!(truncated.functional_form, FunctionalForm::Pdf));
assert!(matches!(truncated.domain_type, DomainType::Discrete));
}

#[test]
fn truncate_discrete_returns_error_when_min_support_exceeds_bounds() {
let rv = sample_discrete_rv();
let result = truncate_discrete(&rv, Number::Integer(0), Number::Integer(3));

assert!(matches!(
result,
Err(msg) if msg == "min support must be greater than or equal to the lowest support value"
));
}
}
3 changes: 3 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ fn applpy_rust(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> {
)?)?;
module.add_function(wrap_pyfunction!(python::api::bootstrap_rv_py, module)?)?;

// transformation functions
module.add_function(wrap_pyfunction!(python::api::truncate_discrete_py, module)?)?;

// dummy function to validate imports
module.add_function(wrap_pyfunction!(dummy_ping, module)?)?;

Expand Down
19 changes: 19 additions & 0 deletions rust/src/python/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use crate::algorithms::number::Number;
use crate::algorithms::order_stat;
use crate::algorithms::rv;
use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable};
use crate::algorithms::transform;

#[pyfunction(name = "discrete_order_stat", signature = (random_variable, n, r, replace="w"))]
pub fn discrete_order_stat_py(
Expand Down Expand Up @@ -150,6 +151,24 @@ pub fn bootstrap_rv_py(variates: Vec<Number>) -> PyResult<FastRV> {
Ok(fast_rv)
}

#[pyfunction(name = "truncate_discrete", signature = (random_variable, min_support, max_support))]
pub fn truncate_discrete_py(
random_variable: &Bound<'_, PyAny>,
min_support: Number,
max_support: Number,
) -> PyResult<FastRV> {
let random_variable: FastRV = random_variable.extract()?;
let truncated_rv =
transform::truncate_discrete(&random_variable.inner, min_support, max_support)
.map_err(PyValueError::new_err)?;
Ok(FastRV::new(
truncated_rv.function,
truncated_rv.support,
truncated_rv.functional_form,
truncated_rv.domain_type,
))
}

#[pyclass]
pub struct FastRV {
inner: RandomVariable,
Expand Down
3 changes: 2 additions & 1 deletion test_applpy/unit/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def test_transform_and_truncate_happy_paths():
assert isinstance(transform(discrete, [[x + 1, x + 2], [0, 1, 2]]), RV)
assert isinstance(transform(piecewise, [[x, x**2], [0, 1, 2]]), RV)
assert isinstance(truncate(continuous, [Rational(1, 4), Rational(3, 4)]), RV)
assert isinstance(truncate(discrete, [1, 1]), RV)
with pytest.raises(ValueError):
truncate(discrete, [1, 1])


def test_mixture_happy_paths():
Expand Down
Loading