Skip to content

Commit

Permalink
feat: add PETScSNES (#482)
Browse files Browse the repository at this point in the history
* feat: add `PETScSNES`

* feat: support automatic sparsity detection for PETSc

* test: add PETScSNES to the wrapper tests

* docs: add PETSc example

* test: skip PETSc tests on windows

* docs: print the benchmark results
  • Loading branch information
avik-pal authored Oct 27, 2024
1 parent 21b02bd commit 28c0189
Show file tree
Hide file tree
Showing 12 changed files with 384 additions and 40 deletions.
8 changes: 7 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ FixedPointAcceleration = "817d07cb-a79a-5c30-9a31-890123675176"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
MINPACK = "4854310b-de5a-5eb6-a2a5-c1dee2bd17f9"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
NLSolvers = "337daf1e-9722-11e9-073e-8b9effe078ba"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
PETSc = "ace2c81b-2b5f-4b1e-a30d-d662738edfe0"
SIAMFANLEquations = "084e46ad-d928-497d-ad5e-07fa361a48c4"
SpeedMapping = "f1835b91-879b-4a3f-a438-e4baacf14412"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Expand All @@ -55,6 +57,7 @@ NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
NonlinearSolveMINPACKExt = "MINPACK"
NonlinearSolveNLSolversExt = "NLSolvers"
NonlinearSolveNLsolveExt = ["NLsolve", "LineSearches"]
NonlinearSolvePETScExt = ["PETSc", "MPI"]
NonlinearSolveSIAMFANLEquationsExt = "SIAMFANLEquations"
NonlinearSolveSpeedMappingExt = "SpeedMapping"
NonlinearSolveSundialsExt = "Sundials"
Expand Down Expand Up @@ -86,13 +89,15 @@ LineSearches = "7.3"
LinearAlgebra = "1.10"
LinearSolve = "2.35"
MINPACK = "1.2"
MPI = "0.20.22"
MaybeInplace = "0.1.4"
NLSolvers = "0.5"
NLsolve = "4.5"
NaNMath = "1"
NonlinearProblemLibrary = "0.1.2"
NonlinearSolveBase = "1"
OrdinaryDiffEqTsit5 = "1.1.0"
PETSc = "0.2"
Pkg = "1.10"
PrecompileTools = "1.2"
Preferences = "1.4"
Expand Down Expand Up @@ -139,6 +144,7 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
NonlinearProblemLibrary = "b7050fa9-e91f-4b37-bcee-a89a063da141"
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
PETSc = "ace2c81b-2b5f-4b1e-a30d-d662738edfe0"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Expand All @@ -152,4 +158,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"]
test = ["Aqua", "BandedMatrices", "BenchmarkTools", "CUDA", "Enzyme", "ExplicitImports", "FastLevenbergMarquardt", "FixedPointAcceleration", "Hwloc", "InteractiveUtils", "LeastSquaresOptim", "LineSearches", "MINPACK", "NLSolvers", "NLsolve", "NaNMath", "NonlinearProblemLibrary", "OrdinaryDiffEqTsit5", "PETSc", "Pkg", "Random", "ReTestItems", "SIAMFANLEquations", "SparseConnectivityTracer", "SpeedMapping", "StableRNGs", "StaticArrays", "Sundials", "Test", "Zygote"]
3 changes: 3 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
PETSc = "ace2c81b-2b5f-4b1e-a30d-d662738edfe0"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand All @@ -31,6 +32,7 @@ AlgebraicMultigrid = "0.5, 0.6"
ArrayInterface = "6, 7"
BenchmarkTools = "1"
BracketingNonlinearSolve = "1"
DiffEqBase = "6.158"
DifferentiationInterface = "0.6.16"
Documenter = "1"
DocumenterCitations = "1"
Expand All @@ -41,6 +43,7 @@ LinearSolve = "2"
NonlinearSolve = "4"
NonlinearSolveBase = "1"
OrdinaryDiffEqTsit5 = "1.1.0"
PETSc = "0.2"
Plots = "1"
Random = "1.10"
SciMLBase = "2.4"
Expand Down
4 changes: 3 additions & 1 deletion docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@ pages = [
"tutorials/modelingtoolkit.md",
"tutorials/small_compile.md",
"tutorials/iterator_interface.md",
"tutorials/optimizing_parameterized_ode.md"
"tutorials/optimizing_parameterized_ode.md",
"tutorials/snes_ex2.md"
],
"Basics" => Any[
"basics/nonlinear_problem.md",
Expand Down Expand Up @@ -45,6 +46,7 @@ pages = [
"api/minpack.md",
"api/nlsolve.md",
"api/nlsolvers.md",
"api/petsc.md",
"api/siamfanlequations.md",
"api/speedmapping.md",
"api/sundials.md"
Expand Down
17 changes: 17 additions & 0 deletions docs/src/api/petsc.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# PETSc.jl

This is an extension for importing solvers from PETSc.jl SNES into the SciML interface. Note
that these solvers do not come by default, and thus one needs to install the package before
using these solvers:

```julia
using Pkg
Pkg.add("PETSc")
using PETSc, NonlinearSolve
```

## Solver API

```@docs
PETScSNES
```
9 changes: 9 additions & 0 deletions docs/src/solvers/nonlinear_system_solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,12 @@ This is a wrapper package for importing solvers from NLSolvers.jl into the SciML
[NLSolvers.jl](https://github.com/JuliaNLSolvers/NLSolvers.jl)

For a list of possible solvers see the [NLSolvers.jl documentation](https://julianlsolvers.github.io/NLSolvers.jl/)

### PETSc.jl

This is a wrapper package for importing solvers from PETSc.jl into the SciML interface.

- [`PETScSNES()`](@ref): A wrapper for
[PETSc.jl](https://github.com/JuliaParallel/PETSc.jl)

For a list of possible solvers see the [PETSc SNES documentation](https://petsc.org/release/manual/snes/)
81 changes: 81 additions & 0 deletions docs/src/tutorials/snes_ex2.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# [PETSc SNES Example 2](@id snes_ex2)

This implements `src/snes/examples/tutorials/ex2.c` from PETSc and `examples/SNES_ex2.jl`
from PETSc.jl using automatic sparsity detection and automatic differentiation using
`NonlinearSolve.jl`.

This example uses Newton's method to solve the equation `u'' + u^{2} = f`
sequentially.

```@example snes_ex2
using NonlinearSolve, PETSc, LinearAlgebra, SparseConnectivityTracer, BenchmarkTools
u0 = fill(0.5, 128)
# In-place residual of the discretized problem `u'' + u^2 = f` on a uniform grid over
# [0, 1] with `length(x)` points. The first and last entries hold the boundary
# residuals `x[1]` and `x[end] - 1`; interior entries use a second-order central
# difference for `u''`. The third argument (parameters) is unused.
function form_residual!(resid, x, _)
    npts = length(x)
    grid = LinRange(0.0, 1.0, npts)
    # Forcing term sampled at every grid point.
    forcing = 6grid .+ (grid .+ 1e-12) .^ 6
    h = 1 / (npts - 1)
    hsq = h^2

    # Left boundary residual.
    resid[1] = x[1]
    # Interior points: finite-difference Laplacian plus the nonlinear term.
    for k in 2:(npts - 1)
        laplacian = (x[k - 1] - 2x[k] + x[k + 1]) / hsq
        resid[k] = laplacian + x[k] * x[k] - forcing[k]
    end
    # Right boundary residual (written last, as in the reference ordering).
    resid[npts] = x[npts] - 1
    return nothing
end
```

To use automatic sparsity detection, we need to specify the `sparsity` keyword argument to
`NonlinearFunction`. See [Automatic Sparsity Detection](@ref sparsity-detection) for more
details.

```@example snes_ex2
nlfunc_dense = NonlinearFunction(form_residual!)
nlfunc_sparse = NonlinearFunction(form_residual!; sparsity = TracerSparsityDetector())
nlprob_dense = NonlinearProblem(nlfunc_dense, u0)
nlprob_sparse = NonlinearProblem(nlfunc_sparse, u0)
```

Now we can solve the problem using `PETScSNES` or with one of the native `NonlinearSolve.jl`
solvers.

```@example snes_ex2
sol_dense_nr = solve(nlprob_dense, NewtonRaphson(); abstol = 1e-8)
sol_dense_snes = solve(nlprob_dense, PETScSNES(); abstol = 1e-8)
sol_dense_nr .- sol_dense_snes
```

```@example snes_ex2
sol_sparse_nr = solve(nlprob_sparse, NewtonRaphson(); abstol = 1e-8)
sol_sparse_snes = solve(nlprob_sparse, PETScSNES(); abstol = 1e-8)
sol_sparse_nr .- sol_sparse_snes
```

As expected the solutions are the same (up to floating point error). Now let's compare the
runtimes.

## Runtimes

### Dense Jacobian

```@example snes_ex2
@benchmark solve($(nlprob_dense), $(NewtonRaphson()); abstol = 1e-8)
```

```@example snes_ex2
@benchmark solve($(nlprob_dense), $(PETScSNES()); abstol = 1e-8)
```

### Sparse Jacobian

```@example snes_ex2
@benchmark solve($(nlprob_sparse), $(NewtonRaphson()); abstol = 1e-8)
```

```@example snes_ex2
@benchmark solve($(nlprob_sparse), $(PETScSNES()); abstol = 1e-8)
```
120 changes: 120 additions & 0 deletions ext/NonlinearSolvePETScExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
module NonlinearSolvePETScExt

using FastClosures: @closure
using MPI: MPI
using NonlinearSolveBase: NonlinearSolveBase, get_tolerance
using NonlinearSolve: NonlinearSolve, PETScSNES
using PETSc: PETSc
using SciMLBase: SciMLBase, NonlinearProblem, ReturnCode
using SparseArrays: AbstractSparseMatrix

# Solve a `NonlinearProblem` with PETSc's SNES solver via PETSc.jl.
#
# Keyword arguments:
#   - `abstol`, `reltol`: tolerances; `nothing` is resolved through `get_tolerance`
#     for the element type of `prob.u0`. They are forwarded to SNES as `snes_atol` /
#     `snes_rtol`, and `abstol` is also used for the final return-code check.
#   - `maxiters`: forwarded to SNES as `snes_max_it`.
#   - `alias_u0`: forwarded to `NonlinearSolve.__construct_extension_f`.
#   - `termination_condition`: must be `nothing`; anything else raises an error.
#   - `show_trace`: a `Val`; its parameter is forwarded as `snes_monitor`.
#
# Returns the result of `SciMLBase.build_solution`, with the SNES object stored as
# `original` and the residual / Jacobian evaluation counts stored in the stats. The
# Jacobian count is `-1` when no Jacobian callback was registered.
function SciMLBase.__solve(
        prob::NonlinearProblem, alg::PETScSNES, args...; abstol = nothing, reltol = nothing,
        maxiters = 1000, alias_u0::Bool = false, termination_condition = nothing,
        show_trace::Val{ShT} = Val(false), kwargs...) where {ShT}
    # XXX: https://petsc.org/release/manualpages/SNES/SNESSetConvergenceTest/
    termination_condition === nothing ||
        error("`PETScSNES` does not support termination conditions!")

    _f!, u0, resid = NonlinearSolve.__construct_extension_f(prob; alias_u0)
    T = eltype(prob.u0)
    # The element type must be one of the scalar types PETSc.jl lists as supported.
    @assert T ∈ PETSc.scalar_types

    if alg.petsclib === missing
        # No library supplied: take the first registered PETSc library whose scalar
        # type matches `T`.
        petsclibidx = findfirst(PETSc.petsclibs) do petsclib
            petsclib isa PETSc.PetscLibType{T}
        end

        if petsclibidx === nothing
            error("No compatible PETSc library found for element type $(T). Pass in a \
                   custom `petsclib` via `PETScSNES(; petsclib = <petsclib>, ....)`.")
        end
        petsclib = PETSc.petsclibs[petsclibidx]
    else
        petsclib = alg.petsclib
    end
    PETSc.initialized(petsclib) || PETSc.initialize(petsclib)

    abstol = get_tolerance(abstol, T)
    reltol = get_tolerance(reltol, T)

    # Number of residual evaluations performed through the SNES callback.
    nf = Ref{Int}(0)

    # Residual callback handed to SNES. Arguments arriving as raw pointers are wrapped
    # as local arrays first, and both arrays are finalized after the evaluation.
    f! = @closure (cfx, cx, user_ctx) -> begin
        nf[] += 1
        fx = cfx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cfx; read = false) : cfx
        x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx
        _f!(fx, x)
        Base.finalize(fx)
        Base.finalize(x)
        return
    end

    snes = PETSc.SNES{T}(petsclib,
        alg.mpi_comm === missing ? MPI.COMM_SELF : alg.mpi_comm;
        alg.snes_options..., snes_monitor = ShT, snes_rtol = reltol,
        snes_atol = abstol, snes_max_it = maxiters)

    PETSc.setfunction!(snes, f!, PETSc.VecSeq(zero(u0)))

    if alg.autodiff === missing && prob.f.jac === nothing
        # Neither an AD backend nor a user Jacobian: no Jacobian callback is registered
        # and the Jacobian counter is reported as -1.
        _jac! = nothing
        njac = Ref{Int}(-1)
    else
        autodiff = alg.autodiff === missing ? nothing : alg.autodiff
        if prob.u0 isa Number
            _jac! = NonlinearSolve.__construct_extension_jac(
                prob, alg, prob.u0, prob.u0; autodiff)
            J_init = zeros(T, 1, 1)
        else
            _jac!, J_init = NonlinearSolve.__construct_extension_jac(
                prob, alg, u0, resid; autodiff, initial_jacobian = Val(true))
        end

        # Number of Jacobian evaluations performed through the SNES callback.
        njac = Ref{Int}(0)

        if J_init isa AbstractSparseMatrix
            PJ = PETSc.MatSeqAIJ(J_init)
            jac! = @closure (cx, J, _, user_ctx) -> begin
                njac[] += 1
                x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx
                if J isa PETSc.AbstractMat
                    # Evaluate into the Julia-side matrix kept in the user context,
                    # then copy into the PETSc matrix and assemble it.
                    _jac!(user_ctx.jacobian, x)
                    copyto!(J, user_ctx.jacobian)
                    PETSc.assemble(J)
                else
                    _jac!(J, x)
                end
                Base.finalize(x)
                return
            end
            PETSc.setjacobian!(snes, jac!, PJ, PJ)
            snes.user_ctx = (; jacobian = J_init)
        else
            PJ = PETSc.MatSeqDense(J_init)
            jac! = @closure (cx, J, _, user_ctx) -> begin
                njac[] += 1
                x = cx isa Ptr{Nothing} ? PETSc.unsafe_localarray(T, cx; write = false) : cx
                _jac!(J, x)
                Base.finalize(x)
                J isa PETSc.AbstractMat && PETSc.assemble(J)
                return
            end
            PETSc.setjacobian!(snes, jac!, PJ, PJ)
        end
    end

    res = PETSc.solve!(u0, snes)

    # Re-evaluate the residual at the returned point to judge convergence.
    _f!(resid, res)
    u_ = prob.u0 isa Number ? res[1] : res
    resid_ = prob.u0 isa Number ? resid[1] : resid

    objective = maximum(abs, resid)
    # XXX: Return Code from PETSc
    retcode = ifelse(objective ≤ abstol, ReturnCode.Success, ReturnCode.Failure)
    return SciMLBase.build_solution(prob, alg, u_, resid_; retcode, original = snes,
        stats = SciMLBase.NLStats(nf[], njac[], -1, -1, -1))
end

end
12 changes: 11 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,15 @@ include("algorithms/extension_algs.jl")
include("utils.jl")
include("default.jl")

# List of solver algorithm types (plus `Nothing`), covering the native algorithm
# families, the extension wrappers, and the two poly-algorithm variants.
# NOTE(review): defined just before `internal/forward_diff.jl` is included, so it is
# presumably consumed there — confirm against that file.
const ALL_SOLVER_TYPES = [
Nothing, AbstractNonlinearSolveAlgorithm, GeneralizedDFSane,
GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm,
LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL,
CMINPACK, PETScSNES,
NonlinearSolvePolyAlgorithm{:NLLS, <:Any}, NonlinearSolvePolyAlgorithm{:NLS, <:Any}
]

include("internal/forward_diff.jl") # we need to define after the algorithms

@setup_workload begin
Expand Down Expand Up @@ -171,8 +180,9 @@ export NonlinearSolvePolyAlgorithm, RobustMultiNewton, FastShortcutNonlinearPoly
FastShortcutNLLSPolyalg

# Extension Algorithms
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, CMINPACK, NLsolveJL, NLSolversJL,
export LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
FixedPointAccelerationJL, SpeedMappingJL, SIAMFANLEquationsJL
export PETScSNES, CMINPACK

# Advanced Algorithms -- Without Bells and Whistles
export GeneralizedFirstOrderAlgorithm, ApproximateJacobianSolveAlgorithm, GeneralizedDFSane
Expand Down
Loading

0 comments on commit 28c0189

Please sign in to comment.