From ab0a1f1f91657eb359acbd2e877b7a04a73e06a1 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Thu, 12 Dec 2024 13:01:04 -0600 Subject: [PATCH] Add more support for batch linear systems --- src/interfaces.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/interfaces.jl b/src/interfaces.jl index e056156..b28b820 100644 --- a/src/interfaces.jl +++ b/src/interfaces.jl @@ -2,6 +2,7 @@ export CudssSolver, cudss, cudss_set, cudss_get """ solver = CudssSolver(A::CuSparseMatrixCSR{T,Cint}, structure::String, view::Char; index::Char='O') + solver = CudssSolver(A::Vector{CuSparseMatrixCSR{T,Cint}}, structure::String, view::Char; index::Char='O') solver = CudssSolver(matrix::CudssMatrix{T}, config::CudssConfig, data::CudssData) The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`. @@ -42,6 +43,13 @@ mutable struct CudssSolver{T} data = CudssData() return new{T}(matrix, config, data) end + + function CudssSolver(A::Vector{CuSparseMatrixCSR{T,Cint}}, structure::String, view::Char; index::Char='O') where T <: BlasFloat + matrix = CudssMatrix(A, structure, view; index) + config = CudssConfig() + data = CudssData() + return new{T}(matrix, config, data) + end end """ @@ -222,6 +230,8 @@ end """ cudss(phase::String, solver::CudssSolver{T}, x::CuVector{T}, b::CuVector{T}) cudss(phase::String, solver::CudssSolver{T}, X::CuMatrix{T}, B::CuMatrix{T}) + cudss(phase::String, solver::CudssSolver{T}, x::Vector{CuVector{T}}, b::Vector{CuVector{T}}) + cudss(phase::String, solver::CudssSolver{T}, X::Vector{CuMatrix{T}}, B::Vector{CuMatrix{T}}) cudss(phase::String, solver::CudssSolver{T}, X::CudssMatrix{T}, B::CudssMatrix{T}) The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`. @@ -246,3 +256,15 @@ function cudss(phase::String, solver::CudssSolver{T}, X::CuMatrix{T}, B::CuMatri rhs = CudssMatrix(B) cudss(phase, solver, solution, rhs) end + +function cudss(phase::String, solver::CudssSolver{T}, x::Vector{CuVector{T}}, b::Vector{CuVector{T}}) where T <: BlasFloat + solution = CudssMatrix(x) + rhs = CudssMatrix(b) + cudss(phase, solver, solution, rhs) +end + +function cudss(phase::String, solver::CudssSolver{T}, X::Vector{CuMatrix{T}}, B::Vector{CuMatrix{T}}) where T <: BlasFloat + solution = CudssMatrix(X) + rhs = CudssMatrix(B) + cudss(phase, solver, solution, rhs) +end