Skip to content

Commit

Permalink
fix: auto-set autodiff for ForwardDiff if trying to propagate Duals
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 8, 2024
1 parent 244d7bb commit 8cf9899
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
6 changes: 6 additions & 0 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ function CommonSolve.solve(
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing
@reset alg.autodiff = AutoForwardDiff()
end
prob = convert(ImmutableNonlinearProblem, prob)
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
Expand All @@ -68,6 +71,9 @@ function CommonSolve.solve(
alg::AbstractSimpleNonlinearSolveAlgorithm,
args...;
kwargs...) where {T, V, P, iip}
if hasfield(typeof(alg), :autodiff) && alg.autodiff === nothing
@reset alg.autodiff = AutoForwardDiff()
end
sol, partials = nonlinearsolve_forwarddiff_solve(prob, alg, args...; kwargs...)
dual_soln = nonlinearsolve_dual_solution(sol.u, partials, prob.p)
return SciMLBase.build_solution(
Expand Down
4 changes: 2 additions & 2 deletions lib/SimpleNonlinearSolve/test/core/forward_diff_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
jacobian_f(u, p::Number) = one.(u) .* (1 / (2 * p))
jacobian_f(u, p::AbstractArray) = diagm(vec(@. 1 / (2 * p)))

@testset for alg in (
@testset "#(nameof(typeof(alg)))" for alg in (
SimpleNewtonRaphson(),
SimpleTrustRegion(),
SimpleTrustRegion(; nlsolve_update_rule = Val(true)),
Expand Down Expand Up @@ -118,7 +118,7 @@ end

θ_init = θ_true .+ 0.1

@testset for alg in (
for alg in (
SimpleGaussNewton(),
SimpleGaussNewton(; autodiff = AutoForwardDiff()),
SimpleGaussNewton(; autodiff = AutoFiniteDiff()),
Expand Down
3 changes: 1 addition & 2 deletions lib/SimpleNonlinearSolve/test/core/qa_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@ end
import ReverseDiff, Tracker, StaticArrays, Zygote
using ExplicitImports, SimpleNonlinearSolve

@test check_no_implicit_imports(
SimpleNonlinearSolve; skip = (Base, Core, SciMLBase)) === nothing
@test check_no_implicit_imports(SimpleNonlinearSolve; skip = (Base, Core)) === nothing
@test check_no_stale_explicit_imports(SimpleNonlinearSolve) === nothing
@test check_all_qualified_accesses_via_owners(SimpleNonlinearSolve) === nothing
end
2 changes: 1 addition & 1 deletion lib/SimpleNonlinearSolve/test/core/rootfind_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
end

@testitem "First Order Methods" setup=[RootfindTestSnippet] tags=[:core] begin
@testset for alg in (
for alg in (
SimpleNewtonRaphson,
SimpleTrustRegion,
(; kwargs...) -> SimpleTrustRegion(; kwargs..., nlsolve_update_rule = Val(true))
Expand Down

0 comments on commit 8cf9899

Please sign in to comment.