Skip to content

Commit

Permalink
Dont dispatch on init and solve!
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 16, 2024
1 parent ca15839 commit 87750c4
Show file tree
Hide file tree
Showing 17 changed files with 175 additions and 111 deletions.
57 changes: 30 additions & 27 deletions src/abstract_types.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
function __internal_init end
function __internal_solve! end

"""
AbstractDescentAlgorithm
Expand All @@ -10,15 +13,15 @@ in which case we use the normal form equations ``JᵀJ δu = Jᵀ fu``. Note tha
factorization is often the faster choice, but it is not as numerically stable as the least
squares solver.
### `SciMLBase.init` specification
### `__internal_init` specification
```julia
SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractDescentAlgorithm, J, fu, u;
__internal_init(prob::NonlinearProblem{uType, iip}, alg::AbstractDescentAlgorithm, J, fu, u;
pre_inverted::Val{INV} = Val(false), linsolve_kwargs = (;), abstol = nothing,
reltol = nothing, alias_J::Bool = true, shared::Val{N} = Val(1),
kwargs...) where {INV, N, uType, iip} --> AbstractDescentCache
SciMLBase.init(prob::NonlinearLeastSquaresProblem{uType, iip},
__internal_init(prob::NonlinearLeastSquaresProblem{uType, iip},
alg::AbstractDescentAlgorithm, J, fu, u; pre_inverted::Val{INV} = Val(false),
linsolve_kwargs = (;), abstol = nothing, reltol = nothing, alias_J::Bool = true,
shared::Val{N} = Val(1), kwargs...) where {INV, N, uType, iip} --> AbstractDescentCache
Expand Down Expand Up @@ -59,10 +62,10 @@ get_linear_solver(alg::AbstractDescentAlgorithm) = __getproperty(alg, Val(:linso
Abstract Type for all Descent Caches.
### `SciMLBase.solve!` specification
### `__internal_solve!` specification
```julia
δu, success, intermediates = SciMLBase.solve!(cache::AbstractDescentCache, J, fu, u,
δu, success, intermediates = __internal_solve!(cache::AbstractDescentCache, J, fu, u,
idx::Val; skip_solve::Bool = false, kwargs...)
```
Expand Down Expand Up @@ -112,10 +115,10 @@ end
Abstract Type for all Line Search Algorithms used in NonlinearSolve.jl.
### `SciMLBase.init` specification
### `__internal_init` specification
```julia
SciMLBase.init(prob::AbstractNonlinearProblem,
__internal_init(prob::AbstractNonlinearProblem,
alg::AbstractNonlinearSolveLineSearchAlgorithm, f::F, fu, u, p, args...;
internalnorm::IN = DEFAULT_NORM,
kwargs...) where {F, IN} --> AbstractNonlinearSolveLineSearchCache
Expand All @@ -128,10 +131,10 @@ abstract type AbstractNonlinearSolveLineSearchAlgorithm end
Abstract Type for all Line Search Caches used in NonlinearSolve.jl.
### `SciMLBase.solve!` specification
### `__internal_solve!` specification
```julia
SciMLBase.solve!(cache::AbstractNonlinearSolveLineSearchCache, u, du; kwargs...)
__internal_solve!(cache::AbstractNonlinearSolveLineSearchCache, u, du; kwargs...)
```
Returns 2 values:
Expand Down Expand Up @@ -226,10 +229,10 @@ abstract type AbstractLinearSolverCache <: Function end
Abstract Type for Damping Functions in DampedNewton.
### `SciMLBase.init` specification
### `__internal_init` specification
```julia
SciMLBase.init(prob::AbstractNonlinearProblem, f::AbstractDampingFunction, initial_damping,
__internal_init(prob::AbstractNonlinearProblem, f::AbstractDampingFunction, initial_damping,
J, fu, u, args...; internal_norm = DEFAULT_NORM,
kwargs...) --> AbstractDampingFunctionCache
```
Expand All @@ -254,10 +257,10 @@ Abstract Type for the Caches created by AbstractDampingFunctions
- `(cache::AbstractDampingFunctionCache)(::Nothing)`: returns the damping factor. The type
of the damping factor returned from `solve!` is guaranteed to be the same as this.
### `SciMLBase.solve!` specification
### `__internal_solve!` specification
```julia
SciMLBase.solve!(cache::AbstractDampingFunctionCache, J, fu, args...; kwargs...)
__internal_solve!(cache::AbstractDampingFunctionCache, J, fu, args...; kwargs...)
```
Returns the damping factor.
Expand Down Expand Up @@ -310,10 +313,10 @@ Abstract Type for all Jacobian Initialization Algorithms used in NonlinearSolve.
- `jacobian_initialized_preinverted(alg)`: whether or not the Jacobian is initialized
preinverted. Defaults to `false`.
### `SciMLBase.init` specification
### `__internal_init` specification
```julia
SciMLBase.init(prob::AbstractNonlinearProblem, alg::AbstractJacobianInitialization,
__internal_init(prob::AbstractNonlinearProblem, alg::AbstractJacobianInitialization,
solver, f::F, fu, u, p; linsolve = missing, internalnorm::IN = DEFAULT_NORM,
kwargs...)
```
Expand Down Expand Up @@ -345,10 +348,10 @@ Abstract Type for all Approximate Jacobian Update Rules used in NonlinearSolve.j
- `store_inverse_jacobian(alg)`: Return `INV`
### `SciMLBase.init` specification
### `__internal_init` specification
```julia
SciMLBase.init(prob::AbstractNonlinearProblem,
__internal_init(prob::AbstractNonlinearProblem,
alg::AbstractApproximateJacobianUpdateRule, J, fu, u, du, args...;
internalnorm::F = DEFAULT_NORM,
kwargs...) where {F} --> AbstractApproximateJacobianUpdateRuleCache{INV}
Expand All @@ -367,10 +370,10 @@ Abstract Type for all Approximate Jacobian Update Rule Caches used in NonlinearS
- `store_inverse_jacobian(alg)`: Return `INV`
### `SciMLBase.solve!` specification
### `__internal_solve!` specification
```julia
SciMLBase.solve!(cache::AbstractApproximateJacobianUpdateRuleCache, J, fu, u, du;
__internal_solve!(cache::AbstractApproximateJacobianUpdateRuleCache, J, fu, u, du;
kwargs...) --> J / J⁻¹
```
"""
Expand All @@ -383,17 +386,17 @@ store_inverse_jacobian(::AbstractApproximateJacobianUpdateRuleCache{INV}) where
Condition for resetting the Jacobian in Quasi-Newton's methods.
### `SciMLBase.init` specification
### `__internal_init` specification
```julia
SciMLBase.init(alg::AbstractResetCondition, J, fu, u, du, args...;
__internal_init(alg::AbstractResetCondition, J, fu, u, du, args...;
kwargs...) --> ResetCache
```
### `SciMLBase.solve!` specification
### `__internal_solve!` specification
```julia
SciMLBase.solve!(cache::ResetCache, J, fu, u, du) --> Bool
__internal_solve!(cache::ResetCache, J, fu, u, du) --> Bool
```
"""
abstract type AbstractResetCondition end
Expand All @@ -403,10 +406,10 @@ abstract type AbstractResetCondition end
Abstract Type for all Trust Region Methods used in NonlinearSolve.jl.
### `SciMLBase.init` specification
### `__internal_init` specification
```julia
SciMLBase.init(prob::AbstractNonlinearProblem, alg::AbstractTrustRegionMethod,
__internal_init(prob::AbstractNonlinearProblem, alg::AbstractTrustRegionMethod,
f::F, fu, u, p, args...; internalnorm::IF = DEFAULT_NORM,
kwargs...) where {F, IF} --> AbstractTrustRegionMethodCache
```
Expand All @@ -423,10 +426,10 @@ Abstract Type for all Trust Region Method Caches used in NonlinearSolve.jl.
- `last_step_accepted(cache)`: whether or not the last step was accepted. Defaults to
`cache.last_step_accepted`. Should if overloaded if the field is not present.
### `SciMLBase.solve!` specification
### `__internal_solve!` specification
```julia
SciMLBase.solve!(cache::AbstractTrustRegionMethodCache, J, fu, u, δu, descent_stats)
__internal_solve!(cache::AbstractTrustRegionMethodCache, J, fu, u, δu, descent_stats)
```
Returns `last_step_accepted`, updated `u_cache` and `fu_cache`. If the last step was
Expand Down
10 changes: 5 additions & 5 deletions src/algorithms/broyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ function reinit_cache!(cache::NoChangeInStateResetCache, args...; kwargs...)
cache.steps_since_change_dfu = 0
end

function SciMLBase.init(alg::NoChangeInStateReset, J, fu, u, du, args...; kwargs...)
function __internal_init(alg::NoChangeInStateReset, J, fu, u, du, args...; kwargs...)
if alg.check_dfu
@bb dfu = copy(fu)
else
Expand All @@ -110,7 +110,7 @@ function SciMLBase.init(alg::NoChangeInStateReset, J, fu, u, du, args...; kwargs
0)
end

function SciMLBase.solve!(cache::NoChangeInStateResetCache, J, fu, u, du)
function __internal_solve!(cache::NoChangeInStateResetCache, J, fu, u, du)
reset_tolerance = cache.reset_tolerance
if cache.check_du
if any(@closure(x->abs(x) reset_tolerance), du)
Expand Down Expand Up @@ -168,7 +168,7 @@ Broyden Update Rule corresponding to "good broyden's method" [broyden1965class](
internalnorm
end

function SciMLBase.init(prob::AbstractNonlinearProblem,
function __internal_init(prob::AbstractNonlinearProblem,
alg::Union{GoodBroydenUpdateRule, BadBroydenUpdateRule}, J, fu, u, du, args...;
internalnorm::F = DEFAULT_NORM, kwargs...) where {F}
@bb J⁻¹dfu = similar(u)
Expand All @@ -187,7 +187,7 @@ function SciMLBase.init(prob::AbstractNonlinearProblem,
return BroydenUpdateRuleCache{mode}(J⁻¹dfu, dfu, u_cache, du_cache, internalnorm)
end

function SciMLBase.solve!(cache::BroydenUpdateRuleCache{mode}, J⁻¹, fu, u, du) where {mode}
function __internal_solve!(cache::BroydenUpdateRuleCache{mode}, J⁻¹, fu, u, du) where {mode}
T = eltype(u)
@bb @. cache.dfu = fu - cache.dfu
@bb cache.J⁻¹dfu = J⁻¹ × vec(cache.dfu)
Expand All @@ -205,7 +205,7 @@ function SciMLBase.solve!(cache::BroydenUpdateRuleCache{mode}, J⁻¹, fu, u, du
return J⁻¹
end

function SciMLBase.solve!(cache::BroydenUpdateRuleCache{mode}, J⁻¹::Diagonal, fu, u,
function __internal_solve!(cache::BroydenUpdateRuleCache{mode}, J⁻¹::Diagonal, fu, u,
du) where {mode}
T = eltype(u)
@bb @. cache.dfu = fu - cache.dfu
Expand Down
12 changes: 6 additions & 6 deletions src/algorithms/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ struct IllConditionedJacobianReset <: AbstractResetCondition end
condition_number_threshold
end

function SciMLBase.init(alg::IllConditionedJacobianReset, J, fu, u, du, args...; kwargs...)
function __internal_init(alg::IllConditionedJacobianReset, J, fu, u, du, args...; kwargs...)
condition_number_threshold = if J isa AbstractMatrix
inv(eps(real(eltype(J)))^(1 // 2))
else
Expand All @@ -73,7 +73,7 @@ function SciMLBase.init(alg::IllConditionedJacobianReset, J, fu, u, du, args...;
return IllConditionedJacobianResetCache(condition_number_threshold)
end

function SciMLBase.solve!(cache::IllConditionedJacobianResetCache, J, fu, u, du)
function __internal_solve!(cache::IllConditionedJacobianResetCache, J, fu, u, du)
J isa Number && return iszero(J)
J isa Diagonal && return any(iszero, diag(J))
J isa AbstractMatrix && return cond(J) cache.condition_number_threshold
Expand All @@ -98,7 +98,7 @@ Update rule for [`Klement`](@ref).
fu_cache
end

function SciMLBase.init(prob::AbstractNonlinearProblem, alg::KlementUpdateRule, J, fu, u,
function __internal_init(prob::AbstractNonlinearProblem, alg::KlementUpdateRule, J, fu, u,
du, args...; kwargs...)
@bb Jdu = similar(fu)
if J isa Diagonal || J isa Number
Expand All @@ -112,14 +112,14 @@ function SciMLBase.init(prob::AbstractNonlinearProblem, alg::KlementUpdateRule,
return KlementUpdateRuleCache(Jdu, J_cache, J_cache_2, Jdu_cache, fu_cache)
end

function SciMLBase.solve!(cache::KlementUpdateRuleCache, J::Number, fu, u, du)
function __internal_solve!(cache::KlementUpdateRuleCache, J::Number, fu, u, du)
Jdu = J^2 * du^2
J = J + ((fu - cache.fu_cache - J * du) / ifelse(iszero(Jdu), 1e-5, Jdu)) * du * J^2
cache.fu_cache = fu
return J
end

function SciMLBase.solve!(cache::KlementUpdateRuleCache, J_::Diagonal, fu, u, du)
function __internal_solve!(cache::KlementUpdateRuleCache, J_::Diagonal, fu, u, du)
T = eltype(u)
J = _restructure(u, diag(J_))
@bb @. cache.Jdu = (J^2) * (du^2)
Expand All @@ -129,7 +129,7 @@ function SciMLBase.solve!(cache::KlementUpdateRuleCache, J_::Diagonal, fu, u, du
return Diagonal(vec(J))
end

function SciMLBase.solve!(cache::KlementUpdateRuleCache, J::AbstractMatrix, fu, u, du)
function __internal_solve!(cache::KlementUpdateRuleCache, J::AbstractMatrix, fu, u, du)
T = eltype(u)
@bb @. cache.J_cache = J'^2
@bb @. cache.Jdu = du^2
Expand Down
5 changes: 3 additions & 2 deletions src/algorithms/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ end

jacobian_initialized_preinverted(::BroydenLowRankInitialization) = true

function SciMLBase.init(prob::AbstractNonlinearProblem,
function __internal_init(prob::AbstractNonlinearProblem,
alg::BroydenLowRankInitialization{T}, solver, f::F, fu, u, p; maxiters = 1000,
internalnorm::IN = DEFAULT_NORM, kwargs...) where {T, F, IN}
if u isa Number # Use the standard broyden
return init(prob, IdentityInitialization(true, FullStructure()), solver, f, fu, u,
return __internal_init(prob, IdentityInitialization(true, FullStructure()), solver,
f, fu, u,
p; maxiters, kwargs...)
end
# Pay to cost of slightly more allocations to prevent type-instability for StaticArrays
Expand Down
6 changes: 3 additions & 3 deletions src/algorithms/levenberg_marquardt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ function returns_norm_form_damping(::Union{LevenbergMarquardtDampingFunction,
return true
end

function SciMLBase.init(prob::AbstractNonlinearProblem,
function __internal_init(prob::AbstractNonlinearProblem,
f::LevenbergMarquardtDampingFunction, initial_damping, J, fu, u, ::Val{NF};
internalnorm::F = DEFAULT_NORM, kwargs...) where {F, NF}
T = promote_type(eltype(u), eltype(fu))
Expand All @@ -115,7 +115,7 @@ end

(damping::LevenbergMarquardtDampingCache)(::Nothing) = damping.J_damped

function SciMLBase.solve!(damping::LevenbergMarquardtDampingCache, J, fu, ::Val{false};
function __internal_solve!(damping::LevenbergMarquardtDampingCache, J, fu, ::Val{false};
kwargs...)
if __can_setindex(damping.J_diag_cache)
sum!(abs2, _vec(damping.J_diag_cache), J')
Expand All @@ -129,7 +129,7 @@ function SciMLBase.solve!(damping::LevenbergMarquardtDampingCache, J, fu, ::Val{
return damping.J_damped
end

function SciMLBase.solve!(damping::LevenbergMarquardtDampingCache, JᵀJ, fu, ::Val{true};
function __internal_solve!(damping::LevenbergMarquardtDampingCache, JᵀJ, fu, ::Val{true};
kwargs...)
damping.DᵀD = __update_LM_diagonal!!(damping.DᵀD, JᵀJ)
@bb @. damping.J_damped = damping.λ * damping.DᵀD
Expand Down
4 changes: 2 additions & 2 deletions src/algorithms/pseudo_transient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ function requires_normal_form_rhs(cache::Union{SwitchedEvolutionRelaxation,
return false
end

function SciMLBase.init(prob::AbstractNonlinearProblem, f::SwitchedEvolutionRelaxation,
function __internal_init(prob::AbstractNonlinearProblem, f::SwitchedEvolutionRelaxation,
initial_damping, J, fu, u, args...; internalnorm::F = DEFAULT_NORM,
kwargs...) where {F}
T = promote_type(eltype(u), eltype(fu))
Expand All @@ -62,7 +62,7 @@ end

(damping::SwitchedEvolutionRelaxationCache)(::Nothing) = damping.α⁻¹

function SciMLBase.solve!(damping::SwitchedEvolutionRelaxationCache, J, fu, args...;
function __internal_solve!(damping::SwitchedEvolutionRelaxationCache, J, fu, args...;
kwargs...)
res_norm = damping.internalnorm(fu)
damping.α⁻¹ *= res_norm / damping.res_norm
Expand Down
Loading

0 comments on commit 87750c4

Please sign in to comment.