Skip to content

Commit

Permalink
Add a generic API
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Dec 8, 2023
1 parent f5e5798 commit 1894f7b
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 1 deletion.
56 changes: 56 additions & 0 deletions generic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""
    lu(A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat

Build a `CudssSolver` for the square matrix `A` with structure `'G'` and view `'F'`,
run the `"analysis"` and `"factorization"` phases on it, and return the solver.

`LinearAlgebra.checksquare` throws if `A` is not square.
"""
function LinearAlgebra.lu(A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat
    dim = LinearAlgebra.checksquare(A)
    solver = CudssSolver(A, 'G', 'F')
    # `cudss` takes a solution and a right-hand side for every phase; these
    # descriptors only carry the element type and the dimension.
    sol = CudssMatrix(T, dim)
    rhs = CudssMatrix(T, dim)
    for phase in ("analysis", "factorization")
        cudss(phase, solver, sol, rhs)
    end
    return solver
end

"""
    ldlt(A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat

Build a `CudssSolver` for the square matrix `A` with view `'F'` and a structure flag
that depends on the element type (`'S'` when `T <: Real`, `'H'` otherwise), run the
`"analysis"` and `"factorization"` phases on it, and return the solver.

`LinearAlgebra.checksquare` throws if `A` is not square.
"""
function LinearAlgebra.ldlt(A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat
    order = LinearAlgebra.checksquare(A)
    if T <: Real
        structure = 'S'
    else
        structure = 'H'
    end
    solver = CudssSolver(A, structure, 'F')
    # Solution / right-hand-side descriptors required by the `cudss` call signature.
    x, b = CudssMatrix(T, order), CudssMatrix(T, order)
    foreach(phase -> cudss(phase, solver, x, b), ("analysis", "factorization"))
    return solver
end

"""
    cholesky(A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat

Build a `CudssSolver` for the square positive-definite matrix `A` with view `'F'`,
run the `"analysis"` and `"factorization"` phases on it, and return the solver.

The structure flag is `"SPD"` (symmetric positive definite) when `T <: Real` and
`"HPD"` (Hermitian positive definite) otherwise.

`LinearAlgebra.checksquare` throws if `A` is not square.
"""
function LinearAlgebra.cholesky(A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat
    n = LinearAlgebra.checksquare(A)
    # The positive-definite structures are spelled "SPD" / "HPD"; the transposed
    # spellings "SDP" / "HDP" do not name a matrix structure.
    structure = T <: Real ? "SPD" : "HPD"
    solver = CudssSolver(A, structure, 'F')
    # Solution / right-hand-side descriptors required by the `cudss` call signature.
    x = CudssMatrix(T, n)
    b = CudssMatrix(T, n)
    cudss("analysis", solver, x, b)
    cudss("factorization", solver, x, b)
    return solver
end

# Generate `lu!`, `ldlt!` and `cholesky!` methods that reuse an existing solver.
# All three share one body: replace the matrix stored in the solver with `A`, run
# the "factorization" phase again, and return the solver.
for fun in (:lu!, :ldlt!, :cholesky!)
    @eval function LinearAlgebra.$fun(
        solver::CudssSolver, A::CuSparseMatrixCSR{T,Cint}
    ) where T <: BlasFloat
        dim = LinearAlgebra.checksquare(A)
        cudss_set(solver.matrix, A)
        # Solution / right-hand-side descriptors required by the `cudss` call signature.
        sol = CudssMatrix(T, dim)
        rhs = CudssMatrix(T, dim)
        cudss("factorization", solver, sol, rhs)
        return solver
    end
end

# Generate `ldiv!` methods for one right-hand side (`CuVector`) and for several
# right-hand sides (`CuMatrix`). The solver is expected to have been factorized
# already (e.g. by `lu`, `ldlt` or `cholesky`), so only the "solve" phase runs here.
for type in (:CuVector, :CuMatrix)
    @eval begin
        # In-place variant: `b` is passed to `cudss` as both the solution and the
        # right-hand side, and is returned, following the `LinearAlgebra.ldiv!`
        # convention of returning the overwritten argument.
        function LinearAlgebra.ldiv!(solver::CudssSolver, b::$type{T}) where T <: BlasFloat
            cudss("solve", solver, b, b)
            return b
        end

        # Out-of-place variant: the solution is written to `x`, which is returned.
        function LinearAlgebra.ldiv!(x::$type{T}, solver::CudssSolver, b::$type{T}) where T <: BlasFloat
            cudss("solve", solver, x, b)
            return x
        end
    end
end
2 changes: 2 additions & 0 deletions src/CUDSS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ using LinearAlgebra
using SparseArrays

import CUDA: @checked, libraryPropertyType, cudaDataType, initialize_context, retry_reclaim, CUstream
import LinearAlgebra: lu, lu!, ldlt, ldlt!, cholesky, cholesky!, ldiv!, BlasFloat

include("libcudss.jl")
include("error.jl")
include("types.jl")
include("helpers.jl")
include("management.jl")
include("interfaces.jl")
include("generic.jl")

end # module CUDSS
22 changes: 21 additions & 1 deletion src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,27 @@ export CudssMatrix, CudssData, CudssConfig
mutable struct CudssMatrix
matrix::cudssMatrix_t

function CudssMatrix(v::CuVector)
# Create a dense matrix descriptor of element type `T` from the dimension `n` alone:
# `cudssMatrixCreateDn` receives (n, 1, n), `CU_NULL` in place of a data array, and
# the layout flag 'C'. `cudssMatrixDestroy` is registered as the finalizer.
function CudssMatrix(T::DataType, n::Integer)
    handle = Ref{cudssMatrix_t}()
    cudssMatrixCreateDn(handle, n, 1, n, CU_NULL, T, 'C')
    # `finalizer` returns the object it was attached to.
    return finalizer(cudssMatrixDestroy, new(handle[]))
end

# Create a dense matrix descriptor of element type `T` from the dimensions `m` and
# `n`, with `CU_NULL` in place of a data array. With `transposed=false` the call
# receives (m, n, m) and layout 'C'; with `transposed=true` it receives (n, m, m) and
# layout 'R'. `cudssMatrixDestroy` is registered as the finalizer.
function CudssMatrix(T::DataType, m::Integer, n::Integer; transposed::Bool=false)
    handle = Ref{cudssMatrix_t}()
    nrows, ncols, layout = transposed ? (n, m, 'R') : (m, n, 'C')
    # The fourth argument is `m` in both cases.
    cudssMatrixCreateDn(handle, nrows, ncols, m, CU_NULL, T, layout)
    # `finalizer` returns the object it was attached to.
    return finalizer(cudssMatrixDestroy, new(handle[]))
end

function CudssMatrix(v::CuVector)
m = length(v)
matrix_ref = Ref{cudssMatrix_t}()
cudssMatrixCreateDn(matrix_ref, m, 1, m, v, eltype(v), 'C')
Expand Down
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,8 @@ include("test_cudss.jl")
@testset "CudssExecution" begin
cudss_execution()
end

@testset "Generic API" begin
cudss_generic()
end
end
78 changes: 78 additions & 0 deletions test/test_cudss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,3 +190,81 @@ function cudss_execution()
end
end
end

"""
    cudss_generic()

Test the generic `LinearAlgebra` API (`lu`, `ldlt`, `cholesky`, their in-place `!`
variants and `ldiv!`) on sparse CSR matrices for every supported precision.

Each scenario factorizes a random `n × n` matrix, solves, and checks that the residual
norm is at most `sqrt(eps(R))`; it then rescales the matrix, refactorizes with the same
solver, solves in place and checks the residual again.
"""
function cudss_generic()
    n = 100
    p = 5
    @testset "precision = $T" for T in (Float32, Float64, ComplexF32, ComplexF64)
        R = real(T)
        # ASCII spelling of `≤ √eps(R)`.
        tol = sqrt(eps(R))
        @testset "Unsymmetric -- Non-Hermitian" begin
            A_cpu = sprand(T, n, n, 0.02) + I
            x_cpu = zeros(T, n)
            b_cpu = rand(T, n)

            A_gpu = CuSparseMatrixCSR(A_cpu)
            x_gpu = CuVector(x_cpu)
            b_gpu = CuVector(b_cpu)

            solver = lu(A_gpu)
            ldiv!(x_gpu, solver, b_gpu)
            r_gpu = b_gpu - A_gpu * x_gpu
            @test norm(r_gpu) <= tol

            A_gpu2 = rand(T) * A_gpu
            lu!(solver, A_gpu2)
            x_gpu .= b_gpu
            ldiv!(solver, x_gpu)
            r_gpu2 = b_gpu - A_gpu2 * x_gpu
            @test norm(r_gpu2) <= tol
        end

        @testset "view = $view" for view in ('F',)
            @testset "Symmetric -- Hermitian" begin
                A_cpu = sprand(T, n, n, 0.01) + I
                A_cpu = A_cpu + A_cpu'
                X_cpu = zeros(T, n, p)
                B_cpu = rand(T, n, p)

                A_gpu = CuSparseMatrixCSR(A_cpu)
                X_gpu = CuMatrix(X_cpu)
                B_gpu = CuMatrix(B_cpu)

                solver = ldlt(A_gpu)
                # `ldlt` selects the 'H' structure exactly when `T` is not real, so the
                # element type is tested directly; no `structure` variable exists here.
                (T <: Complex) && cudss_set(solver, "pivot_type", 'N')
                ldiv!(X_gpu, solver, B_gpu)
                R_gpu = B_gpu - A_gpu * X_gpu
                @test norm(R_gpu) <= tol

                # A real scale factor keeps the rescaled matrix symmetric / Hermitian.
                A_gpu2 = rand(R) * A_gpu
                ldlt!(solver, A_gpu2)
                X_gpu .= B_gpu
                ldiv!(solver, X_gpu)
                R_gpu2 = B_gpu - A_gpu2 * X_gpu
                @test norm(R_gpu2) <= tol
            end

            @testset "SPD -- HPD" begin
                A_cpu = sprand(T, n, n, 0.01)
                A_cpu = A_cpu * A_cpu' + I
                X_cpu = zeros(T, n, p)
                B_cpu = rand(T, n, p)

                A_gpu = CuSparseMatrixCSR(A_cpu)
                X_gpu = CuMatrix(X_cpu)
                B_gpu = CuMatrix(B_cpu)

                solver = cholesky(A_gpu)
                ldiv!(X_gpu, solver, B_gpu)
                R_gpu = B_gpu - A_gpu * X_gpu
                @test norm(R_gpu) <= tol

                # A real, nonnegative scale factor keeps the rescaled matrix
                # positive definite (barring an exact zero draw).
                A_gpu2 = rand(R) * A_gpu
                cholesky!(solver, A_gpu2)
                X_gpu .= B_gpu
                ldiv!(solver, X_gpu)
                R_gpu2 = B_gpu - A_gpu2 * X_gpu
                @test norm(R_gpu2) <= tol
            end
        end
    end
end

0 comments on commit 1894f7b

Please sign in to comment.