Skip to content

Commit

Permalink
Add a generic API
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Dec 8, 2023
1 parent 023c792 commit 1a81699
Show file tree
Hide file tree
Showing 5 changed files with 160 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/CUDSS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ using LinearAlgebra
using SparseArrays

import CUDA: @checked, libraryPropertyType, cudaDataType, initialize_context, retry_reclaim, CUstream
import LinearAlgebra: BlasFloat
import LinearAlgebra: lu, lu!, ldlt, ldlt!, cholesky, cholesky!, ldiv!, BlasFloat

include("libcudss.jl")
include("error.jl")
include("types.jl")
include("helpers.jl")
include("management.jl")
include("interfaces.jl")
include("generic.jl")

end # module CUDSS
57 changes: 57 additions & 0 deletions src/generic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
"""
    lu(A::CuSparseMatrixCSR{T}) where T <: BlasFloat

Create a `CudssSolver` for the square sparse matrix `A` with structure `"G"` and
view `'F'`, run the `"analysis"` and `"factorization"` phases, and return the solver.
"""
function LinearAlgebra.lu(A::CuSparseMatrixCSR{T}) where T <: BlasFloat
    dim = LinearAlgebra.checksquare(A)
    solver = CudssSolver(A, "G", 'F')
    # `cudss` takes a solution and a right-hand side argument for every phase;
    # these descriptors of size `dim` fill those slots here.
    sol = CudssMatrix(T, dim)
    rhs = CudssMatrix(T, dim)
    for phase in ("analysis", "factorization")
        cudss(phase, solver, sol, rhs)
    end
    return solver
end

"""
    ldlt(A::CuSparseMatrixCSR{T}) where T <: BlasFloat

Create a `CudssSolver` for the square sparse matrix `A` (structure `"S"` for real
element types, `"H"` otherwise; view `'F'`), run the `"analysis"` and
`"factorization"` phases, and return the solver. For complex element types the
solver parameter `"pivot_type"` is set to `'N'` before the analysis.
"""
function LinearAlgebra.ldlt(A::CuSparseMatrixCSR{T}) where T <: BlasFloat
    dim = LinearAlgebra.checksquare(A)
    if T <: Real
        structure = "S"
    else
        structure = "H"
    end
    solver = CudssSolver(A, structure, 'F')
    if T <: Complex
        cudss_set(solver, "pivot_type", 'N')
    end
    # `cudss` takes a solution and a right-hand side argument for every phase;
    # these descriptors of size `dim` fill those slots here.
    sol = CudssMatrix(T, dim)
    rhs = CudssMatrix(T, dim)
    for phase in ("analysis", "factorization")
        cudss(phase, solver, sol, rhs)
    end
    return solver
end

"""
    cholesky(A::CuSparseMatrixCSR{T}) where T <: BlasFloat

Create a `CudssSolver` for the square sparse matrix `A` (structure `"SPD"` for real
element types, `"HPD"` otherwise; view `'F'`), run the `"analysis"` and
`"factorization"` phases, and return the solver.
"""
function LinearAlgebra.cholesky(A::CuSparseMatrixCSR{T}) where T <: BlasFloat
    dim = LinearAlgebra.checksquare(A)
    if T <: Real
        structure = "SPD"
    else
        structure = "HPD"
    end
    solver = CudssSolver(A, structure, 'F')
    # `cudss` takes a solution and a right-hand side argument for every phase;
    # these descriptors of size `dim` fill those slots here.
    sol = CudssMatrix(T, dim)
    rhs = CudssMatrix(T, dim)
    for phase in ("analysis", "factorization")
        cudss(phase, solver, sol, rhs)
    end
    return solver
end

# Generate the in-place refactorization methods `lu!`, `ldlt!` and `cholesky!`.
# All three share one body: the matrix held by an existing solver is replaced by
# `A` and only the "factorization" phase is run again (no new "analysis" phase).
for refactor! in (:lu!, :ldlt!, :cholesky!)
    @eval function LinearAlgebra.$refactor!(solver::CudssSolver{T}, A::CuSparseMatrixCSR{T}) where T <: BlasFloat
        dim = LinearAlgebra.checksquare(A)
        cudss_set(solver.matrix, A)
        # `cudss` takes a solution and a right-hand side argument for every phase;
        # these descriptors of size `dim` fill those slots here.
        sol = CudssMatrix(T, dim)
        rhs = CudssMatrix(T, dim)
        cudss("factorization", solver, sol, rhs)
        return solver
    end
end

# Generate `ldiv!` methods for dense right-hand sides stored as `CuVector` or
# `CuMatrix`. Both methods run the "solve" phase of an already factorized solver
# and, following the `LinearAlgebra.ldiv!` contract, return the array that
# received the solution (previously the return value of `cudss` leaked out).
for ArrayType in (:CuVector, :CuMatrix)
    @eval begin
        # Two-argument form: `b` is passed as both the solution and the right-hand
        # side, so it is overwritten with the solution.
        function LinearAlgebra.ldiv!(solver::CudssSolver{T}, b::$ArrayType{T}) where T <: BlasFloat
            cudss("solve", solver, b, b)
            return b
        end

        # Three-argument form: the solution is written to `x`; `b` is the
        # right-hand side.
        function LinearAlgebra.ldiv!(x::$ArrayType{T}, solver::CudssSolver{T}, b::$ArrayType{T}) where T <: BlasFloat
            cudss("solve", solver, x, b)
            return x
        end
    end
end
20 changes: 20 additions & 0 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,26 @@ mutable struct CudssMatrix{T}
type::Type{T}
matrix::cudssMatrix_t

# Inner constructor: create a dense cuDSS matrix descriptor with `n` rows, one
# column and leading dimension `n`, layout 'C', element type `T`, and `CU_NULL`
# as the data pointer (no device buffer is attached at construction).
# `cudssMatrixDestroy` is registered as finalizer of the returned object.
# (Fixes the duplicated `where where`, which was a syntax error.)
function CudssMatrix(::Type{T}, n::Integer) where T <: BlasFloat
    matrix_ref = Ref{cudssMatrix_t}()
    cudssMatrixCreateDn(matrix_ref, n, 1, n, CU_NULL, T, 'C')
    obj = new{T}(T, matrix_ref[])
    finalizer(cudssMatrixDestroy, obj)
    return obj
end

# Inner constructor: create a dense cuDSS matrix descriptor of element type `T`
# with `CU_NULL` as the data pointer and leading dimension `m`.
# With `transposed=false` the descriptor is `m`-by-`n` with layout 'C';
# with `transposed=true` it is `n`-by-`m` with layout 'R'.
# `cudssMatrixDestroy` is registered as finalizer of the returned object.
function CudssMatrix(::Type{T}, m::Integer, n::Integer; transposed::Bool=false) where T <: BlasFloat
    handle = Ref{cudssMatrix_t}()
    nrows = transposed ? n : m
    ncols = transposed ? m : n
    layout = transposed ? 'R' : 'C'
    cudssMatrixCreateDn(handle, nrows, ncols, m, CU_NULL, T, layout)
    wrapper = new{T}(T, handle[])
    finalizer(cudssMatrixDestroy, wrapper)
    return wrapper
end

function CudssMatrix(v::CuVector{T}) where T <: BlasFloat
m = length(v)
matrix_ref = Ref{cudssMatrix_t}()
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,8 @@ include("test_cudss.jl")
# Run the checks implemented by `cudss_execution()` as their own test set.
@testset "CudssExecution" begin
cudss_execution()
end

# Run the checks implemented by `cudss_generic()` (the `lu`/`ldlt`/`cholesky`,
# `lu!`/`ldlt!`/`cholesky!` and `ldiv!` wrappers) as their own test set.
@testset "Generic API" begin
cudss_generic()
end
end
77 changes: 77 additions & 0 deletions test/test_cudss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,80 @@ function cudss_execution()
end
end
end

# Test the generic LinearAlgebra-style API: `lu`, `ldlt` and `cholesky`, their
# in-place refactorization variants and both forms of `ldiv!`, for every supported
# element type. Each solve is checked through the norm of the residual `b - A * x`.
#
# NOTE(review): the comparison in the `@test` lines was missing (the bare
# `@test norm(r) eps(R)` does not parse); it is restored here as `≤ √eps(R)`.
# Confirm the tolerance against the upstream test-suite.
function cudss_generic()
    n = 100
    p = 5
    @testset "precision = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
        R = real(T)
        @testset "Unsymmetric -- Non-Hermitian" begin
            A_cpu = sprand(T, n, n, 0.02) + I
            x_cpu = zeros(T, n)
            b_cpu = rand(T, n)

            A_gpu = CuSparseMatrixCSR(A_cpu)
            x_gpu = CuVector(x_cpu)
            b_gpu = CuVector(b_cpu)

            solver = lu(A_gpu)
            ldiv!(x_gpu, solver, b_gpu)
            r_gpu = b_gpu - A_gpu * x_gpu
            @test norm(r_gpu) ≤ √eps(R)

            # Refactorize with rescaled values and solve in place
            # (x_gpu holds the right-hand side on entry).
            A_gpu2 = rand(T) * A_gpu
            lu!(solver, A_gpu2)
            x_gpu .= b_gpu
            ldiv!(solver, x_gpu)
            r_gpu2 = b_gpu - A_gpu2 * x_gpu
            @test norm(r_gpu2) ≤ √eps(R)
        end

        @testset "view = $view" for view in ('F',)
            @testset "Symmetric -- Hermitian" begin
                A_cpu = sprand(T, n, n, 0.01) + I
                A_cpu = A_cpu + A_cpu'
                X_cpu = zeros(T, n, p)
                B_cpu = rand(T, n, p)

                A_gpu = CuSparseMatrixCSR(A_cpu)
                X_gpu = CuMatrix(X_cpu)
                B_gpu = CuMatrix(B_cpu)

                solver = ldlt(A_gpu)
                ldiv!(X_gpu, solver, B_gpu)
                R_gpu = B_gpu - A_gpu * X_gpu
                @test norm(R_gpu) ≤ √eps(R)

                # A real scaling factor keeps the matrix symmetric / Hermitian.
                A_gpu2 = rand(R) * A_gpu
                ldlt!(solver, A_gpu2)
                X_gpu .= B_gpu
                ldiv!(solver, X_gpu)
                R_gpu2 = B_gpu - A_gpu2 * X_gpu
                @test norm(R_gpu2) ≤ √eps(R)
            end

            @testset "SPD -- HPD" begin
                A_cpu = sprand(T, n, n, 0.01)
                A_cpu = A_cpu * A_cpu' + I
                X_cpu = zeros(T, n, p)
                B_cpu = rand(T, n, p)

                A_gpu = CuSparseMatrixCSR(A_cpu)
                X_gpu = CuMatrix(X_cpu)
                B_gpu = CuMatrix(B_cpu)

                solver = cholesky(A_gpu)
                ldiv!(X_gpu, solver, B_gpu)
                R_gpu = B_gpu - A_gpu * X_gpu
                @test norm(R_gpu) ≤ √eps(R)

                # A real scaling factor keeps the matrix symmetric / Hermitian.
                A_gpu2 = rand(R) * A_gpu
                cholesky!(solver, A_gpu2)
                X_gpu .= B_gpu
                ldiv!(solver, X_gpu)
                R_gpu2 = B_gpu - A_gpu2 * X_gpu
                @test norm(R_gpu2) ≤ √eps(R)
            end
        end
    end
end

0 comments on commit 1a81699

Please sign in to comment.