Skip to content

Commit

Permalink
Fix more issues...
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Dec 12, 2024
1 parent a7f0ed6 commit 9e55231
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 27 deletions.
30 changes: 15 additions & 15 deletions gen/cudss.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ needs_context = false
5 = "CuPtr{Cvoid}"

[api.cudssMatrixCreateBatchDn.argtypes]
6 = "Ptr{CuPtr{Cvoid}}"
6 = "CuPtr{Ptr{Cvoid}}"

[api.cudssMatrixCreateCsr.argtypes]
5 = "CuPtr{Cvoid}"
Expand All @@ -43,16 +43,16 @@ needs_context = false
8 = "CuPtr{Cvoid}"

[api.cudssMatrixCreateBatchCsr.argtypes]
6 = "Ptr{CuPtr{Cvoid}}"
7 = "Ptr{CuPtr{Cvoid}}"
8 = "Ptr{CuPtr{Cvoid}}"
9 = "Ptr{CuPtr{Cvoid}}"
6 = "CuPtr{Ptr{Cvoid}}"
7 = "CuPtr{Ptr{Cvoid}}"
8 = "CuPtr{Ptr{Cvoid}}"
9 = "CuPtr{Ptr{Cvoid}}"

[api.cudssMatrixGetDn.argtypes]
5 = "Ptr{CuPtr{Cvoid}}"

[api.cudssMatrixGetBatchDn.argtypes]
6 = "Ptr{Ptr{CuPtr{Cvoid}}}"
6 = "Ptr{CuPtr{Ptr{Cvoid}}}"

[api.cudssMatrixGetCsr.argtypes]
5 = "Ptr{CuPtr{Cvoid}}"
Expand All @@ -61,16 +61,16 @@ needs_context = false
8 = "Ptr{CuPtr{Cvoid}}"

[api.cudssMatrixGetBatchCsr.argtypes]
6 = "Ptr{Ptr{CuPtr{Cvoid}}}"
7 = "Ptr{Ptr{CuPtr{Cvoid}}}"
8 = "Ptr{Ptr{CuPtr{Cvoid}}}"
9 = "Ptr{Ptr{CuPtr{Cvoid}}}"
6 = "Ptr{CuPtr{Ptr{Cvoid}}}"
7 = "Ptr{CuPtr{Ptr{Cvoid}}}"
8 = "Ptr{CuPtr{Ptr{Cvoid}}}"
9 = "Ptr{CuPtr{Ptr{Cvoid}}}"

[api.cudssMatrixSetValues.argtypes]
2 = "CuPtr{Cvoid}"

[api.cudssMatrixSetBatchValues.argtypes]
2 = "Ptr{CuPtr{Cvoid}}"
2 = "CuPtr{Ptr{Cvoid}}"

[api.cudssMatrixSetCsrPointers.argtypes]
2 = "CuPtr{Cvoid}"
Expand All @@ -79,7 +79,7 @@ needs_context = false
5 = "CuPtr{Cvoid}"

[api.cudssMatrixSetBatchCsrPointers.argtypes]
2 = "Ptr{CuPtr{Cvoid}}"
3 = "Ptr{CuPtr{Cvoid}}"
4 = "Ptr{CuPtr{Cvoid}}"
5 = "Ptr{CuPtr{Cvoid}}"
2 = "CuPtr{Ptr{Cvoid}}"
3 = "CuPtr{Ptr{Cvoid}}"
4 = "CuPtr{Ptr{Cvoid}}"
5 = "CuPtr{Ptr{Cvoid}}"
2 changes: 1 addition & 1 deletion gen/libcudss_prologue.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# CUDSS uses CUDA runtime objects, which are compatible with our driver usage
const cudaStream_t = CUstream

const cudaDataType_t = cudaDataType
const CUPTR_C_NULL = CuPtr{Ptr{Cvoid}}(0)
12 changes: 7 additions & 5 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,15 @@ mutable struct CudssMatrix{T}
nrows = [size(Aᵢ,1) for Aᵢ in A]
ncols = [size(Aᵢ,2) for Aᵢ in A]
nnzA = [nnz(Aᵢ) for Aᵢ in A]
rowPtrs = [Aᵢ.rowPtr for Aᵢ in A]
colVals = [Aᵢ.colVal for Aᵢ in A]
nzVals = [Aᵢ.nzVal for Aᵢ in A]
PTR_CU_NULL = Ptr{CuPtr{Cvoid}}()
rowsPtrs = [pointer(Aᵢ.rowPtr) for Aᵢ in A] |> CuVector
colVals = [pointer(Aᵢ.colVal) for Aᵢ in A] |> CuVector
nzVals = [pointer(Aᵢ.nzVal) for Aᵢ in A] |> CuVector
cudssMatrixCreateBatchCsr(matrix_ref, nbatch, nrows, ncols, nnzA, rowPtrs,
PTR_CU_NULL, colVals, nzVals, Cint, T, structure,
CUPTR_C_NULL, colVals, nzVals, Cint, T, structure,
view, index)
# unsafe_free!(rowsPtrs)
# unsafe_free!(colVals)
# unsafe_free!(nzVals)
obj = new{T}(T, matrix_ref[])
finalizer(cudssMatrixDestroy, obj)
obj
Expand Down
12 changes: 7 additions & 5 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,13 @@ function cudss_set(matrix::CudssMatrix{T}, v::Vector{CuMatrix{T}}) where T <: Bl
end

function cudss_set(matrix::CudssMatrix{T}, A::Vector{CuSparseMatrixCSR{T,Cint}}) where T <: BlasFloat
rowsPtrs = [A.rowPtr for Aᵢ in A]
colVals = [A.colVal for Aᵢ in A]
nzVals = [A.nzVal for Aᵢ in A]
PTR_CU_NULL = Ptr{CuPtr{Cvoid}}()
cudssMatrixSetBatchCsrPointers(matrix, rowsPtrs, PTR_CU_NULL, colVals, nzVals)
rowsPtrs = [pointer(Aᵢ.rowPtr) for Aᵢ in A] |> CuVector
colVals = [pointer(Aᵢ.colVal) for Aᵢ in A] |> CuVector
nzVals = [pointer(Aᵢ.nzVal) for Aᵢ in A] |> CuVector
cudssMatrixSetBatchCsrPointers(matrix, rowsPtrs, CUPTR_C_NULL, colVals, nzVals)
# unsafe_free!(rowsPtrs)
# unsafe_free!(colVals)
# unsafe_free!(nzVals)
end

function cudss_set(solver::CudssSolver{T}, A::Vector{CuSparseMatrixCSR{T,Cint}}) where T <: BlasFloat
Expand Down
2 changes: 1 addition & 1 deletion src/libcudss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ using CEnum

# CUDSS uses CUDA runtime objects, which are compatible with our driver usage
const cudaStream_t = CUstream

const cudaDataType_t = cudaDataType
const CUPTR_C_NULL = CuPtr{Ptr{Cvoid}}(0)

@cenum cudssOpType_t::UInt32 begin
CUDSS_SUM = 0
Expand Down

0 comments on commit 9e55231

Please sign in to comment.