From a1284e9c2cd057680756fba19d4886ddeb16f5ab Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Fri, 8 Dec 2023 12:27:56 -0600 Subject: [PATCH] Add a type T for CudssMatrix and CudssSolver --- src/helpers.jl | 23 ++++++++++++----------- src/interfaces.jl | 26 +++++++++++++------------- 2 files changed, 25 insertions(+), 24 deletions(-) diff --git a/src/helpers.jl b/src/helpers.jl index 72aaca0..707dc56 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -27,38 +27,39 @@ export CudssMatrix, CudssData, CudssConfig - `'Z'`: 0-based indexing; - `'O'`: 1-based indexing. """ -mutable struct CudssMatrix +mutable struct CudssMatrix{T} + type::Type{T} matrix::cudssMatrix_t - function CudssMatrix(v::CuVector) + function CudssMatrix(v::CuVector{T}) where T <: BlasFloat m = length(v) matrix_ref = Ref{cudssMatrix_t}() - cudssMatrixCreateDn(matrix_ref, m, 1, m, v, eltype(v), 'C') - obj = new(matrix_ref[]) + cudssMatrixCreateDn(matrix_ref, m, 1, m, v, T, 'C') + obj = new(T, matrix_ref[]) finalizer(cudssMatrixDestroy, obj) obj end - function CudssMatrix(A::CuMatrix; transposed::Bool=false) + function CudssMatrix(A::CuMatrix{T}; transposed::Bool=false) where T <: BlasFloat m,n = size(A) matrix_ref = Ref{cudssMatrix_t}() if transposed - cudssMatrixCreateDn(matrix_ref, n, m, m, A, eltype(A), 'R') + cudssMatrixCreateDn(matrix_ref, n, m, m, A, T, 'R') else - cudssMatrixCreateDn(matrix_ref, m, n, m, A, eltype(A), 'C') + cudssMatrixCreateDn(matrix_ref, m, n, m, A, T, 'C') end - obj = new(matrix_ref[]) + obj = new(T, matrix_ref[]) finalizer(cudssMatrixDestroy, obj) obj end - function CudssMatrix(A::CuSparseMatrixCSR, structure::String, view::Char; index::Char='O') + function CudssMatrix(A::CuSparseMatrixCSR{T}, structure::String, view::Char; index::Char='O') where T <: BlasFloat m,n = size(A) matrix_ref = Ref{cudssMatrix_t}() cudssMatrixCreateCsr(matrix_ref, m, n, nnz(A), A.rowPtr, CU_NULL, - A.colVal, A.nzVal, eltype(A.rowPtr), eltype(A.nzVal), structure, + A.colVal, A.nzVal, eltype(A.rowPtr), T, structure, view, index) - obj = new(matrix_ref[]) + obj = new(T, matrix_ref[]) finalizer(cudssMatrixDestroy, obj) obj end diff --git a/src/interfaces.jl b/src/interfaces.jl index 2671952..bf98c30 100644 --- a/src/interfaces.jl +++ b/src/interfaces.jl @@ -25,10 +25,10 @@ One constructor of `CudssSolver` takes as input the same parameters as [`CudssMa `CudssSolver` can be also constructed from the three structures `CudssMatrix`, `CudssConfig` and `CudssData` if needed. """ -mutable struct CudssSolver +mutable struct CudssSolver{T} matrix::CudssMatrix config::CudssConfig - data::CudssData + data::CudssData{T} function CudssSolver(matrix::CudssMatrix, config::CudssConfig, data::CudssData) return new(matrix, config, data) @@ -68,15 +68,15 @@ The available data parameter is: """ function cudss_set end -function cudss_set(matrix::CudssMatrix, v::CuVector) +function cudss_set(matrix::CudssMatrix{T}, v::CuVector{T}) where T <: BlasFloat cudssMatrixSetValues(matrix, v) end -function cudss_set(matrix::CudssMatrix, A::CuMatrix) +function cudss_set(matrix::CudssMatrix{T}, A::CuMatrix{T}) where T <: BlasFloat cudssMatrixSetValues(matrix, A) end -function cudss_set(matrix::CudssMatrix, A::CuSparseMatrixCSR) +function cudss_set(matrix::CudssMatrix{T}, A::CuSparseMatrixCSR{T}) where T <: BlasFloat cudssMatrixSetCsrPointers(matrix, A.rowPtr, CU_NULL, A.colVal, A.nzVal) end @@ -166,18 +166,18 @@ The phases `"solve_fwd"`, `"solve_diag"` and `"solve_bwd"` are available but not """ function cudss end -function cudss(phase::String, solver::CudssSolver, x::CuVector, b::CuVector) +function cudss(phase::String, solver::CudssSolver{T}, X::CudssMatrix{T}, B::CudssMatrix{T}) where T <: BlasFloat + cudssExecute(handle(), phase, solver.config, solver.data, solver.matrix, X, B) +end + +function cudss(phase::String, solver::CudssSolver{T}, x::CuVector{T}, b::CuVector{T}) where T <: BlasFloat solution = CudssMatrix(x) rhs = CudssMatrix(b) - cudssExecute(handle(), phase, solver.config, solver.data, solver.matrix, solution, rhs) + cudss(phase, solver, solution, rhs) end -function cudss(phase::String, solver::CudssSolver, X::CuMatrix, B::CuMatrix) +function cudss(phase::String, solver::CudssSolver{T}, X::CuMatrix{T}, B::CuMatrix{T}) where T <: BlasFloat solution = CudssMatrix(X) rhs = CudssMatrix(B) - cudssExecute(handle(), phase, solver.config, solver.data, solver.matrix, solution, rhs) -end - -function cudss(phase::String, solver::CudssSolver, X::CudssMatrix, B::CudssMatrix) - cudssExecute(handle(), phase, solver.config, solver.data, solver.matrix, X, B) + cudss(phase, solver, solution, rhs) end