Skip to content

Commit

Permalink
refactor!: move preconditioners inside linear solvers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 31, 2024
1 parent fcee7a1 commit 8762dcc
Show file tree
Hide file tree
Showing 17 changed files with 103 additions and 147 deletions.
4 changes: 0 additions & 4 deletions docs/src/native/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ documentation.
uses the LinearSolve.jl default algorithm choice. For more information on available
algorithm choices, see the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `precs`: the choice of preconditioners for the linear solver. Defaults to using no
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `linesearch`: the line search algorithm to use. Defaults to
[`NoLineSearch()`](@extref LineSearch.NoLineSearch), which means that no line search is
performed.
Expand Down
2 changes: 2 additions & 0 deletions docs/src/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
### Breaking Changes in `NonlinearSolve.jl` v4

- `ApproximateJacobianSolveAlgorithm` has been renamed to `QuasiNewtonAlgorithm`.
- Preconditioners for the linear solver needs to be specified with the linear solver
instead of `precs` keyword argument.
- See [common breaking changes](@ref common-breaking-changes-v4v2) below.

### Breaking Changes in `SimpleNonlinearSolve.jl` v2
Expand Down
67 changes: 29 additions & 38 deletions docs/src/tutorials/large_systems.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ end
u0 = init_brusselator_2d(xyd_brusselator)
prob_brusselator_2d = NonlinearProblem(
brusselator_2d_loop, u0, p; abstol = 1e-10, reltol = 1e-10)
brusselator_2d_loop, u0, p; abstol = 1e-10, reltol = 1e-10
)
```

## Choosing Jacobian Types
Expand Down Expand Up @@ -140,7 +141,8 @@ using SparseConnectivityTracer
prob_brusselator_2d_autosparse = NonlinearProblem(
NonlinearFunction(brusselator_2d_loop; sparsity = TracerSparsityDetector()),
u0, p; abstol = 1e-10, reltol = 1e-10)
u0, p; abstol = 1e-10, reltol = 1e-10
)
@btime solve(prob_brusselator_2d_autosparse,
NewtonRaphson(; autodiff = AutoForwardDiff(; chunksize = 12)));
Expand Down Expand Up @@ -235,7 +237,7 @@ choices, see the

Any [LinearSolve.jl-compatible preconditioner](https://docs.sciml.ai/LinearSolve/stable/basics/Preconditioners/)
can be used as a preconditioner in the linear solver interface. To define preconditioners,
one must define a `precs` function in compatible with nonlinear solvers which returns the
one must define a `precs` function in compatible with linear solvers which returns the
left and right preconditioners, matrices which approximate the inverse of `W = I - gamma*J`
used in the solution of the ODE. An example of this with using
[IncompleteLU.jl](https://github.com/haampie/IncompleteLU.jl) is as follows:
Expand All @@ -244,26 +246,18 @@ used in the solution of the ODE. An example of this with using
# FIXME: On 1.10+ this is broken. Skipping this for now.
using IncompleteLU

function incompletelu(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
if newW === nothing || newW
Pl = ilu(W, τ = 50.0)
else
Pl = Plprev
end
Pl, nothing
end
incompletelu(W, p = nothing) = ilu(W, τ = 50.0), LinearAlgebra.I

@btime solve(prob_brusselator_2d_sparse,
NewtonRaphson(linsolve = KrylovJL_GMRES(), precs = incompletelu, concrete_jac = true));
NewtonRaphson(linsolve = KrylovJL_GMRES(precs = incompletelu), concrete_jac = true)
);
nothing # hide
```

Notice a few things about this preconditioner. This preconditioner uses the sparse Jacobian,
and thus we set `concrete_jac = true` to tell the algorithm to generate the Jacobian
(otherwise, a Jacobian-free algorithm is used with GMRES by default). Then `newW = true`
whenever a new `W` matrix is computed, and `newW = nothing` during the startup phase of the
solver. Thus, we do a check `newW === nothing || newW` and when true, it's only at these
points when we update the preconditioner, otherwise we just pass on the previous version.
(otherwise, a Jacobian-free algorithm is used with GMRES by default).

We use `convert(AbstractMatrix,W)` to get the concrete `W` matrix (matching `jac_prototype`,
thus `SpraseMatrixCSC`) which we can use in the preconditioner's definition. Then we use
`IncompleteLU.ilu` on that sparse matrix to generate the preconditioner. We return
Expand All @@ -279,39 +273,36 @@ which is more automatic. The setup is very similar to before:
```@example ill_conditioned_nlprob
using AlgebraicMultigrid
function algebraicmultigrid(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
if newW === nothing || newW
Pl = aspreconditioner(ruge_stuben(convert(AbstractMatrix, W)))
else
Pl = Plprev
end
Pl, nothing
function algebraicmultigrid(W, p = nothing)
return aspreconditioner(ruge_stuben(convert(AbstractMatrix, W))), LinearAlgebra.I
end
@btime solve(prob_brusselator_2d_sparse,
NewtonRaphson(linsolve = KrylovJL_GMRES(), precs = algebraicmultigrid,
concrete_jac = true));
NewtonRaphson(
linsolve = KrylovJL_GMRES(; precs = algebraicmultigrid), concrete_jac = true
)
);
nothing # hide
```

or with a Jacobi smoother:

```@example ill_conditioned_nlprob
function algebraicmultigrid2(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
if newW === nothing || newW
A = convert(AbstractMatrix, W)
Pl = AlgebraicMultigrid.aspreconditioner(AlgebraicMultigrid.ruge_stuben(
A, presmoother = AlgebraicMultigrid.Jacobi(rand(size(A, 1))),
postsmoother = AlgebraicMultigrid.Jacobi(rand(size(A, 1)))))
else
Pl = Plprev
end
Pl, nothing
function algebraicmultigrid2(W, p = nothing)
A = convert(AbstractMatrix, W)
Pl = AlgebraicMultigrid.aspreconditioner(AlgebraicMultigrid.ruge_stuben(
A, presmoother = AlgebraicMultigrid.Jacobi(rand(size(A, 1))),
postsmoother = AlgebraicMultigrid.Jacobi(rand(size(A, 1)))
))
return Pl, LinearAlgebra.I
end
@btime solve(prob_brusselator_2d_sparse,
NewtonRaphson(linsolve = KrylovJL_GMRES(), precs = algebraicmultigrid2,
concrete_jac = true));
@btime solve(
prob_brusselator_2d_sparse,
NewtonRaphson(
linsolve = KrylovJL_GMRES(precs = algebraicmultigrid2), concrete_jac = true
)
);
nothing # hide
```

Expand Down
25 changes: 4 additions & 21 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,15 @@ using LinearAlgebra: ColumnNorm
using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils

function (cache::LinearSolveJLCache)(;
A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing,
cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs...
A = nothing, b = nothing, linu = nothing,
reuse_A_if_factorization = false, verbose = true, kwargs...
)
cache.stats.nsolve += 1

update_A!(cache, A, reuse_A_if_factorization)
b !== nothing && setproperty!(cache.lincache, :b, b)
linu !== nothing && NonlinearSolveBase.set_lincache_u!(cache, linu)

Plprev = cache.lincache.Pl
Prprev = cache.lincache.Pr

if cache.precs === nothing
Pl, Pr = nothing, nothing
else
Pl, Pr = cache.precs(
cache.lincache.A, du, linu, p, nothing,
A !== nothing, Plprev, Prprev, cachedata
)
end

if Pl !== nothing || Pr !== nothing
Pl, Pr = NonlinearSolveBase.wrap_preconditioners(Pl, Pr, linu)
cache.lincache.Pl = Pl
cache.lincache.Pr = Pr
end

linres = solve!(cache.lincache)
cache.lincache = linres.cache
# Unfortunately LinearSolve.jl doesn't have the most uniform ReturnCode handling
Expand All @@ -58,7 +40,8 @@ function (cache::LinearSolveJLCache)(;
linprob = LinearProblem(A, b; u0 = linres.u)
cache.additional_lincache = init(
linprob, QRFactorization(ColumnNorm()); alias_u0 = false,
alias_A = false, alias_b = false, cache.lincache.Pl, cache.lincache.Pr)
alias_A = false, alias_b = false
)
else
cache.additional_lincache.A = A
cache.additional_lincache.b = b
Expand Down
5 changes: 1 addition & 4 deletions lib/NonlinearSolveBase/src/descent/damped_newton.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""
DampedNewtonDescent(;
linsolve = nothing, precs = nothing, initial_damping, damping_fn
)
DampedNewtonDescent(; linsolve = nothing, initial_damping, damping_fn)
A Newton descent algorithm with damping. The damping factor is computed using the
`damping_fn` function. The descent direction is computed as ``(JᵀJ + λDᵀD) δu = -fu``. For
Expand All @@ -20,7 +18,6 @@ The damping factor returned must be a non-negative number.
"""
@kwdef @concrete struct DampedNewtonDescent <: AbstractDescentDirection
linsolve = nothing
precs = nothing
initial_damping
damping_fn <: AbstractDampingFunction
end
Expand Down
10 changes: 5 additions & 5 deletions lib/NonlinearSolveBase/src/descent/dogleg.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Dogleg(; linsolve = nothing, precs = nothing)
Dogleg(; linsolve = nothing)
Switch between Newton's method and the steepest descent method depending on the size of the
trust region. The trust region is specified via keyword argument `trust_region` to
Expand All @@ -15,18 +15,18 @@ end
supports_trust_region(::Dogleg) = true
get_linear_solver(alg::Dogleg) = get_linear_solver(alg.newton_descent)

function Dogleg(; linsolve = nothing, precs = nothing, damping = Val(false),
function Dogleg(; linsolve = nothing, damping = Val(false),
damping_fn = missing, initial_damping = missing, kwargs...)
if !Utils.unwrap_val(damping)
return Dogleg(NewtonDescent(; linsolve, precs), SteepestDescent(; linsolve, precs))
return Dogleg(NewtonDescent(; linsolve), SteepestDescent(; linsolve))
end
if damping_fn === missing || initial_damping === missing
throw(ArgumentError("`damping_fn` and `initial_damping` must be supplied if \
`damping = Val(true)`."))
end
return Dogleg(
DampedNewtonDescent(; linsolve, precs, damping_fn, initial_damping),
SteepestDescent(; linsolve, precs)
DampedNewtonDescent(; linsolve, damping_fn, initial_damping),
SteepestDescent(; linsolve)
)
end

Expand Down
7 changes: 3 additions & 4 deletions lib/NonlinearSolveBase/src/descent/newton.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
NewtonDescent(; linsolve = nothing, precs = nothing)
NewtonDescent(; linsolve = nothing)
Compute the descent direction as ``J δu = -fu``. For non-square Jacobian problems, this is
commonly referred to as the Gauss-Newton Descent.
Expand All @@ -8,7 +8,6 @@ See also [`Dogleg`](@ref), [`SteepestDescent`](@ref), [`DampedNewtonDescent`](@r
"""
@kwdef @concrete struct NewtonDescent <: AbstractDescentDirection
linsolve = nothing
precs = nothing
end

supports_line_search(::NewtonDescent) = true
Expand Down Expand Up @@ -103,15 +102,15 @@ function InternalAPI.solve!(
@static_timeit cache.timer "linear solve" begin
linres = cache.lincache(;
A = Utils.maybe_symmetric(cache.JᵀJ_cache), b = cache.Jᵀfu_cache,
kwargs..., linu = Utils.safe_vec(δu), du = Utils.safe_vec(δu),
kwargs..., linu = Utils.safe_vec(δu),
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1))
)
end
else
@static_timeit cache.timer "linear solve" begin
linres = cache.lincache(;
A = J, b = Utils.safe_vec(fu),
kwargs..., linu = Utils.safe_vec(δu), du = Utils.safe_vec(δu),
kwargs..., linu = Utils.safe_vec(δu),
reuse_A_if_factorization = !new_jacobian || idx !== Val(1)
)
end
Expand Down
4 changes: 1 addition & 3 deletions lib/NonlinearSolveBase/src/descent/steepest.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
SteepestDescent(; linsolve = nothing, precs = nothing)
SteepestDescent(; linsolve = nothing)
Compute the descent direction as ``δu = -Jᵀfu``. The linear solver and preconditioner are
only used if `J` is provided in the inverted form.
Expand All @@ -8,7 +8,6 @@ See also [`Dogleg`](@ref), [`NewtonDescent`](@ref), [`DampedNewtonDescent`](@ref
"""
@kwdef @concrete struct SteepestDescent <: AbstractDescentDirection
linsolve = nothing
precs = nothing
end

supports_line_search(::SteepestDescent) = true
Expand Down Expand Up @@ -57,7 +56,6 @@ function InternalAPI.solve!(
A = J === nothing ? nothing : transpose(J)
linres = cache.lincache(;
A, b = Utils.safe_vec(fu), kwargs..., linu = Utils.safe_vec(δu),
du = Utils.safe_vec(δu),
reuse_A_if_factorization = !new_jacobian || idx !== Val(1)
)
δu = Utils.restructure(SciMLBase.get_du(cache, idx), linres.u)
Expand Down
23 changes: 5 additions & 18 deletions lib/NonlinearSolveBase/src/linear_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ end
lincache
linsolve
additional_lincache::Any
precs
stats::NLStats
end

Expand All @@ -34,8 +33,8 @@ handled:
```julia
(cache::LinearSolverCache)(;
A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing,
weight = nothing, cachedata = nothing, reuse_A_if_factorization = false, kwargs...)
A = nothing, b = nothing, linu = nothing, reuse_A_if_factorization = false, kwargs...
)
```
Returns the solution of the system `u` and stores the updated cache in `cache.lincache`.
Expand All @@ -60,15 +59,11 @@ aliasing arguments even after cache construction, i.e., if we passed in an `A` t
not mutated, we do this by copying over `A` to a preconstructed cache.
"""
function construct_linear_solver(alg, linsolve, A, b, u; stats, kwargs...)
no_preconditioner = !hasfield(typeof(alg), :precs) || alg.precs === nothing

if (A isa Number && b isa Number) || (A isa Diagonal)
return NativeJLLinearSolveCache(A, b, stats)
elseif linsolve isa typeof(\)
!no_preconditioner &&
error("Default Julia Backsolve Operator `\\` doesn't support Preconditioners")
return NativeJLLinearSolveCache(A, b, stats)
elseif no_preconditioner && linsolve === nothing
elseif linsolve === nothing
if (A isa SMatrix || A isa WrappedArray{<:Any, <:SMatrix})
return NativeJLLinearSolveCache(A, b, stats)
end
Expand All @@ -78,17 +73,9 @@ function construct_linear_solver(alg, linsolve, A, b, u; stats, kwargs...)
@bb u_cache = copy(u_fixed)
linprob = LinearProblem(A, b; u0 = u_cache, kwargs...)

if no_preconditioner
precs, Pl, Pr = nothing, nothing, nothing
else
precs = alg.precs
Pl, Pr = precs(A, nothing, u, ntuple(Returns(nothing), 6)...)
end
Pl, Pr = wrap_preconditioners(Pl, Pr, u)

# unlias here, we will later use these as caches
lincache = init(linprob, linsolve; alias_A = false, alias_b = false, Pl, Pr)
return LinearSolveJLCache(lincache, linsolve, nothing, precs, stats)
lincache = init(linprob, linsolve; alias_A = false, alias_b = false)
return LinearSolveJLCache(lincache, linsolve, nothing, stats)
end

function (cache::NativeJLLinearSolveCache)(;
Expand Down
6 changes: 3 additions & 3 deletions lib/NonlinearSolveFirstOrder/src/gauss_newton.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
GaussNewton(;
concrete_jac = nothing, linsolve = nothing, linesearch = missing, precs = nothing,
concrete_jac = nothing, linsolve = nothing, linesearch = missing,
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
)
Expand All @@ -9,12 +9,12 @@ matrices via colored automatic differentiation and preconditioned linear solvers
for large-scale and numerically-difficult nonlinear systems.
"""
function GaussNewton(;
concrete_jac = nothing, linsolve = nothing, linesearch = missing, precs = nothing,
concrete_jac = nothing, linsolve = nothing, linesearch = missing,
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
)
return GeneralizedFirstOrderAlgorithm(;
linesearch,
descent = NewtonDescent(; linsolve, precs),
descent = NewtonDescent(; linsolve),
autodiff, vjp_autodiff, jvp_autodiff,
concrete_jac,
name = :GaussNewton
Expand Down
5 changes: 2 additions & 3 deletions lib/NonlinearSolveFirstOrder/src/levenberg_marquardt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
LevenbergMarquardt(;
linsolve = nothing, precs = nothing,
linsolve = nothing,
damping_initial::Real = 1.0, α_geodesic::Real = 0.75, disable_geodesic = Val(false),
damping_increase_factor::Real = 2.0, damping_decrease_factor::Real = 3.0,
finite_diff_step_geodesic = 0.1, b_uphill::Real = 1.0, min_damping_D::Real = 1e-8,
Expand Down Expand Up @@ -33,15 +33,14 @@ For the remaining arguments, see [`GeodesicAcceleration`](@ref) and
[`NonlinearSolveFirstOrder.LevenbergMarquardtTrustRegion`](@ref) documentations.
"""
function LevenbergMarquardt(;
linsolve = nothing, precs = nothing,
linsolve = nothing,
damping_initial::Real = 1.0, α_geodesic::Real = 0.75, disable_geodesic = Val(false),
damping_increase_factor::Real = 2.0, damping_decrease_factor::Real = 3.0,
finite_diff_step_geodesic = 0.1, b_uphill::Real = 1.0, min_damping_D::Real = 1e-8,
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
)
descent = DampedNewtonDescent(;
linsolve,
precs,
initial_damping = damping_initial,
damping_fn = LevenbergMarquardtDampingFunction(
damping_increase_factor, damping_decrease_factor, min_damping_D
Expand Down
Loading

0 comments on commit 8762dcc

Please sign in to comment.