Skip to content

Commit

Permalink
Fix some tests due to an MTK update (and something related to FiniteD…
Browse files Browse the repository at this point in the history
…iff.jl) (#311)

* Fix the OOP test (which came with a MTK upgrade)

* Fix an autodiff test (not only MTK but also FiniteDiff.jl changed)

* JuliaFormatter.jl
  • Loading branch information
nathanaelbosch authored Apr 9, 2024
1 parent 6df1cd5 commit 0525b2d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
FiniteHorizonGramians = "b59a298d-d283-4a37-9369-85a9f9a111a5"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Expand Down
22 changes: 16 additions & 6 deletions test/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ using ProbNumDiffEq
using ModelingToolkit
using Test
using LinearAlgebra
using FiniteDiff
using FiniteDifferences
using ForwardDiff
# using ReverseDiff
# using Zygote
Expand All @@ -11,7 +11,14 @@ import ODEProblemLibrary: prob_ode_fitzhughnagumo

@testset "solver: $ALG" for ALG in (EK0, EK1, DiagonalEK1)
_prob = prob_ode_fitzhughnagumo
prob = ODEProblem(modelingtoolkitize(_prob), _prob.u0, _prob.tspan, jac=true)
prob = ODEProblem(
structural_simplify(modelingtoolkitize(_prob)),
_prob.u0,
_prob.tspan,
jac=true,
)
prob = remake(prob, p=collect(_prob.p))

function param_to_loss(p)
sol = solve(
remake(prob, p=p),
Expand All @@ -37,12 +44,15 @@ import ODEProblemLibrary: prob_ode_fitzhughnagumo
return norm(sol.u[end]) # Dummy loss
end

dldp = FiniteDiff.finite_difference_gradient(param_to_loss, prob.p)
dldu0 = FiniteDiff.finite_difference_gradient(startval_to_loss, prob.u0)
# dldp = FiniteDiff.finite_difference_gradient(param_to_loss, prob.p)
# dldu0 = FiniteDiff.finite_difference_gradient(startval_to_loss, prob.u0)
# For some reason FiniteDiff.jl is not working anymore so we use FiniteDifferences.jl:
dldp = grad(central_fdm(5, 1), param_to_loss, prob.p)[1]
dldu0 = grad(central_fdm(5, 1), startval_to_loss, prob.u0)[1]

@testset "ForwardDiff.jl" begin
@test ForwardDiff.gradient(param_to_loss, prob.p) dldp rtol = 1e-3
@test ForwardDiff.gradient(startval_to_loss, prob.u0) dldu0 rtol = 5e-3
@test ForwardDiff.gradient(param_to_loss, prob.p) dldp rtol = 1e-2
@test ForwardDiff.gradient(startval_to_loss, prob.u0) dldu0 rtol = 5e-2
end

# @testset "ReverseDiff.jl" begin
Expand Down
4 changes: 3 additions & 1 deletion test/oop_problems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ using Test
end
@testset "with jacobian" begin
# now with defined jac
prob = ODEProblem(modelingtoolkitize(prob), prob.u0, prob.tspan, jac=true)
prob = ODEProblem(
structural_simplify(modelingtoolkitize(prob)), prob.u0, prob.tspan, jac=true,
)
@test solve(prob, EK0(order=4)) isa ProbNumDiffEq.ProbODESolution
@test solve(prob, EK1(order=4)) isa ProbNumDiffEq.ProbODESolution
@test solve(prob, EK1(order=4, initialization=ClassicSolverInit())) isa
Expand Down

0 comments on commit 0525b2d

Please sign in to comment.