Skip to content

Commit 90530f0

Browse files
committed
Add cuDSS-style block CSR interface for solver creation
Introduces a new API for creating solvers from block CSR matrices, modeled after NVIDIA's cuDSS library interface: - CsrTypes.h: Enums for MatrixType, MatrixView, IndexBase, IndexType - CsrSolver.h/.cpp: BlockCsrDescriptor and createSolverFromBlockCsr() - Solver.h/.cpp: loadFromCsr() and extractToCsr() for value loading - CsrSolverTest.cpp: Unit tests covering structure conversion, index types, base handling, and full factor+solve workflow The block CSR interface provides a natural entry point for users with existing sparse matrix data, supporting both int32 and int64 indices, zero and one-based indexing, and lower/upper triangular views. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude claude-opus-4-5-20251101
1 parent 6bd4dce commit 90530f0

8 files changed

Lines changed: 859 additions & 0 deletions

File tree

baspacho/baspacho/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
set(BaSpaCho_sources
33
CoalescedBlockMatrix.cpp
44
ComputationModel.cpp
5+
CsrSolver.cpp
56
EliminationTree.cpp
67
MatOpsFast.cpp
78
MatOpsRef.cpp

baspacho/baspacho/CsrSolver.cpp

Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#include "baspacho/baspacho/CsrSolver.h"
9+
#include <stdexcept>
10+
#include <string>
11+
12+
namespace BaSpaCho {
13+
14+
namespace {
15+
16+
// Helper to convert indices from void* based on IndexType
17+
template <typename OutT>
18+
void copyIndices(std::vector<OutT>& out, const void* src, int64_t count, IndexType indexType,
19+
int64_t baseAdjust) {
20+
out.resize(count);
21+
if (indexType == INDEX_INT32) {
22+
const int32_t* srcTyped = static_cast<const int32_t*>(src);
23+
for (int64_t i = 0; i < count; i++) {
24+
out[i] = static_cast<OutT>(srcTyped[i]) - baseAdjust;
25+
}
26+
} else {
27+
const int64_t* srcTyped = static_cast<const int64_t*>(src);
28+
for (int64_t i = 0; i < count; i++) {
29+
out[i] = static_cast<OutT>(srcTyped[i]) - baseAdjust;
30+
}
31+
}
32+
}
33+
34+
// Validate the block CSR descriptor
35+
void validateDescriptor(const BlockCsrDescriptor& desc) {
36+
// Check for null pointers
37+
if (desc.numBlocks > 0 && desc.rowStart == nullptr) {
38+
throw std::invalid_argument("BlockCsrDescriptor: rowStart is null");
39+
}
40+
if (desc.numBlockNonzeros > 0 && desc.colIndices == nullptr) {
41+
throw std::invalid_argument("BlockCsrDescriptor: colIndices is null");
42+
}
43+
if (desc.numBlocks > 0 && desc.blockSizes == nullptr) {
44+
throw std::invalid_argument("BlockCsrDescriptor: blockSizes is null");
45+
}
46+
47+
// Check non-negative counts
48+
if (desc.numBlocks < 0) {
49+
throw std::invalid_argument("BlockCsrDescriptor: numBlocks must be non-negative");
50+
}
51+
if (desc.numBlockNonzeros < 0) {
52+
throw std::invalid_argument("BlockCsrDescriptor: numBlockNonzeros must be non-negative");
53+
}
54+
55+
// Validate matrix type/view combination
56+
validateMatrixTypeView(desc.mtype, desc.mview);
57+
58+
// Check block sizes are positive
59+
for (int64_t i = 0; i < desc.numBlocks; i++) {
60+
if (desc.blockSizes[i] <= 0) {
61+
throw std::invalid_argument("BlockCsrDescriptor: blockSizes[" + std::to_string(i) +
62+
"] must be positive, got " + std::to_string(desc.blockSizes[i]));
63+
}
64+
}
65+
}
66+
67+
} // namespace
68+
69+
SparseStructure blockCsrToSparseStructure(const BlockCsrDescriptor& desc) {
70+
validateDescriptor(desc);
71+
72+
if (desc.numBlocks == 0) {
73+
return SparseStructure({0}, {});
74+
}
75+
76+
// Determine base adjustment for 1-based indexing
77+
int64_t baseAdjust = (desc.indexBase == BASE_ONE) ? 1 : 0;
78+
79+
// Copy row pointers (ptrs)
80+
std::vector<int64_t> ptrs;
81+
copyIndices(ptrs, desc.rowStart, desc.numBlocks + 1, desc.indexType, baseAdjust);
82+
83+
// Validate row pointers
84+
if (ptrs[0] != 0) {
85+
throw std::invalid_argument("BlockCsrDescriptor: rowStart[0] must be 0 (after base adjustment)");
86+
}
87+
if (ptrs[desc.numBlocks] != desc.numBlockNonzeros) {
88+
throw std::invalid_argument(
89+
"BlockCsrDescriptor: rowStart[numBlocks] must equal numBlockNonzeros");
90+
}
91+
92+
// Copy column indices (inds)
93+
std::vector<int64_t> inds;
94+
copyIndices(inds, desc.colIndices, desc.numBlockNonzeros, desc.indexType, baseAdjust);
95+
96+
// Validate column indices
97+
for (int64_t i = 0; i < desc.numBlockNonzeros; i++) {
98+
if (inds[i] < 0 || inds[i] >= desc.numBlocks) {
99+
throw std::invalid_argument("BlockCsrDescriptor: colIndices[" + std::to_string(i) +
100+
"] out of range");
101+
}
102+
}
103+
104+
// Create SparseStructure (currently in CSR format)
105+
SparseStructure ss(std::move(ptrs), std::move(inds));
106+
107+
// Handle matrix view:
108+
// - MVIEW_LOWER: Already in correct format for BaSpaCho (lower triangular CSR)
109+
// - MVIEW_UPPER: Need to transpose to get lower triangular
110+
// - MVIEW_FULL: For SPD, we only need lower triangle, so clear upper
111+
if (desc.mview == MVIEW_UPPER) {
112+
// Transpose converts upper CSR to lower CSR
113+
ss = ss.transpose();
114+
} else if (desc.mview == MVIEW_FULL) {
115+
// For full matrix, extract lower triangle only
116+
// clear(false) clears upper half, keeping lower
117+
ss = ss.clear(false);
118+
}
119+
120+
return ss;
121+
}
122+
123+
std::vector<int64_t> getParamSizes(const BlockCsrDescriptor& desc) {
124+
if (desc.numBlocks == 0) {
125+
return {};
126+
}
127+
return std::vector<int64_t>(desc.blockSizes, desc.blockSizes + desc.numBlocks);
128+
}
129+
130+
SolverPtr createSolverFromBlockCsr(const Settings& settings, const BlockCsrDescriptor& desc,
131+
const std::vector<int64_t>& sparseElimRanges,
132+
const std::unordered_set<int64_t>& elimLastIds) {
133+
// Convert block CSR to SparseStructure
134+
SparseStructure ss = blockCsrToSparseStructure(desc);
135+
136+
// Get parameter sizes
137+
std::vector<int64_t> paramSizes = getParamSizes(desc);
138+
139+
// Delegate to existing createSolver
140+
return createSolver(settings, paramSizes, ss, sparseElimRanges, elimLastIds);
141+
}
142+
143+
// Template instantiations for createSolverFromBlockCsrWithValues
144+
template <typename T>
145+
SolverPtr createSolverFromBlockCsrWithValues(const Settings& settings,
146+
const BlockCsrDescriptor& desc, const T* values,
147+
std::vector<T>& outData,
148+
const std::vector<int64_t>& sparseElimRanges) {
149+
// Create solver (structure only)
150+
SolverPtr solver = createSolverFromBlockCsr(settings, desc, sparseElimRanges, {});
151+
152+
// Resize output data buffer
153+
outData.resize(solver->dataSize());
154+
155+
// Zero-initialize (fill-in entries need to be zero)
156+
std::fill(outData.begin(), outData.end(), T(0));
157+
158+
// Convert row pointers to int64_t for loadFromCsr
159+
int64_t baseAdjust = (desc.indexBase == BASE_ONE) ? 1 : 0;
160+
std::vector<int64_t> rowStart(desc.numBlocks + 1);
161+
std::vector<int64_t> colIndices(desc.numBlockNonzeros);
162+
163+
if (desc.indexType == INDEX_INT32) {
164+
const int32_t* rowStartSrc = static_cast<const int32_t*>(desc.rowStart);
165+
const int32_t* colIndicesSrc = static_cast<const int32_t*>(desc.colIndices);
166+
for (int64_t i = 0; i <= desc.numBlocks; i++) {
167+
rowStart[i] = static_cast<int64_t>(rowStartSrc[i]) - baseAdjust;
168+
}
169+
for (int64_t i = 0; i < desc.numBlockNonzeros; i++) {
170+
colIndices[i] = static_cast<int64_t>(colIndicesSrc[i]) - baseAdjust;
171+
}
172+
} else {
173+
const int64_t* rowStartSrc = static_cast<const int64_t*>(desc.rowStart);
174+
const int64_t* colIndicesSrc = static_cast<const int64_t*>(desc.colIndices);
175+
for (int64_t i = 0; i <= desc.numBlocks; i++) {
176+
rowStart[i] = rowStartSrc[i] - baseAdjust;
177+
}
178+
for (int64_t i = 0; i < desc.numBlockNonzeros; i++) {
179+
colIndices[i] = colIndicesSrc[i] - baseAdjust;
180+
}
181+
}
182+
183+
// Load values from CSR format into internal format
184+
solver->loadFromCsr(rowStart.data(), colIndices.data(), desc.blockSizes, values, outData.data());
185+
186+
return solver;
187+
}
188+
189+
// Explicit template instantiations
190+
template SolverPtr createSolverFromBlockCsrWithValues<float>(const Settings& settings,
191+
const BlockCsrDescriptor& desc,
192+
const float* values,
193+
std::vector<float>& outData,
194+
const std::vector<int64_t>& sparseElimRanges);
195+
196+
template SolverPtr createSolverFromBlockCsrWithValues<double>(const Settings& settings,
197+
const BlockCsrDescriptor& desc,
198+
const double* values,
199+
std::vector<double>& outData,
200+
const std::vector<int64_t>& sparseElimRanges);
201+
202+
} // namespace BaSpaCho

baspacho/baspacho/CsrSolver.h

Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
*
4+
* This source code is licensed under the MIT license found in the
5+
* LICENSE file in the root directory of this source tree.
6+
*/
7+
8+
#pragma once
9+
10+
#include <memory>
11+
#include <unordered_set>
12+
#include <vector>
13+
#include "baspacho/baspacho/CsrTypes.h"
14+
#include "baspacho/baspacho/Solver.h"
15+
16+
namespace BaSpaCho {
17+
18+
/**
19+
* Block CSR matrix descriptor - holds structure metadata (no numeric data).
20+
* Modeled after cuDSS cudssMatrixCreateCsr parameters.
21+
*
22+
* This describes a block-sparse matrix in CSR format where:
23+
* - rowStart[i] gives the index in colIndices where row i's blocks begin
24+
* - colIndices[rowStart[i]..rowStart[i+1]) are the column indices of blocks in row i
25+
* - blockSizes[i] gives the dimension of the i-th parameter block
26+
*
27+
* The total matrix dimension is sum(blockSizes).
28+
*/
29+
struct BlockCsrDescriptor {
30+
int64_t numBlocks; // Number of block rows/cols (square matrix)
31+
int64_t numBlockNonzeros; // Number of non-zero blocks
32+
const void* rowStart; // Row start pointers [numBlocks+1], type per indexType
33+
const void* colIndices; // Column indices [numBlockNonzeros], type per indexType
34+
const int64_t* blockSizes; // Size of each block [numBlocks]
35+
IndexType indexType; // INT32 or INT64
36+
MatrixType mtype; // GENERAL, SYMMETRIC, SPD (only SPD supported)
37+
MatrixView mview; // FULL, LOWER, UPPER
38+
IndexBase indexBase; // ZERO or ONE based
39+
40+
BlockCsrDescriptor()
41+
: numBlocks(0),
42+
numBlockNonzeros(0),
43+
rowStart(nullptr),
44+
colIndices(nullptr),
45+
blockSizes(nullptr),
46+
indexType(INDEX_INT64),
47+
mtype(MTYPE_SPD),
48+
mview(MVIEW_LOWER),
49+
indexBase(BASE_ZERO) {}
50+
};
51+
52+
/**
53+
* Create a solver from block-level CSR structure.
54+
*
55+
* This is the primary cuDSS-style interface for block CSR matrices.
56+
* The descriptor provides only the sparsity structure; numeric values
57+
* are loaded separately via Solver::loadFromCsr() or the accessor.
58+
*
59+
* @param settings Solver settings (backend, threading, fill policy)
60+
* @param desc Block CSR descriptor (structure only)
61+
* @param sparseElimRanges Optional ranges for sparse elimination optimization
62+
* @param elimLastIds Optional IDs to keep at end for partial factorization
63+
* @return Unique pointer to solver
64+
*
65+
* @throws std::invalid_argument if desc has invalid parameters
66+
*/
67+
SolverPtr createSolverFromBlockCsr(const Settings& settings, const BlockCsrDescriptor& desc,
68+
const std::vector<int64_t>& sparseElimRanges = {},
69+
const std::unordered_set<int64_t>& elimLastIds = {});
70+
71+
/**
72+
* Create a solver from block-level CSR with values preloaded.
73+
*
74+
* Convenience function that creates solver and loads initial values.
75+
* The values array should contain dense block data in CSR order:
76+
* - Blocks are in row-major order within each block
77+
* - Blocks appear in the order specified by rowStart/colIndices
78+
*
79+
* @param settings Solver settings
80+
* @param desc Block CSR descriptor (structure only)
81+
* @param values Numeric values for all blocks (row-major within each block)
82+
* @param outData Output data buffer (will be resized to solver.dataSize())
83+
* @param sparseElimRanges Optional sparse elimination ranges
84+
* @return Unique pointer to solver
85+
*
86+
* @throws std::invalid_argument if desc has invalid parameters
87+
*/
88+
template <typename T>
89+
SolverPtr createSolverFromBlockCsrWithValues(const Settings& settings,
90+
const BlockCsrDescriptor& desc, const T* values,
91+
std::vector<T>& outData,
92+
const std::vector<int64_t>& sparseElimRanges = {});
93+
94+
/**
95+
* Convert block CSR descriptor to SparseStructure.
96+
*
97+
* Internal helper function that converts the CSR format to BaSpaCho's
98+
* internal SparseStructure representation.
99+
*
100+
* @param desc Block CSR descriptor
101+
* @return SparseStructure in lower triangular CSR format
102+
*/
103+
SparseStructure blockCsrToSparseStructure(const BlockCsrDescriptor& desc);
104+
105+
/**
106+
* Get parameter sizes from block CSR descriptor.
107+
*
108+
* @param desc Block CSR descriptor
109+
* @return Vector of block sizes
110+
*/
111+
std::vector<int64_t> getParamSizes(const BlockCsrDescriptor& desc);
112+
113+
} // namespace BaSpaCho

0 commit comments

Comments
 (0)