From 0b5f7cdbe23f8ed41ad8041a45dace4eee91c3d0 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Tue, 3 Mar 2026 13:18:03 +0100 Subject: [PATCH] Fix sytrs! test: verify non-pivoting path directly instead of skipping MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The non-pivoting cusolver path was compared against LAPACK which always uses pivoting, making the comparison invalid. For the non-pivoting case, verify correctness directly by checking A * x ≈ b against the original inputs. Keep the LAPACK comparison for the pivoting case where both use the same algorithm. Also move the version check before matrix allocation to avoid unnecessary work when skipping. Co-Authored-By: Claude Opus 4.6 --- test/libraries/cusolver/dense_generic.jl | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/test/libraries/cusolver/dense_generic.jl b/test/libraries/cusolver/dense_generic.jl index ec63ca82fb..2c59a1e120 100644 --- a/test/libraries/cusolver/dense_generic.jl +++ b/test/libraries/cusolver/dense_generic.jl @@ -124,6 +124,7 @@ p = 5 @testset "sytrs!" begin @testset "uplo = $uplo" for uplo in ('L', 'U') @testset "pivoting = $pivoting" for pivoting in (false, true) + !pivoting && (CUSOLVER.version() < v"11.7.2") && continue A = rand(elty,n,n) B = rand(elty,n,p) C = rand(elty,n) @@ -131,22 +132,26 @@ p = 5 d_A = CuMatrix(A) d_B = CuMatrix(B) d_C = CuVector(C) - !pivoting && (CUSOLVER.version() < v"11.7.2") && continue if pivoting d_A, d_ipiv, _ = CUSOLVER.sytrf!(uplo, d_A; pivoting) d_ipiv = CuVector{Int64}(d_ipiv) CUSOLVER.sytrs!(uplo, d_A, d_ipiv, d_B) CUSOLVER.sytrs!(uplo, d_A, d_ipiv, d_C) + A, ipiv, _ = LAPACK.sytrf!(uplo, A) + LAPACK.sytrs!(uplo, A, ipiv, B) + LAPACK.sytrs!(uplo, A, ipiv, C) + @test B ≈ collect(d_B) + @test C ≈ collect(d_C) else d_A, _ = CUSOLVER.sytrf!(uplo, d_A; pivoting) CUSOLVER.sytrs!(uplo, d_A, d_B) CUSOLVER.sytrs!(uplo, d_A, d_C) + # Verify correctness directly: non-pivoting cusolver cannot be + # compared against LAPACK (which always pivots), so instead check + # that A * x ≈ b for the original inputs. + @test A * collect(d_B) ≈ B + @test A * collect(d_C) ≈ C end - A, ipiv, _ = LAPACK.sytrf!(uplo, A) - LAPACK.sytrs!(uplo, A, ipiv, B) - LAPACK.sytrs!(uplo, A, ipiv, C) - @test B ≈ collect(d_B) - @test C ≈ collect(d_C) end end end