Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BVProblem with constraints #3323

Draft
wants to merge 28 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,16 @@ MTKBifurcationKitExt = "BifurcationKit"
MTKChainRulesCoreExt = "ChainRulesCore"
MTKDeepDiffsExt = "DeepDiffs"
MTKHomotopyContinuationExt = "HomotopyContinuation"
MTKLabelledArraysExt = "LabelledArrays"
MTKInfiniteOptExt = "InfiniteOpt"
MTKLabelledArraysExt = "LabelledArrays"

[compat]
AbstractTrees = "0.3, 0.4"
ArrayInterface = "6, 7"
BifurcationKit = "0.4"
BlockArrays = "1.1"
BoundaryValueDiffEq = "5.12.0"
BoundaryValueDiffEqAscher = "1.1.0"
ChainRulesCore = "1"
Combinatorics = "1"
CommonSolve = "0.2.4"
Expand Down Expand Up @@ -139,8 +141,8 @@ SimpleNonlinearSolve = "0.1.0, 1, 2"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
StochasticDiffEq = "6.72.1"
StochasticDelayDiffEq = "1.8.1"
StochasticDiffEq = "6.72.1"
SymbolicIndexingInterface = "0.3.36"
SymbolicUtils = "3.7"
Symbolics = "6.19"
Expand All @@ -152,6 +154,8 @@ julia = "1.9"
[extras]
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
BoundaryValueDiffEqAscher = "7227322d-7511-4e07-9247-ad6ff830280e"
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
Expand Down Expand Up @@ -183,4 +187,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEq", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
209 changes: 209 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,215 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]

"""
```julia
SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
parammap = DiffEqBase.NullParameters();
constraints = nothing, guesses = nothing,
version = nothing, tgrad = false,
jac = true, sparse = true,
simplify = false,
kwargs...) where {iip}
```

Create a boundary value problem from the [`ODESystem`](@ref). The arguments `dvs` and
`ps` are used to set the order of the dependent variable and parameter vectors,
respectively. `u0map` is used to specify fixed initial values for the states.

Every variable must have either an initial guess supplied using `guesses` or
a fixed initial value specified using `u0map`.

`constraints` are used to specify boundary conditions to the ODESystem in the
form of equations. These values should specify values that state variables should
take at specific points, as in `x(0.5) ~ 1`). More general constraints that
should hold over the entire solution, such as `x(t)^2 + y(t)^2`, should be
specified as one of the equations used to build the `ODESystem`. Below is an example.

```julia
@parameters g
@variables x(..) y(t) [state_priority = 10] λ(t)
eqs = [D(D(x(t))) ~ λ * x(t)
D(D(y)) ~ λ * y - g
x(t)^2 + y^2 ~ 1]
@mtkbuild pend = ODESystem(eqs, t)

tspan = (0.0, 1.5)
u0map = [x(t) => 0.6, y => 0.8]
parammap = [g => 1]
guesses = [λ => 1]
constraints = [x(0.5) ~ 1]

bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
```

If no `constraints` are specified, the problem will be treated as an initial value problem.

If the `ODESystem` has algebraic equations like `x(t)^2 + y(t)^2`, the resulting
`BVProblem` must be solved using BVDAE solvers, such as Ascher.
"""
function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem(sys::AbstractODESystem,
u0map::StaticArray,
args...;
kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
end

function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
constraints = nothing, guesses = Dict(),
version = nothing, tgrad = false,
callback = nothing,
check_length = true,
warn_initialize_determined = true,
eval_expression = false,
eval_module = @__MODULE__,
kwargs...) where {iip, specialize}

if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
end
!isnothing(callback) && error("BVP solvers do not support callbacks.")

has_alg_eqs(sys) && error("The BVProblem currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.

constraintsts = nothing
constraintps = nothing
sts = unknowns(sys)
ps = parameters(sys)

# Constraint validation
if !isnothing(constraints)
constraints isa Equation ||
constraints isa Vector{Equation} ||
error("Constraints must be specified as an equation or a vector of equations.")

(length(constraints) + length(u0map) > length(sts)) &&
error("The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) cannot exceed the total number of states.")
end

# ODESystems without algebraic equations should use both fixed values + guesses
# for initialization.
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, _u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan, guesses,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)

stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]

bc = process_constraints(sys, constraints, u0, u0_idxs, tspan, iip)

return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
end

get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")

# Validate that all the variables in the BVP constraints are well-formed states or parameters.
function validate_constraint_syms(eq, constraintsts, constraintps, sts, ps, iv)
for var in constraintsts
if length(arguments(var)) > 1
error("Too many arguments for variable $var.")
elseif arguments(var) == iv
var ∈ sts || error("Constraint equation $eq contains a variable $var that is not a variable of the ODESystem.")
error("Constraint equation $eq contains a variable $var that does not have a specified argument. Such equations should be specified as algebraic equations to the ODESystem rather than a boundary constraints.")
else
operation(var)(iv) ∈ sts || error("Constraint equation $eq contains a variable $(operation(var)) that is not a variable of the ODESystem.")
end
end

for var in constraintps
if !iscall(var)
var ∈ ps || error("Constraint equation $eq contains a parameter $var that is not a parameter of the ODESystem.")
else
length(arguments(var)) > 1 && error("Too many arguments for parameter $var.")
operation(var) ∈ ps || error("Constraint equations contain a parameter $var that is not a parameter of the ODESystem.")
end
end
end

"""
process_constraints(sys, constraints, u0, tspan, iip)

Given an ODESystem with some constraints, generate the boundary condition function.
"""
function process_constraints(sys::ODESystem, constraints, u0, u0_idxs, tspan, iip)

iv = get_iv(sys)
sts = get_unknowns(sys)
ps = get_ps(sys)
np = length(ps)
ns = length(sts)

stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])

@variables sol(..)[1:ns] p[1:np]
exprs = Any[]

constraintsts = OrderedSet()
constraintps = OrderedSet()

!isnothing(constraints) && for cons in constraints
collect_vars!(constraintsts, constraintps, cons, iv)
validate_constraint_syms(cons, constraintsts, constraintps, Set(sts), Set(ps), iv)
expr = cons.rhs - cons.lhs

for st in constraintsts
x = operation(st)
t = arguments(st)[1]
idx = stidxmap[x(iv)]

expr = Symbolics.substitute(expr, Dict(x(t) => sol(t)[idx]))
end

for var in constraintps
if iscall(var)
x = operation(var)
t = arguments(var)[1]
idx = pidxmap[x]

expr = Symbolics.substitute(expr, Dict(x(t) => p[idx]))
else
idx = pidxmap[var]
expr = Symbolics.substitute(expr, Dict(var => p[idx]))
end
end

empty!(constraintsts)
empty!(constraintps)
push!(exprs, expr)
end

init_cond_exprs = Any[]

for i in u0_idxs
expr = sol(tspan[1])[i] - u0[i]
push!(init_cond_exprs, expr)
end

exprs = vcat(init_cond_exprs, exprs)
@show exprs
bcs = Symbolics.build_function(exprs, sol, p, expression = Val{false})
if iip
return (resid, u, p, t) -> bcs[2](resid, u, p)
else
return (u, p, t) -> bcs[1](u, p)
end
end

"""
```julia
DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
Expand Down
Loading
Loading