From 1894f7b7f16bb191305f589f91f2edbe05a3726b Mon Sep 17 00:00:00 2001
From: Alexis Montoison
Date: Fri, 8 Dec 2023 01:22:16 -0600
Subject: [PATCH] Add a generic API

---
 src/generic.jl     | 56 ++++++++++++++++++++++++++++++++
 src/CUDSS.jl       |  2 ++
 src/helpers.jl     | 22 +++++++++++-
 test/runtests.jl   |  4 +++
 test/test_cudss.jl | 78 ++++++++++++++++++++++++++++++++++++++++++++++
 5 files changed, 161 insertions(+), 1 deletion(-)
 create mode 100644 src/generic.jl

diff --git a/src/generic.jl b/src/generic.jl
new file mode 100644
index 0000000..d6843ff
--- /dev/null
+++ b/src/generic.jl
@@ -0,0 +1,56 @@
+function LinearAlgebra.lu(A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat
+    n = LinearAlgebra.checksquare(A)
+    solver = CudssSolver(A, 'G', 'F')
+    x = CudssMatrix(T, n)
+    b = CudssMatrix(T, n)
+    cudss("analysis", solver, x, b)
+    cudss("factorization", solver, x, b)
+    return solver
+end
+
+function LinearAlgebra.ldlt(A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat
+    n = LinearAlgebra.checksquare(A)
+    structure = T <: Real ? 'S' : 'H'
+    solver = CudssSolver(A, structure, 'F')
+    x = CudssMatrix(T, n)
+    b = CudssMatrix(T, n)
+    cudss("analysis", solver, x, b)
+    cudss("factorization", solver, x, b)
+    return solver
+end
+
+function LinearAlgebra.cholesky(A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat
+    n = LinearAlgebra.checksquare(A)
+    structure = T <: Real ? "SPD" : "HPD"
+    solver = CudssSolver(A, structure, 'F')
+    x = CudssMatrix(T, n)
+    b = CudssMatrix(T, n)
+    cudss("analysis", solver, x, b)
+    cudss("factorization", solver, x, b)
+    return solver
+end
+
+for fun in (:lu!, :ldlt!, :cholesky!)
+    @eval begin
+        function LinearAlgebra.$fun(solver::CudssSolver, A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat
+            n = LinearAlgebra.checksquare(A)
+            cudss_set(solver.matrix, A)
+            x = CudssMatrix(T, n)
+            b = CudssMatrix(T, n)
+            cudss("factorization", solver, x, b)
+            return solver
+        end
+    end
+end
+
+for type in (:CuVector, :CuMatrix)
+    @eval begin
+        function LinearAlgebra.ldiv!(solver::CudssSolver, b::$type{T}) where T <: BlasFloat
+            cudss("solve", solver, b, b)
+        end
+
+        function LinearAlgebra.ldiv!(x::$type{T}, solver::CudssSolver, b::$type{T}) where T <: BlasFloat
+            cudss("solve", solver, x, b)
+        end
+    end
+end
diff --git a/src/CUDSS.jl b/src/CUDSS.jl
index 5c40109..121ba16 100644
--- a/src/CUDSS.jl
+++ b/src/CUDSS.jl
@@ -6,6 +6,7 @@ using LinearAlgebra
 using SparseArrays
 
 import CUDA: @checked, libraryPropertyType, cudaDataType, initialize_context, retry_reclaim, CUstream
+import LinearAlgebra: lu, lu!, ldlt, ldlt!, cholesky, cholesky!, ldiv!, BlasFloat
 
 include("libcudss.jl")
 include("error.jl")
@@ -13,5 +14,6 @@ include("types.jl")
 include("helpers.jl")
 include("management.jl")
 include("interfaces.jl")
+include("generic.jl")
 
 end # module CUDSS
diff --git a/src/helpers.jl b/src/helpers.jl
index 788bfb2..32f3fba 100644
--- a/src/helpers.jl
+++ b/src/helpers.jl
@@ -30,7 +30,27 @@ export CudssMatrix, CudssData, CudssConfig
 mutable struct CudssMatrix
     matrix::cudssMatrix_t
 
-    function CudssMatrix(v::CuVector)
+    function CudssMatrix(T::DataType, n::Integer)
+        matrix_ref = Ref{cudssMatrix_t}()
+        cudssMatrixCreateDn(matrix_ref, n, 1, n, CU_NULL, T, 'C')
+        obj = new(matrix_ref[])
+        finalizer(cudssMatrixDestroy, obj)
+        obj
+    end
+
+    function CudssMatrix(T::DataType, m::Integer, n::Integer; transposed::Bool=false)
+        matrix_ref = Ref{cudssMatrix_t}()
+        if transposed
+            cudssMatrixCreateDn(matrix_ref, n, m, m, CU_NULL, T, 'R')
+        else
+            cudssMatrixCreateDn(matrix_ref, m, n, m, CU_NULL, T, 'C')
+        end
+        obj = new(matrix_ref[])
+        finalizer(cudssMatrixDestroy, obj)
+        obj
+    end
+
+    function CudssMatrix(v::CuVector)
         m = length(v)
         matrix_ref = Ref{cudssMatrix_t}()
         cudssMatrixCreateDn(matrix_ref, m, 1, m, v, eltype(v), 'C')
diff --git a/test/runtests.jl b/test/runtests.jl
index 173e550..d6c060b 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -26,4 +26,8 @@ include("test_cudss.jl")
     @testset "CudssExecution" begin
         cudss_execution()
     end
+
+    @testset "Generic API" begin
+        cudss_generic()
+    end
 end
diff --git a/test/test_cudss.jl b/test/test_cudss.jl
index 5721a93..2ba9aa1 100644
--- a/test/test_cudss.jl
+++ b/test/test_cudss.jl
@@ -190,3 +190,81 @@ function cudss_execution()
         end
     end
 end
+
+function cudss_generic()
+    n = 100
+    p = 5
+    @testset "precision = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
+        R = real(T)
+        @testset "Unsymmetric -- Non-Hermitian" begin
+            A_cpu = sprand(T, n, n, 0.02) + I
+            x_cpu = zeros(T, n)
+            b_cpu = rand(T, n)
+
+            A_gpu = CuSparseMatrixCSR(A_cpu)
+            x_gpu = CuVector(x_cpu)
+            b_gpu = CuVector(b_cpu)
+
+            solver = lu(A_gpu)
+            ldiv!(x_gpu, solver, b_gpu)
+            r_gpu = b_gpu - A_gpu * x_gpu
+            @test norm(r_gpu) ≤ √eps(R)
+
+            A_gpu2 = rand(T) * A_gpu
+            lu!(solver, A_gpu2)
+            x_gpu .= b_gpu
+            ldiv!(solver, x_gpu)
+            r_gpu2 = b_gpu - A_gpu2 * x_gpu
+            @test norm(r_gpu2) ≤ √eps(R)
+        end
+
+        @testset "view = $view" for view in ('F',)
+            @testset "Symmetric -- Hermitian" begin
+                A_cpu = sprand(T, n, n, 0.01) + I
+                A_cpu = A_cpu + A_cpu'
+                X_cpu = zeros(T, n, p)
+                B_cpu = rand(T, n, p)
+
+                A_gpu = CuSparseMatrixCSR(A_cpu)
+                X_gpu = CuMatrix(X_cpu)
+                B_gpu = CuMatrix(B_cpu)
+
+                solver = ldlt(A_gpu)
+                (T <: Complex) && cudss_set(solver, "pivot_type", 'N')
+                ldiv!(X_gpu, solver, B_gpu)
+                R_gpu = B_gpu - A_gpu * X_gpu
+                @test norm(R_gpu) ≤ √eps(R)
+
+                A_gpu2 = rand(R) * A_gpu
+                ldlt!(solver, A_gpu2)
+                X_gpu .= B_gpu
+                ldiv!(solver, X_gpu)
+                R_gpu2 = B_gpu - A_gpu2 * X_gpu
+                @test norm(R_gpu2) ≤ √eps(R)
+            end
+
+            @testset "SPD -- HPD" begin
+                A_cpu = sprand(T, n, n, 0.01)
+                A_cpu = A_cpu * A_cpu' + I
+                X_cpu = zeros(T, n, p)
+                B_cpu = rand(T, n, p)
+
+                A_gpu = CuSparseMatrixCSR(A_cpu)
+                X_gpu = CuMatrix(X_cpu)
+                B_gpu = CuMatrix(B_cpu)
+
+                solver = cholesky(A_gpu)
+                ldiv!(X_gpu, solver, B_gpu)
+                R_gpu = B_gpu - A_gpu * X_gpu
+                @test norm(R_gpu) ≤ √eps(R)
+
+                A_gpu2 = rand(R) * A_gpu
+                cholesky!(solver, A_gpu2)
+                X_gpu .= B_gpu
+                ldiv!(solver, X_gpu)
+                R_gpu2 = B_gpu - A_gpu2 * X_gpu
+                @test norm(R_gpu2) ≤ √eps(R)
+            end
+        end
+    end
+end
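
For illustration, a minimal end-to-end sketch of the generic API introduced by this patch, assuming a CUDA-capable device; the SPD test problem below is arbitrary and only shows the intended call sequence (analysis and factorization through cholesky, the solve phase through ldiv!, and refactorization through cholesky!):

using CUDA, CUDA.CUSPARSE, CUDSS
using LinearAlgebra, SparseArrays

T = Float64
n = 100

# Build an SPD problem on the CPU and move it to the GPU (cuDSS works with CSR storage).
A_cpu = sprand(T, n, n, 0.01)
A_cpu = A_cpu * A_cpu' + I
b_cpu = rand(T, n)

A_gpu = CuSparseMatrixCSR(A_cpu)
b_gpu = CuVector(b_cpu)
x_gpu = CUDA.zeros(T, n)

# cholesky runs the analysis and factorization phases and returns a CudssSolver.
solver = cholesky(A_gpu)

# ldiv! runs the solve phase with the computed factors.
ldiv!(x_gpu, solver, b_gpu)
norm(b_gpu - A_gpu * x_gpu)  # small residual expected

# Refactorize a matrix with the same sparsity pattern, then solve again.
c = rand(T)
cholesky!(solver, c * A_gpu)
ldiv!(x_gpu, solver, b_gpu)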