Skip to content

Commit

Permalink
Support cuDSS v0.4.0
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Dec 11, 2024
1 parent 19defac commit 352ab10
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 8 deletions.
4 changes: 2 additions & 2 deletions gen/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ Clang = "40e3b903-d033-50b4-a0cc-940c62c95e31"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"

[compat]
CUDA_SDK_jll = "12.5.1"
CUDSS_jll = "0.3.0"
CUDA_SDK_jll = "12.6.3"
CUDSS_jll = "0.4.0"
julia = "1.6"
7 changes: 4 additions & 3 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,10 @@ The available data parameters are:
- `"perm_row"`: Final row permutation (which includes effects of both reordering and pivoting);
- `"perm_col"`: Final column permutation (which includes effects of both reordering and pivoting);
- `"diag"`: Diagonal of the factorized matrix;
- `"hybrid_device_memory_min"`: Minimal amount of device memory (number of bytes) required in the hybrid memory mode.
- `"hybrid_device_memory_min"`: Minimal amount of device memory (number of bytes) required in the hybrid memory mode;
- `"memory_estimates"`: Memory estimates (in bytes) for host and device memory required for the chosen memory mode.
The data parameters `"info"`, `"lu_nnz"`, `"perm_reorder_row"`, `"perm_reorder_col"` and `"hybrid_device_memory_min"` require the phase `"analyse"` performed by [`cudss`](@ref).
The data parameters `"info"`, `"lu_nnz"`, `"perm_reorder_row"`, `"perm_reorder_col"`, `"hybrid_device_memory_min"` and `"memory_estimates"` require the phase `"analyse"` performed by [`cudss`](@ref).
The data parameters `"npivots"`, `"inertia"` and `"diag"` require the phases `"analyse"` and `"factorization"` performed by [`cudss`](@ref).
The data parameters `"perm_row"` and `"perm_col"` are available but not yet functional.
"""
Expand All @@ -173,7 +174,7 @@ function cudss_get(data::CudssData, parameter::String)
if (parameter == "user_perm") || (parameter == "comm")
throw(ArgumentError("The data parameter \"$parameter\" cannot be retrieved."))
end
if (parameter == "perm_reorder_row") || (parameter == "perm_reorder_col") || (parameter == "perm_row") || (parameter == "perm_col") || (parameter == "diag")
if (parameter == "perm_reorder_row") || (parameter == "perm_reorder_col") || (parameter == "perm_row") || (parameter == "perm_col") || (parameter == "diag") || (parameter == "memory_estimates")
throw(ArgumentError("The data parameter \"$parameter\" is not supported by CUDSS.jl."))
end
type = CUDSS_TYPES[parameter]
Expand Down
79 changes: 79 additions & 0 deletions src/libcudss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ end
CUDSS_DATA_USER_PERM = 9
CUDSS_DATA_HYBRID_DEVICE_MEMORY_MIN = 10
CUDSS_DATA_COMM = 11
CUDSS_DATA_MEMORY_ESTIMATES = 12
end

@cenum cudssPhase_t::UInt32 begin
Expand Down Expand Up @@ -255,6 +256,35 @@ end
indexBase::cudssIndexBase_t)::cudssStatus_t
end

@checked function cudssMatrixCreateBatchDn(matrix, batchCount, nrows, ncols, ld, values,
valueType, layout)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixCreateBatchDn(matrix::Ptr{cudssMatrix_t},
batchCount::Int64, nrows::Ptr{Cvoid},
ncols::Ptr{Cvoid}, ld::Ptr{Cvoid},
values::Ptr{Ptr{Cvoid}},
valueType::cudaDataType_t,
layout::cudssLayout_t)::cudssStatus_t
end

@checked function cudssMatrixCreateBatchCsr(matrix, batchCount, nrows, ncols, nnz, rowStart,
rowEnd, colIndices, values, indexType,
valueType, mtype, mview, indexBase)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixCreateBatchCsr(matrix::Ptr{cudssMatrix_t},
batchCount::Int64, nrows::Ptr{Cvoid},
ncols::Ptr{Cvoid}, nnz::Ptr{Cvoid},
rowStart::Ptr{Ptr{Cvoid}},
rowEnd::Ptr{Ptr{Cvoid}},
colIndices::Ptr{Ptr{Cvoid}},
values::Ptr{Ptr{Cvoid}},
indexType::cudaDataType_t,
valueType::cudaDataType_t,
mtype::cudssMatrixType_t,
mview::cudssMatrixViewType_t,
indexBase::cudssIndexBase_t)::cudssStatus_t
end

@checked function cudssMatrixDestroy(matrix)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixDestroy(matrix::cudssMatrix_t)::cudssStatus_t
Expand Down Expand Up @@ -300,6 +330,55 @@ end
values::CuPtr{Cvoid})::cudssStatus_t
end

@checked function cudssMatrixGetBatchDn(matrix, batchCount, nrows, ncols, ld, values, type,
layout)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixGetBatchDn(matrix::cudssMatrix_t,
batchCount::Ptr{Int64},
nrows::Ptr{Ptr{Cvoid}},
ncols::Ptr{Ptr{Cvoid}},
ld::Ptr{Ptr{Cvoid}},
values::Ptr{Ptr{Ptr{Cvoid}}},
type::Ptr{cudaDataType_t},
layout::Ptr{cudssLayout_t})::cudssStatus_t
end

@checked function cudssMatrixGetBatchCsr(matrix, batchCount, nrows, ncols, nnz, rowStart,
rowEnd, colIndices, values, indexType, valueType,
mtype, mview, indexBase)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixGetBatchCsr(matrix::cudssMatrix_t,
batchCount::Ptr{Int64},
nrows::Ptr{Ptr{Cvoid}},
ncols::Ptr{Ptr{Cvoid}},
nnz::Ptr{Ptr{Cvoid}},
rowStart::Ptr{Ptr{Ptr{Cvoid}}},
rowEnd::Ptr{Ptr{Ptr{Cvoid}}},
colIndices::Ptr{Ptr{Ptr{Cvoid}}},
values::Ptr{Ptr{Ptr{Cvoid}}},
indexType::Ptr{cudaDataType_t},
valueType::Ptr{cudaDataType_t},
mtype::Ptr{cudssMatrixType_t},
mview::Ptr{cudssMatrixViewType_t},
indexBase::Ptr{cudssIndexBase_t})::cudssStatus_t
end

@checked function cudssMatrixSetBatchValues(matrix, values)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixSetBatchValues(matrix::cudssMatrix_t,
values::Ptr{Ptr{Cvoid}})::cudssStatus_t
end

@checked function cudssMatrixSetBatchCsrPointers(matrix, rowOffsets, rowEnd, colIndices,
values)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixSetBatchCsrPointers(matrix::cudssMatrix_t,
rowOffsets::Ptr{Ptr{Cvoid}},
rowEnd::Ptr{Ptr{Cvoid}},
colIndices::Ptr{Ptr{Cvoid}},
values::Ptr{Ptr{Cvoid}})::cudssStatus_t
end

@checked function cudssMatrixGetFormat(matrix, format)
initialize_context()
@gcsafe_ccall libcudss.cudssMatrixGetFormat(matrix::cudssMatrix_t,
Expand Down
5 changes: 4 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

const CUDSS_DATA_PARAMETERS = ("info", "lu_nnz", "npivots", "inertia", "perm_reorder_row",
"perm_reorder_col", "perm_row", "perm_col", "diag", "user_perm",
"hybrid_device_memory_min", "comm")
"hybrid_device_memory_min", "comm", "memory_estimates")

const CUDSS_CONFIG_PARAMETERS = ("reordering_alg", "factorization_alg", "solve_alg", "matching_type",
"solve_mode", "ir_n_steps", "ir_tol", "pivot_type", "pivot_threshold",
Expand All @@ -23,6 +23,7 @@ const CUDSS_TYPES = Dict{String, DataType}(
"user_perm" => Vector{Cint},
"hybrid_device_memory_min" => Int64,
"comm" => Ptr{Cvoid},
"memory_estimates" => Vector{Int64},
# config type
"reordering_alg" => cudssAlgType_t,
"factorization_alg" => cudssAlgType_t,
Expand Down Expand Up @@ -103,6 +104,8 @@ function Base.convert(::Type{cudssDataParam_t}, data::String)
return CUDSS_DATA_HYBRID_DEVICE_MEMORY_MIN
elseif data == "comm"
return CUDSS_DATA_COMM
elseif data == "memory_estimates"
return CUDSS_DATA_MEMORY_ESTIMATES
else
throw(ArgumentError("Unknown data parameter $data"))
end
Expand Down
4 changes: 2 additions & 2 deletions test/test_cudss.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function cudss_version()
@test CUDSS.version() == v"0.3.0"
@test CUDSS.version() == v"0.4.0"
end

function cudss_dense()
Expand Down Expand Up @@ -102,7 +102,7 @@ function cudss_solver()
end

@testset "data parameter = $parameter" for parameter in CUDSS_DATA_PARAMETERS
parameter ("perm_row", "perm_col", "perm_reorder_row", "perm_reorder_col", "diag") && continue
parameter ("perm_row", "perm_col", "perm_reorder_row", "perm_reorder_col", "diag", "memory_estimates") && continue
if (parameter != "user_perm") && (parameter != "comm")
(parameter == "inertia") && !(structure ("S", "H")) && continue
val = cudss_get(solver, parameter)
Expand Down

0 comments on commit 352ab10

Please sign in to comment.