Skip to content

Commit

Permalink
Add a type T for CudssMatrix and CudssSolver
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Dec 8, 2023
1 parent 7192527 commit a1284e9
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
23 changes: 12 additions & 11 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,38 +27,39 @@ export CudssMatrix, CudssData, CudssConfig
- `'Z'`: 0-based indexing;
- `'O'`: 1-based indexing.
"""
mutable struct CudssMatrix
mutable struct CudssMatrix{T}
type::Type{T}
matrix::cudssMatrix_t

function CudssMatrix(v::CuVector)
function CudssMatrix(v::CuVector{T}) where T <: BlasFloat
m = length(v)
matrix_ref = Ref{cudssMatrix_t}()
cudssMatrixCreateDn(matrix_ref, m, 1, m, v, eltype(v), 'C')
obj = new(matrix_ref[])
cudssMatrixCreateDn(matrix_ref, m, 1, m, v, T, 'C')
obj = new(T, matrix_ref[])
finalizer(cudssMatrixDestroy, obj)
obj
end

function CudssMatrix(A::CuMatrix; transposed::Bool=false)
function CudssMatrix(A::CuMatrix{T}; transposed::Bool=false) where T <: BlasFloat
m,n = size(A)
matrix_ref = Ref{cudssMatrix_t}()
if transposed
cudssMatrixCreateDn(matrix_ref, n, m, m, A, eltype(A), 'R')
cudssMatrixCreateDn(matrix_ref, n, m, m, A, T, 'R')
else
cudssMatrixCreateDn(matrix_ref, m, n, m, A, eltype(A), 'C')
cudssMatrixCreateDn(matrix_ref, m, n, m, A, T, 'C')
end
obj = new(matrix_ref[])
obj = new(T, matrix_ref[])
finalizer(cudssMatrixDestroy, obj)
obj
end

function CudssMatrix(A::CuSparseMatrixCSR, structure::String, view::Char; index::Char='O')
function CudssMatrix(A::CuSparseMatrixCSR{T}, structure::String, view::Char; index::Char='O') where T <: BlasFloat
m,n = size(A)
matrix_ref = Ref{cudssMatrix_t}()
cudssMatrixCreateCsr(matrix_ref, m, n, nnz(A), A.rowPtr, CU_NULL,
A.colVal, A.nzVal, eltype(A.rowPtr), eltype(A.nzVal), structure,
A.colVal, A.nzVal, eltype(A.rowPtr), T, structure,
view, index)
obj = new(matrix_ref[])
obj = new(T, matrix_ref[])
finalizer(cudssMatrixDestroy, obj)
obj
end
Expand Down
26 changes: 13 additions & 13 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ One constructor of `CudssSolver` takes as input the same parameters as [`CudssMa
`CudssSolver` can be also constructed from the three structures `CudssMatrix`, `CudssConfig` and `CudssData` if needed.
"""
mutable struct CudssSolver
mutable struct CudssSolver{T}
matrix::CudssMatrix
config::CudssConfig
data::CudssData
data::CudssData{T}

function CudssSolver(matrix::CudssMatrix, config::CudssConfig, data::CudssData)
return new(matrix, config, data)
Expand Down Expand Up @@ -68,15 +68,15 @@ The available data parameter is:
"""
function cudss_set end

function cudss_set(matrix::CudssMatrix, v::CuVector)
function cudss_set(matrix::CudssMatrix{T}, v::CuVector{T}) where T <: BlasFloat
cudssMatrixSetValues(matrix, v)
end

function cudss_set(matrix::CudssMatrix, A::CuMatrix)
function cudss_set(matrix::CudssMatrix{T}, A::CuMatrix{T}) where T <: BlasFloat
cudssMatrixSetValues(matrix, A)
end

function cudss_set(matrix::CudssMatrix, A::CuSparseMatrixCSR)
function cudss_set(matrix::CudssMatrix{T}, A::CuSparseMatrixCSR{T}) where T <: BlasFloat
cudssMatrixSetCsrPointers(matrix, A.rowPtr, CU_NULL, A.colVal, A.nzVal)
end

Expand Down Expand Up @@ -166,18 +166,18 @@ The phases `"solve_fwd"`, `"solve_diag"` and `"solve_bwd"` are available but not
"""
function cudss end

function cudss(phase::String, solver::CudssSolver, x::CuVector, b::CuVector)
function cudss(phase::String, solver::CudssSolver{T}, X::CudssMatrix{T}, B::CudssMatrix{T}) where T <: BlasFloat
cudssExecute(handle(), phase, solver.config, solver.data, solver.matrix, X, B)
end

function cudss(phase::String, solver::CudssSolver{T}, x::CuVector{T}, b::CuVector{T}) where T <: BlasFloat
solution = CudssMatrix(x)
rhs = CudssMatrix(b)
cudssExecute(handle(), phase, solver.config, solver.data, solver.matrix, solution, rhs)
cudss(phase, solver, solution, rhs)
end

function cudss(phase::String, solver::CudssSolver, X::CuMatrix, B::CuMatrix)
function cudss(phase::String, solver::CudssSolver{T}, X::CuMatrix{T}, B::CuMatrix{T}) where T <: BlasFloat
solution = CudssMatrix(X)
rhs = CudssMatrix(B)
cudssExecute(handle(), phase, solver.config, solver.data, solver.matrix, solution, rhs)
end

function cudss(phase::String, solver::CudssSolver, X::CudssMatrix, B::CudssMatrix)
cudssExecute(handle(), phase, solver.config, solver.data, solver.matrix, X, B)
cudss(phase, solver, solution, rhs)
end

0 comments on commit a1284e9

Please sign in to comment.