Skip to content

Commit

Permalink
test: add a simple test for AD
Browse files Browse the repository at this point in the history
  • Loading branch information
DhairyaLGandhi authored and ChrisRackauckas committed Aug 20, 2024
1 parent 60235e3 commit 93908be
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
5 changes: 5 additions & 0 deletions test/extensions/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
[deps]
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
31 changes: 31 additions & 0 deletions test/extensions/ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using Zygote
using SymbolicIndexingInterface
using SciMLStructures
using OrdinaryDiffEq
using SciMLSensitivity

@variables x(t)[1:3] y(t)
@parameters p[1:3, 1:3] q
eqs = [
D(x) ~ p * x
D(y) ~ sum(p) + q * y
]
u0 = [x => zeros(3),
y => 1.]
ps = [p => zeros(3, 3),
q => 1.]
tspan = (0., 10.)
@mtkbuild sys = ODESystem(eqs, t)
prob = ODEProblem(sys, u0, tspan, ps)
sol = solve(prob, Tsit5())

mtkparams = parameter_values(prob)
new_p = rand(10)
gs = gradient(new_p) do new_p
new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p)
new_prob = remake(prob, p = new_params)
new_sol = solve(new_prob, Tsit5())
sum(new_sol)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,6 @@ end
if GROUP == "All" || GROUP == "Extensions"
activate_extensions_env()
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
@safetestset "Auto Differentiation Test" include("extensions/ad.jl")
end
end

0 comments on commit 93908be

Please sign in to comment.