diff --git a/test/libraries/cusolver/dense_generic.jl b/test/libraries/cusolver/dense_generic.jl index ec63ca82fb..2c59a1e120 100644 --- a/test/libraries/cusolver/dense_generic.jl +++ b/test/libraries/cusolver/dense_generic.jl @@ -124,6 +124,7 @@ p = 5 @testset "sytrs!" begin @testset "uplo = $uplo" for uplo in ('L', 'U') @testset "pivoting = $pivoting" for pivoting in (false, true) + !pivoting && (CUSOLVER.version() < v"11.7.2") && continue A = rand(elty,n,n) B = rand(elty,n,p) C = rand(elty,n) @@ -131,22 +132,26 @@ p = 5 d_A = CuMatrix(A) d_B = CuMatrix(B) d_C = CuVector(C) - !pivoting && (CUSOLVER.version() < v"11.7.2") && continue if pivoting d_A, d_ipiv, _ = CUSOLVER.sytrf!(uplo, d_A; pivoting) d_ipiv = CuVector{Int64}(d_ipiv) CUSOLVER.sytrs!(uplo, d_A, d_ipiv, d_B) CUSOLVER.sytrs!(uplo, d_A, d_ipiv, d_C) + A, ipiv, _ = LAPACK.sytrf!(uplo, A) + LAPACK.sytrs!(uplo, A, ipiv, B) + LAPACK.sytrs!(uplo, A, ipiv, C) + @test B ≈ collect(d_B) + @test C ≈ collect(d_C) else d_A, _ = CUSOLVER.sytrf!(uplo, d_A; pivoting) CUSOLVER.sytrs!(uplo, d_A, d_B) CUSOLVER.sytrs!(uplo, d_A, d_C) + # Verify correctness directly: non-pivoting cusolver cannot be + # compared against LAPACK (which always pivots), so instead check + # that A * x ≈ b for the original inputs. + @test A * collect(d_B) ≈ B + @test A * collect(d_C) ≈ C end - A, ipiv, _ = LAPACK.sytrf!(uplo, A) - LAPACK.sytrs!(uplo, A, ipiv, B) - LAPACK.sytrs!(uplo, A, ipiv, C) - @test B ≈ collect(d_B) - @test C ≈ collect(d_C) end end end