diff --git a/src/GenericSparseArrays.jl b/src/GenericSparseArrays.jl index 621a78d..18e846b 100644 --- a/src/GenericSparseArrays.jl +++ b/src/GenericSparseArrays.jl @@ -30,7 +30,8 @@ export AbstractGenericSparseArray, AbstractGenericSparseVector, AbstractGenericSparseMatrix, AbstractGenericSparseVecOrMat export GenericSparseVector, - GenericSparseMatrixCSC, GenericSparseMatrixCSR, GenericSparseMatrixCOO + GenericSparseMatrixCSC, GenericSparseMatrixCSR, GenericSparseMatrixCOO, + GenericSparseDiagMatrix include("core.jl") include("helpers.jl") @@ -45,6 +46,9 @@ include("matrix_csr/matrix_csr.jl") include("matrix_coo/matrix_coo_kernels.jl") include("matrix_coo/matrix_coo.jl") +include("matrix_diag/matrix_diag_kernels.jl") +include("matrix_diag/matrix_diag.jl") + include("conversions/conversion_kernels.jl") include("conversions/conversions.jl") diff --git a/src/conversions/conversions.jl b/src/conversions/conversions.jl index e29e565..ca330d7 100644 --- a/src/conversions/conversions.jl +++ b/src/conversions/conversions.jl @@ -344,3 +344,119 @@ function GenericSparseMatrixCOO(A::Adjoint{Tv, <:GenericSparseMatrixCSR}) where conj.(parent_coo.nzval), ) end
+
+# ============================================================================
+# DIA ↔ COO Conversions
+# ============================================================================
+
+# DIA -> COO: produce one (row, col, value) triplet for each of the `nnz(A)`
+# stored entries. All output vectors are created with `similar(A.nzval, ...)`.
+function GenericSparseMatrixCOO(A::GenericSparseDiagMatrix{Tv, Ti}) where {Tv, Ti}
+    nrows, ncols = size(A)
+    nstored = nnz(A)
+
+    rowind = similar(A.nzval, Ti, nstored)
+    colind = similar(A.nzval, Ti, nstored)
+    nzval = similar(A.nzval, Tv, nstored)
+
+    # With nothing stored the (empty) outputs are already final; skip the launch.
+    if nstored != 0
+        fill_coo! = kernel_diag_to_coo!(get_backend(A.nzval))
+        fill_coo!(
+            rowind, colind, nzval, A.offsets, A.diag_ptrs, A.nzval, nrows;
+            ndrange = (nstored,),
+        )
+    end
+
+    return GenericSparseMatrixCOO(nrows, ncols, rowind, colind, nzval)
+end
+
+# COO -> DIA goes through CSC; the CSC -> DIA method below builds the structure.
+function GenericSparseDiagMatrix(A::GenericSparseMatrixCOO{Tv, Ti}) where {Tv, Ti}
+    return GenericSparseDiagMatrix(GenericSparseMatrixCSC(A))
+end
+
+# transpose(DIA) -> COO: convert the parent and swap the two index vectors.
+function GenericSparseMatrixCOO(A::Transpose{Tv, <:GenericSparseDiagMatrix}) where {Tv}
+    P = GenericSparseMatrixCOO(A.parent)
+    return GenericSparseMatrixCOO(size(A, 1), size(A, 2), P.colind, P.rowind, P.nzval)
+end
+
+# adjoint(DIA) -> COO: as for transpose, with the values conjugated as well.
+function GenericSparseMatrixCOO(A::Adjoint{Tv, <:GenericSparseDiagMatrix}) where {Tv}
+    P = GenericSparseMatrixCOO(A.parent)
+    return GenericSparseMatrixCOO(
+        size(A, 1), size(A, 2), P.colind, P.rowind, conj.(P.nzval),
+    )
+end
+
+# ============================================================================
+# DIA ↔ CSC Conversions
+# ============================================================================
+
+# DIA -> CSC (plain, transposed, adjoint) all reuse the COO conversions above.
+GenericSparseMatrixCSC(A::GenericSparseDiagMatrix) =
+    GenericSparseMatrixCSC(GenericSparseMatrixCOO(A))
+
+# CSC -> DIA: build the diagonal structure from a `SparseMatrixCSC` copy, then
+# adapt the result to the backend of `A.nzval` unless that backend is the CPU.
+function GenericSparseDiagMatrix(A::GenericSparseMatrixCSC{Tv, Ti}) where {Tv, Ti}
+    dia_cpu = GenericSparseDiagMatrix(SparseMatrixCSC(A))
+    backend = get_backend(A.nzval)
+    backend isa KernelAbstractions.CPU && return dia_cpu
+    return Adapt.adapt_structure(backend, dia_cpu)
+end
+
+GenericSparseMatrixCSC(A::Transpose{Tv, <:GenericSparseDiagMatrix}) where {Tv} =
+    GenericSparseMatrixCSC(GenericSparseMatrixCOO(A))
+
+GenericSparseMatrixCSC(A::Adjoint{Tv, <:GenericSparseDiagMatrix}) where {Tv} =
+    GenericSparseMatrixCSC(GenericSparseMatrixCOO(A))
+
+# ============================================================================
+# DIA ↔ CSR Conversions
+# ============================================================================
+
+# DIA -> CSR via COO; CSR -> DIA via CSC.
+GenericSparseMatrixCSR(A::GenericSparseDiagMatrix) =
+    GenericSparseMatrixCSR(GenericSparseMatrixCOO(A))
+
+function GenericSparseDiagMatrix(A::GenericSparseMatrixCSR{Tv, Ti}) where {Tv, Ti}
+    return GenericSparseDiagMatrix(GenericSparseMatrixCSC(A))
+end
+
+# ============================================================================
+# SparseMatrixCSC ↔ GenericSparseDiagMatrix Conversions
+# ============================================================================
+
+# Every variant is routed through the matching COO conversion.
+SparseMatrixCSC(A::GenericSparseDiagMatrix) = SparseMatrixCSC(GenericSparseMatrixCOO(A))
+
+SparseMatrixCSC(A::Transpose{Tv, <:GenericSparseDiagMatrix}) where {Tv} =
+    SparseMatrixCSC(GenericSparseMatrixCOO(A))
+
+SparseMatrixCSC(A::Adjoint{Tv, <:GenericSparseDiagMatrix}) where {Tv} =
+    SparseMatrixCSC(GenericSparseMatrixCOO(A))
diff --git a/src/matrix_diag/matrix_diag.jl b/src/matrix_diag/matrix_diag.jl new file mode 100644 index 0000000..5c45b00 --- /dev/null +++ b/src/matrix_diag/matrix_diag.jl @@ -0,0 +1,625 @@
+# GenericSparseDiagMatrix implementation
+
+"""
+    GenericSparseDiagMatrix{Tv,Ti,OffsetsT<:AbstractVector{Ti},DiagPtrsT<:AbstractVector{Ti},NzValT<:AbstractVector{Tv}} <: AbstractGenericSparseMatrix{Tv,Ti}
+
+Sparse Diagonal (DIA) format matrix with generic storage vectors. Stores the
+matrix as a collection of diagonals. Each diagonal is identified by an offset
+from the main diagonal (0 = main, positive = above, negative = below).
+
+The values of all diagonals are packed contiguously in `nzval`, and `diag_ptrs`
+indicates where each diagonal's data starts (similar to `colptr` in CSC).
+
+# Fields
+- `m::Int` - number of rows
+- `n::Int` - number of columns
+- `offsets::OffsetsT` - diagonal offsets (length = number of diagonals)
+- `diag_ptrs::DiagPtrsT` - pointer to the start of each diagonal in nzval (length = ndiags+1)
+- `nzval::NzValT` - stored values (packed diagonals)
+"""
+struct GenericSparseDiagMatrix{
+        Tv,
+        Ti,
+        OffsetsT <: AbstractVector{Ti},
+        DiagPtrsT <: AbstractVector{Ti},
+        NzValT <: AbstractVector{Tv},
+    } <: AbstractGenericSparseMatrix{Tv, Ti}
+    m::Int
+    n::Int
+    offsets::OffsetsT
+    diag_ptrs::DiagPtrsT
+    nzval::NzValT
+
+    # Validates backend agreement, non-negative dimensions and the length of
+    # `diag_ptrs`, then stores *copies* of the three vectors.
+    function GenericSparseDiagMatrix(
+            m::Integer,
+            n::Integer,
+            offsets::OffsetsT,
+            diag_ptrs::DiagPtrsT,
+            nzval::NzValT,
+        ) where {
+            Tv,
+            Ti,
+            OffsetsT <: AbstractVector{Ti},
+            DiagPtrsT <: AbstractVector{Ti},
+            NzValT <: AbstractVector{Tv},
+        }
+        get_backend(offsets) == get_backend(diag_ptrs) == get_backend(nzval) ||
+            throw(ArgumentError("All storage vectors must be on the same device/backend."))
+
+        m >= 0 || throw(ArgumentError("m must be non-negative"))
+        n >= 0 || throw(ArgumentError("n must be non-negative"))
+
+        ndiags = length(offsets)
+        length(diag_ptrs) == ndiags + 1 ||
+            throw(ArgumentError("diag_ptrs length must be ndiags+1"))
+
+        return new{Tv, Ti, OffsetsT, DiagPtrsT, NzValT}(
+            Int(m),
+            Int(n),
+            copy(offsets),
+            copy(diag_ptrs),
+            copy(nzval),
+        )
+    end
+end
+
+"""
+    _diag_length(m, n, d)
+
+Compute the length of the diagonal at offset `d` for an m×n matrix.
+
+The result is negative when `d` lies outside the matrix (`d > n` or `d < -m`);
+callers use that to reject invalid offsets.
+"""
+function _diag_length(m::Integer, n::Integer, d::Integer)
+    return d >= 0 ? min(m, n - d) : min(m + d, n)
+end
+
+"""
+    GenericSparseDiagMatrix(m, n, diag_offsets, diag_values)
+
+Construct a `GenericSparseDiagMatrix` from a vector of diagonal offsets and
+a vector of diagonal value vectors.
+
+`diag_values[k]` must have exactly `_diag_length(m, n, diag_offsets[k])` entries.
+The packed storage is allocated with `similar(diag_values[1], ...)`. When no
+diagonals are given, an empty matrix backed by plain `Vector`s is returned.
+
+# Throws
+- `ArgumentError` if the two inputs differ in length, an offset lies outside
+  the matrix, or a diagonal has the wrong length.
+"""
+function GenericSparseDiagMatrix(
+        m::Integer,
+        n::Integer,
+        diag_offsets::AbstractVector{Ti},
+        diag_values::AbstractVector{<:AbstractVector{Tv}},
+    ) where {Tv, Ti <: Integer}
+    ndiags = length(diag_offsets)
+    length(diag_values) == ndiags ||
+        throw(ArgumentError("diag_offsets and diag_values must have the same length"))
+
+    # Without any diagonal there is no value vector to call `similar` on;
+    # indexing `diag_values[1]` would throw a BoundsError.
+    if ndiags == 0
+        return GenericSparseDiagMatrix(m, n, Vector{Ti}(), Ti[one(Ti)], Vector{Tv}())
+    end
+
+    # Validate diagonal lengths
+    for k in 1:ndiags
+        d = diag_offsets[k]
+        expected_len = _diag_length(m, n, d)
+        expected_len >= 0 ||
+            throw(ArgumentError("Invalid diagonal offset $d for $m×$n matrix"))
+        length(diag_values[k]) == expected_len ||
+            throw(
+            ArgumentError(
+                "Diagonal at offset $d should have length $expected_len, got $(length(diag_values[k]))",
+            ),
+        )
+    end
+
+    # Build diag_ptrs on CPU (1-based running sum of the diagonal lengths)
+    diag_ptrs_cpu = Vector{Ti}(undef, ndiags + 1)
+    diag_ptrs_cpu[1] = one(Ti)
+    for k in 1:ndiags
+        diag_ptrs_cpu[k + 1] = diag_ptrs_cpu[k] + Ti(length(diag_values[k]))
+    end
+
+    # Concatenate all diagonal values
+    total_nnz = diag_ptrs_cpu[end] - one(Ti)
+    nzval = similar(diag_values[1], Tv, Int(total_nnz))
+    pos = 1
+    for k in 1:ndiags
+        len = length(diag_values[k])
+        nzval[pos:(pos + len - 1)] .= diag_values[k]
+        pos += len
+    end
+
+    # Build offsets
+    offsets = similar(diag_values[1], Ti, ndiags)
+    offsets .= diag_offsets
+
+    diag_ptrs = similar(diag_values[1], Ti, ndiags + 1)
+    diag_ptrs .= diag_ptrs_cpu
+
+    return GenericSparseDiagMatrix(m, n, offsets, diag_ptrs, nzval)
+end
+
+# Conversion from SparseMatrixCSC.
+# Every diagonal that holds at least one stored entry of `A` is stored in full:
+# `nzval` starts as zeros and only the positions present in `A` are overwritten.
+function GenericSparseDiagMatrix(A::SparseMatrixCSC{Tv, Ti}) where {Tv, Ti}
+    m, n = size(A)
+    rows = rowvals(A)
+    vals = nonzeros(A)
+
+    # Find all non-zero diagonals
+    diag_set = Set{Ti}()
+    for col in 1:n
+        for j in nzrange(A, col)
+            push!(diag_set, Ti(col - rows[j]))
+        end
+    end
+
+    diag_offsets = sort!(collect(diag_set))
+    ndiags = length(diag_offsets)
+
+    if ndiags == 0
+        return GenericSparseDiagMatrix(m, n, Vector{Ti}(), Ti[one(Ti)], Vector{Tv}())
+    end
+
+    # Build diag_ptrs
+    diag_ptrs = Vector{Ti}(undef, ndiags + 1)
+    diag_ptrs[1] = one(Ti)
+    for k in 1:ndiags
+        d = diag_offsets[k]
+        diag_ptrs[k + 1] = diag_ptrs[k] + Ti(_diag_length(m, n, Int(d)))
+    end
+
+    total_nnz = Int(diag_ptrs[end] - one(Ti))
+    nzval = zeros(Tv, total_nnz)
+
+    # Create a map from offset to diagonal index
+    offset_to_idx = Dict{Ti, Int}(d => k for (k, d) in enumerate(diag_offsets))
+
+    # Fill nzval
+    for col in 1:n
+        for j in nzrange(A, col)
+            row = rows[j]
+            d = Ti(col - row)
+            k = offset_to_idx[d]
+            # First row touched by diagonal d: 1 for d >= 0, 1 - d for d < 0.
+            row_start = max(1, 1 - Int(d))
+            local_idx = row - row_start + 1
+            nzval[Int(diag_ptrs[k]) + local_idx - 1] = vals[j]
+        end
+    end
+
+    return GenericSparseDiagMatrix(m, n, diag_offsets, diag_ptrs, nzval)
+end
+
+# Adapt each storage vector individually and rebuild the matrix around them.
+Adapt.adapt_structure(to, A::GenericSparseDiagMatrix) = GenericSparseDiagMatrix(
+    A.m,
+    A.n,
+    Adapt.adapt_structure(to, A.offsets),
+    Adapt.adapt_structure(to, A.diag_ptrs),
+    Adapt.adapt_structure(to, A.nzval),
+)
+
+Base.size(A::GenericSparseDiagMatrix) = (A.m, A.n)
+Base.length(A::GenericSparseDiagMatrix) = A.m * A.n
+Base.copy(A::GenericSparseDiagMatrix) =
+    GenericSparseDiagMatrix(A.m, A.n, copy(A.offsets), copy(A.diag_ptrs), copy(A.nzval))
+
+Base.collect(A::GenericSparseDiagMatrix) = collect(SparseMatrixCSC(A))
+
+# Same size as `A`, no stored diagonals; `diag_ptrs` holds the single entry 1.
+function Base.zero(A::GenericSparseDiagMatrix)
+    offsets = similar(A.offsets, 0)
+    diag_ptrs = similar(A.diag_ptrs, 1)
+    fill!(diag_ptrs, one(eltype(diag_ptrs)))
+    nzval = similar(A.nzval, 0)
+    return GenericSparseDiagMatrix(A.m, A.n, offsets, diag_ptrs, nzval)
+end
+
+function Base.:-(A::GenericSparseDiagMatrix)
+    return GenericSparseDiagMatrix(A.m, A.n, copy(A.offsets), copy(A.diag_ptrs), -A.nzval)
+end
+
+# For real element types `conj` and `real` return `A` itself (no copy).
+Base.conj(A::GenericSparseDiagMatrix{<:Real}) = A
+function Base.conj(A::GenericSparseDiagMatrix{<:Complex})
+    return GenericSparseDiagMatrix(
+        A.m, A.n, copy(A.offsets), copy(A.diag_ptrs), conj.(A.nzval),
+    )
+end
+
+Base.real(A::GenericSparseDiagMatrix{<:Real}) = A
+function
Base.real(A::GenericSparseDiagMatrix{<:Complex})
+    return GenericSparseDiagMatrix(
+        A.m, A.n, copy(A.offsets), copy(A.diag_ptrs), real.(A.nzval),
+    )
+end
+
+Base.imag(A::GenericSparseDiagMatrix{<:Real}) = zero(A)
+function Base.imag(A::GenericSparseDiagMatrix{<:Complex})
+    return GenericSparseDiagMatrix(
+        A.m, A.n, copy(A.offsets), copy(A.diag_ptrs), imag.(A.nzval),
+    )
+end
+
+# Accessors return the stored vectors themselves, not copies.
+SparseArrays.nonzeros(A::GenericSparseDiagMatrix) = A.nzval
+getoffsets(A::GenericSparseDiagMatrix) = A.offsets
+getdiagptrs(A::GenericSparseDiagMatrix) = A.diag_ptrs
+
+# Trace: the sum of the stored main diagonal (offset 0), or zero when no such
+# diagonal is stored. Throws `DimensionMismatch` for non-square matrices.
+function LinearAlgebra.tr(A::GenericSparseDiagMatrix)
+    m, n = size(A)
+    m == n || throw(DimensionMismatch("Matrix must be square to compute the trace."))
+
+    # Search a host copy of the offsets to avoid scalar indexing on device arrays.
+    offsets_cpu = collect(A.offsets)
+    k = findfirst(iszero, offsets_cpu)
+    k === nothing && return zero(eltype(A))
+
+    diag_ptrs_cpu = collect(A.diag_ptrs)
+    start_idx = diag_ptrs_cpu[k]
+    end_idx = diag_ptrs_cpu[k + 1] - 1
+    # A view avoids materialising a copy of the diagonal before reducing it.
+    return sum(view(A.nzval, start_idx:end_idx))
+end
+
+# Matrix-Vector and Matrix-Matrix multiplication
+# One `mul!` method is generated per (wrapper of A, wrapper of B) combination
+# returned by `trans_adj_wrappers`; `transa` selects the _T or _N kernel.
+for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:GenericSparseDiagMatrix)
+    for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:DenseVecOrMat)
+        TypeA = wrapa(:(T1))
+        TypeB = wrapb(:(T2))
+        TypeC = :(DenseVecOrMat{T3})
+
+        kernel_spmatmul! = transa ? :kernel_spmatmul_diag_T! : :kernel_spmatmul_diag_N!
+
+        @eval function LinearAlgebra.mul!(
+                C::$TypeC,
+                A::$TypeA,
+                B::$TypeB,
+                α::Number,
+                β::Number,
+            ) where {$(whereT1(:T1)), $(whereT2(:T2)), T3}
+            size(A, 2) == size(B, 1) || throw(
+                DimensionMismatch(
+                    "second dimension of A, $(size(A, 2)), does not match the first dimension of B, $(size(B, 1))",
+                ),
+            )
+            size(A, 1) == size(C, 1) || throw(
+                DimensionMismatch(
+                    "first dimension of A, $(size(A, 1)), does not match the first dimension of C, $(size(C, 1))",
+                ),
+            )
+            size(B, 2) == size(C, 2) || throw(
+                DimensionMismatch(
+                    "second dimension of B, $(size(B, 2)), does not match the second dimension of C, $(size(C, 2))",
+                ),
+            )
+
+            promote_type(T1, T2, eltype(α), eltype(β)) <: T3 || throw(
+                ArgumentError(
+                    "element types of A, B, α, and β must be promotable to the element type of C",
+                ),
+            )
+
+            _A = $(unwrapa(:A))
+            _B = $(unwrapb(:B))
+
+            backend_C = get_backend(C)
+            backend_A = get_backend(_A)
+            backend_B = get_backend(_B)
+
+            backend_A == backend_B == backend_C ||
+                throw(ArgumentError("All arrays must have the same backend"))
+
+            # Scale (or fill) C by β first; the kernel then accumulates α*op(A)*op(B).
+            β != one(β) && LinearAlgebra._rmul_or_fill!(C, β)
+
+            total_nnz = nnz(_A)
+            total_nnz == 0 && return C
+
+            # One work item per (stored entry, column of C).
+            kernel! = $kernel_spmatmul!(backend_A)
+            kernel!(
+                C,
+                getoffsets(_A),
+                getdiagptrs(_A),
+                getnzval(_A),
+                _B,
+                α,
+                Val{$conja}(),
+                Val{$conjb}(),
+                Val{$transb}();
+                ndrange = (total_nnz, size(C, 2)),
+            )
+
+            return C
+        end
+    end
+end
+
+# Three-argument dot product: dot(x, A, y) = x' * A * y
+for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:GenericSparseDiagMatrix)
+    TypeA = wrapa(:(T1))
+
+    kernel_dot! = transa ? :kernel_workgroup_dot_diag_T! : :kernel_workgroup_dot_diag_N!
+
+    @eval function LinearAlgebra.dot(
+            x::AbstractVector{T2},
+            A::$TypeA,
+            y::AbstractVector{T3},
+        ) where {$(whereT1(:T1)), T2, T3}
+        size(A, 1) == length(x) || throw(
+            DimensionMismatch(
+                "first dimension of A, $(size(A, 1)), does not match the length of x, $(length(x))",
+            ),
+        )
+        size(A, 2) == length(y) || throw(
+            DimensionMismatch(
+                "second dimension of A, $(size(A, 2)), does not match the length of y, $(length(y))",
+            ),
+        )
+
+        _A = $(unwrapa(:A))
+
+        backend_x = get_backend(x)
+        backend_A = get_backend(_A)
+        backend_y = get_backend(y)
+
+        backend_x == backend_A == backend_y ||
+            throw(ArgumentError("All arrays must have the same backend"))
+
+        T = promote_type(T1, T2, T3)
+
+        total_nnz = nnz(_A)
+        total_nnz == 0 && return zero(T)
+
+        nzval = getnzval(_A)
+
+        # At most 256 workgroups of 256 work items; each work item strides over
+        # the stored entries inside the kernel.
+        group_size = 256
+        n_groups = min(cld(total_nnz, group_size), 256)
+        total_workitems = group_size * n_groups
+
+        # Allocate array for block results (one per workgroup)
+        block_results = similar(nzval, T, n_groups)
+
+        # Launch kernel with workgroup configuration
+        kernel! = $kernel_dot!(backend_A, group_size)
+        kernel!(
+            block_results,
+            x,
+            getoffsets(_A),
+            getdiagptrs(_A),
+            nzval,
+            y,
+            total_nnz,
+            Val{$conja}();
+            ndrange = (total_workitems,),
+        )
+
+        # Final reduction: sum all block results
+        return sum(block_results)
+    end
+end
+
+# Helper function for adding GenericSparseDiagMatrix to dense matrix
+# (in place: every stored entry of A is added to the matching entry of C).
+function _add_sparse_to_dense!(C::DenseMatrix, A::GenericSparseDiagMatrix)
+    backend = get_backend(A)
+    nnz_val = nnz(A)
+    nnz_val == 0 && return C
+
+    kernel! = kernel_add_sparse_to_dense_diag!(backend)
+    kernel!(C, getoffsets(A), getdiagptrs(A), getnzval(A), A.m; ndrange = (nnz_val,))
+
+    return C
+end
+
+# Addition between two GenericSparseDiagMatrix: convert to COO, add, convert back
+function Base.:+(A::GenericSparseDiagMatrix, B::GenericSparseDiagMatrix)
+    size(A) == size(B) || throw(
+        DimensionMismatch(
+            "dimensions must match: A has dims $(size(A)), B has dims $(size(B))",
+        ),
+    )
+
+    backend_A = get_backend(A)
+    backend_B = get_backend(B)
+    backend_A == backend_B ||
+        throw(ArgumentError("Both matrices must have the same backend"))
+
+    # Convert to COO, add, convert back
+    A_coo = GenericSparseMatrixCOO(A)
+    B_coo = GenericSparseMatrixCOO(B)
+    C_coo = A_coo + B_coo
+    return GenericSparseDiagMatrix(C_coo)
+end
+
+# Addition with transpose/adjoint support
+for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:GenericSparseDiagMatrix)
+    for (wrapb, transb, conjb, unwrapb, whereT2) in
+        trans_adj_wrappers(:GenericSparseDiagMatrix)
+        # Skip the case where both are not transposed (already handled above)
+        # NOTE(review): this `continue` also skips the `-` definition below for
+        # the unwrapped pair; confirm a generic `A - B` method exists elsewhere.
+        (transa == false && transb == false) && continue
+
+        TypeA = wrapa(:(T1))
+        TypeB = wrapb(:(T2))
+
+        @eval function Base.:+(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)), $(whereT2(:T2))}
+            size(A) == size(B) || throw(
+                DimensionMismatch(
+                    "dimensions must match: A has dims $(size(A)), B has dims $(size(B))",
+                ),
+            )
+
+            # Same backend check as the unwrapped `+`, applied to the parents.
+            backend_A = get_backend($(unwrapa(:A)))
+            backend_B = get_backend($(unwrapb(:B)))
+            backend_A == backend_B ||
+                throw(ArgumentError("Both matrices must have the same backend"))
+
+            # Convert to COO, add, convert back
+            A_coo = GenericSparseMatrixCOO(A)
+            B_coo = GenericSparseMatrixCOO(B)
+            C_coo = A_coo + B_coo
+            return GenericSparseDiagMatrix(C_coo)
+        end
+
+        @eval function Base.:-(A::$TypeA, B::$TypeB) where {$(whereT1(:T1)), $(whereT2(:T2))}
+            return A + (-B)
+        end
+    end
+end
+
+# Addition with UniformScaling
+function Base.:+(A::GenericSparseDiagMatrix{Tv, Ti}, J::UniformScaling) where {Tv, Ti}
+    m, n = size(A)
+    m == n || throw(DimensionMismatch("Matrix must be square to add UniformScaling."))
+    λ = J.λ
+    iszero(λ) && return copy(A)
+
+    # Convert to COO,
add UniformScaling, convert back
+    A_coo = GenericSparseMatrixCOO(A)
+    C_coo = A_coo + J
+    return GenericSparseDiagMatrix(C_coo)
+end
+
+# Sparse-sparse multiplication: delegate to COO
+function Base.:(*)(A::GenericSparseDiagMatrix, B::GenericSparseDiagMatrix)
+    if size(A, 2) != size(B, 1)
+        throw(
+            DimensionMismatch(
+                "second dimension of A, $(size(A, 2)), does not match first dimension of B, $(size(B, 1))",
+            ),
+        )
+    end
+    if get_backend(A) != get_backend(B)
+        throw(ArgumentError("Both matrices must have the same backend"))
+    end
+
+    # The product is formed on the COO representations and converted back.
+    product = GenericSparseMatrixCOO(A) * GenericSparseMatrixCOO(B)
+    return GenericSparseDiagMatrix(product)
+end
+
+# Multiplication with transpose/adjoint support
+for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:GenericSparseDiagMatrix)
+    for (wrapb, transb, conjb, unwrapb, whereT2) in
+        trans_adj_wrappers(:GenericSparseDiagMatrix)
+        # The unwrapped pair already has its own method above.
+        (transa == false && transb == false) && continue
+
+        TypeA = wrapa(:(T1))
+        TypeB = wrapb(:(T2))
+
+        @eval function Base.:(*)(
+                A::$TypeA,
+                B::$TypeB,
+            ) where {$(whereT1(:T1)), $(whereT2(:T2))}
+            if size(A, 2) != size(B, 1)
+                throw(
+                    DimensionMismatch(
+                        "second dimension of A, $(size(A, 2)), does not match first dimension of B, $(size(B, 1))",
+                    ),
+                )
+            end
+            # Backends are compared on the unwrapped parents.
+            if get_backend($(unwrapa(:A))) != get_backend($(unwrapb(:B)))
+                throw(ArgumentError("Both matrices must have the same backend"))
+            end
+
+            product = GenericSparseMatrixCOO(A) * GenericSparseMatrixCOO(B)
+            return GenericSparseDiagMatrix(product)
+        end
+    end
+end
+
+# Kronecker product: delegate to COO
+function LinearAlgebra.kron(
+        A::GenericSparseDiagMatrix{Tv1, Ti1},
+        B::GenericSparseDiagMatrix{Tv2, Ti2},
+    ) where {Tv1, Ti1, Tv2, Ti2}
+    if get_backend(A) != get_backend(B)
+        throw(ArgumentError("Both arrays must have the same backend"))
+    end
+
+    product = kron(GenericSparseMatrixCOO(A), GenericSparseMatrixCOO(B))
+    return GenericSparseDiagMatrix(product)
+end
+
+for (wrapa, transa, conja, unwrapa, whereT1) in trans_adj_wrappers(:GenericSparseDiagMatrix)
+    for (wrapb, transb, conjb, unwrapb, whereT2) in trans_adj_wrappers(:GenericSparseDiagMatrix)
+        # The unwrapped pair already has its own method above.
+        (transa == false && transb == false) && continue
+
+        TypeA = wrapa(:(T1))
+        TypeB = wrapb(:(T2))
+
+        @eval function LinearAlgebra.kron(
+                A::$TypeA,
+                B::$TypeB,
+            ) where {$(whereT1(:T1)), $(whereT2(:T2))}
+            product = kron(GenericSparseMatrixCOO(A), GenericSparseMatrixCOO(B))
+            return GenericSparseDiagMatrix(product)
+        end
+    end
+end
+
+# kron with Diagonal
+function LinearAlgebra.kron(
+        D::Diagonal{Tv1},
+        B::GenericSparseDiagMatrix{Tv2, Ti},
+    ) where {Tv1, Tv2, Ti}
+    return GenericSparseDiagMatrix(kron(D, GenericSparseMatrixCOO(B)))
+end
+
+function LinearAlgebra.kron(
+        A::GenericSparseDiagMatrix{Tv1, Ti},
+        D::Diagonal{Tv2},
+    ) where {Tv1, Ti, Tv2}
+    return GenericSparseDiagMatrix(kron(GenericSparseMatrixCOO(A), D))
+end
+
+# kron with Diagonal and transpose/adjoint wrappers
+for (wrap, trans, conj, unwrap, whereT) in trans_adj_wrappers(:GenericSparseDiagMatrix)
+    # Skip identity case (already handled above)
+    trans == false && continue
+
+    TypeB = wrap(:(T))
+
+    # kron(D, op(B))
+    @eval function LinearAlgebra.kron(
+            D::Diagonal{Tv1},
+            B::$TypeB,
+        ) where {Tv1, $(whereT(:T))}
+        B_coo = GenericSparseMatrixCOO(B)
+        return kron(D, B_coo) |> GenericSparseDiagMatrix
+    end
+
+    # kron(op(A), D)
+    TypeA = wrap(:(T))
+    @eval function LinearAlgebra.kron(
+            A::$TypeA,
+            D::Diagonal{Tv2},
+        ) where {$(whereT(:T)), Tv2}
+        A_coo = GenericSparseMatrixCOO(A)
+        return kron(A_coo, D) |> GenericSparseDiagMatrix
+    end
+end
+
+# Non-square matrices are reported as not symmetric; otherwise the check is
+# delegated to the CSC representation.
+function LinearAlgebra.issymmetric(A::GenericSparseDiagMatrix)
+    m, n = size(A)
+    m == n || return false
+    return issymmetric(GenericSparseMatrixCSC(A))
+end
+
+# Non-square matrices are reported as not Hermitian; otherwise the check is
+# delegated to the CSC representation.
+function LinearAlgebra.ishermitian(A::GenericSparseDiagMatrix)
+    m, n = size(A)
+    m == n || return false
+    return ishermitian(GenericSparseMatrixCSC(A))
+end
diff --git a/src/matrix_diag/matrix_diag_kernels.jl b/src/matrix_diag/matrix_diag_kernels.jl new file mode 100644 index 0000000..61d9071 --- /dev/null +++ b/src/matrix_diag/matrix_diag_kernels.jl @@ -0,0 +1,192 @@
+# Helper function to find which diagonal a flat nzval index belongs to
+# and compute the (row, col) position. Used by all DIA kernels.
+# Returns the tuple (i, j). No validity flag is returned and `idx` is not
+# range-checked: callers must pass an index in 1:length(nzval).
+@inline function _diag_index_to_ij(idx, offsets, diag_ptrs)
+    ndiags = length(offsets)
+    # Binary search over diag_ptrs[1:ndiags] for the last diagonal whose start
+    # pointer is <= idx.
+    lo = 1
+    hi = ndiags
+    diag_idx = 1
+    while lo <= hi
+        mid = (lo + hi) ÷ 2
+        if diag_ptrs[mid] <= idx
+            diag_idx = mid
+            lo = mid + 1
+        else
+            hi = mid - 1
+        end
+    end
+
+    d = offsets[diag_idx]
+    diag_start = diag_ptrs[diag_idx]
+    local_pos = idx - diag_start + 1
+
+    # A diagonal with offset d >= 0 starts at row 1; with d < 0 at row 1 - d.
+    row_start = max(1, 1 - d)
+    i = row_start + local_pos - 1
+    j = i + d
+
+    return (i, j)
+end
+
+# Kernels for SpMV: y = α * A * x + β * y (non-transposed)
+# Launched with ndrange = (total_nnz, ncols_C) - one work item per (nz entry, output column)
+# The kernel only accumulates the α * A * B term into C.
+@kernel inbounds = true function kernel_spmatmul_diag_N!(
+        C,
+        @Const(offsets),
+        @Const(diag_ptrs),
+        @Const(nzval),
+        @Const(B),
+        α,
+        ::Val{conjA},
+        ::Val{conjB},
+        ::Val{transB},
+    ) where {conjA, conjB, transB}
+    nz_idx, k = @index(Global, NTuple)
+
+    i, j = _diag_index_to_ij(nz_idx, offsets, diag_ptrs)
+
+    val = nzval[nz_idx]
+    val = conjA ? conj(val) : val
+
+    # transB swaps the indices used to read B; conjB conjugates the value read.
+    Bi, Bj = transB ? (k, j) : (j, k)
+    b_val = conjB ? conj(B[Bi, Bj]) : B[Bi, Bj]
+    # Atomic: stored entries of the same row (on different diagonals) are
+    # handled by different work items and update the same C[i, k].
+    @atomic C[i, k] += α * val * b_val
+end
+
+# Kernels for SpMV: y = α * A' * x + β * y (transposed)
+# A' has size (n, m), so row of A' = col of A, col of A' = row of A
+@kernel inbounds = true function kernel_spmatmul_diag_T!(
+        C,
+        @Const(offsets),
+        @Const(diag_ptrs),
+        @Const(nzval),
+        @Const(B),
+        α,
+        ::Val{conjA},
+        ::Val{conjB},
+        ::Val{transB},
+    ) where {conjA, conjB, transB}
+    nz_idx, k = @index(Global, NTuple)
+
+    i, j = _diag_index_to_ij(nz_idx, offsets, diag_ptrs)
+
+    # In original A: entry at (i, j) with value val
+    # In A^T: entry at (j, i), so row=j, col=i
+    val = nzval[nz_idx]
+    val = conjA ? conj(val) : val
+
+    Bi, Bj = transB ? (k, i) : (i, k)
+    b_val = conjB ? conj(B[Bi, Bj]) : B[Bi, Bj]
+    # Atomic: stored entries of the same column of A update the same C[j, k].
+    @atomic C[j, k] += α * val * b_val
+end
+
+# Kernel for three-argument dot: dot(x, A, y) = x' * A * y (non-transposed)
+# Each work item sums the entries idx = global_id, global_id + stride, ...
+# up to total_nnz; work item 1 of each group then adds up the group's partial
+# sums and writes one value to block_results[group_id].
+@kernel inbounds = true unsafe_indices = true function kernel_workgroup_dot_diag_N!(
+        block_results,
+        @Const(x),
+        @Const(offsets),
+        @Const(diag_ptrs),
+        @Const(nzval),
+        @Const(y),
+        @Const(total_nnz),
+        ::Val{conjA},
+    ) where {conjA}
+    local_id = @index(Local, Linear)
+    group_id = @index(Group, Linear)
+    global_id = @index(Global, Linear)
+
+    workgroup_size = @uniform @groupsize()[1]
+    stride = @uniform @ndrange()[1]
+
+    shared = @localmem(eltype(block_results), workgroup_size)
+
+    local_sum = zero(eltype(block_results))
+    for idx in global_id:stride:total_nnz
+        i, j = _diag_index_to_ij(idx, offsets, diag_ptrs)
+        val = conjA ? conj(nzval[idx]) : nzval[idx]
+        local_sum += dot(x[i], val, y[j])
+    end
+
+    shared[local_id] = local_sum
+    # All partial sums must be written before work item 1 reads them.
+    @synchronize()
+
+    if local_id == 1
+        s = zero(eltype(block_results))
+        for k in 1:workgroup_size
+            s += shared[k]
+        end
+        block_results[group_id] = s
+    end
+end
+
+# Kernel for three-argument dot: dot(x, A', y) (transposed)
+# Same structure as the non-transposed kernel, with the roles of i and j
+# swapped when indexing x and y.
+@kernel inbounds = true unsafe_indices = true function kernel_workgroup_dot_diag_T!(
+        block_results,
+        @Const(x),
+        @Const(offsets),
+        @Const(diag_ptrs),
+        @Const(nzval),
+        @Const(y),
+        @Const(total_nnz),
+        ::Val{conjA},
+    ) where {conjA}
+    local_id = @index(Local, Linear)
+    group_id = @index(Group, Linear)
+    global_id = @index(Global, Linear)
+
+    workgroup_size = @uniform @groupsize()[1]
+    stride = @uniform @ndrange()[1]
+
+    shared = @localmem(eltype(block_results), workgroup_size)
+
+    local_sum = zero(eltype(block_results))
+    for idx in global_id:stride:total_nnz
+        i, j = _diag_index_to_ij(idx, offsets, diag_ptrs)
+        # A^T[j, i] = A[i, j], so x is indexed by j, y by i
+        val = conjA ? conj(nzval[idx]) : nzval[idx]
+        local_sum += dot(x[j], val, y[i])
+    end
+
+    shared[local_id] = local_sum
+    # All partial sums must be written before work item 1 reads them.
+    @synchronize()
+
+    if local_id == 1
+        s = zero(eltype(block_results))
+        for k in 1:workgroup_size
+            s += shared[k]
+        end
+        block_results[group_id] = s
+    end
+end
+
+# Kernel for adding sparse DIA entries to a dense matrix
+# One work item per stored entry. The `m` argument is not used in the body.
+@kernel inbounds = true function kernel_add_sparse_to_dense_diag!(
+        C,
+        @Const(offsets),
+        @Const(diag_ptrs),
+        @Const(nzval),
+        m,
+    )
+    idx = @index(Global)
+
+    i, j = _diag_index_to_ij(idx, offsets, diag_ptrs)
+    C[i, j] += nzval[idx]
+end
+
+# Kernel for converting DIA format to COO format
+# One work item per stored entry: writes its row, column and value at the same
+# flat index in the outputs. The `m` argument is not used in the body.
+@kernel inbounds = true function kernel_diag_to_coo!(
+        rowind,
+        colind,
+        nzval_out,
+        @Const(offsets),
+        @Const(diag_ptrs),
+        @Const(nzval_in),
+        m,
+    )
+    idx = @index(Global)
+
+    i, j = _diag_index_to_ij(idx, offsets, diag_ptrs)
+
+    rowind[idx] = i
+    colind[idx] = j
+    nzval_out[idx] = nzval_in[idx]
+end
diff --git a/test/runtests.jl b/test/runtests.jl index 2ca9daf..801dcc1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,6 +17,7 @@ include(joinpath(@__DIR__, "shared", "vector.jl")) include(joinpath(@__DIR__, "shared", "matrix_csc.jl")) include(joinpath(@__DIR__, "shared", "matrix_csr.jl")) include(joinpath(@__DIR__, "shared", "matrix_coo.jl")) +include(joinpath(@__DIR__, "shared", "matrix_diag.jl")) include(joinpath(@__DIR__, "shared", "conversions.jl")) const GROUP_LIST = ("All", "Code-Quality", "CPU", "CUDA", "Metal", "Reactant") @@ -59,6 +60,13 @@ if GROUP in ("All", "CPU") (Float32, Float64), (ComplexF32, ComplexF64), ) + shared_test_matrix_diag( + func, + name, + (Int32, Int64), + (Float32, Float64), + (ComplexF32, ComplexF64), + ) shared_test_conversions( func, name, diff --git a/test/shared/matrix_diag.jl b/test/shared/matrix_diag.jl new file mode 100644 index 0000000..871c877 --- /dev/null +++ b/test/shared/matrix_diag.jl @@ -0,0 +1,443 @@
+function shared_test_matrix_diag(
+        op,
+        array_type::String,
+        int_types::Tuple,
+
float_types::Tuple,
+        complex_types::Tuple,
+    )
+    # Runs the conversion and linear-algebra test groups under one testset
+    # labelled with `array_type`.
+    return @testset "GenericSparseDiagMatrix $array_type" verbose = true begin
+        shared_test_conversion_matrix_diag(
+            op,
+            array_type,
+            int_types,
+            float_types,
+            complex_types,
+        )
+        shared_test_linearalgebra_matrix_diag(
+            op,
+            array_type,
+            int_types,
+            float_types,
+            complex_types,
+        )
+    end
+end
+
+# Round-trip conversion tests for GenericSparseDiagMatrix.
+# `op` is passed to `adapt` to move each matrix to the storage under test.
+# `array_type`, `int_types`, `float_types` and `complex_types` are accepted for
+# a uniform signature but are not referenced in this function; the fixtures use
+# fixed Float32/Float64 element types.
+function shared_test_conversion_matrix_diag(
+        op,
+        array_type::String,
+        int_types::Tuple,
+        float_types::Tuple,
+        complex_types::Tuple,
+    )
+    return @testset "Conversion" begin
+        A = spzeros(Float32, 0, 0)
+
+        # Empty matrix
+        dA = adapt(op, GenericSparseDiagMatrix(A))
+        @test size(dA) == (0, 0)
+        @test length(dA) == 0
+        @test nnz(dA) == 0
+        @test SparseMatrixCSC(dA) == A
+
+        # Square matrix with diagonals
+        # (stored entries lie on the main diagonal and the first superdiagonal)
+        B = sparse([1, 2, 1, 2, 3], [1, 2, 2, 3, 3], Float64[1, 2, 3, 4, 5], 3, 3)
+        dB = adapt(op, GenericSparseDiagMatrix(B))
+        @test size(dB) == (3, 3)
+        @test SparseMatrixCSC(dB) ≈ B
+
+        # Rectangular matrix
+        C = sparse([1, 2, 1], [1, 2, 3], Float64[1, 2, 3], 3, 4)
+        dC = adapt(op, GenericSparseDiagMatrix(C))
+        @test size(dC) == (3, 4)
+        @test SparseMatrixCSC(dC) ≈ C
+
+        # Test conversion from COO
+        B_coo = adapt(op, GenericSparseMatrixCOO(B))
+        dB_from_coo = GenericSparseDiagMatrix(B_coo)
+        @test SparseMatrixCSC(dB_from_coo) ≈ B
+
+        # Test conversion from CSR
+        B_csr = adapt(op, GenericSparseMatrixCSR(B))
+        dB_from_csr = GenericSparseDiagMatrix(B_csr)
+        @test SparseMatrixCSC(dB_from_csr) ≈ B
+
+        # Test conversion to COO
+        dB2 = adapt(op, GenericSparseDiagMatrix(B))
+        B_coo2 = GenericSparseMatrixCOO(dB2)
+        @test SparseMatrixCSC(B_coo2) ≈ B
+
+        # Test conversion to CSC
+        B_csc = GenericSparseMatrixCSC(dB2)
+        @test SparseMatrixCSC(B_csc) ≈ B
+
+        # Test conversion to CSR
+        B_csr2 = GenericSparseMatrixCSR(dB2)
+        @test SparseMatrixCSC(B_csr2) ≈ B
+
+        # Test transpose/adjoint conversions
+        dB3 = adapt(op, GenericSparseDiagMatrix(B))
+        @test SparseMatrixCSC(transpose(dB3)) ≈ transpose(B)
+        @test
SparseMatrixCSC(adjoint(dB3)) ≈ adjoint(B) + end +end + +function shared_test_linearalgebra_matrix_diag( + op, + array_type::String, + int_types::Tuple, + float_types::Tuple, + complex_types::Tuple, + ) + @testset "Sum and Trace" begin + for T in (int_types..., float_types..., complex_types...) + A = sprand(T, 100, 100, 0.05) + dA = adapt(op, GenericSparseDiagMatrix(A)) + + @test sum(dA) ≈ sum(A) + @test tr(dA) ≈ tr(A) + end + end + + @testset "issymmetric and ishermitian" begin + for T in (complex_types...,) + n = 50 + # Non-symmetric/non-hermitian matrix + A_nonsym = sprand(T, n, n, 0.1) + A_nonsym[1, 2] = 1.0 + 0.0im + A_nonsym[2, 1] = 2.0 + 1.0im + dA_nonsym = adapt(op, GenericSparseDiagMatrix(A_nonsym)) + @test issymmetric(dA_nonsym) == false + @test ishermitian(dA_nonsym) == false + @test issymmetric(transpose(dA_nonsym)) == false + @test ishermitian(adjoint(dA_nonsym)) == false + + # Symmetric matrix (complex symmetric is NOT hermitian) + A_sym = sparse(A_nonsym + transpose(A_nonsym)) + dA_sym = adapt(op, GenericSparseDiagMatrix(A_sym)) + @test issymmetric(dA_sym) == true + @test issymmetric(transpose(dA_sym)) == true + + # Hermitian matrix (complex) + A_herm = sparse(A_nonsym + adjoint(A_nonsym)) + dA_herm = adapt(op, GenericSparseDiagMatrix(A_herm)) + @test ishermitian(dA_herm) == true + @test ishermitian(adjoint(dA_herm)) == true + end + end + + @testset "Three-argument dot" begin + for T in (int_types..., float_types..., complex_types...) + if T in (Int32,) + continue + end + for op_A in (identity, transpose, adjoint) + m, n = op_A === identity ? 
(100, 80) : (80, 100) + A = sprand(T, m, n, 0.1) + x = rand(T, size(op_A(A), 1)) + y = rand(T, size(op_A(A), 2)) + + dA = adapt(op, GenericSparseDiagMatrix(A)) + dx = op(x) + dy = op(y) + + result_device = dot(dx, op_A(dA), dy) + result_expected = dot(x, op_A(A), y) + + @test result_device ≈ result_expected + end + end + end + + @testset "Scalar Operations" begin + for T in (int_types..., float_types..., complex_types...) + A = sprand(T, 45, 35, 0.1) + dA = adapt(op, GenericSparseDiagMatrix(A)) + + α = T <: Complex ? T(2.0 + 1.5im) : (T <: Integer ? T(2) : T(1.8)) + + # Test scalar multiplication + scaled_left = α * dA + scaled_right = dA * α + @test nnz(scaled_left) == nnz(dA) + @test nnz(scaled_right) == nnz(dA) + @test collect(nonzeros(scaled_left)) ≈ α .* collect(nonzeros(dA)) + @test collect(nonzeros(scaled_right)) ≈ collect(nonzeros(dA)) .* α + + # Test scalar division + if !(T <: Integer) + divided = dA / α + @test nnz(divided) == nnz(dA) + @test collect(nonzeros(divided)) ≈ collect(nonzeros(dA)) ./ α + end + end + end + + @testset "Unary Operations" begin + for T in (float_types..., complex_types...) 
+ A = sprand(T, 28, 22, 0.15) + dA = adapt(op, GenericSparseDiagMatrix(A)) + + # Test unary plus + pos_A = +dA + @test nnz(pos_A) == nnz(dA) + @test collect(nonzeros(pos_A)) ≈ collect(nonzeros(dA)) + + # Test unary minus + neg_A = -dA + @test nnz(neg_A) == nnz(dA) + @test collect(nonzeros(neg_A)) ≈ -collect(nonzeros(dA)) + + # Test complex operations + if T <: Complex + conj_A = conj(dA) + real_A = real(dA) + imag_A = imag(dA) + + @test nnz(conj_A) == nnz(dA) + @test eltype(conj_A) == T + @test collect(nonzeros(conj_A)) ≈ conj.(collect(nonzeros(dA))) + + @test eltype(real_A) == real(T) + @test collect(nonzeros(real_A)) ≈ real.(collect(nonzeros(dA))) + + @test eltype(imag_A) == real(T) + @test collect(nonzeros(imag_A)) ≈ imag.(collect(nonzeros(dA))) + else + # For real types + conj_A = conj(dA) + real_A = real(dA) + imag_A = imag(dA) + + @test conj_A === dA + @test real_A === dA + @test nnz(imag_A) == 0 + end + end + end + + @testset "UniformScaling Multiplication" begin + for T in (float_types..., complex_types...) + A = sprand(T, 18, 18, 0.2) + dA = adapt(op, GenericSparseDiagMatrix(A)) + + # Test A * I (identity) + result_I = dA * I + @test nnz(result_I) == nnz(dA) + @test collect(nonzeros(result_I)) ≈ collect(nonzeros(dA)) + + # Test I * A (identity) + result_I2 = I * dA + @test nnz(result_I2) == nnz(dA) + @test collect(nonzeros(result_I2)) ≈ collect(nonzeros(dA)) + + # Test with scaled identity + α = T <: Complex ? T(1.5 - 0.8im) : T(2.2) + result_αI = dA * (α * I) + @test nnz(result_αI) == nnz(dA) + @test collect(nonzeros(result_αI)) ≈ α .* collect(nonzeros(dA)) + end + end + + @testset "UniformScaling Addition" begin + for T in (float_types..., complex_types...) 
+ # Test with square matrix + A_sq = sprand(T, 20, 20, 0.2) + dA_sq = adapt(op, GenericSparseDiagMatrix(A_sq)) + + # Test A + I (identity) + result_I = dA_sq + I + expected_I = A_sq + I + @test collect(SparseMatrixCSC(result_I)) ≈ collect(expected_I) + + # Test I + A (identity) + result_I2 = I + dA_sq + @test collect(SparseMatrixCSC(result_I2)) ≈ collect(expected_I) + + # Test with scaled identity + α = T <: Complex ? T(2.0 + 1.0im) : T(3.0) + result_αI = dA_sq + (α * I) + expected_αI = A_sq + (α * I) + @test collect(SparseMatrixCSC(result_αI)) ≈ collect(expected_αI) + + # Test subtraction + result_sub = dA_sq - (α * I) + expected_sub = A_sq - (α * I) + @test collect(SparseMatrixCSC(result_sub)) ≈ collect(expected_sub) + + # Test J - A + result_sub2 = (α * I) - dA_sq + expected_sub2 = (α * I) - A_sq + @test collect(SparseMatrixCSC(result_sub2)) ≈ collect(expected_sub2) + + # Test with non-square matrix throws DimensionMismatch + A_nonsq = sprand(T, 30, 20, 0.2) + dA_nonsq = adapt(op, GenericSparseDiagMatrix(A_nonsq)) + @test_throws DimensionMismatch dA_nonsq + I + + # Test with zero λ (should return copy) + result_zero = dA_sq + (zero(T) * I) + @test collect(SparseMatrixCSC(result_zero)) ≈ collect(A_sq) + end + end + + @testset "Matrix-Scalar, Matrix-Vector and Matrix-Matrix multiplication" begin + for T in (int_types..., float_types..., complex_types...) + for (op_A, op_B) in Iterators.product( + (identity, transpose, adjoint), + (identity, transpose, adjoint), + ) + dims_A = op_A === identity ? (100, 80) : (80, 100) + dims_B = op_B === identity ? (80, 50) : (50, 80) + + A = sprand(T, dims_A..., 0.1) + B = rand(T, dims_B...) 
+ b = rand(T, 80) + c = op_A(A) * b + C = op_A(A) * op_B(B) + + dA = adapt(op, GenericSparseDiagMatrix(A)) + + # Matrix-Scalar multiplication + if T != Int32 + @test collect(SparseMatrixCSC(2 * dA)) ≈ 2 * collect(A) + @test collect(SparseMatrixCSC(dA * 2)) ≈ collect(A * 2) + @test collect(SparseMatrixCSC(dA / 2)) ≈ collect(A) / 2 + end + + # Matrix-Vector multiplication + db = op(b) + dc = op_A(dA) * db + @test collect(dc) ≈ c + dc2 = similar(dc) + mul!(dc2, op_A(dA), db) + @test collect(dc2) ≈ c + + # Matrix-Matrix multiplication + dB = op(B) + dC = op_A(dA) * op_B(dB) + @test collect(dC) ≈ C + dC2 = similar(dB, size(C)...) + mul!(dC2, op_A(dA), op_B(dB)) + @test collect(dC2) ≈ C + end + end + end + + @testset "Sparse + Dense Matrix Addition" begin + for T in (int_types..., float_types..., complex_types...) + m, n = 50, 40 + A = sprand(T, m, n, 0.1) + B = rand(T, m, n) + + dA = adapt(op, GenericSparseDiagMatrix(A)) + dB = op(B) + + # Test sparse + dense + result = dA + dB + expected = Matrix(A) + B + @test collect(result) ≈ expected + + # Test dense + sparse (commutative) + result2 = dB + dA + @test collect(result2) ≈ expected + + # Test dimension mismatch + B_wrong = rand(T, m + 1, n) + dB_wrong = op(B_wrong) + @test_throws DimensionMismatch dA + dB_wrong + end + end + + @testset "Sparse + Sparse Matrix Addition" begin + for T in (int_types..., float_types..., complex_types...) + for (op_A, op_B) in Iterators.product( + (identity, transpose, adjoint), + (identity, transpose, adjoint), + ) + + m, n = (op_A === identity && op_B === identity) ? (50, 40) : (30, 30) + dims_A = op_A === identity ? (m, n) : (n, m) + dims_B = op_B === identity ? 
(m, n) : (n, m) + + A = sprand(T, dims_A..., 0.1) + B = sprand(T, dims_B..., 0.15) + + dA = adapt(op, GenericSparseDiagMatrix(A)) + dB = adapt(op, GenericSparseDiagMatrix(B)) + + # Test sparse + sparse + result = op_A(dA) + op_B(dB) + expected = op_A(A) + op_B(B) + @test collect(SparseMatrixCSC(result)) ≈ Matrix(expected) + @test result isa GenericSparseDiagMatrix + end + end + end + + @testset "Sparse * Sparse Matrix Multiplication" begin + for T in (int_types..., float_types..., complex_types...) + for (op_A, op_B) in Iterators.product( + (identity, transpose, adjoint), + (identity, transpose, adjoint), + ) + + m, k, n = + (op_A === identity && op_B === identity) ? (50, 40, 30) : (30, 30, 30) + dims_A = op_A === identity ? (m, k) : (k, m) + dims_B = op_B === identity ? (k, n) : (n, k) + + A = sprand(T, dims_A..., 0.1) + B = sprand(T, dims_B..., 0.15) + + dA = adapt(op, GenericSparseDiagMatrix(A)) + dB = adapt(op, GenericSparseDiagMatrix(B)) + + # Test sparse * sparse + result = op_A(dA) * op_B(dB) + expected = op_A(A) * op_B(B) + @test collect(SparseMatrixCSC(result)) ≈ Matrix(expected) + @test result isa GenericSparseDiagMatrix + end + end + end + + return @testset "Kronecker Product" begin + for T in (int_types..., float_types..., complex_types...) 
+ A_sparse = sprand(T, 30, 25, 0.1) + B_sparse = sprand(T, 20, 15, 0.1) + D_diag = Diagonal(rand(T, 4)) + + A = adapt(op, GenericSparseDiagMatrix(A_sparse)) + B = adapt(op, GenericSparseDiagMatrix(B_sparse)) + D1 = adapt(op, D_diag) + D2 = Diagonal(FillArrays.Fill(T(2), 4)) + + for (op_A, op_B) in Iterators.product( + (identity, transpose, adjoint), + (identity, transpose, adjoint), + ) + + result = kron(op_A(A), op_B(B)) + expected_A_csc = GenericSparseMatrixCSC(op_A(A)) + expected_B_csc = GenericSparseMatrixCSC(op_B(B)) + expected = kron(SparseMatrixCSC(expected_A_csc), SparseMatrixCSC(expected_B_csc)) + @test collect(SparseMatrixCSC(result)) ≈ collect(expected) + end + + # Test kron(Diagonal, Sparse) and kron(Sparse, Diagonal) + for op_A in (identity, transpose, adjoint) + # kron(D, op(A)) + result1 = kron(D1, op_A(A)) + expected_A_csc = GenericSparseMatrixCSC(op_A(A)) + expected1 = kron(D_diag, SparseMatrixCSC(expected_A_csc)) + @test collect(SparseMatrixCSC(result1)) ≈ collect(expected1) + + # kron(op(A), D) + result2 = kron(op_A(A), D1) + expected2 = kron(SparseMatrixCSC(expected_A_csc), D_diag) + @test collect(SparseMatrixCSC(result2)) ≈ collect(expected2) + end + end + end +end