Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a convert function for cudssAlgType_t #28

Merged
merged 1 commit into from
Mar 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ end
The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.

The available configuration parameters are:
- `"reordering_alg"`: Algorithm for the reordering phase;
- `"factorization_alg"`: Algorithm for the factorization phase;
- `"solve_alg"`: Algorithm for the solving phase;
- `"reordering_alg"`: Algorithm for the reordering phase (`"default"`, `"algo1"`, `"algo2"` or `"algo3"`);
- `"factorization_alg"`: Algorithm for the factorization phase (`"default"`, `"algo1"`, `"algo2"` or `"algo3"`);
- `"solve_alg"`: Algorithm for the solving phase (`"default"`, `"algo1"`, `"algo2"` or `"algo3"`);
- `"matching_type"`: Type of matching;
- `"solve_mode"`: Potential modificator on the system matrix (transpose or adjoint);
- `"ir_n_steps"`: Number of steps during the iterative refinement;
Expand Down Expand Up @@ -123,7 +123,7 @@ The available configuration parameters are:
- `"solve_mode"`: Potential modificator on the system matrix (transpose or adjoint);
- `"ir_n_steps"`: Number of steps during the iterative refinement;
- `"ir_tol"`: Iterative refinement tolerance;
- `"pivot_type"`: Type of pivoting (`'C'`, `'R'` or `'N'`);
- `"pivot_type"`: Type of pivoting;
- `"pivot_threshold"`: Pivoting threshold which is used to determine if digonal element is subject to pivoting;
- `"pivot_epsilon"`: Pivoting epsilon, absolute value to replace singular diagonal elements;
- `"max_lu_nnz"`: Upper limit on the number of nonzero entries in LU factors for non-symmetric matrices.
Expand Down
186 changes: 101 additions & 85 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,89 +32,7 @@ const CUDSS_TYPES = Dict{String, DataType}(
"max_lu_nnz" => Int64
)

## layout type

function Base.convert(::Type{cudssLayout_t}, layout::Char)
if layout == 'R'
CUDSS_LAYOUT_ROW_MAJOR
elseif layout == 'C'
CUDSS_LAYOUT_COL_MAJOR
else
throw(ArgumentError("Unknown layout $layout"))
end
end

## index base

function Base.convert(::Type{cudssIndexBase_t}, index::Char)
if index == 'Z'
return CUDSS_BASE_ZERO
elseif index == 'O'
return CUDSS_BASE_ONE
else
throw(ArgumentError("Unknown index $index"))
end
end

## matrix structure type

function Base.convert(::Type{cudssMatrixType_t}, structure::String)
if structure == "G"
return CUDSS_MTYPE_GENERAL
elseif structure == "S"
return CUDSS_MTYPE_SYMMETRIC
elseif structure == "H"
return CUDSS_MTYPE_HERMITIAN
elseif structure == "SPD"
return CUDSS_MTYPE_SPD
elseif structure == "HPD"
return CUDSS_MTYPE_HPD
else
throw(ArgumentError("Unknown structure $structure"))
end
end

## view type

function Base.convert(::Type{cudssMatrixViewType_t}, view::Char)
if view == 'F'
return CUDSS_MVIEW_FULL
elseif view == 'L'
return CUDSS_MVIEW_LOWER
elseif view == 'U'
return CUDSS_MVIEW_UPPER
else
throw(ArgumentError("Unknown view $view"))
end
end

## pivot type

function Base.convert(::Type{cudssPivotType_t}, pivoting::Char)
if pivoting == 'C'
return CUDSS_PIVOT_COL
elseif pivoting == 'R'
return CUDSS_PIVOT_ROW
elseif pivoting == 'N'
return CUDSS_PIVOT_NONE
else
throw(ArgumentError("Unknown pivoting $pivoting"))
end
end

# matrix format type

function Base.convert(::Type{cudssMatrixFormat_t}, format::Char)
if format == 'D'
return CUDSS_MFORMAT_DENSE
elseif format == 'S'
return CUDSS_MFORMAT_CSR
else
throw(ArgumentError("Unknown format $format"))
end
end

# config type
## config type

function Base.convert(::Type{cudssConfigParam_t}, config::String)
if config == "reordering_alg"
Expand Down Expand Up @@ -144,7 +62,7 @@ function Base.convert(::Type{cudssConfigParam_t}, config::String)
end
end

# data type
## data type

function Base.convert(::Type{cudssDataParam_t}, data::String)
if data == "info"
Expand All @@ -170,7 +88,7 @@ function Base.convert(::Type{cudssDataParam_t}, data::String)
end
end

# phase type
## phase type

function Base.convert(::Type{cudssPhase_t}, phase::String)
if phase == "analysis"
Expand All @@ -191,3 +109,101 @@ function Base.convert(::Type{cudssPhase_t}, phase::String)
throw(ArgumentError("Unknown phase $phase"))
end
end

## matrix structure type

function Base.convert(::Type{cudssMatrixType_t}, structure::String)
if structure == "G"
return CUDSS_MTYPE_GENERAL
elseif structure == "S"
return CUDSS_MTYPE_SYMMETRIC
elseif structure == "H"
return CUDSS_MTYPE_HERMITIAN
elseif structure == "SPD"
return CUDSS_MTYPE_SPD
elseif structure == "HPD"
return CUDSS_MTYPE_HPD
else
throw(ArgumentError("Unknown structure $structure"))
end
end

## view type

function Base.convert(::Type{cudssMatrixViewType_t}, view::Char)
if view == 'F'
return CUDSS_MVIEW_FULL
elseif view == 'L'
return CUDSS_MVIEW_LOWER
elseif view == 'U'
return CUDSS_MVIEW_UPPER
else
throw(ArgumentError("Unknown view $view"))
end
end

## index base

function Base.convert(::Type{cudssIndexBase_t}, index::Char)
if index == 'Z'
return CUDSS_BASE_ZERO
elseif index == 'O'
return CUDSS_BASE_ONE
else
throw(ArgumentError("Unknown index $index"))
end
end

## layout type

function Base.convert(::Type{cudssLayout_t}, layout::Char)
if layout == 'R'
CUDSS_LAYOUT_ROW_MAJOR
elseif layout == 'C'
CUDSS_LAYOUT_COL_MAJOR
else
throw(ArgumentError("Unknown layout $layout"))
end
end

## algorithm type

function Base.convert(::Type{cudssAlgType_t}, algorithm::String)
if algorithm == "default"
CUDSS_ALG_DEFAULT
elseif algorithm == "algo1"
CUDSS_ALG_1
elseif algorithm == "algo2"
CUDSS_ALG_2
elseif algorithm == "algo3"
CUDSS_ALG_3
else
throw(ArgumentError("Unknown algorithm $algorithm"))
end
end

## pivot type

function Base.convert(::Type{cudssPivotType_t}, pivoting::Char)
if pivoting == 'C'
return CUDSS_PIVOT_COL
elseif pivoting == 'R'
return CUDSS_PIVOT_ROW
elseif pivoting == 'N'
return CUDSS_PIVOT_NONE
else
throw(ArgumentError("Unknown pivoting $pivoting"))
end
end

# matrix format type

function Base.convert(::Type{cudssMatrixFormat_t}, format::Char)
if format == 'D'
return CUDSS_MFORMAT_DENSE
elseif format == 'S'
return CUDSS_MFORMAT_CSR
else
throw(ArgumentError("Unknown format $format"))
end
end
1 change: 0 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ using SparseArrays
using LinearAlgebra

import CUDSS: CUDSS_DATA_PARAMETERS, CUDSS_CONFIG_PARAMETERS
import CUDSS: CUDSS_ALG_DEFAULT, CUDSS_ALG_1, CUDSS_ALG_2, CUDSS_ALG_3

@info("CUDSS_INSTALLATION : $(CUDSS.CUDSS_INSTALLATION)")

Expand Down
38 changes: 22 additions & 16 deletions test/test_cudss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,22 +74,26 @@ function cudss_solver()
cudss("factorization", solver, x_gpu, b_gpu)

@testset "config parameter = $parameter" for parameter in CUDSS_CONFIG_PARAMETERS
val = cudss_get(solver, parameter)
for val in (CUDSS_ALG_DEFAULT, CUDSS_ALG_1, CUDSS_ALG_2, CUDSS_ALG_3)
(parameter == "reordering_alg") && cudss_set(solver, parameter, val)
(parameter == "factorization_alg") && cudss_set(solver, parameter, val)
(parameter == "solve_alg") && cudss_set(solver, parameter, val)
@testset "cudss_get" begin
val = cudss_get(solver, parameter)
end
(parameter == "matching_type") && cudss_set(solver, parameter, 0)
(parameter == "solve_mode") && cudss_set(solver, parameter, 0)
(parameter == "ir_n_steps") && cudss_set(solver, parameter, 1)
(parameter == "ir_tol") && cudss_set(solver, parameter, 1e-8)
for val in ('C', 'R', 'N')
(parameter == "pivot_type") && cudss_set(solver, parameter, val)
@testset "cudss_set" begin
(parameter == "matching_type") && cudss_set(solver, parameter, 0)
(parameter == "solve_mode") && cudss_set(solver, parameter, 0)
(parameter == "ir_n_steps") && cudss_set(solver, parameter, 1)
(parameter == "ir_tol") && cudss_set(solver, parameter, 1e-8)
(parameter == "pivot_threshold") && cudss_set(solver, parameter, 2.0)
(parameter == "pivot_epsilon") && cudss_set(solver, parameter, 1e-12)
(parameter == "max_lu_nnz") && cudss_set(solver, parameter, 10)
for algo in ("default", "algo1", "algo2", "algo3")
(parameter == "reordering_alg") && cudss_set(solver, parameter, algo)
(parameter == "factorization_alg") && cudss_set(solver, parameter, algo)
(parameter == "solve_alg") && cudss_set(solver, parameter, algo)
end
for pivoting in ('C', 'R', 'N')
(parameter == "pivot_type") && cudss_set(solver, parameter, pivoting)
end
end
(parameter == "pivot_threshold") && cudss_set(solver, parameter, 2.0)
(parameter == "pivot_epsilon") && cudss_set(solver, parameter, 1e-12)
(parameter == "max_lu_nnz") && cudss_set(solver, parameter, 10)
end

@testset "data parameter = $parameter" for parameter in CUDSS_DATA_PARAMETERS
Expand All @@ -98,8 +102,10 @@ function cudss_solver()
(parameter == "inertia") && !(structure ∈ ("S", "H")) && continue
val = cudss_get(solver, parameter)
else
perm = Cint[i for i=n:-1:1]
cudss_set(solver, parameter, perm)
perm_cpu = Cint[i for i=n:-1:1]
cudss_set(solver, parameter, perm_cpu)
perm_gpu = CuVector{Cint}(perm_cpu)
cudss_set(solver, parameter, perm_gpu)
end
end
end
Expand Down
Loading