diff --git a/test/test_cudss.jl b/test/test_cudss.jl index 1fc01d0..233b7a4 100644 --- a/test/test_cudss.jl +++ b/test/test_cudss.jl @@ -512,13 +512,19 @@ function iterative_refinement() return norm(r_gpu) end - function ir_ldlt(T, A_cpu, x_cpu, b_cpu, ir) - A_gpu = CuSparseMatrixCSR(A_cpu |> tril) + function ir_ldlt(T, A_cpu, x_cpu, b_cpu, ir, uplo) + if uplo == 'L' + A_gpu = CuSparseMatrixCSR(A_cpu |> tril) + elseif uplo == 'U' + A_gpu = CuSparseMatrixCSR(A_cpu |> triu) + else + A_gpu = CuSparseMatrixCSR(A_cpu) + end x_gpu = CuVector(x_cpu) b_gpu = CuVector(b_cpu) structure = T <: Real ? "S" : "H" - solver = CudssSolver(A_gpu, structure, 'L') + solver = CudssSolver(A_gpu, structure, uplo) cudss_set(solver, "ir_n_steps", ir) cudss("analysis", solver, x_gpu, b_gpu) @@ -529,13 +535,19 @@ function iterative_refinement() return norm(r_gpu) end - function ir_llt(T, A_cpu, x_cpu, b_cpu, ir) - A_gpu = CuSparseMatrixCSR(A_cpu |> triu) + function ir_llt(T, A_cpu, x_cpu, b_cpu, ir, uplo) + if uplo == 'L' + A_gpu = CuSparseMatrixCSR(A_cpu |> tril) + elseif uplo == 'U' + A_gpu = CuSparseMatrixCSR(A_cpu |> triu) + else + A_gpu = CuSparseMatrixCSR(A_cpu) + end x_gpu = CuVector(x_cpu) b_gpu = CuVector(b_cpu) structure = T <: Real ? "SPD" : "HPD" - solver = CudssSolver(A_gpu, structure, 'U') + solver = CudssSolver(A_gpu, structure, uplo) cudss_set(solver, "ir_n_steps", ir) cudss("analysis", solver, x_gpu, b_gpu) @@ -562,16 +574,20 @@ function iterative_refinement() A_cpu = A_cpu + A_cpu' x_cpu = zeros(T, n) b_cpu = rand(T, n) - res = ir_ldlt(T, A_cpu, x_cpu, b_cpu, ir) - @test res ≤ √eps(R) + @testset "uplo = $uplo" for uplo in ('L', 'U', 'F') + res = ir_ldlt(T, A_cpu, x_cpu, b_cpu, ir, uplo) + @test res ≤ √eps(R) + end end @testset "LLᵀ / LLᴴ" begin A_cpu = sprand(T, n, n, 0.01) A_cpu = A_cpu * A_cpu' + I x_cpu = zeros(T, n) b_cpu = rand(T, n) - res = ir_llt(T, A_cpu, x_cpu, b_cpu, ir) - @test res ≤ √eps(R) + @testset "uplo = $uplo" for uplo in ('L', 'U', 'F') + res = ir_llt(T, A_cpu, x_cpu, b_cpu, ir, uplo) + @test res ≤ √eps(R) + end end end end