1 change: 1 addition & 0 deletions Cargo.toml
@@ -33,6 +33,7 @@
pharmsol = "=0.22.1"
rand = "0.9.0"
anyhow = "1.0.100"
rayon = "1.10.0"
tikv-jemallocator = "0.6.1"

[features]
default = []
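If the allocator swap should be opt-in rather than unconditional, the new dependency could instead be declared optional and gated behind a Cargo feature. A sketch of that alternative, with a made-up "jemalloc" feature name (not part of this PR):

[dependencies]
tikv-jemallocator = { version = "0.6.1", optional = true }

[features]
default = []
jemalloc = ["dep:tikv-jemallocator"]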
6 changes: 6 additions & 0 deletions src/lib.rs
@@ -73,3 +73,9 @@ pub mod prelude {
pub use pharmsol::fetch_params;
pub use pharmsol::lag;
}

use tikv_jemallocator::Jemalloc;

// Use jemalloc as the global allocator
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;
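One caveat worth noting: tikv-jemallocator does not build on MSVC targets, so crates that also compile on Windows/MSVC commonly guard the allocator swap behind a cfg. A minimal sketch of that variant (not part of this PR):

// Keep the system allocator on MSVC targets, where jemalloc is unsupported
#[cfg(not(target_env = "msvc"))]
use tikv_jemallocator::Jemalloc;

#[cfg(not(target_env = "msvc"))]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;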
72 changes: 29 additions & 43 deletions src/routines/estimation/ipm.rs
@@ -4,7 +4,7 @@ use anyhow::bail;
use faer::linalg::triangular_solve::solve_lower_triangular_in_place;
use faer::linalg::triangular_solve::solve_upper_triangular_in_place;
use faer::{Col, Mat, Row};
use rayon::prelude::*;

/// Applies Burke's Interior Point Method (IPM) to solve a convex optimization problem.
///
/// The objective function to maximize is:
@@ -93,14 +93,14 @@ pub fn burke(psi: &Psi) -> anyhow::Result<(Weights, f64)> {

let mut psi_inner: Mat<f64> = Mat::zeros(psi.nrows(), psi.ncols());

let n_threads = faer::get_global_parallelism().degree();

let rows = psi.nrows();

let mut output: Vec<Mat<f64>> = (0..n_threads).map(|_| Mat::zeros(rows, rows)).collect();

let mut h: Mat<f64> = Mat::zeros(rows, rows);

// Cache-size threshold: prefer sequential execution for small matrices to avoid thread overhead
// For larger matrices, use faer's built-in parallelism, which has better cache behavior
const PARALLEL_THRESHOLD: usize = 512;

while mu > eps || norm_r > eps || gap > eps {
let smu = sig * mu;
// inner = lam ./ y, elementwise.
@@ -109,46 +109,32 @@
let w_plam = Col::from_fn(plam.nrows(), |i| plam.get(i) / w.get(i));

// Scale each column of psi by the corresponding element of 'inner'

if psi.ncols() > n_threads * 128 {
psi_inner
.par_col_partition_mut(n_threads)
.zip(psi.par_col_partition(n_threads))
.zip(inner.par_partition(n_threads))
.zip(output.par_iter_mut())
.for_each(|(((mut psi_inner, psi), inner), output)| {
psi_inner
.as_mut()
.col_iter_mut()
.zip(psi.col_iter())
.zip(inner.iter())
.for_each(|((col, psi_col), inner_val)| {
col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
*x = psi_val * inner_val;
});
});
faer::linalg::matmul::triangular::matmul(
output.as_mut(),
faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
faer::Accum::Replace,
&psi_inner,
faer::linalg::matmul::triangular::BlockStructure::Rectangular,
psi.transpose(),
faer::linalg::matmul::triangular::BlockStructure::Rectangular,
1.0,
faer::Par::Seq,
);
// Use sequential column scaling - cache-friendly access pattern
psi_inner
.as_mut()
.col_iter_mut()
.zip(psi.col_iter())
.zip(inner.iter())
.for_each(|((col, psi_col), inner_val)| {
col.iter_mut().zip(psi_col.iter()).for_each(|(x, psi_val)| {
*x = psi_val * inner_val;
});
});

let mut first_iter = true;
for output in &output {
if first_iter {
h.copy_from(output);
first_iter = false;
} else {
h += output;
}
}
// Use faer's built-in parallelism for the matmul - it has better cache tiling
// than our manual partitioning, which caused false sharing
if psi.ncols() > PARALLEL_THRESHOLD {
faer::linalg::matmul::triangular::matmul(
h.as_mut(),
faer::linalg::matmul::triangular::BlockStructure::TriangularLower,
faer::Accum::Replace,
&psi_inner,
faer::linalg::matmul::triangular::BlockStructure::Rectangular,
psi.transpose(),
faer::linalg::matmul::triangular::BlockStructure::Rectangular,
1.0,
faer::Par::rayon(0), // Let faer handle parallelism with proper cache tiling
);
} else {
psi_inner
.as_mut()
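For reference, the triangular matmul above fills only the lower triangle of H = (psi * diag(inner)) * psi^T, which suffices because H is symmetric. A naive sequential sketch of the same quantity on plain slices (illustration only, not the crate's code):

// Reference for what the TriangularLower matmul computes: the lower
// triangle of H = (psi * diag(inner)) * psi^T, with psi stored row-major
// as nrows x ncols and inner holding one scale factor per column
fn lower_gram(psi: &[f64], nrows: usize, ncols: usize, inner: &[f64]) -> Vec<f64> {
    let mut h = vec![0.0; nrows * nrows];
    for i in 0..nrows {
        // H is symmetric, so only j <= i is computed
        for j in 0..=i {
            let mut acc = 0.0;
            for k in 0..ncols {
                acc += psi[i * ncols + k] * inner[k] * psi[j * ncols + k];
            }
            h[i * nrows + j] = acc;
        }
    }
    h
}

fn main() {
    // Made-up 2 x 3 psi and per-column scale vector
    let psi = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let inner = [0.5, 1.0, 2.0];
    let h = lower_gram(&psi, 2, 3, &inner);
    // h[0] = 1*0.5*1 + 2*1.0*2 + 3*2.0*3 = 22.5
    println!("{:?}", h);
}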
57 changes: 36 additions & 21 deletions src/routines/expansion/adaptative_grid.rs
@@ -1,6 +1,5 @@
use crate::structs::theta::Theta;
use anyhow::Result;
use faer::Row;

/// Implements the adaptive grid algorithm for support point expansion.
///
@@ -25,36 +24,52 @@ pub fn adaptative_grid(
ranges: &[(f64, f64)],
min_dist: f64,
) -> Result<()> {
let mut candidates = Vec::new();
let n_params = ranges.len();
let n_spp = theta.nspp();

// Collect all points first to avoid borrowing conflicts
// Pre-compute deltas for each dimension (cache-friendly: sequential access)
let deltas: Vec<f64> = ranges.iter().map(|(lo, hi)| eps * (hi - lo)).collect();

// Pre-allocate flat buffer for candidates to minimize allocations
// Max candidates = n_spp * n_params * 2 directions
let max_candidates = n_spp * n_params * 2;
let mut candidates: Vec<f64> = Vec::with_capacity(max_candidates * n_params);
let mut n_candidates = 0usize;

// Generate candidates using flat buffer
for spp in theta.matrix().row_iter() {
for (j, val) in spp.iter().enumerate() {
let l = eps * (ranges[j].1 - ranges[j].0); //abs?
for (j, &val) in spp.iter().enumerate() {
let l = deltas[j];

// Check +delta direction
if val + l < ranges[j].1 {
let mut plus = Row::zeros(spp.ncols());
plus[j] = l;
plus += spp;
candidates.push(plus.iter().copied().collect::<Vec<f64>>());
// Append candidate point to flat buffer
for (k, &v) in spp.iter().enumerate() {
candidates.push(if k == j { v + l } else { v });
}
n_candidates += 1;
}

// Check -delta direction
if val - l > ranges[j].0 {
let mut minus = Row::zeros(spp.ncols());
minus[j] = -l;
minus += spp;
candidates.push(minus.iter().copied().collect::<Vec<f64>>());
for (k, &v) in spp.iter().enumerate() {
candidates.push(if k == j { v - l } else { v });
}
n_candidates += 1;
}
}
}

// Option 1: Check all points against the original theta, then add them
let keep = candidates
.iter()
.filter(|point| theta.check_point(point, min_dist))
.cloned()
.collect::<Vec<_>>();
// Filter and add valid candidates
// Use slice views into the flat buffer to avoid allocations
for i in 0..n_candidates {
let start = i * n_params;
let end = start + n_params;
let point = &candidates[start..end];

for point in keep {
theta.add_point(point.as_slice())?;
if theta.check_point(point, min_dist) {
theta.add_point(point)?;
}
}

Ok(())
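As a sanity check on the flat-buffer layout, here is a minimal standalone sketch of the candidate generation above, with plain slices standing in for Theta's matrix and made-up values for eps, the ranges, and the single support point:

fn main() {
    let eps = 0.1;
    let ranges = [(0.0_f64, 100.0_f64), (0.0, 1.0)];
    let spp = [50.0_f64, 0.5]; // one support point
    let n_params = ranges.len();

    // Pre-compute per-dimension deltas, as in the diff
    let deltas: Vec<f64> = ranges.iter().map(|(lo, hi)| eps * (hi - lo)).collect();

    // Candidates are appended back-to-back, n_params values each
    let mut candidates: Vec<f64> = Vec::new();
    for (j, &val) in spp.iter().enumerate() {
        let l = deltas[j];
        if val + l < ranges[j].1 {
            candidates.extend(spp.iter().enumerate().map(|(k, &v)| if k == j { v + l } else { v }));
        }
        if val - l > ranges[j].0 {
            candidates.extend(spp.iter().enumerate().map(|(k, &v)| if k == j { v - l } else { v }));
        }
    }

    // Prints [60.0, 0.5], [40.0, 0.5], [50.0, 0.6], [50.0, 0.4], up to float rounding
    for point in candidates.chunks(n_params) {
        println!("{:?}", point);
    }
}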