1515#include < pybind11/numpy.h>
1616#include < pybind11/stl.h>
1717
18- #include " baspacho/baspacho /Solver.h"
18+ #include " sprux/sprux /Solver.h"
1919
2020namespace py = pybind11;
2121
@@ -119,6 +119,7 @@ class PySpruxSolver {
119119 }
120120
121121 // / Solve using Cholesky factor: L * L^T * x = b. Modifies rhs in place.
122+ // / Permutes rhs to internal order before solving, and back after.
122123 void solve (
123124 py::array_t <double , py::array::c_style> data,
124125 py::array_t <double , py::array::c_style> rhs,
@@ -127,10 +128,24 @@ class PySpruxSolver {
127128 auto data_buf = data.request ();
128129 auto rhs_buf = rhs.request ();
129130 int64_t stride = solver_->order ();
131+
132+ // Permute RHS to internal order
133+ const auto & perm = solver_->paramToSpan ();
134+ std::vector<double > rhs_internal (rhs_buf.size );
135+ for (int64_t i = 0 ; i < solver_->order (); i++) {
136+ rhs_internal[perm[i]] = static_cast <double *>(rhs_buf.ptr )[i];
137+ }
138+
139+ // Solve in-place on internal vector
130140 solver_->solve (
131141 static_cast <const double *>(data_buf.ptr ),
132- static_cast < double *>(rhs_buf. ptr ),
142+ rhs_internal. data ( ),
133143 stride, nrhs);
144+
145+ // Permute back to original order
146+ for (int64_t i = 0 ; i < solver_->order (); i++) {
147+ static_cast <double *>(rhs_buf.ptr )[i] = rhs_internal[perm[i]];
148+ }
134149 }
135150
136151 // -- LU --
@@ -152,6 +167,7 @@ class PySpruxSolver {
152167 }
153168
154169 // / Solve using LU factor: P * L * U * x = b. Modifies rhs in place.
170+ // / Permutes rhs to internal order before solving, and back after.
155171 void solve_lu (
156172 py::array_t <double , py::array::c_style> data,
157173 py::array_t <int64_t , py::array::c_style> pivots,
@@ -162,11 +178,25 @@ class PySpruxSolver {
162178 auto piv_buf = pivots.request ();
163179 auto rhs_buf = rhs.request ();
164180 int64_t stride = solver_->order ();
181+
182+ // Permute RHS to internal order
183+ const auto & perm = solver_->paramToSpan ();
184+ std::vector<double > rhs_internal (rhs_buf.size );
185+ for (int64_t i = 0 ; i < solver_->order (); i++) {
186+ rhs_internal[perm[i]] = static_cast <double *>(rhs_buf.ptr )[i];
187+ }
188+
189+ // Solve in-place on internal vector
165190 solver_->solveLU (
166191 static_cast <const double *>(data_buf.ptr ),
167192 static_cast <const int64_t *>(piv_buf.ptr ),
168- static_cast < double *>(rhs_buf. ptr ),
193+ rhs_internal. data ( ),
169194 stride, nrhs);
195+
196+ // Permute back to original order
197+ for (int64_t i = 0 ; i < solver_->order (); i++) {
198+ static_cast <double *>(rhs_buf.ptr )[i] = rhs_internal[perm[i]];
199+ }
170200 }
171201
172202 // -- LDL^T --
@@ -179,6 +209,7 @@ class PySpruxSolver {
179209 }
180210
181211 // / Solve using LDL^T factor. Modifies rhs in place.
212+ // / Permutes rhs to internal order before solving, and back after.
182213 void solve_ldlt (
183214 py::array_t <double , py::array::c_style> data,
184215 py::array_t <double , py::array::c_style> rhs,
@@ -187,10 +218,24 @@ class PySpruxSolver {
187218 auto data_buf = data.request ();
188219 auto rhs_buf = rhs.request ();
189220 int64_t stride = solver_->order ();
221+
222+ // Permute RHS to internal order
223+ const auto & perm = solver_->paramToSpan ();
224+ std::vector<double > rhs_internal (rhs_buf.size );
225+ for (int64_t i = 0 ; i < solver_->order (); i++) {
226+ rhs_internal[perm[i]] = static_cast <double *>(rhs_buf.ptr )[i];
227+ }
228+
229+ // Solve in-place on internal vector
190230 solver_->solveLDLT (
191231 static_cast <const double *>(data_buf.ptr ),
192- static_cast < double *>(rhs_buf. ptr ),
232+ rhs_internal. data ( ),
193233 stride, nrhs);
234+
235+ // Permute back to original order
236+ for (int64_t i = 0 ; i < solver_->order (); i++) {
237+ static_cast <double *>(rhs_buf.ptr )[i] = rhs_internal[perm[i]];
238+ }
194239 }
195240
196241 // -- CSR load/extract (for populating internal format from standard CSR) --
0 commit comments