Skip to content

Commit

Permalink
Merge pull request #2977 from DhairyaLGandhi/dg/crc
Browse files Browse the repository at this point in the history
AD: Add ChainRules extension for MTKParameters construction
  • Loading branch information
ChrisRackauckas authored Aug 20, 2024
2 parents c113dd1 + 93908be commit 2e35294
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[weakdeps]
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"

[extensions]
MTKBifurcationKitExt = "BifurcationKit"
MTKChainRulesCoreExt = "ChainRulesCore"
MTKDeepDiffsExt = "DeepDiffs"

[compat]
Expand Down
14 changes: 14 additions & 0 deletions ext/MTKChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
module MTKChainRulesCoreExt

import ModelingToolkit as MTK
import ChainRulesCore
import ChainRulesCore: NoTangent

function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
function mtp_pullback(dt)
(NoTangent(), dt.tunable[1:length(tunables)], ntuple(_ -> NoTangent(), length(args))...)
end
MTK.MTKParameters(tunables, args...), mtp_pullback
end

end
5 changes: 5 additions & 0 deletions test/extensions/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
[deps]
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
31 changes: 31 additions & 0 deletions test/extensions/ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
using ModelingToolkit
using ModelingToolkit: t_nounits as t, D_nounits as D
using Zygote
using SymbolicIndexingInterface
using SciMLStructures
using OrdinaryDiffEq
using SciMLSensitivity

@variables x(t)[1:3] y(t)
@parameters p[1:3, 1:3] q
eqs = [
D(x) ~ p * x
D(y) ~ sum(p) + q * y
]
u0 = [x => zeros(3),
y => 1.]
ps = [p => zeros(3, 3),
q => 1.]
tspan = (0., 10.)
@mtkbuild sys = ODESystem(eqs, t)
prob = ODEProblem(sys, u0, tspan, ps)
sol = solve(prob, Tsit5())

mtkparams = parameter_values(prob)
new_p = rand(10)
gs = gradient(new_p) do new_p
new_params = SciMLStructures.replace(SciMLStructures.Tunable(), mtkparams, new_p)
new_prob = remake(prob, p = new_params)
new_sol = solve(new_prob, Tsit5())
sum(new_sol)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,5 +104,6 @@ end
if GROUP == "All" || GROUP == "Extensions"
activate_extensions_env()
@safetestset "BifurcationKit Extension Test" include("extensions/bifurcationkit.jl")
@safetestset "Auto Differentiation Test" include("extensions/ad.jl")
end
end

0 comments on commit 2e35294

Please sign in to comment.