Skip to content

Commit

Permalink
Upgrade the Julia interface for cuDSS 0.2.1
Browse files Browse the repository at this point in the history
  • Loading branch information
Xx-Alexis-xX committed Apr 19, 2024
1 parent 35f75ed commit 9a454fd
Show file tree
Hide file tree
Showing 8 changed files with 62 additions and 32 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "CUDSS"
uuid = "45b445bb-4962-46a0-9369-b4df9d0f772e"
authors = ["Alexis Montoison <[email protected]>"]
version = "0.1.4"
version = "0.2.0"

[deps]
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
Expand All @@ -13,7 +13,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
[compat]
CEnum = "0.4, 0.5"
CUDA = "5"
CUDSS_jll = "0.1.0"
CUDSS_jll = "0.2.1"
julia = "1.6"
LinearAlgebra = "1.6"
SparseArrays = "1.6"
Expand Down
1 change: 0 additions & 1 deletion src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ function LinearAlgebra.ldlt(A::CuSparseMatrixCSR{T,Cint}; view::Char='F', check
n = checksquare(A)
structure = T <: Real ? "S" : "H"
solver = CudssSolver(A, structure, view)
(T <: Complex) && cudss_set(solver, "pivot_type", 'N')
x = CudssMatrix(T, n)
b = CudssMatrix(T, n)
cudss("analysis", solver, x, b)
Expand Down
8 changes: 5 additions & 3 deletions src/helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,14 @@ Base.unsafe_convert(::Type{cudssMatrix_t}, matrix::CudssMatrix) = matrix.matrix
`CudssData` holds internal data (e.g., LU factors arrays).
"""
mutable struct CudssData
handle::cudssHandle_t
data::cudssData_t

function CudssData()
data_ref = Ref{cudssData_t}()
cudssDataCreate(handle(), data_ref)
obj = new(data_ref[])
cudss_handle = handle()
cudssDataCreate(cudss_handle, data_ref)
obj = new(cudss_handle, data_ref[])
finalizer(cudssDataDestroy, obj)
obj
end
Expand All @@ -112,7 +114,7 @@ end
Base.unsafe_convert(::Type{cudssData_t}, data::CudssData) = data.data

function cudssDataDestroy(data::CudssData)
cudssDataDestroy(handle(), data)
cudssDataDestroy(data.handle, data)
end

## Configuration
Expand Down
15 changes: 8 additions & 7 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function cudss_set(data::CudssData, parameter::String, value)
(parameter == "user_perm") || throw(ArgumentError("Only the data parameter \"user_perm\" can be set."))
(value isa Vector{Cint} || value isa CuVector{Cint}) || throw(ArgumentError("The permutation is neither a Vector{Cint} nor a CuVector{Cint}."))
nbytes = sizeof(value)
cudssDataSet(handle(), data, parameter, value, nbytes)
cudssDataSet(data.handle, data, parameter, value, nbytes)
end

function cudss_set(config::CudssConfig, parameter::String, value)
Expand Down Expand Up @@ -138,12 +138,13 @@ The available data parameters are:
- `"lu_nnz"`: Number of non-zero entries in LU factors;
- `"npivots"`: Number of pivots encountered during factorization;
- `"inertia"`: Tuple of positive and negative indices of inertia for symmetric and hermitian non positive-definite matrix types;
- `"perm_reorder"`: Reordering permutation;
- `"perm_reorder_row"`: Reordering permutation for the rows;
- `"perm_reorder_col"`: Reordering permutation for the columns;
- `"perm_row"`: Final row permutation (which includes effects of both reordering and pivoting);
- `"perm_col"`: Final column permutation (which includes effects of both reordering and pivoting);
- `"diag"`: Diagonal of the factorized matrix.
The data parameters `"info"`, `"lu_nnz"` and `"perm_reorder"` require the phase `"analyse"` performed by [`cudss`](@ref).
The data parameters `"info"`, `"lu_nnz"`, `"perm_reorder_row"` and `"perm_reorder_col"` require the phase `"analyse"` performed by [`cudss`](@ref).
The data parameters `"npivots"`, `"inertia"` and `"diag"` require the phases `"analyse"` and `"factorization"` performed by [`cudss`](@ref).
The data parameters `"perm_row"` and `"perm_col"` are available but not yet functional.
"""
Expand All @@ -162,13 +163,13 @@ end
function cudss_get(data::CudssData, parameter::String)
(parameter CUDSS_DATA_PARAMETERS) || throw(ArgumentError("Unknown data parameter $parameter."))
(parameter == "user_perm") && throw(ArgumentError("The data parameter \"user_perm\" cannot be retrieved."))
if (parameter == "perm_reorder") || (parameter == "perm_row") || (parameter == "perm_col") || (parameter == "diag")
if (parameter == "perm_reorder_row") || (parameter == "perm_reorder_col") || (parameter == "perm_row") || (parameter == "perm_col") || (parameter == "diag")
throw(ArgumentError("The data parameter \"$parameter\" is not supported by CUDSS.jl."))
end
type = CUDSS_TYPES[parameter]
val = Ref{type}()
nbytes = sizeof(val)
nbytes_written = Ref{Cint}()
nbytes_written = Ref{Csize_t}()
cudssDataGet(handle(), data, parameter, val, nbytes, nbytes_written)
return val[]
end
Expand All @@ -178,7 +179,7 @@ function cudss_get(config::CudssConfig, parameter::String)
type = CUDSS_TYPES[parameter]
val = Ref{type}()
nbytes = sizeof(val)
nbytes_written = Ref{Cint}()
nbytes_written = Ref{Csize_t}()
cudssConfigGet(config, parameter, val, nbytes, nbytes_written)
return val[]
end
Expand All @@ -196,7 +197,7 @@ The phases `"solve_fwd"`, `"solve_diag"` and `"solve_bwd"` are available but not
function cudss end

function cudss(phase::String, solver::CudssSolver{T}, X::CudssMatrix{T}, B::CudssMatrix{T}) where T <: BlasFloat
cudssExecute(handle(), phase, solver.config, solver.data, solver.matrix, X, B)
cudssExecute(solver.data.handle, phase, solver.config, solver.data, solver.matrix, X, B)
end

function cudss(phase::String, solver::CudssSolver{T}, x::CuVector{T}, b::CuVector{T}) where T <: BlasFloat
Expand Down
40 changes: 30 additions & 10 deletions src/libcudss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,12 @@ end
CUDSS_DATA_LU_NNZ = 1
CUDSS_DATA_NPIVOTS = 2
CUDSS_DATA_INERTIA = 3
CUDSS_DATA_PERM_REORDER = 4
CUDSS_DATA_PERM_ROW = 5
CUDSS_DATA_PERM_COL = 6
CUDSS_DATA_DIAG = 7
CUDSS_DATA_USER_PERM = 8
CUDSS_DATA_PERM_REORDER_ROW = 4
CUDSS_DATA_PERM_REORDER_COL = 5
CUDSS_DATA_PERM_ROW = 6
CUDSS_DATA_PERM_COL = 7
CUDSS_DATA_DIAG = 8
CUDSS_DATA_USER_PERM = 9
end

@cenum cudssPhase_t::UInt32 begin
Expand Down Expand Up @@ -111,31 +112,38 @@ end
CUDSS_MFORMAT_CSR = 1
end

struct cudssDeviceMemHandler_t
ctx::Ptr{Cvoid}
device_alloc::Ptr{Cvoid}
device_free::Ptr{Cvoid}
name::NTuple{64,Cchar}
end

@checked function cudssConfigSet(config, param, value, sizeInBytes)
initialize_context()
@ccall libcudss.cudssConfigSet(config::cudssConfig_t, param::cudssConfigParam_t,
value::Ptr{Cvoid}, sizeInBytes::Cint)::cudssStatus_t
value::Ptr{Cvoid}, sizeInBytes::Csize_t)::cudssStatus_t
end

@checked function cudssConfigGet(config, param, value, sizeInBytes, sizeWritten)
initialize_context()
@ccall libcudss.cudssConfigGet(config::cudssConfig_t, param::cudssConfigParam_t,
value::Ptr{Cvoid}, sizeInBytes::Cint,
sizeWritten::Ptr{Cint})::cudssStatus_t
value::Ptr{Cvoid}, sizeInBytes::Csize_t,
sizeWritten::Ptr{Csize_t})::cudssStatus_t
end

@checked function cudssDataSet(handle, data, param, value, sizeInBytes)
initialize_context()
@ccall libcudss.cudssDataSet(handle::cudssHandle_t, data::cudssData_t,
param::cudssDataParam_t, value::PtrOrCuPtr{Cvoid},
sizeInBytes::Cint)::cudssStatus_t
sizeInBytes::Csize_t)::cudssStatus_t
end

@checked function cudssDataGet(handle, data, param, value, sizeInBytes, sizeWritten)
initialize_context()
@ccall libcudss.cudssDataGet(handle::cudssHandle_t, data::cudssData_t,
param::cudssDataParam_t, value::PtrOrCuPtr{Cvoid},
sizeInBytes::Cint, sizeWritten::Ptr{Cint})::cudssStatus_t
sizeInBytes::Csize_t, sizeWritten::Ptr{Csize_t})::cudssStatus_t
end

@checked function cudssExecute(handle, phase, solverConfig, solverData, inputMatrix,
Expand Down Expand Up @@ -257,3 +265,15 @@ end
@ccall libcudss.cudssMatrixGetFormat(matrix::cudssMatrix_t,
format::Ptr{cudssMatrixFormat_t})::cudssStatus_t
end

@checked function cudssGetDeviceMemHandler(handle, handler)
initialize_context()
@ccall libcudss.cudssGetDeviceMemHandler(handle::cudssHandle_t,
handler::Ptr{cudssDeviceMemHandler_t})::cudssStatus_t
end

@checked function cudssSetDeviceMemHandler(handle, handler)
initialize_context()
@ccall libcudss.cudssSetDeviceMemHandler(handle::cudssHandle_t,
handler::Ptr{cudssDeviceMemHandler_t})::cudssStatus_t
end
13 changes: 8 additions & 5 deletions src/types.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# cuDSS types

const CUDSS_DATA_PARAMETERS = ("info", "lu_nnz", "npivots", "inertia", "perm_reorder",
"perm_row", "perm_col", "diag", "user_perm")
const CUDSS_DATA_PARAMETERS = ("info", "lu_nnz", "npivots", "inertia", "perm_reorder_row",
"perm_reorder_col", "perm_row", "perm_col", "diag", "user_perm")

const CUDSS_CONFIG_PARAMETERS = ("reordering_alg", "factorization_alg", "solve_alg", "matching_type",
"solve_mode", "ir_n_steps", "ir_tol", "pivot_type", "pivot_threshold",
Expand All @@ -13,7 +13,8 @@ const CUDSS_TYPES = Dict{String, DataType}(
"lu_nnz" => Int64,
"npivots" => Cint,
"inertia" => Tuple{Cint, Cint},
"perm_reorder" => Vector{Cint},
"perm_reorder_row" => Vector{Cint},
"perm_reorder_col" => Vector{Cint},
"perm_row" => Vector{Cint},
"perm_col" => Vector{Cint},
"diag" => Vector{Float64},
Expand Down Expand Up @@ -73,8 +74,10 @@ function Base.convert(::Type{cudssDataParam_t}, data::String)
return CUDSS_DATA_NPIVOTS
elseif data == "inertia"
return CUDSS_DATA_INERTIA
elseif data == "perm_reorder"
return CUDSS_DATA_PERM_REORDER
elseif data == "perm_reorder_row"
return CUDSS_DATA_PERM_REORDER_ROW
elseif data == "perm_reorder_col"
return CUDSS_DATA_PERM_REORDER_COL
elseif data == "perm_row"
return CUDSS_DATA_PERM_ROW
elseif data == "perm_col"
Expand Down
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ include("test_cudss.jl")
cudss_sparse()
end

@testset "CudssData" begin
# Issue #1
data = CudssData()
CUDSS.cudssDataDestroy(CUDSS.handle(), data)
end

@testset "CudssSolver" begin
cudss_solver()
end
Expand Down
7 changes: 3 additions & 4 deletions test/test_cudss.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function cudss_version()
@test CUDSS.version() == v"0.1.0"
@test CUDSS.version() == v"0.2.0"
end

function cudss_dense()
Expand Down Expand Up @@ -97,7 +97,7 @@ function cudss_solver()
end

@testset "data parameter = $parameter" for parameter in CUDSS_DATA_PARAMETERS
parameter ("perm_row", "perm_col", "perm_reorder", "diag") && continue
parameter ("perm_row", "perm_col", "perm_reorder_row", "perm_reorder_col", "diag") && continue
if parameter "user_perm"
(parameter == "inertia") && !(structure ("S", "H")) && continue
val = cudss_get(solver, parameter)
Expand Down Expand Up @@ -157,10 +157,9 @@ function cudss_execution()
end
end

symmetric_hermitian_pivots = T <: Real ? ('C', 'R', 'N') : ('N',)
@testset "Symmetric -- Hermitian" begin
@testset "view = $view" for view in ('F', 'L', 'U')
@testset "Pivoting = $pivot" for pivot in symmetric_hermitian_pivots
@testset "Pivoting = $pivot" for pivot in ('C', 'R', 'N')
A_cpu = sprand(T, n, n, 0.01) + I
A_cpu = A_cpu + A_cpu'
X_cpu = zeros(T, n, p)
Expand Down

0 comments on commit 9a454fd

Please sign in to comment.