Skip to content

Commit

Permalink
Handle pointer of pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Dec 12, 2024
1 parent c539c4d commit ac822ee
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
3 changes: 2 additions & 1 deletion src/CUDSS.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module CUDSS

using CUDA, CUDA.APIUtils, CUDA.CUSPARSE
using CUDA, CUDA.APIUtils, CUDA.CUSPARSE, CUDA.CUBLAS
using CUDSS_jll
using LinearAlgebra
using SparseArrays
Expand All @@ -14,6 +14,7 @@ else
end

import CUDA: @checked, libraryPropertyType, cudaDataType, initialize_context, retry_reclaim, CUstream, @gcsafe_ccall
import CUDA.CUBLAS: unsafe_batch
import LinearAlgebra: lu, lu!, ldlt, ldlt!, cholesky, cholesky!, ldiv!, BlasFloat, BlasReal, checksquare
import Base: \

Expand Down
10 changes: 7 additions & 3 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ mutable struct CudssMatrix{T}
nrows = [length(vᵢ) for vᵢ in v]
ncols = [1 for i = 1:nbatch]
ld = nrows
cudssMatrixCreateBatchDn(matrix_ref, nbatch, nrows, ncols, ld, v, T, 'C')
vptrs = unsafe_batch(v)
cudssMatrixCreateBatchDn(matrix_ref, nbatch, nrows, ncols, ld, vptrs, T, 'C')
# unsafe_free!(vptrs)
obj = new{T}(T, matrix_ref[])
finalizer(cudssMatrixDestroy, obj)
obj
Expand All @@ -107,11 +109,13 @@ mutable struct CudssMatrix{T}
nrows = [size(Aᵢ,1) for Aᵢ in A]
ncols = [size(Aᵢ,2) for Aᵢ in A]
ld = nrows
Aptrs = unsafe_batch(A)
if transposed
cudssMatrixCreateBatchDn(matrix_ref, nbatch, ncols, nrows, ld, A, T, 'R')
cudssMatrixCreateBatchDn(matrix_ref, nbatch, ncols, nrows, ld, Aptrs, T, 'R')
else
cudssMatrixCreateBatchDn(matrix_ref, nbatch, nrows, ncols, ld, A, T, 'C')
cudssMatrixCreateBatchDn(matrix_ref, nbatch, nrows, ncols, ld, Aptrs, T, 'C')
end
# unsafe_free!(Aptrs)
obj = new{T}(T, matrix_ref[])
finalizer(cudssMatrixDestroy, obj)
obj
Expand Down
10 changes: 7 additions & 3 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,15 @@ function cudss_set(solver::CudssSolver{T}, A::CuSparseMatrixCSR{T,Cint}) where T
end

function cudss_set(matrix::CudssMatrix{T}, v::Vector{CuVector{T}}) where T <: BlasFloat
cudssMatrixSetBatchValues(matrix, v)
vptrs = unsafe_batch(v)
cudssMatrixSetBatchValues(matrix, vptrs)
# unsafe_free!(vptrs)
end

function cudss_set(matrix::CudssMatrix{T}, v::Vector{CuMatrix{T}}) where T <: BlasFloat
cudssMatrixSetBatchValues(matrix, v)
function cudss_set(matrix::CudssMatrix{T}, A::Vector{CuMatrix{T}}) where T <: BlasFloat
Aptrs = unsafe_batch(A)
cudssMatrixSetBatchValues(matrix, Aptrs)
# unsafe_free!(Aptrs)
end

function cudss_set(matrix::CudssMatrix{T}, A::Vector{CuSparseMatrixCSR{T,Cint}}) where T <: BlasFloat
Expand Down

0 comments on commit ac822ee

Please sign in to comment.