Skip to content

Commit

Permalink
Add check keyword argument to generic calls
Browse files Browse the repository at this point in the history
`lu!(fact, A, check=false)` works for CPU but not for GPU because this keyword argument is not allowed. This makes it hard to use the generic interface in generic code because it has slightly different arguments. This adds the `check` keyword argument to match the original argument interface.
  • Loading branch information
ChrisRackauckas authored and amontoison committed Apr 19, 2024
1 parent d00710d commit 8fca66d
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
* `solver`: an opaque structure [`CudssSolver`](@ref) that stores the factors of the LU decomposition.
"""
function LinearAlgebra.lu(A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat
function LinearAlgebra.lu(A::CuSparseMatrixCSR{T,Cint}; check = false) where T <: BlasFloat
n = checksquare(A)
solver = CudssSolver(A, "G", 'F')
x = CudssMatrix(T, n)
Expand All @@ -28,7 +28,7 @@ end
Compute the LU factorization of a sparse matrix `A` on an NVIDIA GPU, reusing the symbolic factorization stored in `solver`.
The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
"""
function LinearAlgebra.lu!(solver::CudssSolver{T}, A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat
function LinearAlgebra.lu!(solver::CudssSolver{T}, A::CuSparseMatrixCSR{T,Cint}; check = false) where T <: BlasFloat
n = checksquare(A)
cudss_set(solver, A)
x = CudssMatrix(T, n)
Expand All @@ -55,7 +55,7 @@ The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
* `solver`: Opaque structure [`CudssSolver`](@ref) that stores the factors of the LDLᴴ decomposition.
"""
function LinearAlgebra.ldlt(A::CuSparseMatrixCSR{T,Cint}; view::Char='F') where T <: BlasFloat
function LinearAlgebra.ldlt(A::CuSparseMatrixCSR{T,Cint}; view::Char='F', check = false) where T <: BlasFloat
n = checksquare(A)
structure = T <: Real ? "S" : "H"
solver = CudssSolver(A, structure, view)
Expand All @@ -76,7 +76,7 @@ LinearAlgebra.ldlt(A::Hermitian{T,<:CuSparseMatrixCSR{T,Cint}}) where T <: BlasF
Compute the LDLᴴ factorization of a sparse matrix `A` on an NVIDIA GPU, reusing the symbolic factorization stored in `solver`.
The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
"""
function LinearAlgebra.ldlt!(solver::CudssSolver{T}, A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat
function LinearAlgebra.ldlt!(solver::CudssSolver{T}, A::CuSparseMatrixCSR{T,Cint}; check = false) where T <: BlasFloat
n = checksquare(A)
cudss_set(solver, A)
x = CudssMatrix(T, n)
Expand All @@ -103,7 +103,7 @@ The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
* `solver`: Opaque structure [`CudssSolver`](@ref) that stores the factors of the LLᴴ decomposition.
"""
function LinearAlgebra.cholesky(A::CuSparseMatrixCSR{T,Cint}; view::Char='F') where T <: BlasFloat
function LinearAlgebra.cholesky(A::CuSparseMatrixCSR{T,Cint}; view::Char='F', check = false) where T <: BlasFloat
n = checksquare(A)
structure = T <: Real ? "SPD" : "HPD"
solver = CudssSolver(A, structure, view)
Expand All @@ -114,16 +114,16 @@ function LinearAlgebra.cholesky(A::CuSparseMatrixCSR{T,Cint}; view::Char='F') wh
return solver
end

LinearAlgebra.cholesky(A::Symmetric{T,<:CuSparseMatrixCSR{T,Cint}}) where T <: BlasReal = LinearAlgebra.cholesky(A.data, view=A.uplo)
LinearAlgebra.cholesky(A::Hermitian{T,<:CuSparseMatrixCSR{T,Cint}}) where T <: BlasFloat = LinearAlgebra.cholesky(A.data, view=A.uplo)
LinearAlgebra.cholesky(A::Symmetric{T,<:CuSparseMatrixCSR{T,Cint}}; check = false) where T <: BlasReal = LinearAlgebra.cholesky(A.data, view=A.uplo)
LinearAlgebra.cholesky(A::Hermitian{T,<:CuSparseMatrixCSR{T,Cint}}; check = false) where T <: BlasFloat = LinearAlgebra.cholesky(A.data, view=A.uplo)

"""
solver = cholesky!(solver::CudssSolver{T}, A::CuSparseMatrixCSR{T,Cint})
Compute the LLᴴ factorization of a sparse matrix `A` on an NVIDIA GPU, reusing the symbolic factorization stored in `solver`.
The type `T` can be `Float32`, `Float64`, `ComplexF32` or `ComplexF64`.
"""
function LinearAlgebra.cholesky!(solver::CudssSolver{T}, A::CuSparseMatrixCSR{T,Cint}) where T <: BlasFloat
function LinearAlgebra.cholesky!(solver::CudssSolver{T}, A::CuSparseMatrixCSR{T,Cint}; check = false) where T <: BlasFloat
n = checksquare(A)
cudss_set(solver, A)
x = CudssMatrix(T, n)
Expand Down

0 comments on commit 8fca66d

Please sign in to comment.