Skip to content

Commit

Permalink
Merge pull request #3005 from AayushSabharwal/as/autodiff-defaults
Browse files Browse the repository at this point in the history
fix: improve resolution of dependent parameter defaults
  • Loading branch information
ChrisRackauckas authored Sep 2, 2024
2 parents a25a254 + 2f10bf5 commit 7eb5354
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 2 deletions.
3 changes: 1 addition & 2 deletions src/systems/parameter_buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ function MTKParameters(
end

isempty(missing_params) || throw(MissingParametersError(collect(missing_params)))

p = Dict(unwrap(k) => fixpoint_sub(v, bigdefs) for (k, v) in p)
p = Dict(unwrap(k) => (bigdefs[unwrap(k)] = fixpoint_sub(v, bigdefs)) for (k, v) in p)
for (sym, _) in p
if iscall(sym) && operation(sym) === getindex &&
first(arguments(sym)) in all_ps
Expand Down
1 change: 1 addition & 0 deletions test/extensions/Project.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[deps]
BifurcationKit = "0f109fa4-8a5d-4b75-95aa-f515264e7665"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Expand Down
24 changes: 24 additions & 0 deletions test/extensions/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using SymbolicIndexingInterface
using SciMLStructures
using OrdinaryDiffEq
using SciMLSensitivity
using ForwardDiff

@variables x(t)[1:3] y(t)
@parameters p[1:3, 1:3] q
Expand All @@ -27,3 +28,26 @@ gs = gradient(new_p) do new_p
new_sol = solve(new_prob, Tsit5())
sum(new_sol)
end

@testset "Issue#2997" begin
pars = @parameters y0 mh Tγ0 Th0 h ργ0
vars = @variables x(t)
@named sys = ODESystem([D(x) ~ y0],
t,
vars,
pars;
defaults = [
y0 => mh * 3.1 / (2.3 * Th0),
mh => 123.4,
Th0 => (4 / 11)^(1 / 3) * Tγ0,
Tγ0 => (15 / π^2 * ργ0 * (2 * h)^2 / 7)^(1 / 4) / 5
])
sys = structural_simplify(sys)

function x_at_0(θ)
prob = ODEProblem(sys, [sys.x => 1.0], (0.0, 1.0), [sys.ργ0 => θ[1], sys.h => θ[2]])
return prob.u0[1]
end

@test ForwardDiff.gradient(x_at_0, [0.3, 0.7]) == zeros(2)
end

0 comments on commit 7eb5354

Please sign in to comment.