Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/GenericSparseArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ export AbstractGenericSparseArray,
AbstractGenericSparseVector, AbstractGenericSparseMatrix, AbstractGenericSparseVecOrMat

export GenericSparseVector,
GenericSparseMatrixCSC, GenericSparseMatrixCSR, GenericSparseMatrixCOO
GenericSparseMatrixCSC, GenericSparseMatrixCSR, GenericSparseMatrixCOO,
GenericSparseDiagMatrix

include("core.jl")
include("helpers.jl")
Expand All @@ -45,6 +46,9 @@ include("matrix_csr/matrix_csr.jl")
include("matrix_coo/matrix_coo_kernels.jl")
include("matrix_coo/matrix_coo.jl")

include("matrix_diag/matrix_diag_kernels.jl")
include("matrix_diag/matrix_diag.jl")

include("conversions/conversion_kernels.jl")
include("conversions/conversions.jl")

Expand Down
116 changes: 116 additions & 0 deletions src/conversions/conversions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -344,3 +344,119 @@ function GenericSparseMatrixCOO(A::Adjoint{Tv, <:GenericSparseMatrixCSR}) where
conj.(parent_coo.nzval),
)
end

# ============================================================================
# DIA ↔ COO Conversions
# ============================================================================

function GenericSparseMatrixCOO(A::GenericSparseDiagMatrix{Tv, Ti}) where {Tv, Ti}
m, n = size(A)
nnz_count = nnz(A)

if nnz_count == 0
rowind = similar(A.nzval, Ti, 0)
colind = similar(A.nzval, Ti, 0)
nzval = similar(A.nzval, Tv, 0)
return GenericSparseMatrixCOO(m, n, rowind, colind, nzval)
end

backend = get_backend(A.nzval)

# Allocate output arrays on the same backend
rowind = similar(A.nzval, Ti, nnz_count)
colind = similar(A.nzval, Ti, nnz_count)
nzval = similar(A.nzval, Tv, nnz_count)

# Use kernel to convert DIA to COO
kernel! = kernel_diag_to_coo!(backend)
kernel!(rowind, colind, nzval, A.offsets, A.diag_ptrs, A.nzval, m; ndrange = (nnz_count,))

return GenericSparseMatrixCOO(m, n, rowind, colind, nzval)
end

function GenericSparseDiagMatrix(A::GenericSparseMatrixCOO{Tv, Ti}) where {Tv, Ti}
# Convert COO -> CSC -> DIA (via CPU path for structure building)
A_csc = GenericSparseMatrixCSC(A)
return GenericSparseDiagMatrix(A_csc)
end

# Transpose and Adjoint conversions for DIA to COO
function GenericSparseMatrixCOO(A::Transpose{Tv, <:GenericSparseDiagMatrix}) where {Tv}
parent_coo = GenericSparseMatrixCOO(A.parent)
return GenericSparseMatrixCOO(
size(A, 1),
size(A, 2),
parent_coo.colind,
parent_coo.rowind,
parent_coo.nzval,
)
end

function GenericSparseMatrixCOO(A::Adjoint{Tv, <:GenericSparseDiagMatrix}) where {Tv}
parent_coo = GenericSparseMatrixCOO(A.parent)
return GenericSparseMatrixCOO(
size(A, 1),
size(A, 2),
parent_coo.colind,
parent_coo.rowind,
conj.(parent_coo.nzval),
)
end

# ============================================================================
# DIA ↔ CSC Conversions
# ============================================================================

function GenericSparseMatrixCSC(A::GenericSparseDiagMatrix)
return GenericSparseMatrixCSC(GenericSparseMatrixCOO(A))
end

function GenericSparseDiagMatrix(A::GenericSparseMatrixCSC{Tv, Ti}) where {Tv, Ti}
# Convert to CPU SparseMatrixCSC first, then build DIA on CPU, then adapt
A_cpu = SparseMatrixCSC(A)
dia_cpu = GenericSparseDiagMatrix(A_cpu)

# Adapt back to original backend
backend = get_backend(A.nzval)
if backend isa KernelAbstractions.CPU
return dia_cpu
else
return Adapt.adapt_structure(backend, dia_cpu)
end
end

function GenericSparseMatrixCSC(A::Transpose{Tv, <:GenericSparseDiagMatrix}) where {Tv}
return GenericSparseMatrixCSC(GenericSparseMatrixCOO(A))
end

function GenericSparseMatrixCSC(A::Adjoint{Tv, <:GenericSparseDiagMatrix}) where {Tv}
return GenericSparseMatrixCSC(GenericSparseMatrixCOO(A))
end

# ============================================================================
# DIA ↔ CSR Conversions
# ============================================================================

function GenericSparseMatrixCSR(A::GenericSparseDiagMatrix)
return GenericSparseMatrixCSR(GenericSparseMatrixCOO(A))
end

function GenericSparseDiagMatrix(A::GenericSparseMatrixCSR{Tv, Ti}) where {Tv, Ti}
return GenericSparseDiagMatrix(GenericSparseMatrixCSC(A))
end

# ============================================================================
# SparseMatrixCSC ↔ GenericSparseDiagMatrix Conversions
# ============================================================================

function SparseMatrixCSC(A::GenericSparseDiagMatrix)
return SparseMatrixCSC(GenericSparseMatrixCOO(A))
end

function SparseMatrixCSC(A::Transpose{Tv, <:GenericSparseDiagMatrix}) where {Tv}
return SparseMatrixCSC(GenericSparseMatrixCOO(A))
end

function SparseMatrixCSC(A::Adjoint{Tv, <:GenericSparseDiagMatrix}) where {Tv}
return SparseMatrixCSC(GenericSparseMatrixCOO(A))
end
Loading
Loading