From ac822eeff91dddf5354ca77321292fd45bc96b88 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Thu, 12 Dec 2024 14:22:49 -0600 Subject: [PATCH] Handle pointer of pointers --- src/CUDSS.jl | 3 ++- src/helpers.jl | 10 +++++++--- src/interfaces.jl | 10 +++++++--- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/CUDSS.jl b/src/CUDSS.jl index 05f0e43..59b5e18 100644 --- a/src/CUDSS.jl +++ b/src/CUDSS.jl @@ -1,6 +1,6 @@ module CUDSS -using CUDA, CUDA.APIUtils, CUDA.CUSPARSE +using CUDA, CUDA.APIUtils, CUDA.CUSPARSE, CUDA.CUBLAS using CUDSS_jll using LinearAlgebra using SparseArrays @@ -14,6 +14,7 @@ else end import CUDA: @checked, libraryPropertyType, cudaDataType, initialize_context, retry_reclaim, CUstream, @gcsafe_ccall +import CUDA.CUBLAS: unsafe_batch import LinearAlgebra: lu, lu!, ldlt, ldlt!, cholesky, cholesky!, ldiv!, BlasFloat, BlasReal, checksquare import Base: \ diff --git a/src/helpers.jl b/src/helpers.jl index e3f0a5d..8ef8ef1 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -95,7 +95,9 @@ mutable struct CudssMatrix{T} nrows = [length(vᵢ) for vᵢ in v] ncols = [1 for i = 1:nbatch] ld = nrows - cudssMatrixCreateBatchDn(matrix_ref, nbatch, nrows, ncols, ld, v, T, 'C') + vptrs = unsafe_batch(v) + cudssMatrixCreateBatchDn(matrix_ref, nbatch, nrows, ncols, ld, vptrs, T, 'C') + # unsafe_free!(vptrs) obj = new{T}(T, matrix_ref[]) finalizer(cudssMatrixDestroy, obj) obj @@ -107,11 +109,13 @@ mutable struct CudssMatrix{T} nrows = [size(Aᵢ,1) for Aᵢ in A] ncols = [size(Aᵢ,2) for Aᵢ in A] ld = nrows + Aptrs = unsafe_batch(A) if transposed - cudssMatrixCreateBatchDn(matrix_ref, nbatch, ncols, nrows, ld, A, T, 'R') + cudssMatrixCreateBatchDn(matrix_ref, nbatch, ncols, nrows, ld, Aptrs, T, 'R') else - cudssMatrixCreateBatchDn(matrix_ref, nbatch, nrows, ncols, ld, A, T, 'C') + cudssMatrixCreateBatchDn(matrix_ref, nbatch, nrows, ncols, ld, Aptrs, T, 'C') end + # unsafe_free!(Aptrs) obj = new{T}(T, matrix_ref[]) finalizer(cudssMatrixDestroy, obj) obj diff --git a/src/interfaces.jl b/src/interfaces.jl index 1d071f5..12de25b 100644 --- a/src/interfaces.jl +++ b/src/interfaces.jl @@ -106,11 +106,15 @@ function cudss_set(solver::CudssSolver{T}, A::CuSparseMatrixCSR{T,Cint}) where T end function cudss_set(matrix::CudssMatrix{T}, v::Vector{CuVector{T}}) where T <: BlasFloat - cudssMatrixSetBatchValues(matrix, v) + vptrs = unsafe_batch(v) + cudssMatrixSetBatchValues(matrix, vptrs) + # unsafe_free!(vptrs) end -function cudss_set(matrix::CudssMatrix{T}, v::Vector{CuMatrix{T}}) where T <: BlasFloat - cudssMatrixSetBatchValues(matrix, v) +function cudss_set(matrix::CudssMatrix{T}, A::Vector{CuMatrix{T}}) where T <: BlasFloat + Aptrs = unsafe_batch(A) + cudssMatrixSetBatchValues(matrix, Aptrs) + # unsafe_free!(Aptrs) end function cudss_set(matrix::CudssMatrix{T}, A::Vector{CuSparseMatrixCSR{T,Cint}}) where T <: BlasFloat