Skip to content

Commit

Permalink
refactor(SimpleNonlinearSolve): reuse more code from NLB
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 29, 2024
1 parent 99d3216 commit 254c2fb
Show file tree
Hide file tree
Showing 21 changed files with 284 additions and 251 deletions.
2 changes: 1 addition & 1 deletion lib/BracketingNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ConcreteStructs = "0.2.3"
ExplicitImports = "1.10.1"
ForwardDiff = "0.10.36"
InteractiveUtils = "<0.0.1, 1"
NonlinearSolveBase = "1"
NonlinearSolveBase = "1.1"
PrecompileTools = "1.2"
Reexport = "1.2"
SciMLBase = "2.50"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module NonlinearSolveBaseBandedMatricesExt

using BandedMatrices: BandedMatrix
using LinearAlgebra: Diagonal

using NonlinearSolveBase: NonlinearSolveBase, Utils

# This is used if we vcat a Banded Jacobian with a Diagonal Matrix in Levenberg
Expand Down
18 changes: 12 additions & 6 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x)

function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem},
alg, args...; kwargs...)
alg, args...; kwargs...
)
p = Utils.value(prob.p)
if prob isa IntervalNonlinearProblem
tspan = Utils.value.(prob.tspan)
Expand Down Expand Up @@ -55,7 +56,8 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
end

function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...)
prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...
)
p = Utils.value(prob.p)
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
sol = solve(newprob, alg, args...; kwargs...)
Expand Down Expand Up @@ -168,13 +170,17 @@ function NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, f::F, u, p) where {F}
return ForwardDiff.jacobian(Base.Fix2(f, p), u)
end

function NonlinearSolveBase.nonlinearsolve_dual_solution(u::Number, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
function NonlinearSolveBase.nonlinearsolve_dual_solution(
u::Number, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}
) where {T, V, P}
return Dual{T, V, P}(u, partials)
end

function NonlinearSolveBase.nonlinearsolve_dual_solution(u::AbstractArray, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}) where {T, V, P}
function NonlinearSolveBase.nonlinearsolve_dual_solution(
u::AbstractArray, partials,
::Union{<:AbstractArray{<:Dual{T, V, P}}, Dual{T, V, P}}
) where {T, V, P}
return map(((uᵢ, pᵢ),) -> Dual{T, V, P}(uᵢ, pᵢ), zip(u, Utils.restructure(u, partials)))
end

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
module NonlinearSolveBaseLineSearchExt

using LineSearch: LineSearch, AbstractLineSearchCache
using NonlinearSolveBase: NonlinearSolveBase, InternalAPI
using SciMLBase: SciMLBase

using NonlinearSolveBase: NonlinearSolveBase, InternalAPI

function NonlinearSolveBase.callback_into_cache!(
topcache, cache::AbstractLineSearchCache, args...
)
Expand Down
10 changes: 7 additions & 3 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseLinearSolveExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
module NonlinearSolveBaseLinearSolveExt

using ArrayInterface: ArrayInterface

using CommonSolve: CommonSolve, init, solve!
using LinearAlgebra: ColumnNorm
using LinearSolve: LinearSolve, QRFactorization, SciMLLinearSolveAlgorithm
using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils
using SciMLBase: ReturnCode, LinearProblem

using LinearAlgebra: ColumnNorm

using NonlinearSolveBase: NonlinearSolveBase, LinearSolveJLCache, LinearSolveResult, Utils

function (cache::LinearSolveJLCache)(;
A = nothing, b = nothing, linu = nothing, du = nothing, p = nothing,
cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs...)
cachedata = nothing, reuse_A_if_factorization = false, verbose = true, kwargs...
)
cache.stats.nsolve += 1

update_A!(cache, A, reuse_A_if_factorization)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
module NonlinearSolveBaseSparseArraysExt

using NonlinearSolveBase: NonlinearSolveBase, Utils
using SparseArrays: AbstractSparseMatrix, AbstractSparseMatrixCSC, nonzeros, sparse

using NonlinearSolveBase: NonlinearSolveBase, Utils

function NonlinearSolveBase.NAN_CHECK(x::AbstractSparseMatrixCSC)
return any(NonlinearSolveBase.NAN_CHECK, nonzeros(x))
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
module NonlinearSolveBaseSparseMatrixColoringsExt

using ADTypes: ADTypes, AbstractADType
using NonlinearSolveBase: NonlinearSolveBase, Utils
using SciMLBase: SciMLBase, NonlinearFunction

using SparseMatrixColorings: ConstantColoringAlgorithm, GreedyColoringAlgorithm,
LargestFirst

using NonlinearSolveBase: NonlinearSolveBase, Utils

Utils.is_extension_loaded(::Val{:SparseMatrixColorings}) = true

function NonlinearSolveBase.select_fastest_coloring_algorithm(
Expand Down
4 changes: 3 additions & 1 deletion lib/NonlinearSolveBase/src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,9 @@ maybe_unaliased(x::AbstractSciMLOperator, ::Bool) = x
can_setindex(x) = ArrayInterface.can_setindex(x)
can_setindex(::Number) = false

evaluate_f!!(prob::AbstractNonlinearProblem, fu, u, p) = evaluate_f!!(prob.f, fu, u, p)
function evaluate_f!!(prob::AbstractNonlinearProblem, fu, u, p = prob.p)
return evaluate_f!!(prob.f, fu, u, p)
end
function evaluate_f!!(f::NonlinearFunction, fu, u, p)
if SciMLBase.isinplace(f)
f(fu, u, p)
Expand Down
8 changes: 4 additions & 4 deletions lib/SimpleNonlinearSolve/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ version = "2.0.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BracketingNonlinearSolve = "70df07ce-3d50-431d-a3e7-ca6ddb60ac1e"
CommonSolve = "38540f10-b2f7-11e9-35d8-d573e4eb0ff2"
Expand All @@ -21,6 +20,7 @@ NonlinearSolveBase = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
Expand All @@ -37,10 +37,9 @@ SimpleNonlinearSolveTrackerExt = "Tracker"

[compat]
ADTypes = "1.2"
Accessors = "0.1"
Aqua = "0.8.7"
ArrayInterface = "7.16"
BracketingNonlinearSolve = "1"
BracketingNonlinearSolve = "1.1"
ChainRulesCore = "1.24"
CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
Expand All @@ -56,14 +55,15 @@ LineSearch = "0.1.3"
LinearAlgebra = "1.10"
MaybeInplace = "0.1.4"
NonlinearProblemLibrary = "0.1.2"
NonlinearSolveBase = "1"
NonlinearSolveBase = "1.1"
Pkg = "1.10"
PolyesterForwardDiff = "0.1"
PrecompileTools = "1.2"
Random = "1.10"
Reexport = "1.2"
ReverseDiff = "1.15"
SciMLBase = "2.50"
Setfield = "1.1.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4.3"
Test = "1.10"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
module SimpleNonlinearSolveChainRulesCoreExt

using ChainRulesCore: ChainRulesCore, NoTangent

using NonlinearSolveBase: ImmutableNonlinearProblem
using SciMLBase: ChainRulesOriginator, NonlinearLeastSquaresProblem

using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up,
solve_adjoint

function ChainRulesCore.rrule(::typeof(simplenonlinearsolve_solve_up),
function ChainRulesCore.rrule(
::typeof(simplenonlinearsolve_solve_up),
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...
)
out, ∇internal = solve_adjoint(
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...
)
function ∇simplenonlinearsolve_solve_up(Δ)
∂f, ∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal(Δ)
return (
∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), ∂args...)
∂f, ∂prob, ∂sensealg, ∂u0, NoTangent(), ∂p, NoTangent(), NoTangent(), ∂args...
)
end
return out, ∇simplenonlinearsolve_solve_up
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
module SimpleNonlinearSolveReverseDiffExt

using ArrayInterface: ArrayInterface
using NonlinearSolveBase: ImmutableNonlinearProblem
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem, remake

using ArrayInterface: ArrayInterface
using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal

using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint
import SimpleNonlinearSolve: simplenonlinearsolve_solve_up

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
module SimpleNonlinearSolveTrackerExt

using ArrayInterface: ArrayInterface
using NonlinearSolveBase: ImmutableNonlinearProblem
using SciMLBase: TrackerOriginator, NonlinearLeastSquaresProblem, remake

using ArrayInterface: ArrayInterface
using Tracker: Tracker, TrackedArray, TrackedReal

using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint
Expand Down
92 changes: 55 additions & 37 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,46 @@
module SimpleNonlinearSolve

using Accessors: @reset
using BracketingNonlinearSolve: BracketingNonlinearSolve
using CommonSolve: CommonSolve, solve, init, solve!
using ConcreteStructs: @concrete
using FastClosures: @closure
using LineSearch: LiFukushimaLineSearch
using LinearAlgebra: LinearAlgebra, dot
using MaybeInplace: @bb, setindex_trait, CannotSetindex, CanSetindex
using PrecompileTools: @compile_workload, @setup_workload
using Reexport: @reexport
using SciMLBase: SciMLBase, AbstractNonlinearAlgorithm, NonlinearFunction, NonlinearProblem,
NonlinearLeastSquaresProblem, IntervalNonlinearProblem, ReturnCode, remake
using Setfield: @set!

using BracketingNonlinearSolve: BracketingNonlinearSolve
using CommonSolve: CommonSolve, solve, init, solve!
using LineSearch: LiFukushimaLineSearch
using MaybeInplace: @bb
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM,
nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution,
AbstractNonlinearSolveAlgorithm
using SciMLBase: SciMLBase, NonlinearFunction, NonlinearProblem,
NonlinearLeastSquaresProblem, ReturnCode, remake

using LinearAlgebra: LinearAlgebra, dot

using StaticArraysCore: StaticArray, SArray, SVector, MArray

# AD Dependencies
using ADTypes: ADTypes, AutoForwardDiff
using DifferentiationInterface: DifferentiationInterface
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff

using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM,
nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution
using ForwardDiff: ForwardDiff, Dual

const DI = DifferentiationInterface

abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
const DualNonlinearProblem = NonlinearProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}

const DualNonlinearLeastSquaresProblem = NonlinearLeastSquaresProblem{
<:Union{Number, <:AbstractArray}, iip,
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}
} where {iip, T, V, P}

const safe_similar = NonlinearSolveBase.Utils.safe_similar
abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearSolveAlgorithm end

const NLBUtils = NonlinearSolveBase.Utils

is_extension_loaded(::Val) = false

Expand All @@ -42,61 +55,66 @@ include("raphson.jl")
include("trust_region.jl")

# By Pass the highlevel checks for NonlinearProblem for Simple Algorithms
function CommonSolve.solve(prob::NonlinearProblem,
alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...)
function CommonSolve.solve(
prob::NonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, args...;
kwargs...
)
prob = convert(ImmutableNonlinearProblem, prob)
return solve(prob, alg, args...; kwargs...)
end

function CommonSolve.solve(
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{
<:ForwardDiff.Dual{T, V, P}, <:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
prob::DualNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; kwargs...
)
if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing
@reset alg.autodiff = AutoForwardDiff()
@set! alg.autodiff = AutoForwardDiff()
end
prob = convert(ImmutableNonlinearProblem, prob)
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end

function CommonSolve.solve(
prob::NonlinearLeastSquaresProblem{<:Union{Number, <:AbstractArray}, iip,
<:Union{
<:ForwardDiff.Dual{T, V, P}, <:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
prob::DualNonlinearLeastSquaresProblem, alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; kwargs...
)
if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing
@reset alg.autodiff = AutoForwardDiff()
@set! alg.autodiff = AutoForwardDiff()
end
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original
)
end

function CommonSolve.solve(
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...)
args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...
)
if sensealg === nothing && haskey(prob.kwargs, :sensealg)
sensealg = prob.kwargs[:sensealg]
end
new_u0 = u0 !== nothing ? u0 : prob.u0
new_p = p !== nothing ? p : prob.p
return simplenonlinearsolve_solve_up(prob, sensealg, new_u0, u0 === nothing, new_p,
p === nothing, alg, args...; prob.kwargs..., kwargs...)
return simplenonlinearsolve_solve_up(
prob, sensealg,
new_u0, u0 === nothing,
new_p, p === nothing,
alg, args...;
prob.kwargs..., kwargs...
)
end

function simplenonlinearsolve_solve_up(
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0,
u0_changed, p, p_changed, alg, args...; kwargs...)
u0_changed, p, p_changed, alg, args...; kwargs...
)
(u0_changed || p_changed) && (prob = remake(prob; u0, p))
return SciMLBase.__solve(prob, alg, args...; kwargs...)
end
Expand Down Expand Up @@ -131,7 +149,7 @@ function solve_adjoint_internal end

@compile_workload begin
for prob in (prob_scalar, prob_iip, prob_oop), alg in algs
CommonSolve.solve(prob, alg; abstol = 1e-2)
CommonSolve.solve(prob, alg; abstol = 1e-2, verbose = false)
end
end
end
Expand Down
Loading

0 comments on commit 254c2fb

Please sign in to comment.