Skip to content

Commit cee7496

Browse files
authored
feat(transform): add Rust-backed discrete truncate and wire it into Python API (#37)
* first pass on truncation * cleanup of truncation * unit tests for truncate * add truncate to python api * add truncate to python lib * patch rust truncate into python
1 parent 557818d commit cee7496

6 files changed

Lines changed: 212 additions & 18 deletions

File tree

applpy/transform.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@
88
from .rv import RV, RVError, t, x
99

1010

11+
try:
12+
import applpy_rust
13+
except ImportError:
14+
raise ImportError(
15+
"applpy_rust extension is not built. "
16+
"Run `uv sync --extra rust` then "
17+
"`uv run --no-sync maturin develop -m rust/Cargo.toml`."
18+
)
19+
20+
1121
def transform(random_variable, transform_spec):
1222
"""
1323
Procedure Name: Transform
@@ -423,23 +433,14 @@ def _truncate_discrete_functional(pdf_random_variable, cdf_random_variable, supp
423433

424434

425435
def _truncate_discrete(pdf_random_variable, support_interval):
426-
# Find the area of the truncated random variable
427-
truncation_area = 0
428-
for i in range(len(pdf_random_variable.support)):
429-
if pdf_random_variable.support[i] >= support_interval[0]:
430-
if pdf_random_variable.support[i] <= support_interval[1]:
431-
truncation_area += pdf_random_variable.func[i]
432-
# Truncate the random variable and find the probability
433-
# at each point
434-
truncated_functions = []
435-
truncated_support = []
436-
for i in range(len(pdf_random_variable.support)):
437-
if pdf_random_variable.support[i] >= support_interval[0]:
438-
if pdf_random_variable.support[i] <= support_interval[1]:
439-
truncated_functions.append(pdf_random_variable.func[i] / truncation_area)
440-
truncated_support.append(pdf_random_variable.support[i])
441-
# Return the truncated random variable
442-
return RV(truncated_functions, truncated_support, ["discrete", "pdf"])
436+
min_support, max_support = tuple(support_interval)
437+
fast_rv = applpy_rust.truncate_discrete(pdf_random_variable, min_support, max_support)
438+
return RV(
439+
func=fast_rv.function,
440+
support=fast_rv.support,
441+
functional_form=fast_rv.functional_form,
442+
domain_type=fast_rv.domain_type,
443+
)
443444

444445

445446
def mixture(mix_parameters, mix_random_variables):

rust/src/algorithms/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ pub mod moments;
33
pub mod number;
44
pub mod order_stat;
55
pub mod rv;
6+
pub mod transform;

rust/src/algorithms/transform.rs

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
#![allow(dead_code)]
2+
3+
use crate::algorithms::number::Number;
4+
use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable};
5+
6+
/// Truncates a discrete random variable by cutting off a portion of the support
7+
/// and normalizing total probability of the distribution to 1.
8+
///
9+
/// # Arguments
10+
/// * `random_variable` - the random variable to truncate
11+
/// * `min_support` - the minimum support of the new random variable.
12+
/// Must be greater than or equal to the current minimum support.
13+
/// * `max_support` - the maximum support of the new random variable.
14+
/// Must be less than or equal to the current maximum support.
15+
///
16+
/// # Returns
17+
/// * `truncated_rv` - the truncated random variable
18+
///
19+
/// # Examples
20+
/// ```
21+
/// use applpy_rust::algorithms::number::Number;
22+
/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable};
23+
/// use applpy_rust::algorithms::transform::truncate_discrete;
24+
/// use num_rational::Rational64;
25+
///
26+
/// let rv = RandomVariable {
27+
/// function: vec![
28+
/// Number::Rational(Rational64::new(1, 10)),
29+
/// Number::Rational(Rational64::new(2, 10)),
30+
/// Number::Rational(Rational64::new(3, 10)),
31+
/// Number::Rational(Rational64::new(4, 10)),
32+
/// ],
33+
/// support: vec![
34+
/// Number::Integer(1),
35+
/// Number::Integer(2),
36+
/// Number::Integer(3),
37+
/// Number::Integer(4),
38+
/// ],
39+
/// functional_form: FunctionalForm::Pdf,
40+
/// domain_type: DomainType::Discrete,
41+
/// };
42+
///
43+
/// let truncated = truncate_discrete(&rv, Number::Integer(2), Number::Integer(3)).unwrap();
44+
///
45+
/// assert_eq!(truncated.support, vec![Number::Integer(2), Number::Integer(3)]);
46+
/// assert_eq!(
47+
/// truncated.function,
48+
/// vec![
49+
/// Number::Rational(Rational64::new(2, 5)),
50+
/// Number::Rational(Rational64::new(3, 5)),
51+
/// ]
52+
/// );
53+
/// assert!(matches!(truncated.functional_form, FunctionalForm::Pdf));
54+
/// assert!(matches!(truncated.domain_type, DomainType::Discrete));
55+
/// ```
56+
pub fn truncate_discrete(
57+
random_variable: &RandomVariable,
58+
min_support: Number,
59+
max_support: Number,
60+
) -> Result<RandomVariable, String> {
61+
let pdf_random_variable = random_variable.to_pdf()?;
62+
let function = pdf_random_variable.function;
63+
let support = pdf_random_variable.support;
64+
65+
if min_support >= max_support {
66+
return Err("max_support must be greater than the min_support".to_string());
67+
}
68+
69+
let first_support = *support.first().ok_or("support is empty")?;
70+
if min_support < first_support {
71+
return Err(
72+
"min support must be greater than or equal to the lowest support value".to_string(),
73+
);
74+
}
75+
76+
let last_support = *support.last().ok_or("support is empty")?;
77+
if max_support > last_support {
78+
return Err(
79+
"max support must be less than or equal to the highest support value".to_string(),
80+
);
81+
}
82+
83+
let mut truncation_area = Number::Integer(0);
84+
for (&support_value, &function_value) in support.iter().zip(function.iter()) {
85+
if support_value >= min_support && support_value <= max_support {
86+
truncation_area += function_value;
87+
}
88+
}
89+
90+
let zero = Number::Integer(0);
91+
if truncation_area == zero {
92+
return Err("there is no probability mass within the specified support range".to_string());
93+
}
94+
95+
let mut truncated_function = Vec::new();
96+
let mut truncated_support = Vec::new();
97+
98+
for (&support_value, &function_value) in support.iter().zip(function.iter()) {
99+
if support_value >= min_support && support_value <= max_support {
100+
let probability = function_value / truncation_area;
101+
truncated_function.push(probability);
102+
truncated_support.push(support_value);
103+
}
104+
}
105+
106+
let truncated_rv = RandomVariable {
107+
function: truncated_function,
108+
support: truncated_support,
109+
functional_form: FunctionalForm::Pdf,
110+
domain_type: DomainType::Discrete,
111+
};
112+
Ok(truncated_rv)
113+
}
114+
115+
#[cfg(test)]
mod tests {
    use super::*;
    use num_rational::Rational64;

    /// Shorthand for an exact rational `Number`.
    fn ratio(numerator: i64, denominator: i64) -> Number {
        Number::Rational(Rational64::new(numerator, denominator))
    }

    /// Discrete PDF on the support 1, 2, 3, 4 with probabilities
    /// 1/10, 2/10, 3/10 and 4/10 respectively.
    fn sample_discrete_rv() -> RandomVariable {
        RandomVariable {
            function: (1..=4).map(|k| ratio(k, 10)).collect(),
            support: (1..=4).map(Number::Integer).collect(),
            functional_form: FunctionalForm::Pdf,
            domain_type: DomainType::Discrete,
        }
    }

    #[test]
    fn truncate_discrete_renormalizes_probabilities_within_range() {
        let truncated =
            truncate_discrete(&sample_discrete_rv(), Number::Integer(2), Number::Integer(3))
                .unwrap();

        // Points 2 and 3 carry 2/10 + 3/10 = 1/2, so each probability doubles.
        let expected_support = vec![Number::Integer(2), Number::Integer(3)];
        let expected_function = vec![ratio(2, 5), ratio(3, 5)];

        assert_eq!(truncated.support, expected_support);
        assert_eq!(truncated.function, expected_function);
        assert!(matches!(truncated.functional_form, FunctionalForm::Pdf));
        assert!(matches!(truncated.domain_type, DomainType::Discrete));
    }

    #[test]
    fn truncate_discrete_returns_error_when_min_support_exceeds_bounds() {
        // A lower bound of 0 lies below the first support value (1).
        let error =
            truncate_discrete(&sample_discrete_rv(), Number::Integer(0), Number::Integer(3)).err();

        assert_eq!(
            error.as_deref(),
            Some("min support must be greater than or equal to the lowest support value")
        );
    }
}

rust/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,9 @@ fn applpy_rust(_py: Python<'_>, module: &Bound<'_, PyModule>) -> PyResult<()> {
3535
)?)?;
3636
module.add_function(wrap_pyfunction!(python::api::bootstrap_rv_py, module)?)?;
3737

38+
// transformation functions
39+
module.add_function(wrap_pyfunction!(python::api::truncate_discrete_py, module)?)?;
40+
3841
// dummy function to validate imports
3942
module.add_function(wrap_pyfunction!(dummy_ping, module)?)?;
4043

rust/src/python/api.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use crate::algorithms::number::Number;
99
use crate::algorithms::order_stat;
1010
use crate::algorithms::rv;
1111
use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable};
12+
use crate::algorithms::transform;
1213

1314
#[pyfunction(name = "discrete_order_stat", signature = (random_variable, n, r, replace="w"))]
1415
pub fn discrete_order_stat_py(
@@ -150,6 +151,24 @@ pub fn bootstrap_rv_py(variates: Vec<Number>) -> PyResult<FastRV> {
150151
Ok(fast_rv)
151152
}
152153

154+
#[pyfunction(name = "truncate_discrete", signature = (random_variable, min_support, max_support))]
155+
pub fn truncate_discrete_py(
156+
random_variable: &Bound<'_, PyAny>,
157+
min_support: Number,
158+
max_support: Number,
159+
) -> PyResult<FastRV> {
160+
let random_variable: FastRV = random_variable.extract()?;
161+
let truncated_rv =
162+
transform::truncate_discrete(&random_variable.inner, min_support, max_support)
163+
.map_err(PyValueError::new_err)?;
164+
Ok(FastRV::new(
165+
truncated_rv.function,
166+
truncated_rv.support,
167+
truncated_rv.functional_form,
168+
truncated_rv.domain_type,
169+
))
170+
}
171+
153172
#[pyclass]
154173
pub struct FastRV {
155174
inner: RandomVariable,

test_applpy/unit/test_transform.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def test_transform_and_truncate_happy_paths():
5252
assert isinstance(transform(discrete, [[x + 1, x + 2], [0, 1, 2]]), RV)
5353
assert isinstance(transform(piecewise, [[x, x**2], [0, 1, 2]]), RV)
5454
assert isinstance(truncate(continuous, [Rational(1, 4), Rational(3, 4)]), RV)
55-
assert isinstance(truncate(discrete, [1, 1]), RV)
55+
with pytest.raises(ValueError):
56+
truncate(discrete, [1, 1])
5657

5758

5859
def test_mixture_happy_paths():

0 commit comments

Comments
 (0)