From 9a454fd1c0dcf97b2acdce68b6a075d2236f44ca Mon Sep 17 00:00:00 2001 From: Xx-Alexis-xX Date: Fri, 19 Apr 2024 14:48:05 -0400 Subject: [PATCH] Upgrade the Julia interface for cuDSS 0.2.1 --- Project.toml | 4 ++-- src/generic.jl | 1 - src/helpers.jl | 8 +++++--- src/interfaces.jl | 15 ++++++++------- src/libcudss.jl | 40 ++++++++++++++++++++++++++++++---------- src/types.jl | 13 ++++++++----- test/runtests.jl | 6 ++++++ test/test_cudss.jl | 7 +++---- 8 files changed, 62 insertions(+), 32 deletions(-) diff --git a/Project.toml b/Project.toml index 801c476..c1f8972 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "CUDSS" uuid = "45b445bb-4962-46a0-9369-b4df9d0f772e" authors = ["Alexis Montoison "] -version = "0.1.4" +version = "0.2.0" [deps] CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82" @@ -13,7 +13,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] CEnum = "0.4, 0.5" CUDA = "5" -CUDSS_jll = "0.1.0" +CUDSS_jll = "0.2.1" julia = "1.6" LinearAlgebra = "1.6" SparseArrays = "1.6" diff --git a/src/generic.jl b/src/generic.jl index 5e7e9fe..a8a258a 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -59,7 +59,6 @@ function LinearAlgebra.ldlt(A::CuSparseMatrixCSR{T,Cint}; view::Char='F', check n = checksquare(A) structure = T <: Real ? "S" : "H" solver = CudssSolver(A, structure, view) - (T <: Complex) && cudss_set(solver, "pivot_type", 'N') x = CudssMatrix(T, n) b = CudssMatrix(T, n) cudss("analysis", solver, x, b) diff --git a/src/helpers.jl b/src/helpers.jl index 142c49c..f1e8c57 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -98,12 +98,14 @@ Base.unsafe_convert(::Type{cudssMatrix_t}, matrix::CudssMatrix) = matrix.matrix `CudssData` holds internal data (e.g., LU factors arrays). """ mutable struct CudssData + handle::cudssHandle_t data::cudssData_t function CudssData() data_ref = Ref{cudssData_t}() - cudssDataCreate(handle(), data_ref) - obj = new(data_ref[]) + cudss_handle = handle() + cudssDataCreate(cudss_handle, data_ref) + obj = new(cudss_handle, data_ref[]) finalizer(cudssDataDestroy, obj) obj end @@ -112,7 +114,7 @@ end Base.unsafe_convert(::Type{cudssData_t}, data::CudssData) = data.data function cudssDataDestroy(data::CudssData) - cudssDataDestroy(handle(), data) + cudssDataDestroy(data.handle, data) end ## Configuration diff --git a/src/interfaces.jl b/src/interfaces.jl index 36b4e6a..8406a12 100644 --- a/src/interfaces.jl +++ b/src/interfaces.jl @@ -104,7 +104,7 @@ function cudss_set(data::CudssData, parameter::String, value) (parameter == "user_perm") || throw(ArgumentError("Only the data parameter \"user_perm\" can be set.")) (value isa Vector{Cint} || value isa CuVector{Cint}) || throw(ArgumentError("The permutation is neither a Vector{Cint} nor a CuVector{Cint}.")) nbytes = sizeof(value) - cudssDataSet(handle(), data, parameter, value, nbytes) + cudssDataSet(data.handle, data, parameter, value, nbytes) end function cudss_set(config::CudssConfig, parameter::String, value) @@ -138,12 +138,13 @@ The available data parameters are: - `"lu_nnz"`: Number of non-zero entries in LU factors; - `"npivots"`: Number of pivots encountered during factorization; - `"inertia"`: Tuple of positive and negative indices of inertia for symmetric and hermitian non positive-definite matrix types; -- `"perm_reorder"`: Reordering permutation; +- `"perm_reorder_row"`: Reordering permutation for the rows; +- `"perm_reorder_col"`: Reordering permutation for the columns; - `"perm_row"`: Final row permutation (which includes effects of both reordering and pivoting); - `"perm_col"`: Final column permutation (which includes effects of both reordering and pivoting); - `"diag"`: Diagonal of the factorized matrix. -The data parameters `"info"`, `"lu_nnz"` and `"perm_reorder"` require the phase `"analyse"` performed by [`cudss`](@ref). +The data parameters `"info"`, `"lu_nnz"`, `"perm_reorder_row"` and `"perm_reorder_col"` require the phase `"analyse"` performed by [`cudss`](@ref). The data parameters `"npivots"`, `"inertia"` and `"diag"` require the phases `"analyse"` and `"factorization"` performed by [`cudss`](@ref). The data parameters `"perm_row"` and `"perm_col"` are available but not yet functional. """ @@ -162,13 +163,13 @@ end function cudss_get(data::CudssData, parameter::String) (parameter ∈ CUDSS_DATA_PARAMETERS) || throw(ArgumentError("Unknown data parameter $parameter.")) (parameter == "user_perm") && throw(ArgumentError("The data parameter \"user_perm\" cannot be retrieved.")) - if (parameter == "perm_reorder") || (parameter == "perm_row") || (parameter == "perm_col") || (parameter == "diag") + if (parameter == "perm_reorder_row") || (parameter == "perm_reorder_col") || (parameter == "perm_row") || (parameter == "perm_col") || (parameter == "diag") throw(ArgumentError("The data parameter \"$parameter\" is not supported by CUDSS.jl.")) end type = CUDSS_TYPES[parameter] val = Ref{type}() nbytes = sizeof(val) - nbytes_written = Ref{Cint}() + nbytes_written = Ref{Csize_t}() cudssDataGet(handle(), data, parameter, val, nbytes, nbytes_written) return val[] end @@ -178,7 +179,7 @@ function cudss_get(config::CudssConfig, parameter::String) type = CUDSS_TYPES[parameter] val = Ref{type}() nbytes = sizeof(val) - nbytes_written = Ref{Cint}() + nbytes_written = Ref{Csize_t}() cudssConfigGet(config, parameter, val, nbytes, nbytes_written) return val[] end @@ -196,7 +197,7 @@ The phases `"solve_fwd"`, `"solve_diag"` and `"solve_bwd"` are available but not function cudss end function cudss(phase::String, solver::CudssSolver{T}, X::CudssMatrix{T}, B::CudssMatrix{T}) where T <: BlasFloat - cudssExecute(handle(), phase, solver.config, solver.data, solver.matrix, X, B) + cudssExecute(solver.data.handle, phase, solver.config, solver.data, solver.matrix, X, B) end function cudss(phase::String, solver::CudssSolver{T}, x::CuVector{T}, b::CuVector{T}) where T <: BlasFloat diff --git a/src/libcudss.jl b/src/libcudss.jl index 194c992..b681510 100644 --- a/src/libcudss.jl +++ b/src/libcudss.jl @@ -40,11 +40,12 @@ end CUDSS_DATA_LU_NNZ = 1 CUDSS_DATA_NPIVOTS = 2 CUDSS_DATA_INERTIA = 3 - CUDSS_DATA_PERM_REORDER = 4 - CUDSS_DATA_PERM_ROW = 5 - CUDSS_DATA_PERM_COL = 6 - CUDSS_DATA_DIAG = 7 - CUDSS_DATA_USER_PERM = 8 + CUDSS_DATA_PERM_REORDER_ROW = 4 + CUDSS_DATA_PERM_REORDER_COL = 5 + CUDSS_DATA_PERM_ROW = 6 + CUDSS_DATA_PERM_COL = 7 + CUDSS_DATA_DIAG = 8 + CUDSS_DATA_USER_PERM = 9 end @cenum cudssPhase_t::UInt32 begin @@ -111,31 +112,38 @@ end CUDSS_MFORMAT_CSR = 1 end +struct cudssDeviceMemHandler_t + ctx::Ptr{Cvoid} + device_alloc::Ptr{Cvoid} + device_free::Ptr{Cvoid} + name::NTuple{64,Cchar} +end + @checked function cudssConfigSet(config, param, value, sizeInBytes) initialize_context() @ccall libcudss.cudssConfigSet(config::cudssConfig_t, param::cudssConfigParam_t, - value::Ptr{Cvoid}, sizeInBytes::Cint)::cudssStatus_t + value::Ptr{Cvoid}, sizeInBytes::Csize_t)::cudssStatus_t end @checked function cudssConfigGet(config, param, value, sizeInBytes, sizeWritten) initialize_context() @ccall libcudss.cudssConfigGet(config::cudssConfig_t, param::cudssConfigParam_t, - value::Ptr{Cvoid}, sizeInBytes::Cint, - sizeWritten::Ptr{Cint})::cudssStatus_t + value::Ptr{Cvoid}, sizeInBytes::Csize_t, + sizeWritten::Ptr{Csize_t})::cudssStatus_t end @checked function cudssDataSet(handle, data, param, value, sizeInBytes) initialize_context() @ccall libcudss.cudssDataSet(handle::cudssHandle_t, data::cudssData_t, param::cudssDataParam_t, value::PtrOrCuPtr{Cvoid}, - sizeInBytes::Cint)::cudssStatus_t + sizeInBytes::Csize_t)::cudssStatus_t end @checked function cudssDataGet(handle, data, param, value, sizeInBytes, sizeWritten) initialize_context() @ccall libcudss.cudssDataGet(handle::cudssHandle_t, data::cudssData_t, param::cudssDataParam_t, value::PtrOrCuPtr{Cvoid}, - sizeInBytes::Cint, sizeWritten::Ptr{Cint})::cudssStatus_t + sizeInBytes::Csize_t, sizeWritten::Ptr{Csize_t})::cudssStatus_t end @checked function cudssExecute(handle, phase, solverConfig, solverData, inputMatrix, @@ -257,3 +265,15 @@ end @ccall libcudss.cudssMatrixGetFormat(matrix::cudssMatrix_t, format::Ptr{cudssMatrixFormat_t})::cudssStatus_t end + +@checked function cudssGetDeviceMemHandler(handle, handler) + initialize_context() + @ccall libcudss.cudssGetDeviceMemHandler(handle::cudssHandle_t, + handler::Ptr{cudssDeviceMemHandler_t})::cudssStatus_t +end + +@checked function cudssSetDeviceMemHandler(handle, handler) + initialize_context() + @ccall libcudss.cudssSetDeviceMemHandler(handle::cudssHandle_t, + handler::Ptr{cudssDeviceMemHandler_t})::cudssStatus_t +end diff --git a/src/types.jl b/src/types.jl index b07c8b6..2a22b61 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,7 +1,7 @@ # cuDSS types -const CUDSS_DATA_PARAMETERS = ("info", "lu_nnz", "npivots", "inertia", "perm_reorder", - "perm_row", "perm_col", "diag", "user_perm") +const CUDSS_DATA_PARAMETERS = ("info", "lu_nnz", "npivots", "inertia", "perm_reorder_row", + "perm_reorder_col", "perm_row", "perm_col", "diag", "user_perm") const CUDSS_CONFIG_PARAMETERS = ("reordering_alg", "factorization_alg", "solve_alg", "matching_type", "solve_mode", "ir_n_steps", "ir_tol", "pivot_type", "pivot_threshold", @@ -13,7 +13,8 @@ const CUDSS_TYPES = Dict{String, DataType}( "lu_nnz" => Int64, "npivots" => Cint, "inertia" => Tuple{Cint, Cint}, - "perm_reorder" => Vector{Cint}, + "perm_reorder_row" => Vector{Cint}, + "perm_reorder_col" => Vector{Cint}, "perm_row" => Vector{Cint}, "perm_col" => Vector{Cint}, "diag" => Vector{Float64}, @@ -73,8 +74,10 @@ function Base.convert(::Type{cudssDataParam_t}, data::String) return CUDSS_DATA_NPIVOTS elseif data == "inertia" return CUDSS_DATA_INERTIA - elseif data == "perm_reorder" - return CUDSS_DATA_PERM_REORDER + elseif data == "perm_reorder_row" + return CUDSS_DATA_PERM_REORDER_ROW + elseif data == "perm_reorder_col" + return CUDSS_DATA_PERM_REORDER_COL elseif data == "perm_row" return CUDSS_DATA_PERM_ROW elseif data == "perm_col" diff --git a/test/runtests.jl b/test/runtests.jl index 2ecf936..98f7ebd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,6 +22,12 @@ include("test_cudss.jl") cudss_sparse() end + @testset "CudssData" begin + # Issue #1 + data = CudssData() + CUDSS.cudssDataDestroy(CUDSS.handle(), data) + end + @testset "CudssSolver" begin cudss_solver() end diff --git a/test/test_cudss.jl b/test/test_cudss.jl index d1a0fd4..44cf643 100644 --- a/test/test_cudss.jl +++ b/test/test_cudss.jl @@ -1,5 +1,5 @@ function cudss_version() - @test CUDSS.version() == v"0.1.0" + @test CUDSS.version() == v"0.2.0" end function cudss_dense() @@ -97,7 +97,7 @@ function cudss_solver() end @testset "data parameter = $parameter" for parameter in CUDSS_DATA_PARAMETERS - parameter ∈ ("perm_row", "perm_col", "perm_reorder", "diag") && continue + parameter ∈ ("perm_row", "perm_col", "perm_reorder_row", "perm_reorder_col", "diag") && continue if parameter ≠ "user_perm" (parameter == "inertia") && !(structure ∈ ("S", "H")) && continue val = cudss_get(solver, parameter) @@ -157,10 +157,9 @@ function cudss_execution() end end - symmetric_hermitian_pivots = T <: Real ? ('C', 'R', 'N') : ('N',) @testset "Symmetric -- Hermitian" begin @testset "view = $view" for view in ('F', 'L', 'U') - @testset "Pivoting = $pivot" for pivot in symmetric_hermitian_pivots + @testset "Pivoting = $pivot" for pivot in ('C', 'R', 'N') A_cpu = sprand(T, n, n, 0.01) + I A_cpu = A_cpu + A_cpu' X_cpu = zeros(T, n, p)