Skip to content

Commit

Permalink
Put DualAbstractNonlinearProblem solving in subpackages
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikQQY committed Dec 3, 2024
1 parent ce721be commit de6eb96
Show file tree
Hide file tree
Showing 11 changed files with 205 additions and 24 deletions.
19 changes: 5 additions & 14 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,

using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
AbstractNonlinearSolveAlgorithm, Utils, InternalAPI,
AbstractNonlinearSolveCache
AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm

const DI = DifferentiationInterface

const ALL_SOLVER_TYPES = [
Nothing, AbstractNonlinearSolveAlgorithm
const GENERAL_SOLVER_TYPES = [
Nothing, AbstractNonlinearSolveAlgorithm, NonlinearSolvePolyAlgorithm
]

const DualNonlinearProblem = NonlinearProblem{
Expand Down Expand Up @@ -121,7 +121,7 @@ function NonlinearSolveBase.nonlinearsolve_dual_solution(
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials)))
end

for algType in ALL_SOLVER_TYPES
for algType in GENERAL_SOLVER_TYPES
@eval function SciMLBase.__solve(
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
)
Expand Down Expand Up @@ -157,7 +157,7 @@ function InternalAPI.reinit!(
return cache
end

for algType in ALL_SOLVER_TYPES
for algType in GENERAL_SOLVER_TYPES
@eval function SciMLBase.__init(
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
)
Expand Down Expand Up @@ -200,13 +200,4 @@ nodual_value(x) = x
nodual_value(x::Dual) = ForwardDiff.value(x)
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

"""
pickchunksize(x) = pickchunksize(length(x))
pickchunksize(x::Int)
Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length.
"""
@inline pickchunksize(x) = pickchunksize(length(x))
@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)

end
2 changes: 2 additions & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,4 +84,6 @@ export DescentResult, SteepestDescent, NewtonDescent, DampedNewtonDescent, Dogle

export NonlinearSolvePolyAlgorithm

export pickchunksize

end
9 changes: 9 additions & 0 deletions lib/NonlinearSolveBase/src/common_defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,12 @@ function get_tolerance(::Union{StaticArray, Number}, ::Nothing, ::Type{T}) where
# Rational numbers can throw an error if used inside GPU Kernels
return T(real(oneunit(T)) * (eps(real(one(T)))^(real(T)(0.8))))
end

"""
pickchunksize(x) = pickchunksize(length(x))
pickchunksize(x::Int)
Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length.
"""
@inline pickchunksize(x) = pickchunksize(length(x))
@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)
4 changes: 3 additions & 1 deletion lib/NonlinearSolveFirstOrder/src/NonlinearSolveFirstOrder.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode,
using SciMLJacobianOperators: VecJacOperator, JacVecOperator, StatefulJacobianOperator

using FiniteDiff: FiniteDiff # Default Finite Difference Method
using ForwardDiff: ForwardDiff # Default Forward Mode AD
using ForwardDiff: ForwardDiff, Dual # Default Forward Mode AD

include("raphson.jl")
include("gauss_newton.jl")
Expand All @@ -41,6 +41,8 @@ include("poly_algs.jl")

include("solve.jl")

include("forward_diff.jl")

@setup_workload begin
nonlinear_functions = (
(NonlinearFunction{false, NoSpecialize}((u, p) -> u .* u .- p), 0.1),
Expand Down
34 changes: 34 additions & 0 deletions lib/NonlinearSolveFirstOrder/src/forward_diff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

function SciMLBase.__init(
prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs...
)
p = NonlinearSolveBase.nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p)
cache = init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
)
end

function SciMLBase.__solve(
prob::DualAbstractNonlinearProblem, alg::GeneralizedFirstOrderAlgorithm, args...; kwargs...
)
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob, alg, args...; kwargs...
)
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end
6 changes: 6 additions & 0 deletions lib/NonlinearSolveQuasiNewton/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
NonlinearSolveQuasiNewtonForwardDiffExt = "ForwardDiff"

[compat]
ADTypes = "1.9.0"
Aqua = "0.8"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
module NonlinearSolveQuasiNewtonForwardDiffExt

using CommonSolve: CommonSolve, solve
using ForwardDiff: ForwardDiff, Dual
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, remake

using NonlinearSolveBase: NonlinearSolveBase

using NonlinearSolveQuasiNewton: QuasiNewtonAlgorithm

const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

function SciMLBase.__solve(
prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs...
)
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob, alg, args...; kwargs...
)
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end

function SciMLBase.__init(
prob::DualAbstractNonlinearProblem, alg::QuasiNewtonAlgorithm, args...; kwargs...
)
p = nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
cache = init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
)
end

end
6 changes: 6 additions & 0 deletions lib/NonlinearSolveSpectralMethods/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"

[extensions]
NonlinearSolveSpectralMethodsForwardDiffExt = "ForwardDiff"

[compat]
Aqua = "0.8"
BenchmarkTools = "1.5.0"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
module NonlinearSolveSpectralMethodsForwardDiffExt

using CommonSolve: CommonSolve, solve
using ForwardDiff: ForwardDiff, Dual
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, remake

using NonlinearSolveBase: NonlinearSolveBase

using NonlinearSolveSpectralMethods: GeneralizedDFSane

const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

function SciMLBase.__solve(
prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs...
)
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob, alg, args...; kwargs...
)
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end

function SciMLBase.__init(
prob::DualAbstractNonlinearProblem, alg::GeneralizedDFSane, args...; kwargs...
)
p = nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
cache = init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
)
end

end
11 changes: 2 additions & 9 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ using LineSearch: BackTracking
using NonlinearSolveBase: NonlinearSolveBase, InternalAPI, AbstractNonlinearSolveAlgorithm,
AbstractNonlinearSolveCache, Utils, L2_NORM,
enable_timer_outputs, disable_timer_outputs,
NonlinearSolvePolyAlgorithm
NonlinearSolvePolyAlgorithm, pickchunksize

using Preferences: set_preferences!
using SciMLBase: SciMLBase, NLStats, ReturnCode, AbstractNonlinearProblem,
Expand Down Expand Up @@ -53,14 +53,7 @@ include("extension_algs.jl")

include("default.jl")

const ALL_SOLVER_TYPES = [
Nothing, AbstractNonlinearSolveAlgorithm,
GeneralizedDFSane, GeneralizedFirstOrderAlgorithm, QuasiNewtonAlgorithm,
LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL,
CMINPACK, PETScSNES,
NonlinearSolvePolyAlgorithm
]
include("forward_diff.jl")

@setup_workload begin
nonlinear_functions = (
Expand Down
44 changes: 44 additions & 0 deletions src/forward_diff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
const EXTENSION_SOLVER_TYPES = [
LeastSquaresOptimJL, FastLevenbergMarquardtJL, NLsolveJL, NLSolversJL,
SpeedMappingJL, FixedPointAccelerationJL, SIAMFANLEquationsJL,
CMINPACK, PETScSNES
]

const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}
const DualAbstractNonlinearProblem = Union{
DualNonlinearProblem, DualNonlinearLeastSquaresProblem
}

for algType in EXTENSION_SOLVER_TYPES
@eval function SciMLBase.__init(
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
)
p = NonlinearSolveBase.nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p)
cache = init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
)
end
end

for algType in EXTENSION_SOLVER_TYPES
@eval function SciMLBase.__solve(
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
)
sol, partials = NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob, alg, args...; kwargs...
)
dual_soln = NonlinearSolveBase.nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end
end

0 comments on commit de6eb96

Please sign in to comment.