From 0cbc2fc394ab039e6bc7ff0fd44b1cc745f5590a Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 6 Oct 2024 21:58:15 -0400 Subject: [PATCH] test: adjoints --- .../test/core/adjoint_tests.jl | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl diff --git a/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl b/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl new file mode 100644 index 000000000..c56850eb5 --- /dev/null +++ b/lib/SimpleNonlinearSolve/test/core/adjoint_tests.jl @@ -0,0 +1,21 @@ +@testitem "Simple Adjoint Test" tags=[:adjoint] begin + using ForwardDiff, ReverseDiff, SciMLSensitivity, Tracker, Zygote + + ff(u, p) = u .^ 2 .- p + + function solve_nlprob(p) + prob = NonlinearProblem{false}(ff, [1.0, 2.0], p) + sol = solve(prob, SimpleNewtonRaphson()) + res = sol isa AbstractArray ? sol : sol.u + return sum(abs2, res) + end + + p = [3.0, 2.0] + + ∂p_zygote = only(Zygote.gradient(solve_nlprob, p)) + ∂p_forwarddiff = ForwardDiff.gradient(solve_nlprob, p) + ∂p_tracker = Tracker.data(only(Tracker.gradient(solve_nlprob, p))) + ∂p_reversediff = ReverseDiff.gradient(solve_nlprob, p) + @test ∂p_zygote ≈ ∂p_tracker ≈ ∂p_reversediff + @test ∂p_zygote ≈ ∂p_forwarddiff ≈ ∂p_tracker ≈ ∂p_reversediff +end