Skip to content

Commit 73d8f3c

Browse files
committed
patch rust truncate into python
1 parent aad8845 commit 73d8f3c

2 files changed

Lines changed: 20 additions & 18 deletions

File tree

applpy/transform.py

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,16 @@
88
from .rv import RV, RVError, t, x
99

1010

11+
try:
12+
import applpy_rust
13+
except ImportError:
14+
raise ImportError(
15+
"applpy_rust extension is not built. "
16+
"Run `uv sync --extra rust` then "
17+
"`uv run --no-sync maturin develop -m rust/Cargo.toml`."
18+
)
19+
20+
1121
def transform(random_variable, transform_spec):
1222
"""
1323
Procedure Name: Transform
@@ -423,23 +433,14 @@ def _truncate_discrete_functional(pdf_random_variable, cdf_random_variable, supp
423433

424434

425435
def _truncate_discrete(pdf_random_variable, support_interval):
426-
# Find the area of the truncated random variable
427-
truncation_area = 0
428-
for i in range(len(pdf_random_variable.support)):
429-
if pdf_random_variable.support[i] >= support_interval[0]:
430-
if pdf_random_variable.support[i] <= support_interval[1]:
431-
truncation_area += pdf_random_variable.func[i]
432-
# Truncate the random variable and find the probability
433-
# at each point
434-
truncated_functions = []
435-
truncated_support = []
436-
for i in range(len(pdf_random_variable.support)):
437-
if pdf_random_variable.support[i] >= support_interval[0]:
438-
if pdf_random_variable.support[i] <= support_interval[1]:
439-
truncated_functions.append(pdf_random_variable.func[i] / truncation_area)
440-
truncated_support.append(pdf_random_variable.support[i])
441-
# Return the truncated random variable
442-
return RV(truncated_functions, truncated_support, ["discrete", "pdf"])
436+
min_support, max_support = tuple(support_interval)
437+
fast_rv = applpy_rust.truncate_discrete(pdf_random_variable, min_support, max_support)
438+
return RV(
439+
func=fast_rv.function,
440+
support=fast_rv.support,
441+
functional_form=fast_rv.functional_form,
442+
domain_type=fast_rv.domain_type,
443+
)
443444

444445

445446
def mixture(mix_parameters, mix_random_variables):

test_applpy/unit/test_transform.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def test_transform_and_truncate_happy_paths():
5252
assert isinstance(transform(discrete, [[x + 1, x + 2], [0, 1, 2]]), RV)
5353
assert isinstance(transform(piecewise, [[x, x**2], [0, 1, 2]]), RV)
5454
assert isinstance(truncate(continuous, [Rational(1, 4), Rational(3, 4)]), RV)
55-
assert isinstance(truncate(discrete, [1, 1]), RV)
55+
with pytest.raises(ValueError):
56+
truncate(discrete, [1, 1])
5657

5758

5859
def test_mixture_happy_paths():

0 commit comments

Comments
 (0)