Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support cuDSS v0.4.0 #63

Merged
merged 15 commits into from
Dec 12, 2024
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[compat]
CEnum = "0.4, 0.5"
CEnum = "0.5"
CUDA = "5.4.0"
CUDSS_jll = "0.3.0"
CUDSS_jll = "0.4.0"
julia = "1.6"
LinearAlgebra = "1.6"
SparseArrays = "1.6"
Expand Down
4 changes: 2 additions & 2 deletions gen/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"

[compat]
CUDA_SDK_jll = "12.5.1"
CUDSS_jll = "0.3.0"
CUDA_SDK_jll = "12.6.3"
CUDSS_jll = "0.4.0"
julia = "1.6"
27 changes: 27 additions & 0 deletions gen/cudss.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,53 @@ needs_context = false
[api.cudssMatrixCreateDn.argtypes]
5 = "CuPtr{Cvoid}"

[api.cudssMatrixCreateBatchDn.argtypes]
6 = "CuPtr{Ptr{Cvoid}}"

[api.cudssMatrixCreateCsr.argtypes]
5 = "CuPtr{Cvoid}"
6 = "CuPtr{Cvoid}"
7 = "CuPtr{Cvoid}"
8 = "CuPtr{Cvoid}"

[api.cudssMatrixCreateBatchCsr.argtypes]
6 = "CuPtr{Ptr{Cvoid}}"
7 = "CuPtr{Ptr{Cvoid}}"
8 = "CuPtr{Ptr{Cvoid}}"
9 = "CuPtr{Ptr{Cvoid}}"

[api.cudssMatrixGetDn.argtypes]
5 = "Ptr{CuPtr{Cvoid}}"

[api.cudssMatrixGetBatchDn.argtypes]
6 = "Ptr{CuPtr{Ptr{Cvoid}}}"

[api.cudssMatrixGetCsr.argtypes]
5 = "Ptr{CuPtr{Cvoid}}"
6 = "Ptr{CuPtr{Cvoid}}"
7 = "Ptr{CuPtr{Cvoid}}"
8 = "Ptr{CuPtr{Cvoid}}"

[api.cudssMatrixGetBatchCsr.argtypes]
6 = "Ptr{CuPtr{Ptr{Cvoid}}}"
7 = "Ptr{CuPtr{Ptr{Cvoid}}}"
8 = "Ptr{CuPtr{Ptr{Cvoid}}}"
9 = "Ptr{CuPtr{Ptr{Cvoid}}}"

[api.cudssMatrixSetValues.argtypes]
2 = "CuPtr{Cvoid}"

[api.cudssMatrixSetBatchValues.argtypes]
2 = "CuPtr{Ptr{Cvoid}}"

[api.cudssMatrixSetCsrPointers.argtypes]
2 = "CuPtr{Cvoid}"
3 = "CuPtr{Cvoid}"
4 = "CuPtr{Cvoid}"
5 = "CuPtr{Cvoid}"

[api.cudssMatrixSetBatchCsrPointers.argtypes]
2 = "CuPtr{Ptr{Cvoid}}"
3 = "CuPtr{Ptr{Cvoid}}"
4 = "CuPtr{Ptr{Cvoid}}"
5 = "CuPtr{Ptr{Cvoid}}"
3 changes: 2 additions & 1 deletion gen/libcudss_prologue.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# CUDSS uses CUDA runtime objects, which are compatible with our driver usage
const cudaStream_t = CUstream

const cudaDataType_t = cudaDataType
const CUPTR_C_NULL = CuPtr{Ptr{Cvoid}}(0)

3 changes: 2 additions & 1 deletion src/CUDSS.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module CUDSS

using CUDA, CUDA.APIUtils, CUDA.CUSPARSE
using CUDA, CUDA.APIUtils, CUDA.CUSPARSE, CUDA.CUBLAS
using CUDSS_jll
using LinearAlgebra
using SparseArrays
Expand All @@ -14,6 +14,7 @@ else
end

import CUDA: @checked, libraryPropertyType, cudaDataType, initialize_context, retry_reclaim, CUstream, @gcsafe_ccall
import CUDA.CUBLAS: unsafe_batch
import LinearAlgebra: lu, lu!, ldlt, ldlt!, cholesky, cholesky!, ldiv!, BlasFloat, BlasReal, checksquare
import Base: \

Expand Down
57 changes: 56 additions & 1 deletion src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ export CudssMatrix, CudssData, CudssConfig
matrix = CudssMatrix(v::CuVector{T})
matrix = CudssMatrix(A::CuMatrix{T})
matrix = CudssMatrix(A::CuSparseMatrixCSR{T,Cint}, structure::String, view::Char; index::Char='O')
matrix = CudssMatrix(v::Vector{CuVector{T}})
matrix = CudssMatrix(A::Vector{CuMatrix{T}})
matrix = CudssMatrix(A::Vector{CuSparseMatrixCSR{T,Cint}}, structure::String, view::Char; index::Char='O')

The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.

Expand Down Expand Up @@ -79,12 +82,64 @@ mutable struct CudssMatrix{T}
m,n = size(A)
matrix_ref = Ref{cudssMatrix_t}()
cudssMatrixCreateCsr(matrix_ref, m, n, nnz(A), A.rowPtr, CU_NULL,
A.colVal, A.nzVal, eltype(A.rowPtr), T, structure,
A.colVal, A.nzVal, Cint, T, structure,
view, index)
obj = new{T}(T, matrix_ref[])
finalizer(cudssMatrixDestroy, obj)
obj
end

# Batched dense constructor: each vector of `v` becomes one single-column matrix of
# the batch (layout code 'C').
#
# The argument is typed `Vector{<:CuVector{T}}` instead of `Vector{CuVector{T}}`:
# Julia's parametric types are invariant, so the latter only matches a vector whose
# element type is exactly the (non-concrete) alias `CuVector{T}` and rejects vectors
# whose element type is a concrete `CuArray` type. The wider signature accepts both.
function CudssMatrix(v::Vector{<:CuVector{T}}) where T <: BlasFloat
    matrix_ref = Ref{cudssMatrix_t}()
    nbatch = length(v)
    # Per-system dimensions live in host arrays and are handed to cuDSS as raw pointers.
    # NOTE(review): these are `Vector{Int}`; confirm the integer width cuDSS expects
    # for `nrows`/`ncols`/`ld` in this call (no index type is passed here).
    nrows = [length(vᵢ) for vᵢ in v]
    ncols = fill(1, nbatch)
    ld = nrows  # leading dimension of each single-column system is its length
    vptrs = unsafe_batch(v)
    cudssMatrixCreateBatchDn(matrix_ref, nbatch, nrows, ncols, ld, vptrs, T, 'C')
    # `vptrs` is intentionally not released with `unsafe_free!` here.
    # NOTE(review): presumably cuDSS keeps referring to this pointer array — confirm.
    obj = new{T}(T, matrix_ref[])
    finalizer(cudssMatrixDestroy, obj)
    return obj
end

# Batched dense constructor: each matrix of `A` becomes one system of the batch.
#
# With `transposed=false` the matrices are registered with their own dimensions and
# layout code 'C'. With `transposed=true` the row and column counts are swapped and
# layout code 'R' is used, while the leading dimension stays `size(Aᵢ, 1)`.
#
# The argument is typed `Vector{<:CuMatrix{T}}` instead of `Vector{CuMatrix{T}}`:
# Julia's parametric types are invariant, so the latter rejects vectors whose element
# type is a concrete `CuArray` type. The wider signature accepts both.
function CudssMatrix(A::Vector{<:CuMatrix{T}}; transposed::Bool=false) where T <: BlasFloat
    matrix_ref = Ref{cudssMatrix_t}()
    nbatch = length(A)
    # NOTE(review): host `Vector{Int}` arrays passed as raw pointers; confirm the
    # integer width cuDSS expects here (no index type is passed in this call).
    nrows = [size(Aᵢ, 1) for Aᵢ in A]
    ncols = [size(Aᵢ, 2) for Aᵢ in A]
    ld = nrows
    Aptrs = unsafe_batch(A)
    if transposed
        cudssMatrixCreateBatchDn(matrix_ref, nbatch, ncols, nrows, ld, Aptrs, T, 'R')
    else
        cudssMatrixCreateBatchDn(matrix_ref, nbatch, nrows, ncols, ld, Aptrs, T, 'C')
    end
    # `Aptrs` is intentionally not released with `unsafe_free!` here.
    # NOTE(review): presumably cuDSS keeps referring to this pointer array — confirm.
    obj = new{T}(T, matrix_ref[])
    finalizer(cudssMatrixDestroy, obj)
    return obj
end

# Batched sparse constructor: registers every CSR matrix of `A` as one system of a
# cuDSS batch, sharing the same `structure`, `view` and `index` base.
function CudssMatrix(A::Vector{CuSparseMatrixCSR{T,Cint}}, structure::String, view::Char; index::Char='O') where T <: BlasFloat
    matrix_ref = Ref{cudssMatrix_t}()
    nbatch = length(A)
    # The size arrays are passed as untyped pointers next to the index type `Cint`,
    # so they are built as `Cint` arrays. A plain comprehension would yield
    # `Vector{Int}` (64-bit on most platforms), which would not match the declared
    # index type.
    nrows = Cint[size(Aᵢ, 1) for Aᵢ in A]
    ncols = Cint[size(Aᵢ, 2) for Aᵢ in A]
    nnzA = Cint[nnz(Aᵢ) for Aᵢ in A]
    # Device arrays holding, for each system, the pointer to its CSR component.
    rowPtrs = CuVector([pointer(Aᵢ.rowPtr) for Aᵢ in A])
    colVals = CuVector([pointer(Aᵢ.colVal) for Aᵢ in A])
    nzVals = CuVector([pointer(Aᵢ.nzVal) for Aᵢ in A])
    # `CUPTR_C_NULL` is passed for `rowEnd`.
    cudssMatrixCreateBatchCsr(matrix_ref, nbatch, nrows, ncols, nnzA, rowPtrs,
                              CUPTR_C_NULL, colVals, nzVals, Cint, T, structure,
                              view, index)
    # `rowPtrs`, `colVals` and `nzVals` are intentionally not released with
    # `unsafe_free!` here.
    # NOTE(review): presumably cuDSS keeps referring to these pointer arrays — confirm.
    obj = new{T}(T, matrix_ref[])
    finalizer(cudssMatrixDestroy, obj)
    return obj
end
end

Base.unsafe_convert(::Type{cudssMatrix_t}, matrix::CudssMatrix) = matrix.matrix
Expand Down
59 changes: 56 additions & 3 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ export CudssSolver, cudss, cudss_set, cudss_get

"""
solver = CudssSolver(A::CuSparseMatrixCSR{T,Cint}, structure::String, view::Char; index::Char='O')
solver = CudssSolver(A::Vector{CuSparseMatrixCSR{T,Cint}}, structure::String, view::Char; index::Char='O')
solver = CudssSolver(matrix::CudssMatrix{T}, config::CudssConfig, data::CudssData)

The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
Expand Down Expand Up @@ -42,13 +43,24 @@ mutable struct CudssSolver{T}
data = CudssData()
return new{T}(matrix, config, data)
end

# Build a solver for a batch of CSR matrices: wrap `A` in a `CudssMatrix` and pair
# it with a freshly constructed `CudssConfig` and `CudssData`.
function CudssSolver(A::Vector{CuSparseMatrixCSR{T,Cint}}, structure::String, view::Char; index::Char='O') where T <: BlasFloat
    batch = CudssMatrix(A, structure, view; index=index)
    return new{T}(batch, CudssConfig(), CudssData())
end
end

"""
cudss_set(matrix::CudssMatrix{T}, v::CuVector{T})
cudss_set(matrix::CudssMatrix{T}, A::CuMatrix{T})
cudss_set(matrix::CudssMatrix{T}, A::CuSparseMatrixCSR{T,Cint})
cudss_set(solver::CudssSolver{T}, A::CuSparseMatrixCSR{T,Cint})
cudss_set(matrix::CudssMatrix{T}, v::Vector{CuVector{T}})
cudss_set(matrix::CudssMatrix{T}, A::Vector{CuMatrix{T}})
cudss_set(matrix::CudssMatrix{T}, A::Vector{CuSparseMatrixCSR{T,Cint}})
cudss_set(solver::CudssSolver{T}, A::Vector{CuSparseMatrixCSR{T,Cint}})
cudss_set(solver::CudssSolver, parameter::String, value)
cudss_set(config::CudssConfig, parameter::String, value)
cudss_set(data::CudssData, parameter::String, value)
Expand Down Expand Up @@ -93,6 +105,32 @@ function cudss_set(solver::CudssSolver{T}, A::CuSparseMatrixCSR{T,Cint}) where T
cudss_set(solver.matrix, A)
end

# Update the values of a batched dense `matrix` from the vectors in `v`.
#
# The argument is typed `Vector{<:CuVector{T}}` instead of `Vector{CuVector{T}}`:
# Julia's parametric types are invariant, so the latter rejects vectors whose element
# type is a concrete `CuArray` type. The wider signature accepts both.
function cudss_set(matrix::CudssMatrix{T}, v::Vector{<:CuVector{T}}) where T <: BlasFloat
    vptrs = unsafe_batch(v)
    # `vptrs` is intentionally not released with `unsafe_free!` afterwards.
    # NOTE(review): presumably cuDSS keeps referring to this pointer array — confirm.
    return cudssMatrixSetBatchValues(matrix, vptrs)
end

# Update the values of a batched dense `matrix` from the matrices in `A`.
#
# The argument is typed `Vector{<:CuMatrix{T}}` instead of `Vector{CuMatrix{T}}`:
# Julia's parametric types are invariant, so the latter rejects vectors whose element
# type is a concrete `CuArray` type. The wider signature accepts both.
function cudss_set(matrix::CudssMatrix{T}, A::Vector{<:CuMatrix{T}}) where T <: BlasFloat
    Aptrs = unsafe_batch(A)
    # `Aptrs` is intentionally not released with `unsafe_free!` afterwards.
    # NOTE(review): presumably cuDSS keeps referring to this pointer array — confirm.
    return cudssMatrixSetBatchValues(matrix, Aptrs)
end

# Re-point a batched CSR `matrix` at the row-pointer, column-index and value buffers
# of the matrices in `A`. `CUPTR_C_NULL` is passed for the `rowEnd` argument.
function cudss_set(matrix::CudssMatrix{T}, A::Vector{CuSparseMatrixCSR{T,Cint}}) where T <: BlasFloat
    # One device array per CSR component, each holding the per-system pointers.
    row_ptrs = CuVector([pointer(Aᵢ.rowPtr) for Aᵢ in A])
    col_ptrs = CuVector([pointer(Aᵢ.colVal) for Aᵢ in A])
    val_ptrs = CuVector([pointer(Aᵢ.nzVal) for Aᵢ in A])
    # The pointer arrays are intentionally not released with `unsafe_free!` afterwards.
    return cudssMatrixSetBatchCsrPointers(matrix, row_ptrs, CUPTR_C_NULL, col_ptrs, val_ptrs)
end

# Convenience method: forward the new batch of CSR matrices to the solver's own matrix.
function cudss_set(solver::CudssSolver{T}, A::Vector{CuSparseMatrixCSR{T,Cint}}) where T <: BlasFloat
    target = solver.matrix
    return cudss_set(target, A)
end

function cudss_set(solver::CudssSolver, parameter::String, value)
if parameter ∈ CUDSS_CONFIG_PARAMETERS
cudss_set(solver.config, parameter, value)
Expand Down Expand Up @@ -150,9 +188,10 @@ The available data parameters are:
- `"perm_row"`: Final row permutation (which includes effects of both reordering and pivoting);
- `"perm_col"`: Final column permutation (which includes effects of both reordering and pivoting);
- `"diag"`: Diagonal of the factorized matrix;
- `"hybrid_device_memory_min"`: Minimal amount of device memory (number of bytes) required in the hybrid memory mode.
- `"hybrid_device_memory_min"`: Minimal amount of device memory (number of bytes) required in the hybrid memory mode;
- `"memory_estimates"`: Memory estimates (in bytes) for host and device memory required for the chosen memory mode.

The data parameters `"info"`, `"lu_nnz"`, `"perm_reorder_row"`, `"perm_reorder_col"` and `"hybrid_device_memory_min"` require the phase `"analyse"` performed by [`cudss`](@ref).
The data parameters `"info"`, `"lu_nnz"`, `"perm_reorder_row"`, `"perm_reorder_col"`, `"hybrid_device_memory_min"` and `"memory_estimates"` require the phase `"analyse"` performed by [`cudss`](@ref).
The data parameters `"npivots"`, `"inertia"` and `"diag"` require the phases `"analyse"` and `"factorization"` performed by [`cudss`](@ref).
The data parameters `"perm_row"` and `"perm_col"` are available but not yet functional.
"""
Expand All @@ -173,7 +212,7 @@ function cudss_get(data::CudssData, parameter::String)
if (parameter == "user_perm") || (parameter == "comm")
throw(ArgumentError("The data parameter \"$parameter\" cannot be retrieved."))
end
if (parameter == "perm_reorder_row") || (parameter == "perm_reorder_col") || (parameter == "perm_row") || (parameter == "perm_col") || (parameter == "diag")
if (parameter == "perm_reorder_row") || (parameter == "perm_reorder_col") || (parameter == "perm_row") || (parameter == "perm_col") || (parameter == "diag") || (parameter == "memory_estimates")
throw(ArgumentError("The data parameter \"$parameter\" is not supported by CUDSS.jl."))
end
type = CUDSS_TYPES[parameter]
Expand All @@ -197,6 +236,8 @@ end
"""
cudss(phase::String, solver::CudssSolver{T}, x::CuVector{T}, b::CuVector{T})
cudss(phase::String, solver::CudssSolver{T}, X::CuMatrix{T}, B::CuMatrix{T})
cudss(phase::String, solver::CudssSolver{T}, x::Vector{CuVector{T}}, b::Vector{CuVector{T}})
cudss(phase::String, solver::CudssSolver{T}, X::Vector{CuMatrix{T}}, B::Vector{CuMatrix{T}})
cudss(phase::String, solver::CudssSolver{T}, X::CudssMatrix{T}, B::CudssMatrix{T})

The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
Expand All @@ -221,3 +262,15 @@ function cudss(phase::String, solver::CudssSolver{T}, X::CuMatrix{T}, B::CuMatri
rhs = CudssMatrix(B)
cudss(phase, solver, solution, rhs)
end

# Batched variant of `cudss` for vectors: wraps the solutions `x` and the right-hand
# sides `b` in `CudssMatrix` objects and dispatches to the `CudssMatrix` method.
#
# The arguments are typed `Vector{<:CuVector{T}}` instead of `Vector{CuVector{T}}`:
# Julia's parametric types are invariant, so the latter rejects vectors whose element
# type is a concrete `CuArray` type. The wider signature accepts both.
function cudss(phase::String, solver::CudssSolver{T}, x::Vector{<:CuVector{T}}, b::Vector{<:CuVector{T}}) where T <: BlasFloat
    solution = CudssMatrix(x)
    rhs = CudssMatrix(b)
    return cudss(phase, solver, solution, rhs)
end

# Batched variant of `cudss` for matrices: wraps the solutions `X` and the right-hand
# sides `B` in `CudssMatrix` objects and dispatches to the `CudssMatrix` method.
#
# The arguments are typed `Vector{<:CuMatrix{T}}` instead of `Vector{CuMatrix{T}}`:
# Julia's parametric types are invariant, so the latter rejects vectors whose element
# type is a concrete `CuArray` type. The wider signature accepts both.
function cudss(phase::String, solver::CudssSolver{T}, X::Vector{<:CuMatrix{T}}, B::Vector{<:CuMatrix{T}}) where T <: BlasFloat
    solution = CudssMatrix(X)
    rhs = CudssMatrix(B)
    return cudss(phase, solver, solution, rhs)
end
81 changes: 80 additions & 1 deletion src/libcudss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ using CEnum

# CUDSS uses CUDA runtime objects, which are compatible with our driver usage
const cudaStream_t = CUstream

const cudaDataType_t = cudaDataType
const CUPTR_C_NULL = CuPtr{Ptr{Cvoid}}(0)

@cenum cudssOpType_t::UInt32 begin
CUDSS_SUM = 0
Expand Down Expand Up @@ -70,6 +70,7 @@ end
CUDSS_DATA_USER_PERM = 9
CUDSS_DATA_HYBRID_DEVICE_MEMORY_MIN = 10
CUDSS_DATA_COMM = 11
CUDSS_DATA_MEMORY_ESTIMATES = 12
end

@cenum cudssPhase_t::UInt32 begin
Expand Down Expand Up @@ -255,6 +256,35 @@ end
indexBase::cudssIndexBase_t)::cudssStatus_t
end

# Low-level binding for `cudssMatrixCreateBatchDn` in `libcudss`.
# Calls `initialize_context()` and then forwards every argument unchanged through
# `@gcsafe_ccall`; the foreign call returns a `cudssStatus_t`.
# `matrix` is a pointer to the handle slot (`Ptr{cudssMatrix_t}`), `nrows`/`ncols`/`ld`
# are passed as untyped pointers (`Ptr{Cvoid}`), and `values` is a device pointer to
# an array of pointers (`CuPtr{Ptr{Cvoid}}`).
@checked function cudssMatrixCreateBatchDn(matrix, batchCount, nrows, ncols, ld, values,
valueType, layout)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixCreateBatchDn(matrix::Ptr{cudssMatrix_t},
batchCount::Int64, nrows::Ptr{Cvoid},
ncols::Ptr{Cvoid}, ld::Ptr{Cvoid},
values::CuPtr{Ptr{Cvoid}},
valueType::cudaDataType_t,
layout::cudssLayout_t)::cudssStatus_t
end

# Low-level binding for `cudssMatrixCreateBatchCsr` in `libcudss`.
# Calls `initialize_context()` and then forwards every argument unchanged through
# `@gcsafe_ccall`; the foreign call returns a `cudssStatus_t`.
# `matrix` is a pointer to the handle slot (`Ptr{cudssMatrix_t}`), `nrows`/`ncols`/`nnz`
# are passed as untyped pointers (`Ptr{Cvoid}`), and `rowStart`, `rowEnd`,
# `colIndices` and `values` are device pointers to arrays of pointers
# (`CuPtr{Ptr{Cvoid}}`).
@checked function cudssMatrixCreateBatchCsr(matrix, batchCount, nrows, ncols, nnz, rowStart,
rowEnd, colIndices, values, indexType,
valueType, mtype, mview, indexBase)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixCreateBatchCsr(matrix::Ptr{cudssMatrix_t},
batchCount::Int64, nrows::Ptr{Cvoid},
ncols::Ptr{Cvoid}, nnz::Ptr{Cvoid},
rowStart::CuPtr{Ptr{Cvoid}},
rowEnd::CuPtr{Ptr{Cvoid}},
colIndices::CuPtr{Ptr{Cvoid}},
values::CuPtr{Ptr{Cvoid}},
indexType::cudaDataType_t,
valueType::cudaDataType_t,
mtype::cudssMatrixType_t,
mview::cudssMatrixViewType_t,
indexBase::cudssIndexBase_t)::cudssStatus_t
end

@checked function cudssMatrixDestroy(matrix)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixDestroy(matrix::cudssMatrix_t)::cudssStatus_t
Expand Down Expand Up @@ -300,6 +330,55 @@ end
values::CuPtr{Cvoid})::cudssStatus_t
end

# Low-level binding for `cudssMatrixGetBatchDn` in `libcudss`.
# Calls `initialize_context()` and then forwards every argument unchanged through
# `@gcsafe_ccall`; the foreign call returns a `cudssStatus_t`.
# `matrix` is the handle itself (`cudssMatrix_t`); every other argument is a pointer
# (presumably an output slot filled by the library — confirm against the cuDSS docs).
@checked function cudssMatrixGetBatchDn(matrix, batchCount, nrows, ncols, ld, values, type,
layout)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixGetBatchDn(matrix::cudssMatrix_t,
batchCount::Ptr{Int64},
nrows::Ptr{Ptr{Cvoid}},
ncols::Ptr{Ptr{Cvoid}},
ld::Ptr{Ptr{Cvoid}},
values::Ptr{CuPtr{Ptr{Cvoid}}},
type::Ptr{cudaDataType_t},
layout::Ptr{cudssLayout_t})::cudssStatus_t
end

# Low-level binding for `cudssMatrixGetBatchCsr` in `libcudss`.
# Calls `initialize_context()` and then forwards every argument unchanged through
# `@gcsafe_ccall`; the foreign call returns a `cudssStatus_t`.
# `matrix` is the handle itself (`cudssMatrix_t`); every other argument is a pointer
# (presumably an output slot filled by the library — confirm against the cuDSS docs).
@checked function cudssMatrixGetBatchCsr(matrix, batchCount, nrows, ncols, nnz, rowStart,
rowEnd, colIndices, values, indexType, valueType,
mtype, mview, indexBase)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixGetBatchCsr(matrix::cudssMatrix_t,
batchCount::Ptr{Int64},
nrows::Ptr{Ptr{Cvoid}},
ncols::Ptr{Ptr{Cvoid}},
nnz::Ptr{Ptr{Cvoid}},
rowStart::Ptr{CuPtr{Ptr{Cvoid}}},
rowEnd::Ptr{CuPtr{Ptr{Cvoid}}},
colIndices::Ptr{CuPtr{Ptr{Cvoid}}},
values::Ptr{CuPtr{Ptr{Cvoid}}},
indexType::Ptr{cudaDataType_t},
valueType::Ptr{cudaDataType_t},
mtype::Ptr{cudssMatrixType_t},
mview::Ptr{cudssMatrixViewType_t},
indexBase::Ptr{cudssIndexBase_t})::cudssStatus_t
end

# Low-level binding for `cudssMatrixSetBatchValues` in `libcudss`.
# Calls `initialize_context()` and then forwards the matrix handle and `values`
# (a device pointer to an array of pointers, `CuPtr{Ptr{Cvoid}}`) through
# `@gcsafe_ccall`; the foreign call returns a `cudssStatus_t`.
@checked function cudssMatrixSetBatchValues(matrix, values)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixSetBatchValues(matrix::cudssMatrix_t,
values::CuPtr{Ptr{Cvoid}})::cudssStatus_t
end

# Low-level binding for `cudssMatrixSetBatchCsrPointers` in `libcudss`.
# Calls `initialize_context()` and then forwards the matrix handle together with
# `rowOffsets`, `rowEnd`, `colIndices` and `values` (each a device pointer to an
# array of pointers, `CuPtr{Ptr{Cvoid}}`) through `@gcsafe_ccall`; the foreign call
# returns a `cudssStatus_t`.
@checked function cudssMatrixSetBatchCsrPointers(matrix, rowOffsets, rowEnd, colIndices,
values)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixSetBatchCsrPointers(matrix::cudssMatrix_t,
rowOffsets::CuPtr{Ptr{Cvoid}},
rowEnd::CuPtr{Ptr{Cvoid}},
colIndices::CuPtr{Ptr{Cvoid}},
values::CuPtr{Ptr{Cvoid}})::cudssStatus_t
end

@checked function cudssMatrixGetFormat(matrix, format)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixGetFormat(matrix::cudssMatrix_t,
Expand Down
Loading
Loading