Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: move preconditioners inside linear solvers #485

Merged
merged 3 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions docs/src/native/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ documentation.
uses the LinearSolve.jl default algorithm choice. For more information on available
algorithm choices, see the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `precs`: the choice of preconditioners for the linear solver. Defaults to using no
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `linesearch`: the line search algorithm to use. Defaults to
[`NoLineSearch()`](@extref LineSearch.NoLineSearch), which means that no line search is
performed.
Expand Down
2 changes: 2 additions & 0 deletions docs/src/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
### Breaking Changes in `NonlinearSolve.jl` v4

- `ApproximateJacobianSolveAlgorithm` has been renamed to `QuasiNewtonAlgorithm`.
- Preconditioners for the linear solver needs to be specified with the linear solver
instead of `precs` keyword argument.
- See [common breaking changes](@ref common-breaking-changes-v4v2) below.

### Breaking Changes in `SimpleNonlinearSolve.jl` v2
Expand Down
67 changes: 29 additions & 38 deletions docs/src/tutorials/large_systems.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@ end

u0 = init_brusselator_2d(xyd_brusselator)
prob_brusselator_2d = NonlinearProblem(
brusselator_2d_loop, u0, p; abstol = 1e-10, reltol = 1e-10)
brusselator_2d_loop, u0, p; abstol = 1e-10, reltol = 1e-10
)
```

## Choosing Jacobian Types
Expand Down Expand Up @@ -140,7 +141,8 @@ using SparseConnectivityTracer

prob_brusselator_2d_autosparse = NonlinearProblem(
NonlinearFunction(brusselator_2d_loop; sparsity = TracerSparsityDetector()),
u0, p; abstol = 1e-10, reltol = 1e-10)
u0, p; abstol = 1e-10, reltol = 1e-10
)

@btime solve(prob_brusselator_2d_autosparse,
NewtonRaphson(; autodiff = AutoForwardDiff(; chunksize = 12)));
Expand Down Expand Up @@ -235,7 +237,7 @@ choices, see the

Any [LinearSolve.jl-compatible preconditioner](https://docs.sciml.ai/LinearSolve/stable/basics/Preconditioners/)
can be used as a preconditioner in the linear solver interface. To define preconditioners,
one must define a `precs` function in compatible with nonlinear solvers which returns the
one must define a `precs` function in compatible with linear solvers which returns the
left and right preconditioners, matrices which approximate the inverse of `W = I - gamma*J`
used in the solution of the ODE. An example of this with using
[IncompleteLU.jl](https://github.com/haampie/IncompleteLU.jl) is as follows:
Expand All @@ -244,26 +246,18 @@ used in the solution of the ODE. An example of this with using
# FIXME: On 1.10+ this is broken. Skipping this for now.
using IncompleteLU

function incompletelu(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
if newW === nothing || newW
Pl = ilu(W, τ = 50.0)
else
Pl = Plprev
end
Pl, nothing
end
incompletelu(W, p = nothing) = ilu(W, τ = 50.0), LinearAlgebra.I

@btime solve(prob_brusselator_2d_sparse,
NewtonRaphson(linsolve = KrylovJL_GMRES(), precs = incompletelu, concrete_jac = true));
NewtonRaphson(linsolve = KrylovJL_GMRES(precs = incompletelu), concrete_jac = true)
);
nothing # hide
```

Notice a few things about this preconditioner. This preconditioner uses the sparse Jacobian,
and thus we set `concrete_jac = true` to tell the algorithm to generate the Jacobian
(otherwise, a Jacobian-free algorithm is used with GMRES by default). Then `newW = true`
whenever a new `W` matrix is computed, and `newW = nothing` during the startup phase of the
solver. Thus, we do a check `newW === nothing || newW` and when true, it's only at these
points when we update the preconditioner, otherwise we just pass on the previous version.
(otherwise, a Jacobian-free algorithm is used with GMRES by default).

We use `convert(AbstractMatrix,W)` to get the concrete `W` matrix (matching `jac_prototype`,
thus `SpraseMatrixCSC`) which we can use in the preconditioner's definition. Then we use
`IncompleteLU.ilu` on that sparse matrix to generate the preconditioner. We return
Expand All @@ -279,39 +273,36 @@ which is more automatic. The setup is very similar to before:
```@example ill_conditioned_nlprob
using AlgebraicMultigrid

function algebraicmultigrid(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
if newW === nothing || newW
Pl = aspreconditioner(ruge_stuben(convert(AbstractMatrix, W)))
else
Pl = Plprev
end
Pl, nothing
function algebraicmultigrid(W, p = nothing)
return aspreconditioner(ruge_stuben(convert(AbstractMatrix, W))), LinearAlgebra.I
end

@btime solve(prob_brusselator_2d_sparse,
NewtonRaphson(linsolve = KrylovJL_GMRES(), precs = algebraicmultigrid,
concrete_jac = true));
NewtonRaphson(
linsolve = KrylovJL_GMRES(; precs = algebraicmultigrid), concrete_jac = true
)
);
nothing # hide
```

or with a Jacobi smoother:

```@example ill_conditioned_nlprob
function algebraicmultigrid2(W, du, u, p, t, newW, Plprev, Prprev, solverdata)
if newW === nothing || newW
A = convert(AbstractMatrix, W)
Pl = AlgebraicMultigrid.aspreconditioner(AlgebraicMultigrid.ruge_stuben(
A, presmoother = AlgebraicMultigrid.Jacobi(rand(size(A, 1))),
postsmoother = AlgebraicMultigrid.Jacobi(rand(size(A, 1)))))
else
Pl = Plprev
end
Pl, nothing
function algebraicmultigrid2(W, p = nothing)
A = convert(AbstractMatrix, W)
Pl = AlgebraicMultigrid.aspreconditioner(AlgebraicMultigrid.ruge_stuben(
A, presmoother = AlgebraicMultigrid.Jacobi(rand(size(A, 1))),
postsmoother = AlgebraicMultigrid.Jacobi(rand(size(A, 1)))
))
return Pl, LinearAlgebra.I
end

@btime solve(prob_brusselator_2d_sparse,
NewtonRaphson(linsolve = KrylovJL_GMRES(), precs = algebraicmultigrid2,
concrete_jac = true));
@btime solve(
prob_brusselator_2d_sparse,
NewtonRaphson(
linsolve = KrylovJL_GMRES(precs = algebraicmultigrid2), concrete_jac = true
)
);
nothing # hide
```

Expand Down
25 changes: 4 additions & 21 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,33 +11,15 @@ using LinearAlgebra: ColumnNorm
using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils

function (cache::LinearSolveJLCache)(;
A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing,
cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs...
A = nothing, b = nothing, linu = nothing,
reuse_A_if_factorization = false, verbose = true, kwargs...
)
cache.stats.nsolve += 1

update_A!(cache, A, reuse_A_if_factorization)
b !== nothing && setproperty!(cache.lincache, :b, b)
linu !== nothing && NonlinearSolveBase.set_lincache_u!(cache, linu)

Plprev = cache.lincache.Pl
Prprev = cache.lincache.Pr

if cache.precs === nothing
Pl, Pr = nothing, nothing
else
Pl, Pr = cache.precs(
cache.lincache.A, du, linu, p, nothing,
A !== nothing, Plprev, Prprev, cachedata
)
end

if Pl !== nothing || Pr !== nothing
Pl, Pr = NonlinearSolveBase.wrap_preconditioners(Pl, Pr, linu)
cache.lincache.Pl = Pl
cache.lincache.Pr = Pr
end

linres = solve!(cache.lincache)
cache.lincache = linres.cache
# Unfortunately LinearSolve.jl doesn't have the most uniform ReturnCode handling
Expand All @@ -58,7 +40,8 @@ function (cache::LinearSolveJLCache)(;
linprob = LinearProblem(A, b; u0 = linres.u)
cache.additional_lincache = init(
linprob, QRFactorization(ColumnNorm()); alias_u0 = false,
alias_A = false, alias_b = false, cache.lincache.Pl, cache.lincache.Pr)
alias_A = false, alias_b = false
)
else
cache.additional_lincache.A = A
cache.additional_lincache.b = b
Expand Down
5 changes: 1 addition & 4 deletions lib/NonlinearSolveBase/src/descent/damped_newton.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""
DampedNewtonDescent(;
linsolve = nothing, precs = nothing, initial_damping, damping_fn
)
DampedNewtonDescent(; linsolve = nothing, initial_damping, damping_fn)

A Newton descent algorithm with damping. The damping factor is computed using the
`damping_fn` function. The descent direction is computed as ``(JᵀJ + λDᵀD) δu = -fu``. For
Expand All @@ -20,7 +18,6 @@ The damping factor returned must be a non-negative number.
"""
@kwdef @concrete struct DampedNewtonDescent <: AbstractDescentDirection
linsolve = nothing
precs = nothing
initial_damping
damping_fn <: AbstractDampingFunction
end
Expand Down
10 changes: 5 additions & 5 deletions lib/NonlinearSolveBase/src/descent/dogleg.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Dogleg(; linsolve = nothing, precs = nothing)
Dogleg(; linsolve = nothing)

Switch between Newton's method and the steepest descent method depending on the size of the
trust region. The trust region is specified via keyword argument `trust_region` to
Expand All @@ -15,18 +15,18 @@ end
supports_trust_region(::Dogleg) = true
get_linear_solver(alg::Dogleg) = get_linear_solver(alg.newton_descent)

function Dogleg(; linsolve = nothing, precs = nothing, damping = Val(false),
function Dogleg(; linsolve = nothing, damping = Val(false),
damping_fn = missing, initial_damping = missing, kwargs...)
if !Utils.unwrap_val(damping)
return Dogleg(NewtonDescent(; linsolve, precs), SteepestDescent(; linsolve, precs))
return Dogleg(NewtonDescent(; linsolve), SteepestDescent(; linsolve))
end
if damping_fn === missing || initial_damping === missing
throw(ArgumentError("`damping_fn` and `initial_damping` must be supplied if \
`damping = Val(true)`."))
end
return Dogleg(
DampedNewtonDescent(; linsolve, precs, damping_fn, initial_damping),
SteepestDescent(; linsolve, precs)
DampedNewtonDescent(; linsolve, damping_fn, initial_damping),
SteepestDescent(; linsolve)
)
end

Expand Down
7 changes: 3 additions & 4 deletions lib/NonlinearSolveBase/src/descent/newton.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
NewtonDescent(; linsolve = nothing, precs = nothing)
NewtonDescent(; linsolve = nothing)

Compute the descent direction as ``J δu = -fu``. For non-square Jacobian problems, this is
commonly referred to as the Gauss-Newton Descent.
Expand All @@ -8,7 +8,6 @@ See also [`Dogleg`](@ref), [`SteepestDescent`](@ref), [`DampedNewtonDescent`](@r
"""
@kwdef @concrete struct NewtonDescent <: AbstractDescentDirection
linsolve = nothing
precs = nothing
end

supports_line_search(::NewtonDescent) = true
Expand Down Expand Up @@ -103,15 +102,15 @@ function InternalAPI.solve!(
@static_timeit cache.timer "linear solve" begin
linres = cache.lincache(;
A = Utils.maybe_symmetric(cache.JᵀJ_cache), b = cache.Jᵀfu_cache,
kwargs..., linu = Utils.safe_vec(δu), du = Utils.safe_vec(δu),
kwargs..., linu = Utils.safe_vec(δu),
reuse_A_if_factorization = !new_jacobian || (idx !== Val(1))
)
end
else
@static_timeit cache.timer "linear solve" begin
linres = cache.lincache(;
A = J, b = Utils.safe_vec(fu),
kwargs..., linu = Utils.safe_vec(δu), du = Utils.safe_vec(δu),
kwargs..., linu = Utils.safe_vec(δu),
reuse_A_if_factorization = !new_jacobian || idx !== Val(1)
)
end
Expand Down
4 changes: 1 addition & 3 deletions lib/NonlinearSolveBase/src/descent/steepest.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
SteepestDescent(; linsolve = nothing, precs = nothing)
SteepestDescent(; linsolve = nothing)

Compute the descent direction as ``δu = -Jᵀfu``. The linear solver and preconditioner are
only used if `J` is provided in the inverted form.
Expand All @@ -8,7 +8,6 @@ See also [`Dogleg`](@ref), [`NewtonDescent`](@ref), [`DampedNewtonDescent`](@ref
"""
@kwdef @concrete struct SteepestDescent <: AbstractDescentDirection
linsolve = nothing
precs = nothing
end

supports_line_search(::SteepestDescent) = true
Expand Down Expand Up @@ -57,7 +56,6 @@ function InternalAPI.solve!(
A = J === nothing ? nothing : transpose(J)
linres = cache.lincache(;
A, b = Utils.safe_vec(fu), kwargs..., linu = Utils.safe_vec(δu),
du = Utils.safe_vec(δu),
reuse_A_if_factorization = !new_jacobian || idx !== Val(1)
)
δu = Utils.restructure(SciMLBase.get_du(cache, idx), linres.u)
Expand Down
3 changes: 1 addition & 2 deletions lib/NonlinearSolveBase/src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ Construct a cache for the Jacobian of `f` w.r.t. `u`.
- `jvp_autodiff`: Automatic Differentiation or Finite Differencing backend for computing
the Jacobian-vector product.
- `linsolve`: Linear Solver Algorithm used to determine if we need a concrete jacobian
or if possible we can just use a [`SciMLJacobianOperators.JacobianOperator`](@ref)
instead.
or if possible we can just use a `JacobianOperator` instead.
"""
function construct_jacobian_cache(
prob, alg, f::NonlinearFunction, fu, u = prob.u0, p = prob.p; stats,
Expand Down
23 changes: 5 additions & 18 deletions lib/NonlinearSolveBase/src/linear_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ end
lincache
linsolve
additional_lincache::Any
precs
stats::NLStats
end

Expand All @@ -34,8 +33,8 @@ handled:

```julia
(cache::LinearSolverCache)(;
A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing,
weight = nothing, cachedata = nothing, reuse_A_if_factorization = false, kwargs...)
A = nothing, b = nothing, linu = nothing, reuse_A_if_factorization = false, kwargs...
)
```

Returns the solution of the system `u` and stores the updated cache in `cache.lincache`.
Expand All @@ -60,15 +59,11 @@ aliasing arguments even after cache construction, i.e., if we passed in an `A` t
not mutated, we do this by copying over `A` to a preconstructed cache.
"""
function construct_linear_solver(alg, linsolve, A, b, u; stats, kwargs...)
no_preconditioner = !hasfield(typeof(alg), :precs) || alg.precs === nothing

if (A isa Number && b isa Number) || (A isa Diagonal)
return NativeJLLinearSolveCache(A, b, stats)
elseif linsolve isa typeof(\)
!no_preconditioner &&
error("Default Julia Backsolve Operator `\\` doesn't support Preconditioners")
return NativeJLLinearSolveCache(A, b, stats)
elseif no_preconditioner && linsolve === nothing
elseif linsolve === nothing
if (A isa SMatrix || A isa WrappedArray{<:Any, <:SMatrix})
return NativeJLLinearSolveCache(A, b, stats)
end
Expand All @@ -78,17 +73,9 @@ function construct_linear_solver(alg, linsolve, A, b, u; stats, kwargs...)
@bb u_cache = copy(u_fixed)
linprob = LinearProblem(A, b; u0 = u_cache, kwargs...)

if no_preconditioner
precs, Pl, Pr = nothing, nothing, nothing
else
precs = alg.precs
Pl, Pr = precs(A, nothing, u, ntuple(Returns(nothing), 6)...)
end
Pl, Pr = wrap_preconditioners(Pl, Pr, u)

# unlias here, we will later use these as caches
lincache = init(linprob, linsolve; alias_A = false, alias_b = false, Pl, Pr)
return LinearSolveJLCache(lincache, linsolve, nothing, precs, stats)
lincache = init(linprob, linsolve; alias_A = false, alias_b = false)
return LinearSolveJLCache(lincache, linsolve, nothing, stats)
end

function (cache::NativeJLLinearSolveCache)(;
Expand Down
6 changes: 3 additions & 3 deletions lib/NonlinearSolveFirstOrder/src/gauss_newton.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
GaussNewton(;
concrete_jac = nothing, linsolve = nothing, linesearch = missing, precs = nothing,
concrete_jac = nothing, linsolve = nothing, linesearch = missing,
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
)

Expand All @@ -9,12 +9,12 @@ matrices via colored automatic differentiation and preconditioned linear solvers
for large-scale and numerically-difficult nonlinear systems.
"""
function GaussNewton(;
concrete_jac = nothing, linsolve = nothing, linesearch = missing, precs = nothing,
concrete_jac = nothing, linsolve = nothing, linesearch = missing,
autodiff = nothing, vjp_autodiff = nothing, jvp_autodiff = nothing
)
return GeneralizedFirstOrderAlgorithm(;
linesearch,
descent = NewtonDescent(; linsolve, precs),
descent = NewtonDescent(; linsolve),
autodiff, vjp_autodiff, jvp_autodiff,
concrete_jac,
name = :GaussNewton
Expand Down
Loading
Loading