Skip to content

Commit

Permalink
refactor: centralize autodiff selection
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 23, 2024
1 parent a0e22f4 commit 6dbc95e
Show file tree
Hide file tree
Showing 21 changed files with 247 additions and 301 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "4.0.0"
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Expand Down Expand Up @@ -65,6 +66,7 @@ ArrayInterface = "7.16"
BandedMatrices = "1.5"
BenchmarkTools = "1.4"
CUDA = "5.5"
CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.158.3"
DifferentiationInterface = "0.6.1"
Expand Down
11 changes: 6 additions & 5 deletions docs/src/native/solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ documentation.
preconditioners. For more information on specifying preconditioners for LinearSolve
algorithms, consult the
[LinearSolve.jl documentation](https://docs.sciml.ai/LinearSolve/stable/).
- `linesearch`: the line search algorithm to use. Defaults to [`NoLineSearch()`](@extref LineSearch.NoLineSearch),
which means that no line search is performed.
- `autodiff`/`jacobian_ad`: Determines the backend used for the Jacobian. Note that this
- `linesearch`: the line search algorithm to use. Defaults to
[`NoLineSearch()`](@extref LineSearch.NoLineSearch), which means that no line search is
performed.
- `autodiff`: Determines the backend used for the Jacobian. Note that this
argument is ignored if an analytical Jacobian is passed, as that will be used instead.
Defaults to `nothing` which means that a default is selected according to the problem
specification! Valid choices are types from ADTypes.jl.
- `forward_ad`/`jvp_autodiff`: similar to `autodiff`, but is used to compute Jacobian
- `jvp_autodiff`: similar to `autodiff`, but is used to compute Jacobian
Vector Products. Ignored if the NonlinearFunction contains the `jvp` function.
- `reverse_ad`/`vjp_autodiff`: similar to `autodiff`, but is used to compute Vector
- `vjp_autodiff`: similar to `autodiff`, but is used to compute Vector
Jacobian Products. Ignored if the NonlinearFunction contains the `vjp` function.
- `concrete_jac`: whether to build a concrete Jacobian. If a Krylov-subspace method is
used, then the Jacobian will not be constructed and instead direct Jacobian-Vector
Expand Down
6 changes: 3 additions & 3 deletions docs/src/release_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
- Use of termination conditions from `DiffEqBase` has been removed. Use the termination
conditions from `NonlinearSolveBase` instead.
- If no autodiff is provided, we now choose from a list of autodiffs based on the packages
loaded. For example, if `Enzyme` is loaded, we will default to that. In general, we
don't guarantee the exact autodiff selected if `autodiff` is not provided (i.e.
`nothing`).
loaded. For example, if `Enzyme` is loaded, we will default to that (for reverse mode).
In general, we don't guarantee the exact autodiff selected if `autodiff` is not provided
(i.e. `nothing`).

## Dec '23

Expand Down
32 changes: 21 additions & 11 deletions lib/NonlinearSolveBase/src/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,34 @@

# Ordering is important here. We want to select the first one that is compatible with the
# problem.
const ReverseADs = (
ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse),
ADTypes.AutoZygote(),
ADTypes.AutoTracker(),
ADTypes.AutoReverseDiff(; compile = true),
ADTypes.AutoReverseDiff(),
ADTypes.AutoFiniteDiff()
)
# XXX: Remove this once Enzyme is properly supported on Julia 1.11+
#
# Candidate reverse-mode AD backends, listed in order of preference (earlier entries take
# precedence). On Julia 1.11 and newer (including pre-releases, hence `v"1.11-"`) Enzyme
# is demoted to just before the finite-difference fallback; on older versions it is
# listed first.
@static if VERSION >= v"1.11-"
    const ReverseADs = (
        ADTypes.AutoZygote(),
        ADTypes.AutoTracker(),
        ADTypes.AutoReverseDiff(; compile = true),
        ADTypes.AutoReverseDiff(),
        ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse),
        ADTypes.AutoFiniteDiff()
    )
else
    const ReverseADs = (
        ADTypes.AutoEnzyme(; mode = EnzymeCore.Reverse),
        ADTypes.AutoZygote(),
        ADTypes.AutoTracker(),
        ADTypes.AutoReverseDiff(; compile = true),
        ADTypes.AutoReverseDiff(),
        ADTypes.AutoFiniteDiff()
    )
end

const ForwardADs = (
ADTypes.AutoEnzyme(; mode = EnzymeCore.Forward),
ADTypes.AutoPolyesterForwardDiff(),
ADTypes.AutoForwardDiff(),
ADTypes.AutoEnzyme(; mode = EnzymeCore.Forward),
ADTypes.AutoFiniteDiff()
)

# TODO: Handle Sparsity

function select_forward_mode_autodiff(
prob::AbstractNonlinearProblem, ad::AbstractADType; warn_check_mode::Bool = true)
if warn_check_mode && !(ADTypes.mode(ad) isa ADTypes.ForwardMode)
Expand Down
1 change: 0 additions & 1 deletion lib/SimpleNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ function prepare_jacobian(prob, autodiff, _, x::Number)
if SciMLBase.has_jac(prob.f) || SciMLBase.has_vjp(prob.f) || SciMLBase.has_jvp(prob.f)
return AnalyticJacobian()
end
# return DI.prepare_derivative(prob.f, autodiff, x, Constant(prob.p))
return DINoPreparation()
end
function prepare_jacobian(prob, autodiff, fx, x)
Expand Down
130 changes: 68 additions & 62 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using PrecompileTools: @compile_workload, @setup_workload

using ArrayInterface: ArrayInterface, can_setindex, restructure, fast_scalar_indexing,
ismutable
using CommonSolve: solve, init, solve!
using ConcreteStructs: @concrete
using DiffEqBase: DiffEqBase # Needed for `init` / `solve` dispatches
using FastClosures: @closure
Expand All @@ -21,13 +22,18 @@ using NonlinearSolveBase: NonlinearSolveBase, nonlinearsolve_forwarddiff_solve,
nonlinearsolve_dual_solution, nonlinearsolve_∂f_∂p,
nonlinearsolve_∂f_∂u, L2_NORM, AbsNormTerminationMode,
AbstractNonlinearTerminationMode,
AbstractSafeBestNonlinearTerminationMode
AbstractSafeBestNonlinearTerminationMode,
select_forward_mode_autodiff, select_reverse_mode_autodiff,
select_jacobian_autodiff
using Printf: @printf
using Preferences: Preferences, @load_preference, @set_preferences!
using RecursiveArrayTools: recursivecopy!
using SciMLBase: AbstractNonlinearAlgorithm, AbstractNonlinearProblem, _unwrap_val,
isinplace, NLStats
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, AbstractNonlinearProblem,
_unwrap_val, isinplace, NLStats, NonlinearFunction,
NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, get_du, step!,
set_u!, LinearProblem, IdentityOperator
using SciMLOperators: AbstractSciMLOperator
using SimpleNonlinearSolve: SimpleNonlinearSolve
using StaticArraysCore: StaticArray, SVector, SArray, MArray, Size, SMatrix
using SymbolicIndexingInterface: SymbolicIndexingInterface, ParameterIndexingProxy,
symbolic_container, parameter_values, state_values, getu,
Expand Down Expand Up @@ -95,65 +101,65 @@ include("internal/forward_diff.jl") # we need to define after the algorithms
include("utils.jl")
include("default.jl")

@setup_workload begin
nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
(NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
probs_nls = NonlinearProblem[]
for (fn, u0) in nlfuncs
push!(probs_nls, NonlinearProblem(fn, u0, 2.0))
end

nls_algs = (
NewtonRaphson(),
TrustRegion(),
LevenbergMarquardt(),
Broyden(),
Klement(),
nothing
)

probs_nlls = NonlinearLeastSquaresProblem[]
nlfuncs = (
(NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
(NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
(
NonlinearFunction{true}(
(du, u, p) -> du[1] = u[1] * u[1] - p, resid_prototype = zeros(1)),
[0.1, 0.0]),
(
NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
resid_prototype = zeros(4)),
[0.1, 0.1]
)
)
for (fn, u0) in nlfuncs
push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
end

nlls_algs = (
LevenbergMarquardt(),
GaussNewton(),
TrustRegion(),
nothing
)

@compile_workload begin
@sync begin
for T in (Float32, Float64), (fn, u0) in nlfuncs
Threads.@spawn NonlinearProblem(fn, T.(u0), T(2))
end
for (fn, u0) in nlfuncs
Threads.@spawn NonlinearLeastSquaresProblem(fn, u0, 2.0)
end
for prob in probs_nls, alg in nls_algs
Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
end
for prob in probs_nlls, alg in nlls_algs
Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
end
end
end
end
# @setup_workload begin
# nlfuncs = ((NonlinearFunction{false}((u, p) -> u .* u .- p), 0.1),
# (NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p), [0.1]))
# probs_nls = NonlinearProblem[]
# for (fn, u0) in nlfuncs
# push!(probs_nls, NonlinearProblem(fn, u0, 2.0))
# end

# nls_algs = (
# NewtonRaphson(),
# TrustRegion(),
# LevenbergMarquardt(),
# Broyden(),
# Klement(),
# nothing
# )

# probs_nlls = NonlinearLeastSquaresProblem[]
# nlfuncs = (
# (NonlinearFunction{false}((u, p) -> (u .^ 2 .- p)[1:1]), [0.1, 0.0]),
# (NonlinearFunction{false}((u, p) -> vcat(u .* u .- p, u .* u .- p)), [0.1, 0.1]),
# (
# NonlinearFunction{true}(
# (du, u, p) -> du[1] = u[1] * u[1] - p, resid_prototype = zeros(1)),
# [0.1, 0.0]),
# (
# NonlinearFunction{true}((du, u, p) -> du .= vcat(u .* u .- p, u .* u .- p),
# resid_prototype = zeros(4)),
# [0.1, 0.1]
# )
# )
# for (fn, u0) in nlfuncs
# push!(probs_nlls, NonlinearLeastSquaresProblem(fn, u0, 2.0))
# end

# nlls_algs = (
# LevenbergMarquardt(),
# GaussNewton(),
# TrustRegion(),
# nothing
# )

# @compile_workload begin
# @sync begin
# for T in (Float32, Float64), (fn, u0) in nlfuncs
# Threads.@spawn NonlinearProblem(fn, T.(u0), T(2))
# end
# for (fn, u0) in nlfuncs
# Threads.@spawn NonlinearLeastSquaresProblem(fn, u0, 2.0)
# end
# for prob in probs_nls, alg in nls_algs
# Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
# end
# for prob in probs_nlls, alg in nlls_algs
# Threads.@spawn solve(prob, alg; abstol = 1e-2, verbose = false)
# end
# end
# end
# end

# Reexports
@reexport using SciMLBase, SimpleNonlinearSolve, NonlinearSolveBase
Expand Down
59 changes: 30 additions & 29 deletions src/algorithms/broyden.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Broyden(; max_resets::Int = 100, linesearch = NoLineSearch(), reset_tolerance = nothing,
Broyden(; max_resets::Int = 100, linesearch = nothing, reset_tolerance = nothing,
init_jacobian::Val = Val(:identity), autodiff = nothing, alpha = nothing)
An implementation of `Broyden`'s Method [broyden1965class](@cite) with resetting and line
Expand Down Expand Up @@ -29,36 +29,37 @@ search.
problem
"""
function Broyden(;
max_resets = 100, linesearch = NoLineSearch(), reset_tolerance = nothing,
init_jacobian::Val{IJ} = Val(:identity), autodiff = nothing,
alpha = nothing, update_rule::Val{UR} = Val(:good_broyden)) where {IJ, UR}
if IJ === :identity
if UR === :diagonal
initialization = IdentityInitialization(alpha, DiagonalStructure())
else
initialization = IdentityInitialization(alpha, FullStructure())
end
elseif IJ === :true_jacobian
initialization = TrueJacobianInitialization(FullStructure(), autodiff)
else
throw(ArgumentError("`init_jacobian` must be one of `:identity` or \
`:true_jacobian`"))
end
max_resets = 100, linesearch = nothing, reset_tolerance = nothing,
init_jacobian = Val(:identity), autodiff = nothing, alpha = nothing,
update_rule = Val(:good_broyden))
initialization = broyden_init(init_jacobian, update_rule, autodiff, alpha)
update_rule = broyden_update_rule(update_rule)
return ApproximateJacobianSolveAlgorithm{
init_jacobian isa Val{:true_jacobian}, :Broyden}(;
linesearch, descent = NewtonDescent(), update_rule, max_resets, initialization,
reinit_rule = NoChangeInStateReset(; reset_tolerance))
end

update_rule = if UR === :good_broyden
GoodBroydenUpdateRule()
elseif UR === :bad_broyden
BadBroydenUpdateRule()
elseif UR === :diagonal
GoodBroydenUpdateRule()
else
throw(ArgumentError("`update_rule` must be one of `:good_broyden`, `:bad_broyden`, \
or `:diagonal`"))
end
"""
    broyden_init(init_jacobian::Val, update_rule::Val, autodiff, alpha)

Pick the Jacobian initialization for `Broyden` by dispatching on the `init_jacobian` and
`update_rule` value types. An unrecognized `init_jacobian` value raises an error.
"""
function broyden_init(::Val{:identity}, ::Val{:diagonal}, _autodiff, alpha)
    # Identity start combined with the diagonal update rule uses a diagonal structure.
    structure = DiagonalStructure()
    return IdentityInitialization(alpha, structure)
end
function broyden_init(::Val{:identity}, ::Val, _autodiff, alpha)
    # Any other update rule with an identity start uses the full structure.
    structure = FullStructure()
    return IdentityInitialization(alpha, structure)
end
function broyden_init(::Val{:true_jacobian}, ::Val, autodiff, _alpha)
    # `alpha` is not used when starting from the true Jacobian.
    structure = FullStructure()
    return TrueJacobianInitialization(structure, autodiff)
end
function broyden_init(::Val{J}, ::Val{U}, _autodiff, _alpha) where {J, U}
    # Fallback: reached only when no more specific method above matched.
    jac_choice = Meta.quot(J)
    rule_choice = Meta.quot(U)
    error("Unknown combination of `init_jacobian = Val($(jac_choice))` and \
           `update_rule = Val($(rule_choice))`. Please choose a valid combination.")
end

return ApproximateJacobianSolveAlgorithm{IJ === :true_jacobian, :Broyden}(;
linesearch, descent = NewtonDescent(), update_rule, max_resets,
initialization, reinit_rule = NoChangeInStateReset(; reset_tolerance))
# Translate the `update_rule` value type into the update-rule object used by `Broyden`.
# Note that `:diagonal` maps to the same rule object as `:good_broyden`.
function broyden_update_rule(::Val{:good_broyden})
    return GoodBroydenUpdateRule()
end
function broyden_update_rule(::Val{:bad_broyden})
    return BadBroydenUpdateRule()
end
function broyden_update_rule(::Val{:diagonal})
    return GoodBroydenUpdateRule()
end
function broyden_update_rule(::Val{U}) where {U}
    # Fallback for any symbol not handled by the methods above.
    rule_choice = Meta.quot(U)
    error("Unknown update rule `update_rule = Val($(rule_choice))`. Please choose a \
           valid update rule.")
end

# Checks for no significant change for `nsteps`
Expand Down
1 change: 1 addition & 0 deletions src/algorithms/dfsane.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# XXX: remove kwargs with unicode
"""
DFSane(; σ_min = 1 // 10^10, σ_max = 1e10, σ_1 = 1, M::Int = 10, γ = 1 // 10^4,
τ_min = 1 // 10, τ_max = 1 // 2, n_exp::Int = 2, max_inner_iterations::Int = 100,
Expand Down
13 changes: 7 additions & 6 deletions src/algorithms/gauss_newton.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""
GaussNewton(; concrete_jac = nothing, linsolve = nothing, linesearch = NoLineSearch(),
precs = DEFAULT_PRECS, adkwargs...)
GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
linesearch = nothing, vjp_autodiff = nothing, autodiff = nothing,
jvp_autodiff = nothing)
An advanced GaussNewton implementation with support for efficient handling of sparse
matrices via colored automatic differentiation and preconditioned linear solvers. Designed
for large-scale and numerically-difficult nonlinear least squares problems.
"""
function GaussNewton(; concrete_jac = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
linesearch = NoLineSearch(), vjp_autodiff = nothing, autodiff = nothing)
descent = NewtonDescent(; linsolve, precs)
return GeneralizedFirstOrderAlgorithm(; concrete_jac, name = :GaussNewton, descent,
jacobian_ad = autodiff, reverse_ad = vjp_autodiff, linesearch)
linesearch = nothing, vjp_autodiff = nothing, autodiff = nothing,
jvp_autodiff = nothing)
return GeneralizedFirstOrderAlgorithm{concrete_jac, :GaussNewton}(; linesearch,
descent = NewtonDescent(; linsolve, precs), autodiff, vjp_autodiff, jvp_autodiff)
end
Loading

0 comments on commit 6dbc95e

Please sign in to comment.