From 7192527f8a53e8a2104c6e76299f8eb3522db931 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Fri, 8 Dec 2023 12:08:49 -0600 Subject: [PATCH] Use a String for the structure --- README.md | 4 ++-- src/helpers.jl | 10 +++++----- src/interfaces.jl | 10 +++++----- src/types.jl | 12 ------------ test/test_cudss.jl | 12 ++++++------ 5 files changed, 18 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index 2e8f575..071b19e 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ A_gpu = CuSparseMatrixCSR(A_cpu) x_gpu = CuVector(x_cpu) b_gpu = CuVector(b_cpu) -solver = CudssSolver(A_gpu, 'G', 'F') +solver = CudssSolver(A_gpu, "G", 'F') cudss("analysis", solver, x_gpu, b_gpu) cudss("factorization", solver, x_gpu, b_gpu) @@ -75,7 +75,7 @@ A_gpu = CuSparseMatrixCSR(A_cpu |> tril) X_gpu = CuMatrix(X_cpu) B_gpu = CuMatrix(B_cpu) -structure = T <: Real ? 'S' : 'H' +structure = T <: Real ? "S" : "H" solver = CudssSolver(A_gpu, structure, 'L') cudss("analysis", solver, X_gpu, B_gpu) diff --git a/src/helpers.jl b/src/helpers.jl index 788bfb2..72aaca0 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -6,15 +6,15 @@ export CudssMatrix, CudssData, CudssConfig """ matrix = CudssMatrix(v::CuVector) matrix = CudssMatrix(A::CuMatrix) - matrix = CudssMatrix(A::CuSparseMatrixCSR, struture::Union{Char, String}, view::Char; index::Char='O') + matrix = CudssMatrix(A::CuSparseMatrixCSR, struture::String, view::Char; index::Char='O') `CudssMatrix` is a wrapper for `CuVector`, `CuMatrix` and `CuSparseMatrixCSR`. `CudssMatrix` is used to pass matrix of the linear system, as well as solution and right-hand side. `structure` specifies the stucture for sparse matrices: -- `'G'` or `"G"`: General matrix -- LDU factorization; -- `'S'` or `"S"`: Real symmetric matrix -- LDLᵀ factorization; -- `'H'` or `"H"`: Complex Hermitian matrix -- LDLᴴ factorization; +- `"G"`: General matrix -- LDU factorization; +- `"S"`: Real symmetric matrix -- LDLᵀ factorization; +- `"H"`: Complex Hermitian matrix -- LDLᴴ factorization; - `"SPD"`: Symmetric positive-definite matrix -- LLᵀ factorization; - `"HPD"`: Hermitian positive-definite matrix -- LLᴴ factorization. @@ -52,7 +52,7 @@ mutable struct CudssMatrix obj end - function CudssMatrix(A::CuSparseMatrixCSR, structure::Union{Char, String}, view::Char; index::Char='O') + function CudssMatrix(A::CuSparseMatrixCSR, structure::String, view::Char; index::Char='O') m,n = size(A) matrix_ref = Ref{cudssMatrix_t}() cudssMatrixCreateCsr(matrix_ref, m, n, nnz(A), A.rowPtr, CU_NULL, diff --git a/src/interfaces.jl b/src/interfaces.jl index f1465a4..2671952 100644 --- a/src/interfaces.jl +++ b/src/interfaces.jl @@ -1,16 +1,16 @@ export CudssSolver, cudss, cudss_set, cudss_get """ - solver = CudssSolver(A::CuSparseMatrixCSR, structure::Union{Char, String}, view::Char; index::Char='O') + solver = CudssSolver(A::CuSparseMatrixCSR, structure::String, view::Char; index::Char='O') solver = CudssSolver(matrix::CudssMatrix, config::CudssConfig, data::CudssData) `CudssSolver` contains all structures required to solve linear systems with cuDSS. One constructor of `CudssSolver` takes as input the same parameters as [`CudssMatrix`](@ref). `structure` specifies the stucture for sparse matrices: -- `'G'` or `"G"`: General matrix -- LDU factorization; -- `'S'` or `"S"`: Real symmetric matrix -- LDLᵀ factorization; -- `'H'` or `"H"`: Complex Hermitian matrix -- LDLᴴ factorization; +- `"G"`: General matrix -- LDU factorization; +- `"S"`: Real symmetric matrix -- LDLᵀ factorization; +- `"H"`: Complex Hermitian matrix -- LDLᴴ factorization; - `"SPD"`: Symmetric positive-definite matrix -- LLᵀ factorization; - `"HPD"`: Hermitian positive-definite matrix -- LLᴴ factorization. @@ -34,7 +34,7 @@ mutable struct CudssSolver return new(matrix, config, data) end - function CudssSolver(A::CuSparseMatrixCSR, structure::Union{Char, String}, view::Char; index::Char='O') + function CudssSolver(A::CuSparseMatrixCSR, structure::String, view::Char; index::Char='O') matrix = CudssMatrix(A, structure, view; index) config = CudssConfig() data = CudssData() diff --git a/src/types.jl b/src/types.jl index 47db64f..f334c51 100644 --- a/src/types.jl +++ b/src/types.jl @@ -58,18 +58,6 @@ end ## matrix structure type -function Base.convert(::Type{cudssMatrixType_t}, structure::Char) - if structure == 'G' - return CUDSS_MTYPE_GENERAL - elseif structure == 'S' - return CUDSS_MTYPE_SYMMETRIC - elseif structure == 'H' - return CUDSS_MTYPE_HERMITIAN - else - throw(ArgumentError("Unknown structure $structure")) - end -end - function Base.convert(::Type{cudssMatrixType_t}, structure::String) if structure == "G" return CUDSS_MTYPE_GENERAL diff --git a/test/test_cudss.jl b/test/test_cudss.jl index 5721a93..6c22ef8 100644 --- a/test/test_cudss.jl +++ b/test/test_cudss.jl @@ -41,7 +41,7 @@ function cudss_sparse() A_cpu = A_cpu + A_cpu' A_gpu = CuSparseMatrixCSR(A_cpu) @testset "view = $view" for view in ('L', 'U', 'F') - @testset "structure = $structure" for structure in ('G', "G", 'S', "S", 'H', "H", "SPD", "HPD") + @testset "structure = $structure" for structure in ("G", "S", "H", "SPD", "HPD") matrix = CudssMatrix(A_gpu, structure, view) format = Ref{CUDSS.cudssMatrixFormat_t}() CUDSS.cudssMatrixGetFormat(matrix, format) @@ -62,7 +62,7 @@ function cudss_solver() A_cpu = sprand(T, n, n, 1.0) A_cpu = A_cpu + A_cpu' A_gpu = CuSparseMatrixCSR(A_cpu) - @testset "structure = $structure" for structure in ('G', "G", 'S', "S", 'H', "H", "SPD", "HPD") + @testset "structure = $structure" for structure in ("G", "S", "H", "SPD", "HPD") @testset "view = $view" for view in ('L', 'U', 'F') solver = CudssSolver(A_gpu, structure, view) @@ -95,7 +95,7 @@ function cudss_solver() @testset "data parameter = $parameter" for parameter in CUDSS_DATA_PARAMETERS parameter ∈ ("perm_row", "perm_col", "perm_reorder", "diag") && continue if parameter ≠ "user_perm" - (parameter == "inertia") && !(structure ∈ ('S', "S", 'H', "H")) && continue + (parameter == "inertia") && !(structure ∈ ("S", "H")) && continue val = cudss_get(solver, parameter) else perm = Cint[i for i=n:-1:1] @@ -121,7 +121,7 @@ function cudss_execution() x_gpu = CuVector(x_cpu) b_gpu = CuVector(b_cpu) - matrix = CudssMatrix(A_gpu, 'G', 'F') + matrix = CudssMatrix(A_gpu, "G", 'F') config = CudssConfig() data = CudssData() solver = CudssSolver(matrix, config, data) @@ -147,12 +147,12 @@ function cudss_execution() X_gpu = CuMatrix(X_cpu) B_gpu = CuMatrix(B_cpu) - structure = T <: Real ? 'S' : 'H' + structure = T <: Real ? "S" : "H" matrix = CudssMatrix(A_gpu, structure, view) config = CudssConfig() data = CudssData() solver = CudssSolver(matrix, config, data) - (structure == 'H') && cudss_set(solver, "pivot_type", 'N') + (structure == "H") && cudss_set(solver, "pivot_type", 'N') cudss("analysis", solver, X_gpu, B_gpu) cudss("factorization", solver, X_gpu, B_gpu)