Skip to content

Commit bd433d4

Browse files
committed
fix: add RHS permutation handling in Python solve methods
The solve methods (solve, solve_lu, solve_ldlt) now handle permutation automatically - permuting the RHS to internal order before solving and back to original order after. This makes the Python API simpler since callers don't need to manage permutations themselves.
1 parent 60c9b52 commit bd433d4

1 file changed

Lines changed: 49 additions & 4 deletions

File tree

python/sprux_bindings.cpp

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#include <pybind11/numpy.h>
1616
#include <pybind11/stl.h>
1717

18-
#include "baspacho/baspacho/Solver.h"
18+
#include "sprux/sprux/Solver.h"
1919

2020
namespace py = pybind11;
2121

@@ -119,6 +119,7 @@ class PySpruxSolver {
119119
}
120120

121121
/// Solve using Cholesky factor: L * L^T * x = b. Modifies rhs in place.
122+
/// Permutes rhs to internal order before solving, and back after.
122123
void solve(
123124
py::array_t<double, py::array::c_style> data,
124125
py::array_t<double, py::array::c_style> rhs,
@@ -127,10 +128,24 @@ class PySpruxSolver {
127128
auto data_buf = data.request();
128129
auto rhs_buf = rhs.request();
129130
int64_t stride = solver_->order();
131+
132+
// Permute RHS to internal order
133+
const auto& perm = solver_->paramToSpan();
134+
std::vector<double> rhs_internal(rhs_buf.size);
135+
for (int64_t i = 0; i < solver_->order(); i++) {
136+
rhs_internal[perm[i]] = static_cast<double*>(rhs_buf.ptr)[i];
137+
}
138+
139+
// Solve in-place on internal vector
130140
solver_->solve(
131141
static_cast<const double*>(data_buf.ptr),
132-
static_cast<double*>(rhs_buf.ptr),
142+
rhs_internal.data(),
133143
stride, nrhs);
144+
145+
// Permute back to original order
146+
for (int64_t i = 0; i < solver_->order(); i++) {
147+
static_cast<double*>(rhs_buf.ptr)[i] = rhs_internal[perm[i]];
148+
}
134149
}
135150

136151
// -- LU --
@@ -152,6 +167,7 @@ class PySpruxSolver {
152167
}
153168

154169
/// Solve using LU factor: P * L * U * x = b. Modifies rhs in place.
170+
/// Permutes rhs to internal order before solving, and back after.
155171
void solve_lu(
156172
py::array_t<double, py::array::c_style> data,
157173
py::array_t<int64_t, py::array::c_style> pivots,
@@ -162,11 +178,25 @@ class PySpruxSolver {
162178
auto piv_buf = pivots.request();
163179
auto rhs_buf = rhs.request();
164180
int64_t stride = solver_->order();
181+
182+
// Permute RHS to internal order
183+
const auto& perm = solver_->paramToSpan();
184+
std::vector<double> rhs_internal(rhs_buf.size);
185+
for (int64_t i = 0; i < solver_->order(); i++) {
186+
rhs_internal[perm[i]] = static_cast<double*>(rhs_buf.ptr)[i];
187+
}
188+
189+
// Solve in-place on internal vector
165190
solver_->solveLU(
166191
static_cast<const double*>(data_buf.ptr),
167192
static_cast<const int64_t*>(piv_buf.ptr),
168-
static_cast<double*>(rhs_buf.ptr),
193+
rhs_internal.data(),
169194
stride, nrhs);
195+
196+
// Permute back to original order
197+
for (int64_t i = 0; i < solver_->order(); i++) {
198+
static_cast<double*>(rhs_buf.ptr)[i] = rhs_internal[perm[i]];
199+
}
170200
}
171201

172202
// -- LDL^T --
@@ -179,6 +209,7 @@ class PySpruxSolver {
179209
}
180210

181211
/// Solve using LDL^T factor. Modifies rhs in place.
212+
/// Permutes rhs to internal order before solving, and back after.
182213
void solve_ldlt(
183214
py::array_t<double, py::array::c_style> data,
184215
py::array_t<double, py::array::c_style> rhs,
@@ -187,10 +218,24 @@ class PySpruxSolver {
187218
auto data_buf = data.request();
188219
auto rhs_buf = rhs.request();
189220
int64_t stride = solver_->order();
221+
222+
// Permute RHS to internal order
223+
const auto& perm = solver_->paramToSpan();
224+
std::vector<double> rhs_internal(rhs_buf.size);
225+
for (int64_t i = 0; i < solver_->order(); i++) {
226+
rhs_internal[perm[i]] = static_cast<double*>(rhs_buf.ptr)[i];
227+
}
228+
229+
// Solve in-place on internal vector
190230
solver_->solveLDLT(
191231
static_cast<const double*>(data_buf.ptr),
192-
static_cast<double*>(rhs_buf.ptr),
232+
rhs_internal.data(),
193233
stride, nrhs);
234+
235+
// Permute back to original order
236+
for (int64_t i = 0; i < solver_->order(); i++) {
237+
static_cast<double*>(rhs_buf.ptr)[i] = rhs_internal[perm[i]];
238+
}
194239
}
195240

196241
// -- CSR load/extract (for populating internal format from standard CSR) --

0 commit comments

Comments
 (0)