| Branch | log-likelihood |
| Testbed | rust-moan |
| Benchmark | Latency result in seconds (Result Δ%) | Baseline | Upper boundary in seconds (Limit %) |
|---|---|---|---|
| bimodal_ke_npag | 11.93 s (-51.33%) | 24.50 s | 32.42 s (36.79%) |
| bimodal_ke_npod | 5.79 s (-43.29%) | 10.21 s | 12.97 s (44.65%) |
| bimodal_ke_postprob | 1.70 s (-49.66%) | 3.38 s | 4.42 s (38.44%) |
Pull request overview
This PR introduces log-space likelihood computations to improve numerical stability when dealing with many observations or extreme parameter values. The implementation adds a flag to track whether Psi matrices contain regular or log-space likelihoods, and provides unified dispatcher functions to automatically choose the appropriate calculation method.
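To illustrate the stability problem the overview refers to, here is a minimal sketch that is not taken from the PR. The function names `likelihood_product` and `log_likelihood_sum` are hypothetical. With many observations, the product of per-observation likelihoods can underflow to zero in `f64`, while the sum of their logarithms stays finite.

```rust
/// Hypothetical helper: joint likelihood as a plain product.
/// Underflows to 0.0 once the product drops below the smallest f64.
pub fn likelihood_product(per_observation: &[f64]) -> f64 {
    per_observation.iter().product()
}

/// Hypothetical helper: joint log-likelihood as a sum of logs.
/// Stays finite as long as every individual likelihood is positive.
pub fn log_likelihood_sum(per_observation: &[f64]) -> f64 {
    per_observation.iter().map(|p| p.ln()).sum()
}
```

For example, 1000 observations that each have likelihood `1e-3` give a product of exactly `0.0`, while the log-likelihood is roughly `-6907.76`.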
Key Changes
- Added `is_log_space` field to `Psi` struct with corresponding builder pattern and dispatcher functions for likelihood calculations
- Implemented numerically stable `logsumexp` family of functions in new `routines::math` module
- Extended IPM with log-space variant (`burke_log`) and unified dispatcher (`burke_ipm`)
- Updated all algorithms (NPAG, NPOD, POSTPROB) and related code to handle both regular and log-space computations
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| src/structs/psi.rs | Added is_log_space flag, PsiBuilder, calculate_log_psi(), and calculate_psi_dispatch() for unified likelihood calculations |
| src/routines/settings.rs | Added log_space setting (default: true) to enable log-space computations |
| src/routines/output/posterior.rs | Updated posterior calculation to handle log-space using logsumexp; removed duplicate function |
| src/routines/mod.rs | Added math module for numerical utilities |
| src/routines/math.rs | New module with logsumexp, logsumexp_weighted, and logsumexp_rows implementations |
| src/routines/estimation/qr.rs | Updated QR decomposition to apply softmax normalization for log-space matrices |
| src/routines/estimation/ipm.rs | Added burke_log() for log-space IPM and burke_ipm() dispatcher |
| src/routines/condensation/mod.rs | Updated to use burke_ipm() dispatcher |
| src/bestdose/posterior.rs | Updated to use calculate_psi_dispatch() and burke_ipm() |
| src/algorithms/postprob.rs | Updated to use calculate_psi_dispatch() and burke_ipm() with log_space setting |
| src/algorithms/npod.rs | Updated to use dispatchers and handle log-space in expansion phase |
| src/algorithms/npag.rs | Updated to use dispatchers and compute f1 correctly for log-space |
| src/algorithms/mod.rs | Updated validate_psi() to handle log-space (NEG_INFINITY is valid) |
| Cargo.toml | Bumped pharmsol dependency to 0.22.0 for log_psi function support |
    pub fn logsumexp_weighted(log_values: &[f64], log_weights: &[f64]) -> f64 {
        assert_eq!(
            log_values.len(),
            log_weights.len(),
            "log_values and log_weights must have the same length"
        );

        let combined: Vec<f64> = log_values
            .iter()
            .zip(log_weights.iter())
            .map(|(&lv, &lw)| lv + lw)
            .collect();

        logsumexp(&combined)
    }
[nitpick] Consider adding a test case for logsumexp_weighted that verifies the behavior when weights have different magnitudes or when some log_values/log_weights are special values like NEG_INFINITY or INFINITY. This would improve test coverage and ensure edge cases are handled correctly.
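A sketch of what such tests could look like follows. It assumes `logsumexp` in `src/routines/math.rs` matches the helper quoted elsewhere in this review, and it copies both functions so the block is self-contained. The test names are hypothetical.

```rust
/// Numerically stable log(sum_i exp(values[i])); assumed to match the PR's helper.
pub fn logsumexp(values: &[f64]) -> f64 {
    let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if max_val.is_infinite() {
        return max_val;
    }
    max_val + values.iter().map(|&x| (x - max_val).exp()).sum::<f64>().ln()
}

pub fn logsumexp_weighted(log_values: &[f64], log_weights: &[f64]) -> f64 {
    assert_eq!(
        log_values.len(),
        log_weights.len(),
        "log_values and log_weights must have the same length"
    );
    let combined: Vec<f64> = log_values
        .iter()
        .zip(log_weights.iter())
        .map(|(&lv, &lw)| lv + lw)
        .collect();
    logsumexp(&combined)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn weights_of_different_magnitude() {
        let got = logsumexp_weighted(&[-1000.0, -1001.0], &[0.25_f64.ln(), 0.75_f64.ln()]);
        // log(0.25 * e^-1000 + 0.75 * e^-1001), factored to avoid underflow
        let want = -1000.0 + (0.25 + 0.75 * (-1.0_f64).exp()).ln();
        assert!((got - want).abs() < 1e-9);
    }

    #[test]
    fn zero_weight_drops_its_term() {
        // log(w) = NEG_INFINITY encodes a zero weight
        let got = logsumexp_weighted(&[-1.0, -2.0], &[0.0, f64::NEG_INFINITY]);
        assert_eq!(got, -1.0);
    }

    #[test]
    fn infinity_plus_neg_infinity_is_masked() {
        // INFINITY + NEG_INFINITY is NaN, and f64::max ignores NaN, so this
        // implementation returns NEG_INFINITY instead of propagating NaN.
        let got = logsumexp_weighted(&[f64::INFINITY], &[f64::NEG_INFINITY]);
        assert_eq!(got, f64::NEG_INFINITY);
    }
}
```

The last case may be worth a deliberate decision: as written, a NaN produced by `INFINITY + NEG_INFINITY` is silently reported as zero probability.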
src/algorithms/npod.rs
Outdated
    let pyl = if self.psi.is_log_space() {
        // pyl[i] = sum_j(exp(log_psi[i,j]) * w[j]) = sum_j(exp(log_psi[i,j] + log(w[j])))
        // Using logsumexp for stability, then exp to get regular values
        let log_w: Array1<f64> = w.iter().map(|&x| x.ln()).collect();
If any weight in w is zero or negative, x.ln() will produce NEG_INFINITY or NaN, which may lead to unexpected behavior. While weights should typically be positive (from the IPM), consider adding a validation check or documenting this assumption to prevent potential issues.
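One way to make that assumption explicit is a small checked conversion. This is a sketch only: the name `checked_log_weights` and the `String` error type are hypothetical, and the PR itself may prefer `bail!` or a different policy for zero weights.

```rust
/// Hypothetical helper: convert weights to log-weights, rejecting inputs
/// whose logarithm would be NEG_INFINITY (zero) or NaN (negative or NaN).
pub fn checked_log_weights(weights: &[f64]) -> Result<Vec<f64>, String> {
    for (j, &w) in weights.iter().enumerate() {
        // `!(w > 0.0)` is also true for NaN, which `w <= 0.0` would miss
        if !(w > 0.0) {
            return Err(format!("weight {j} must be positive, found {w}"));
        }
    }
    Ok(weights.iter().map(|&w| w.ln()).collect())
}
```

For reference, `0.0_f64.ln()` evaluates to `NEG_INFINITY` and `(-1.0_f64).ln()` evaluates to NaN, which is why an unchecked `.ln()` can leak either value into later sums.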
    // Calculate py[i] = sum_j(psi[i,j] * w[j]) for each subject i
    // In log-space: py[i] = logsumexp_j(log_psi[i,j] + log(w[j]))
    let py: Vec<f64> = if is_log_space {
        let log_w: Vec<f64> = (0..w.len()).map(|j| w.weights().get(j).ln()).collect();
If any weight in w is zero or negative, w.weights().get(j).ln() will produce NEG_INFINITY or NaN, potentially causing numerical issues in the posterior calculation. While weights from the IPM should be positive, consider adding validation or documenting this assumption.
      let posterior = Mat::from_fn(psi_matrix.nrows(), psi_matrix.ncols(), |i, j| {
    -     psi_matrix.get(i, j) * w.weights().get(j) / py.get(i)
    +     if is_log_space {
    +         let log_w_j = w.weights().get(j).ln();
If the weight at index j is zero or negative, w.weights().get(j).ln() will produce NEG_INFINITY or NaN, potentially causing numerical issues. While weights from the IPM should be positive, consider adding validation or documenting this assumption.
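For context, the log-space posterior this code appears to compute is `posterior[i][j] = exp(log_psi[i][j] + log(w[j]) - log_py[i])`, with `log_py[i] = logsumexp_j(log_psi[i][j] + log(w[j]))`. The sketch below uses plain `Vec`s instead of the crate's matrix types. The function name `log_space_posterior` and the all-zero row for subjects with zero total probability are assumptions, not the PR's behavior.

```rust
/// Numerically stable log(sum_i exp(values[i])).
fn logsumexp(values: &[f64]) -> f64 {
    let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    if max_val.is_infinite() {
        return max_val;
    }
    max_val + values.iter().map(|&x| (x - max_val).exp()).sum::<f64>().ln()
}

/// Hypothetical helper: posterior probabilities from a log-likelihood matrix
/// (rows = subjects, columns = support points) and positive weights.
pub fn log_space_posterior(log_psi: &[Vec<f64>], weights: &[f64]) -> Vec<Vec<f64>> {
    let log_w: Vec<f64> = weights.iter().map(|&w| w.ln()).collect();
    log_psi
        .iter()
        .map(|row| {
            assert_eq!(row.len(), log_w.len(), "row length must match weights");
            // log(psi[i][j] * w[j]) for every support point j
            let combined: Vec<f64> = row.iter().zip(&log_w).map(|(&lp, &lw)| lp + lw).collect();
            let log_py = logsumexp(&combined);
            if log_py == f64::NEG_INFINITY {
                // Assumption: a subject with zero total probability gets an
                // all-zero row instead of NaN from NEG_INFINITY - NEG_INFINITY.
                return vec![0.0; row.len()];
            }
            combined.iter().map(|&c| (c - log_py).exp()).collect()
        })
        .collect()
}
```

Subtracting `log_py` before exponentiating keeps every exponent at or below zero, so each row sums to one even when the raw likelihoods would underflow.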
    let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
    (0..psi.nrows())
        .map(|i| {
If any weight in w is zero or negative, .ln() will produce NEG_INFINITY or NaN, potentially causing numerical issues. While weights from the IPM should be positive, consider adding validation or documenting this assumption.
Suggested change:

    - let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
    - (0..psi.nrows())
    -     .map(|i| {
    + if w.weights().iter().any(|&x| x <= 0.0) {
    +     bail!("All weights must be positive before taking logarithm, found zero or negative weight.");
    + }
    + let log_w: Vec<f64> = w.weights().iter().map(|&x| x.ln()).collect();
    + (0..psi.nrows())
    /// Compute log-sum-exp of a slice for numerical stability
    #[inline]
    fn logsumexp(values: &[f64]) -> f64 {
        let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        if max_val.is_infinite() {
            return max_val;
        }
        max_val
            + values
                .iter()
                .map(|&x| (x - max_val).exp())
                .sum::<f64>()
                .ln()
    }
This logsumexp function duplicates the implementation in src/routines/math.rs. Consider importing and using crate::routines::math::logsumexp instead to maintain a single source of truth and avoid potential inconsistencies.
Suggested change:

    - /// Compute log-sum-exp of a slice for numerical stability
    - #[inline]
    - fn logsumexp(values: &[f64]) -> f64 {
    -     let max_val = values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
    -     if max_val.is_infinite() {
    -         return max_val;
    -     }
    -     max_val
    -         + values
    -             .iter()
    -             .map(|&x| (x - max_val).exp())
    -             .sum::<f64>()
    -             .ln()
    - }
    + use crate::routines::math::logsumexp;
    // Validate that all entries are finite
    for row in log_psi_mat.row_iter() {
        for &x in row.iter() {
            if !x.is_finite() {
                bail!("Input log-psi matrix must have finite entries");
The validation rejects NEG_INFINITY values, but NEG_INFINITY is a valid log-space representation of zero probability (log(0) = -∞), so the function can fail on valid log-space inputs. Consider allowing NEG_INFINITY while still rejecting NaN and positive infinity.
Suggested change:

    - // Validate that all entries are finite
    - for row in log_psi_mat.row_iter() {
    -     for &x in row.iter() {
    -         if !x.is_finite() {
    -             bail!("Input log-psi matrix must have finite entries");
    + // Validate that all entries are valid: allow NEG_INFINITY, but reject NaN and positive infinity
    + for row in log_psi_mat.row_iter() {
    +     for &x in row.iter() {
    +         if x.is_nan() || x == f64::INFINITY {
    +             bail!("Input log-psi matrix must not contain NaN or positive infinity");
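The relaxed check can be exercised in isolation with a sketch like the one below. It is written against nested `Vec`s rather than the crate's matrix type, and `validate_log_psi` with its `String` error is a hypothetical stand-in for the `bail!`-based code in the PR.

```rust
/// Hypothetical helper: accept finite values and NEG_INFINITY (log of zero
/// probability), reject NaN and positive infinity.
pub fn validate_log_psi(log_psi: &[Vec<f64>]) -> Result<(), String> {
    for (i, row) in log_psi.iter().enumerate() {
        for (j, &x) in row.iter().enumerate() {
            if x.is_nan() || x == f64::INFINITY {
                return Err(format!(
                    "log-psi[{i}][{j}] is {x}; NaN and positive infinity are not allowed"
                ));
            }
        }
    }
    Ok(())
}
```

A matrix such as `[[0.0, NEG_INFINITY]]` passes this check, while any NaN or `INFINITY` entry returns an error naming its position.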
- wip
- wip
- 2 tests are failing
- Fix tests
- Add to_space for converting