|
8 | 8 | from .rv import RV, RVError, t, x |
9 | 9 |
|
10 | 10 |
|
| 11 | +try: |
| 12 | + import applpy_rust |
| 13 | +except ImportError: |
| 14 | + raise ImportError( |
| 15 | + "applpy_rust extension is not built. " |
| 16 | + "Run `uv sync --extra rust` then " |
| 17 | + "`uv run --no-sync maturin develop -m rust/Cargo.toml`." |
| 18 | + ) |
| 19 | + |
| 20 | + |
11 | 21 | def transform(random_variable, transform_spec): |
12 | 22 | """ |
13 | 23 | Procedure Name: Transform |
@@ -423,23 +433,14 @@ def _truncate_discrete_functional(pdf_random_variable, cdf_random_variable, supp |
423 | 433 |
|
424 | 434 |
|
425 | 435 | def _truncate_discrete(pdf_random_variable, support_interval): |
426 | | - # Find the area of the truncated random variable |
427 | | - truncation_area = 0 |
428 | | - for i in range(len(pdf_random_variable.support)): |
429 | | - if pdf_random_variable.support[i] >= support_interval[0]: |
430 | | - if pdf_random_variable.support[i] <= support_interval[1]: |
431 | | - truncation_area += pdf_random_variable.func[i] |
432 | | - # Truncate the random variable and find the probability |
433 | | - # at each point |
434 | | - truncated_functions = [] |
435 | | - truncated_support = [] |
436 | | - for i in range(len(pdf_random_variable.support)): |
437 | | - if pdf_random_variable.support[i] >= support_interval[0]: |
438 | | - if pdf_random_variable.support[i] <= support_interval[1]: |
439 | | - truncated_functions.append(pdf_random_variable.func[i] / truncation_area) |
440 | | - truncated_support.append(pdf_random_variable.support[i]) |
441 | | - # Return the truncated random variable |
442 | | - return RV(truncated_functions, truncated_support, ["discrete", "pdf"]) |
| 436 | + min_support, max_support = tuple(support_interval) |
| 437 | + fast_rv = applpy_rust.truncate_discrete(pdf_random_variable, min_support, max_support) |
| 438 | + return RV( |
| 439 | + func=fast_rv.function, |
| 440 | + support=fast_rv.support, |
| 441 | + functional_form=fast_rv.functional_form, |
| 442 | + domain_type=fast_rv.domain_type, |
| 443 | + ) |
443 | 444 |
|
444 | 445 |
|
445 | 446 | def mixture(mix_parameters, mix_random_variables): |
|
0 commit comments