Skip to content

Commit

Permalink
Add more support for batch linear systems
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Dec 12, 2024
1 parent 1132a51 commit ab0a1f1
Showing 1 changed file with 22 additions and 0 deletions.
22 changes: 22 additions & 0 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ export CudssSolver, cudss, cudss_set, cudss_get

"""
solver = CudssSolver(A::CuSparseMatrixCSR{T,Cint}, structure::String, view::Char; index::Char='O')
solver = CudssSolver(A::Vector{CuSparseMatrixCSR{T,Cint}}, structure::String, view::Char; index::Char='O')
solver = CudssSolver(matrix::CudssMatrix{T}, config::CudssConfig, data::CudssData)
The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
Expand Down Expand Up @@ -42,6 +43,13 @@ mutable struct CudssSolver{T}
data = CudssData()
return new{T}(matrix, config, data)
end

function CudssSolver(A::Vector{CuSparseMatrixCSR{T,Cint}}, structure::String, view::Char; index::Char='O') where T <: BlasFloat
matrix = CudssMatrix(A, structure, view; index)
config = CudssConfig()
data = CudssData()
return new{T}(matrix, config, data)
end
end

"""
Expand Down Expand Up @@ -222,6 +230,8 @@ end
"""
cudss(phase::String, solver::CudssSolver{T}, x::CuVector{T}, b::CuVector{T})
cudss(phase::String, solver::CudssSolver{T}, X::CuMatrix{T}, B::CuMatrix{T})
cudss(phase::String, solver::CudssSolver{T}, x::Vector{CuVector{T}}, b::Vector{CuVector{T}})
cudss(phase::String, solver::CudssSolver{T}, X::Vector{CuMatrix{T}}, B::Vector{CuMatrix{T}})
cudss(phase::String, solver::CudssSolver{T}, X::CudssMatrix{T}, B::CudssMatrix{T})
The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
Expand All @@ -246,3 +256,15 @@ function cudss(phase::String, solver::CudssSolver{T}, X::CuMatrix{T}, B::CuMatri
rhs = CudssMatrix(B)
cudss(phase, solver, solution, rhs)
end

function cudss(phase::String, solver::CudssSolver{T}, x::Vector{CuVector{T}}, b::Vector{CuVector{T}}) where T <: BlasFloat
solution = CudssMatrix(x)
rhs = CudssMatrix(b)
cudss(phase, solver, solution, rhs)
end

function cudss(phase::String, solver::CudssSolver{T}, X::Vector{CuMatrix{T}}, B::Vector{CuMatrix{T}}) where T <: BlasFloat
solution = CudssMatrix(X)
rhs = CudssMatrix(B)
cudss(phase, solver, solution, rhs)
end

0 comments on commit ab0a1f1

Please sign in to comment.