Skip to content

Commit

Permalink
Add a type T for CudssMatrix and CudssSolver
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Dec 8, 2023
1 parent a1284e9 commit c31f3f3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 20 deletions.
14 changes: 8 additions & 6 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@ export CudssMatrix, CudssData, CudssConfig
## Matrix

"""
matrix = CudssMatrix(v::CuVector)
matrix = CudssMatrix(A::CuMatrix)
matrix = CudssMatrix(A::CuSparseMatrixCSR, structure::String, view::Char; index::Char='O')
matrix = CudssMatrix(v::CuVector{T})
matrix = CudssMatrix(A::CuMatrix{T})
matrix = CudssMatrix(A::CuSparseMatrixCSR{T}, structure::String, view::Char; index::Char='O')
The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
`CudssMatrix` is a wrapper for `CuVector`, `CuMatrix` and `CuSparseMatrixCSR`.
`CudssMatrix` is used to pass the matrix of the linear system, as well as the solution and the right-hand side.
Expand Down Expand Up @@ -35,7 +37,7 @@ mutable struct CudssMatrix{T}
m = length(v)
matrix_ref = Ref{cudssMatrix_t}()
cudssMatrixCreateDn(matrix_ref, m, 1, m, v, T, 'C')
obj = new(T, matrix_ref[])
obj = new{T}(T, matrix_ref[])
finalizer(cudssMatrixDestroy, obj)
obj
end
Expand All @@ -48,7 +50,7 @@ mutable struct CudssMatrix{T}
else
cudssMatrixCreateDn(matrix_ref, m, n, m, A, T, 'C')
end
obj = new(T, matrix_ref[])
obj = new{T}(T, matrix_ref[])
finalizer(cudssMatrixDestroy, obj)
obj
end
Expand All @@ -59,7 +61,7 @@ mutable struct CudssMatrix{T}
cudssMatrixCreateCsr(matrix_ref, m, n, nnz(A), A.rowPtr, CU_NULL,
A.colVal, A.nzVal, eltype(A.rowPtr), T, structure,
view, index)
obj = new(T, matrix_ref[])
obj = new{T}(T, matrix_ref[])
finalizer(cudssMatrixDestroy, obj)
obj
end
Expand Down
34 changes: 20 additions & 14 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
export CudssSolver, cudss, cudss_set, cudss_get

"""
solver = CudssSolver(A::CuSparseMatrixCSR, structure::String, view::Char; index::Char='O')
solver = CudssSolver(matrix::CudssMatrix, config::CudssConfig, data::CudssData)
solver = CudssSolver(A::CuSparseMatrixCSR{T}, structure::String, view::Char; index::Char='O')
solver = CudssSolver(matrix::CudssMatrix{T}, config::CudssConfig, data::CudssData)
The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
`CudssSolver` contains all structures required to solve linear systems with cuDSS.
One constructor of `CudssSolver` takes as input the same parameters as [`CudssMatrix`](@ref).
Expand All @@ -26,30 +28,32 @@ One constructor of `CudssSolver` takes as input the same parameters as [`CudssMa
`CudssSolver` can also be constructed from the three structures `CudssMatrix`, `CudssConfig` and `CudssData` if needed.
"""
mutable struct CudssSolver{T}
matrix::CudssMatrix
matrix::CudssMatrix{T}
config::CudssConfig
data::CudssData{T}
data::CudssData

function CudssSolver(matrix::CudssMatrix, config::CudssConfig, data::CudssData)
return new(matrix, config, data)
function CudssSolver(matrix::CudssMatrix{T}, config::CudssConfig, data::CudssData) where T <: BlasFloat
return new{T}(matrix, config, data)
end

function CudssSolver(A::CuSparseMatrixCSR, structure::String, view::Char; index::Char='O')
function CudssSolver(A::CuSparseMatrixCSR{T}, structure::String, view::Char; index::Char='O') where T <: BlasFloat
matrix = CudssMatrix(A, structure, view; index)
config = CudssConfig()
data = CudssData()
return new(matrix, config, data)
return new{T}(matrix, config, data)
end
end

"""
cudss_set(matrix::CudssMatrix, v::CuVector)
cudss_set(matrix::CudssMatrix, A::CuMatrix)
cudss_set(matrix::CudssMatrix, A::CuSparseMatrixCSR)
cudss_set(matrix::CudssMatrix{T}, v::CuVector{T})
cudss_set(matrix::CudssMatrix{T}, A::CuMatrix{T})
cudss_set(matrix::CudssMatrix{T}, A::CuSparseMatrixCSR{T})
cudss_set(solver::CudssSolver, param::String, value)
cudss_set(config::CudssConfig, param::String, value)
cudss_set(data::CudssData, param::String, value)
The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
The available configuration parameters are:
- `"reordering_alg"`: Algorithm for the reordering phase;
- `"factorization_alg"`: Algorithm for the factorization phase;
Expand Down Expand Up @@ -157,9 +161,11 @@ function cudss_get(config::CudssConfig, param::String)
end

"""
cudss(phase::String, solver::CudssSolver, x::CuVector, b::CuVector)
cudss(phase::String, solver::CudssSolver, X::CuMatrix, B::CuMatrix)
cudss(phase::String, solver::CudssSolver, X::CudssMatrix, B::CudssMatrix)
cudss(phase::String, solver::CudssSolver{T}, x::CuVector{T}, b::CuVector{T})
cudss(phase::String, solver::CudssSolver{T}, X::CuMatrix{T}, B::CuMatrix{T})
cudss(phase::String, solver::CudssSolver{T}, X::CudssMatrix{T}, B::CudssMatrix{T})
The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
The available phases are `"analysis"`, `"factorization"`, `"refactorization"` and `"solve"`.
The phases `"solve_fwd"`, `"solve_diag"` and `"solve_bwd"` are available but not yet functional.
Expand Down

0 comments on commit c31f3f3

Please sign in to comment.