diff --git a/src/interfaces.jl b/src/interfaces.jl index be3b863..0fb6925 100644 --- a/src/interfaces.jl +++ b/src/interfaces.jl @@ -97,10 +97,9 @@ end function cudss_set(data::CudssData, param::String, value) (param ∈ CUDSS_DATA_PARAMETERS) || throw(ArgumentError("Unknown data parameter $param.")) (param == "user_perm") || throw(ArgumentError("Only the data parameter \"user_perm\" can be set.")) - type = CUDSS_TYPES[param] - val = Ref{type}(value) - nbytes = sizeof(val) - cudssDataSet(handle(), data, param, val, nbytes) + (value isa Vector{Cint} || value isa CuVector{Cint}) || throw(ArgumentError("The permutation is neither a Vector{Cint} nor a CuVector{Cint}.")) + nbytes = sizeof(value) + cudssDataSet(handle(), data, param, value, nbytes) end function cudss_set(config::CudssConfig, param::String, value) diff --git a/src/libcudss.jl b/src/libcudss.jl index 639b53d..194c992 100644 --- a/src/libcudss.jl +++ b/src/libcudss.jl @@ -127,14 +127,14 @@ end @checked function cudssDataSet(handle, data, param, value, sizeInBytes) initialize_context() @ccall libcudss.cudssDataSet(handle::cudssHandle_t, data::cudssData_t, - param::cudssDataParam_t, value::Ptr{Cvoid}, + param::cudssDataParam_t, value::PtrOrCuPtr{Cvoid}, sizeInBytes::Cint)::cudssStatus_t end @checked function cudssDataGet(handle, data, param, value, sizeInBytes, sizeWritten) initialize_context() @ccall libcudss.cudssDataGet(handle::cudssHandle_t, data::cudssData_t, - param::cudssDataParam_t, value::Ptr{Cvoid}, + param::cudssDataParam_t, value::PtrOrCuPtr{Cvoid}, sizeInBytes::Cint, sizeWritten::Ptr{Cint})::cudssStatus_t end diff --git a/test/runtests.jl b/test/runtests.jl index 2a3a653..72d5b89 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -34,4 +34,8 @@ include("test_cudss.jl") @testset "Generic API" begin cudss_generic() end + + @testset "User permutation" begin + user_permutation() + end end diff --git a/test/test_cudss.jl b/test/test_cudss.jl index 50ea65c..56a7f11 100644 --- a/test/test_cudss.jl +++ b/test/test_cudss.jl @@ -344,3 +344,106 @@ function cudss_generic() end end end + +function user_permutation() + function permutation_lu(T, A_cpu, x_cpu, b_cpu, permutation) + A_gpu = CuSparseMatrixCSR(A_cpu) + x_gpu = CuVector(x_cpu) + b_gpu = CuVector(b_cpu) + + solver = CudssSolver(A_gpu, "G", 'F') + + cudss_set(solver, "user_perm", permutation) + + cudss("analysis", solver, x_gpu, b_gpu) + cudss("factorization", solver, x_gpu, b_gpu) + cudss("solve", solver, x_gpu, b_gpu) + + nz = cudss_get(solver, "lu_nnz") + return nz + end + + function permutation_ldlt(T, A_cpu, x_cpu, b_cpu, permutation) + A_gpu = CuSparseMatrixCSR(A_cpu |> tril) + x_gpu = CuVector(x_cpu) + b_gpu = CuVector(b_cpu) + + structure = T <: Real ? "S" : "H" + solver = CudssSolver(A_gpu, structure, 'L') + cudss_set(solver, "user_perm", permutation) + + cudss("analysis", solver, x_gpu, b_gpu) + cudss("factorization", solver, x_gpu, b_gpu) + cudss("solve", solver, x_gpu, b_gpu) + + nz = cudss_get(solver, "lu_nnz") + return nz + end + + function permutation_llt(T, A_cpu, x_cpu, b_cpu, permutation) + A_gpu = CuSparseMatrixCSR(A_cpu |> triu) + x_gpu = CuVector(x_cpu) + b_gpu = CuVector(b_cpu) + + structure = T <: Real ? "SPD" : "HPD" + solver = CudssSolver(A_gpu, structure, 'U') + cudss_set(solver, "user_perm", permutation) + + cudss("analysis", solver, x_gpu, b_gpu) + cudss("factorization", solver, x_gpu, b_gpu) + cudss("solve", solver, x_gpu, b_gpu) + + nz = cudss_get(solver, "lu_nnz") + return nz + end + + n = 1000 + perm1_cpu = Vector{Cint}(undef, n) + perm2_cpu = Vector{Cint}(undef, n) + for i = 1:n + perm1_cpu[i] = i + perm2_cpu[i] = n-i+1 + end + perm1_gpu = CuVector{Cint}(perm1_cpu) + perm2_gpu = CuVector{Cint}(perm2_cpu) + @testset "precision = $T" for T in (Float32, Float64, ComplexF32, ComplexF64) + @testset "LU" begin + A_cpu = sprand(T, n, n, 0.05) + I + x_cpu = zeros(T, n) + b_cpu = rand(T, n) + nz1_cpu = permutation_lu(T, A_cpu, x_cpu, b_cpu, perm1_cpu) + nz2_cpu = permutation_lu(T, A_cpu, x_cpu, b_cpu, perm2_cpu) + nz1_gpu = permutation_lu(T, A_cpu, x_cpu, b_cpu, perm1_gpu) + nz2_gpu = permutation_lu(T, A_cpu, x_cpu, b_cpu, perm2_gpu) + @test nz1_cpu == nz1_gpu + @test nz2_cpu == nz2_gpu + @test nz1_cpu != nz2_cpu + end + @testset "LDLᵀ / LDLᴴ" begin + A_cpu = sprand(T, n, n, 0.05) + I + A_cpu = A_cpu + A_cpu' + x_cpu = zeros(T, n) + b_cpu = rand(T, n) + nz1_cpu = permutation_ldlt(T, A_cpu, x_cpu, b_cpu, perm1_cpu) + nz2_cpu = permutation_ldlt(T, A_cpu, x_cpu, b_cpu, perm2_cpu) + nz1_gpu = permutation_ldlt(T, A_cpu, x_cpu, b_cpu, perm1_gpu) + nz2_gpu = permutation_ldlt(T, A_cpu, x_cpu, b_cpu, perm2_gpu) + @test nz1_cpu == nz1_gpu + @test nz2_cpu == nz2_gpu + @test nz1_cpu != nz2_cpu + end + @testset "LLᵀ / LLᴴ" begin + A_cpu = sprand(T, n, n, 0.01) + A_cpu = A_cpu * A_cpu' + I + x_cpu = zeros(T, n) + b_cpu = rand(T, n) + nz1_cpu = permutation_llt(T, A_cpu, x_cpu, b_cpu, perm1_cpu) + nz2_cpu = permutation_llt(T, A_cpu, x_cpu, b_cpu, perm2_cpu) + nz1_gpu = permutation_llt(T, A_cpu, x_cpu, b_cpu, perm1_gpu) + nz2_gpu = permutation_llt(T, A_cpu, x_cpu, b_cpu, perm2_gpu) + @test nz1_cpu == nz1_gpu + @test nz2_cpu == nz2_gpu + @test nz1_cpu != nz2_cpu + end + end +end