Skip to content

Commit

Permalink
Merge pull request #2928 from AayushSabharwal/as/param-init
Browse files Browse the repository at this point in the history
feat: allow parameters to be unknowns in the initialization system
  • Loading branch information
ChrisRackauckas authored Oct 10, 2024
2 parents 8ce64bf + 84a1f2e commit 385c034
Show file tree
Hide file tree
Showing 15 changed files with 803 additions and 45 deletions.
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ ConstructionBase = "1"
DataInterpolations = "6.4"
DataStructures = "0.17, 0.18"
DeepDiffs = "1"
DiffEqBase = "6.103.0"
DiffEqBase = "6.157"
DiffEqCallbacks = "2.16, 3, 4"
DiffEqNoiseProcess = "5"
DiffRules = "0.1, 1.0"
Expand Down Expand Up @@ -110,12 +110,13 @@ NonlinearSolve = "3.14"
OffsetArrays = "1"
OrderedCollections = "1"
OrdinaryDiffEq = "6.82.0"
OrdinaryDiffEqCore = "1.7.0"
PrecompileTools = "1"
REPL = "1"
RecursiveArrayTools = "3.26"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "2.55"
SciMLBase = "2.56.1"
SciMLStructures = "1.0"
Serialization = "1"
Setfield = "0.7, 0.8, 1"
Expand Down Expand Up @@ -148,6 +149,7 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand All @@ -162,4 +164,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"]
134 changes: 134 additions & 0 deletions docs/src/tutorials/initialization.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,87 @@ long enough you will see that `λ = 0` is required for this equation, but since
problem constructor. Additionally, any warning about not being fully determined can
be suppressed via passing `warn_initialize_determined = false`.

## Initialization of parameters

Parameters may also be treated as unknowns in the initialization system. Doing so works
almost identically to the standard case. For a parameter to be an initialization unknown
(henceforth referred to as "solved parameter") it must represent a floating point number
(have a `symtype` of `Real` or `<:AbstractFloat`) or an array of such numbers. Additionally,
it must have a guess and one of the following conditions must be satisfied:

1. The value of the parameter as passed to `ODEProblem` is an expression involving other
variables/parameters. For example, if `[p => 2q + x]` is passed to `ODEProblem`. In
this case, `p ~ 2q + x` is used as an equation during initialization.
2. The parameter has a default (and no value for it is given to `ODEProblem`, since
that is condition 1). The default will be used as an equation during initialization.
3. The parameter has a default of `missing`. If `ODEProblem` is given a value for this
parameter, it is used as an equation during initialization (whether the value is an
expression or not).
4. `ODEProblem` is given a value of `missing` for the parameter. If the parameter has a
default, it will be used as an equation during initialization.

All parameter dependencies (where the dependent parameter is a floating point number or
array thereof) also become equations during initialization, and the dependent parameters
become unknowns.

`remake` will reconstruct the initialization system and problem, given the new
constraints provided to it. The new values will be combined with the original
variable-value mapping provided to `ODEProblem` and used to construct the initialization
problem.

### Parameter initialization by example

Consider the following system, where the sum of two unknowns is a constant parameter
`total`.

```@example paraminit
using ModelingToolkit, OrdinaryDiffEq # hidden
using ModelingToolkit: t_nounits as t, D_nounits as D # hidden
@variables x(t) y(t)
@parameters total
@mtkbuild sys = ODESystem([D(x) ~ -x, total ~ x + y], t;
defaults = [total => missing], guesses = [total => 1.0])
```

Given any two of `x`, `y` and `total` we can determine the remaining variable.

```@example paraminit
prob = ODEProblem(sys, [x => 1.0, y => 2.0], (0.0, 1.0))
integ = init(prob, Tsit5())
@assert integ.ps[total] ≈ 3.0 # hide
integ.ps[total]
```

Suppose we want to re-create this problem, but now solve for `x` given `total` and `y`:

```@example paraminit
prob2 = remake(prob; u0 = [y => 1.0], p = [total => 4.0])
initsys = prob2.f.initializeprob.f.sys
```

The system is now overdetermined. In fact:

```@example paraminit
[equations(initsys); observed(initsys)]
```

The system can never be satisfied and will always lead to an `InitialFailure`. This is
due to the aforementioned behavior of retaining the original variable-value mapping
provided to `ODEProblem`. To fix this, we pass `x => nothing` to `remake` to remove its
retained value.

```@example paraminit
prob2 = remake(prob; u0 = [y => 1.0, x => nothing], p = [total => 4.0])
initsys = prob2.f.initializeprob.f.sys
```

The system is fully determined, and the equations are solvable.

```@example
[equations(initsys); observed(initsys)]
```

## Diving Deeper: Constructing the Initialization System

To get a better sense of the initialization system and to help debug it, you can construct
Expand Down Expand Up @@ -383,3 +464,56 @@ sol[α * x - β * x * y]
```@example init
plot(sol)
```

## Solving for parameters during initialization

Sometimes, it is necessary to solve for a parameter during initialization. For example,
given a spring-mass system we want to find the un-stretched length of the spring given
that the initial condition of the system is its steady state.

```@example init
using ModelingToolkitStandardLibrary.Mechanical.TranslationalModelica: Fixed, Mass, Spring,
Force, Damper
using ModelingToolkitStandardLibrary.Blocks: Constant
@named mass = Mass(; m = 1.0, s = 1.0, v = 0.0, a = 0.0)
@named fixed = Fixed(; s0 = 0.0)
@named spring = Spring(; c = 2.0, s_rel0 = missing)
@named gravity = Force()
@named constant = Constant(; k = 9.81)
@named damper = Damper(; d = 0.1)
@mtkbuild sys = ODESystem(
[connect(fixed.flange, spring.flange_a), connect(spring.flange_b, mass.flange_a),
connect(mass.flange_a, gravity.flange), connect(constant.output, gravity.f),
connect(fixed.flange, damper.flange_a), connect(damper.flange_b, mass.flange_a)],
t;
systems = [fixed, spring, mass, gravity, constant, damper],
guesses = [spring.s_rel0 => 1.0])
```

Note that we explicitly provide `s_rel0 = missing` to the spring. Parameters are only
solved for during initialization if their value (either default, or explicitly passed
to the `ODEProblem` constructor) is `missing`. We also need to provide a guess for the
parameter.

If a parameter is not given a value of `missing`, and does not have a default or initial
value, the `ODEProblem` constructor will throw an error. If the parameter _does_ have a
value of `missing`, it must be given a guess.

```@example init
prob = ODEProblem(sys, [], (0.0, 1.0))
prob.ps[spring.s_rel0]
```

Note that the value of the parameter in the problem is zero, similar to unknowns that
are solved for during initialization.

```@example init
integ = init(prob)
integ.ps[spring.s_rel0]
```

The un-stretched length of the spring is now correctly calculated. The same result can be
achieved if `s_rel0 = missing` is omitted when constructing `spring`, and instead
`spring.s_rel0 => missing` is passed to the `ODEProblem` constructor along with values
of other parameters.
23 changes: 22 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,15 @@ function has_observed_with_lhs(sys, sym)
end
end

function has_parameter_dependency_with_lhs(sys, sym)
has_parameter_dependencies(sys) || return false
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
return any(isequal(sym), ic.dependent_pars)
else
return any(isequal(sym), [eq.lhs for eq in parameter_dependencies(sys)])
end
end

function _all_ts_idxs!(ts_idxs, ::NotSymbolic, sys, sym)
if is_variable(sys, sym) || is_independent_variable(sys, sym)
push!(ts_idxs, ContinuousTimeseries())
Expand Down Expand Up @@ -1344,9 +1353,21 @@ function namespace_assignment(eq::Assignment, sys)
Assignment(_lhs, _rhs)
end

function is_array_of_symbolics(x)
symbolic_type(x) == ArraySymbolic() && return true
symbolic_type(x) == ScalarSymbolic() && return false
x isa AbstractArray &&
any(y -> symbolic_type(y) != NotSymbolic() || is_array_of_symbolics(y), x)
end

function namespace_expr(
O, sys, n = nameof(sys); ivs = independent_variables(sys))
O = unwrap(O)
# Exceptions for arrays of symbolic and Ref of a symbolic, the latter
# of which shows up in broadcasts
if symbolic_type(O) == NotSymbolic() && !(O isa AbstractArray) && !(O isa Ref)
return O
end
if any(isequal(O), ivs)
return O
elseif iscall(O)
Expand All @@ -1368,7 +1389,7 @@ function namespace_expr(
end
elseif isvariable(O)
renamespace(n, O)
elseif O isa Array
elseif O isa AbstractArray && is_array_of_symbolics(O)
let sys = sys, n = n
map(o -> namespace_expr(o, sys, n; ivs), O)
end
Expand Down
Loading

0 comments on commit 385c034

Please sign in to comment.