diff --git a/lib/NonlinearSolveBase/Project.toml b/lib/NonlinearSolveBase/Project.toml index 3999de770..a89d035cf 100644 --- a/lib/NonlinearSolveBase/Project.toml +++ b/lib/NonlinearSolveBase/Project.toml @@ -20,10 +20,12 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [extensions] +NonlinearSolveBaseDiffEqBaseExt = "DiffEqBase" NonlinearSolveBaseForwardDiffExt = "ForwardDiff" NonlinearSolveBaseSparseArraysExt = "SparseArrays" @@ -33,6 +35,7 @@ ArrayInterface = "7.9" CommonSolve = "0.2.4" Compat = "4.15" ConcreteStructs = "0.2.3" +DiffEqBase = "6.149" DifferentiationInterface = "0.6.1" EnzymeCore = "0.8" FastClosures = "0.3" diff --git a/lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl new file mode 100644 index 000000000..346a5ee55 --- /dev/null +++ b/lib/NonlinearSolveBase/ext/NonlinearSolveBaseDiffEqBaseExt.jl @@ -0,0 +1,16 @@ +module NonlinearSolveBaseDiffEqBaseExt + +using DiffEqBase: DiffEqBase +using SciMLBase: remake + +using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem + +function DiffEqBase.get_concrete_problem( + prob::ImmutableNonlinearProblem, isadapt; kwargs...) + u0 = DiffEqBase.get_concrete_u0(prob, isadapt, nothing, kwargs) + u0 = DiffEqBase.promote_u0(u0, prob.p, nothing) + p = DiffEqBase.get_concrete_p(prob, kwargs) + return remake(prob; u0 = u0, p = p) +end + +end diff --git a/lib/SimpleNonlinearSolve/Project.toml b/lib/SimpleNonlinearSolve/Project.toml index fae63544a..8c6322730 100644 --- a/lib/SimpleNonlinearSolve/Project.toml +++ b/lib/SimpleNonlinearSolve/Project.toml @@ -46,7 +46,7 @@ CUDA = "5.3" ChainRulesCore = "1.24" CommonSolve = "0.2.4" ConcreteStructs = "0.2.3" -DiffEqBase = "6.155" +DiffEqBase = "6.149" DifferentiationInterface = "0.6.1" Enzyme = "0.13" ExplicitImports = "1.9" @@ -79,6 +79,7 @@ julia = "1.10" AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" @@ -95,4 +96,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["AllocCheck", "Aqua", "CUDA", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "SciMLSensitivity", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"] +test = ["AllocCheck", "Aqua", "CUDA", "DiffEqBase", "Enzyme", "ExplicitImports", "InteractiveUtils", "NonlinearProblemLibrary", "Pkg", "PolyesterForwardDiff", "Random", "ReverseDiff", "SciMLSensitivity", "StaticArrays", "Test", "TestItemRunner", "Tracker", "Zygote"] diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl index 950a04019..4954ffb26 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveDiffEqBaseExt.jl @@ -4,6 +4,8 @@ using DiffEqBase: DiffEqBase using SimpleNonlinearSolve: SimpleNonlinearSolve +SimpleNonlinearSolve.is_extension_loaded(::Val{:DiffEqBase}) = true + function SimpleNonlinearSolve.solve_adjoint_internal(args...; kwargs...) return DiffEqBase._solve_adjoint(args...; kwargs...) end diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl index 7b476b3c5..0a407986e 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl @@ -32,6 +32,8 @@ for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem) ∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal(Δ...) return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) end + + return Array(out), ∇simplenonlinearsolve_solve_up end end diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl index ead5a8e29..d29c2ac61 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl @@ -31,6 +31,8 @@ for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem) ∂prob, ∂sensealg, ∂u0, ∂p, _, ∂args... = ∇internal(Tracker.data(Δ)) return (∂prob, ∂sensealg, ∂u0, nothing, ∂p, nothing, nothing, ∂args...) end + + return out, ∇simplenonlinearsolve_solve_up end end diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index 0b33c923e..23de7dbc6 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -20,7 +20,8 @@ using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder -using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM +using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM, + nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution const DI = DifferentiationInterface @@ -47,6 +48,20 @@ function CommonSolve.solve(prob::NonlinearProblem, return solve(prob, alg, args...; kwargs...) end +function CommonSolve.solve( + prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip, + <:Union{ + <:ForwardDiff.Dual{T, V, P}, <:AbstractArray{<:ForwardDiff.Dual{T, V, P}}}}, + alg::AbstractSimpleNonlinearSolveAlgorithm, + args...; + kwargs...) where {T, V, P, iip} + prob = convert(ImmutableNonlinearProblem, prob) + sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...) + dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p) + return SciMLBase.build_solution( + prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original) +end + function CommonSolve.solve( prob::ImmutableNonlinearProblem, alg::AbstractSimpleNonlinearSolveAlgorithm, args...; sensealg = nothing, u0 = nothing, p = nothing, kwargs...) @@ -59,9 +74,8 @@ function CommonSolve.solve( p === nothing, alg, args...; prob.kwargs..., kwargs...) end -function simplenonlinearsolve_solve_up( - prob::ImmutableNonlinearProblem, sensealg, u0, u0_changed, p, p_changed, - alg::AbstractSimpleNonlinearSolveAlgorithm, args...; kwargs...) +function simplenonlinearsolve_solve_up(prob::ImmutableNonlinearProblem, sensealg, u0, + u0_changed, p, p_changed, alg, args...; kwargs...) (u0_changed || p_changed) && (prob = remake(prob; u0, p)) return SciMLBase.__solve(prob, alg, args...; kwargs...) end diff --git a/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl b/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl index c56850eb5..449801ad7 100644 --- a/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl +++ b/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl @@ -1,5 +1,6 @@ @testitem "Simple Adjoint Test" tags=[:adjoint] begin - using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote + using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote, DiffEqBase, + SimpleNonlinearSolve ff(u, p) = u .^ 2 .- p diff --git a/lib/SimpleNonlinearSolve/test/core/allocation_tests.jl b/lib/SimpleNonlinearSolve/test/core/allocation_tests.jl index 5da872b71..67cee39c0 100644 --- a/lib/SimpleNonlinearSolve/test/core/allocation_tests.jl +++ b/lib/SimpleNonlinearSolve/test/core/allocation_tests.jl @@ -34,7 +34,7 @@ @test true catch e @error e - @test false broken = (alg isa SimpleHalley) + @test false broken=(alg isa SimpleHalley) end end end diff --git a/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl b/lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl new file mode 100644 index 000000000..e69de29bb