Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BVProblem with constraints #3323

Open
wants to merge 43 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
9733460
init
vyudu Nov 22, 2024
d95e4a7
Merge remote-tracking branch 'origin/master' into MTK
vyudu Nov 22, 2024
b3da813
up
vyudu Dec 1, 2024
86c82ce
Merge remote-tracking branch 'origin/master' into MTK
vyudu Dec 1, 2024
4affeac
up
vyudu Dec 1, 2024
a3429ea
up
vyudu Dec 1, 2024
f751fbb
up
vyudu Dec 2, 2024
a9fdfd6
up
vyudu Dec 3, 2024
a9f2106
up
vyudu Dec 4, 2024
18fdd5f
up
vyudu Dec 13, 2024
9d65a33
fixing create_array
vyudu Dec 16, 2024
999ec30
revert Project.toml
vyudu Dec 16, 2024
9226ad6
Up
vyudu Dec 16, 2024
0cb4893
Merge remote-tracking branch 'origin/master' into MTK
vyudu Dec 16, 2024
67d8164
formatting
vyudu Dec 16, 2024
25988f3
up
vyudu Dec 17, 2024
bb28d4f
up
vyudu Dec 17, 2024
b2bf7c0
fix
vyudu Dec 17, 2024
3751c2a
up
vyudu Dec 20, 2024
ef1f089
up
vyudu Jan 8, 2025
d23d6f7
Merge remote-tracking branch 'origin/master' into MTK
vyudu Jan 8, 2025
2a25200
extend BVProblem for constraint equations
vyudu Jan 9, 2025
50504ab
adding tests
vyudu Jan 11, 2025
5d082ab
up
vyudu Jan 11, 2025
b83e003
refactor the bc creation function
vyudu Jan 14, 2025
db5eb66
up
vyudu Jan 14, 2025
e802946
test update
vyudu Jan 15, 2025
e74e047
fix
vyudu Jan 15, 2025
86d4144
test more solvers:
vyudu Jan 17, 2025
ec386fe
Refactor constraints
vyudu Jan 28, 2025
90ce80d
refactor tests
vyudu Jan 28, 2025
a15c670
fix sym validation
vyudu Jan 28, 2025
c6ef04a
remove file
vyudu Jan 28, 2025
7878225
up
vyudu Jan 28, 2025
5bcfdff
up
vyudu Jan 28, 2025
0493b5d
remove lines
vyudu Jan 28, 2025
1d32b6e
up
vyudu Jan 28, 2025
2b3ca96
up
vyudu Jan 28, 2025
0324522
fix typo
vyudu Jan 28, 2025
2a079be
Fix setter
vyudu Jan 28, 2025
d70a470
fix
vyudu Jan 28, 2025
37092f1
lower tol
vyudu Jan 29, 2025
e5eb8bd
fix Project.toml
vyudu Jan 29, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,14 +73,16 @@ MTKBifurcationKitExt = "BifurcationKit"
MTKChainRulesCoreExt = "ChainRulesCore"
MTKDeepDiffsExt = "DeepDiffs"
MTKHomotopyContinuationExt = "HomotopyContinuation"
MTKLabelledArraysExt = "LabelledArrays"
MTKInfiniteOptExt = "InfiniteOpt"
MTKLabelledArraysExt = "LabelledArrays"

[compat]
AbstractTrees = "0.3, 0.4"
ArrayInterface = "6, 7"
BifurcationKit = "0.4"
BlockArrays = "1.1"
BoundaryValueDiffEq = "5.12.0"
BoundaryValueDiffEqAscher = "1.1.0"
ChainRulesCore = "1"
Combinatorics = "1"
CommonSolve = "0.2.4"
Expand Down Expand Up @@ -139,8 +141,8 @@ SimpleNonlinearSolve = "0.1.0, 1, 2"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
StochasticDiffEq = "6.72.1"
StochasticDelayDiffEq = "1.8.1"
StochasticDiffEq = "6.72.1"
SymbolicIndexingInterface = "0.3.36"
SymbolicUtils = "3.7"
Symbolics = "6.19"
Expand All @@ -152,6 +154,8 @@ julia = "1.9"
[extras]
AmplNLWriter = "7c4d4715-977e-5154-bfe0-e096adeac482"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
BoundaryValueDiffEq = "764a87c0-6b3e-53db-9096-fe964310641d"
BoundaryValueDiffEqAscher = "7227322d-7511-4e07-9247-ad6ff830280e"
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
DeepDiffs = "ab62b9b5-e342-54a8-a765-a90f495de1a6"
Expand Down Expand Up @@ -183,4 +187,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
test = ["AmplNLWriter", "BenchmarkTools", "BoundaryValueDiffEq", "BoundaryValueDiffEqAscher", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"]
207 changes: 207 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -849,6 +849,213 @@ function DiffEqBase.ODEProblem{iip, specialize}(sys::AbstractODESystem, u0map =
end
get_callback(prob::ODEProblem) = prob.kwargs[:callback]

"""
```julia
SciMLBase.BVProblem{iip}(sys::AbstractODESystem, u0map, tspan,
parammap = DiffEqBase.NullParameters();
constraints = nothing, guesses = nothing,
version = nothing, tgrad = false,
jac = true, sparse = true,
simplify = false,
kwargs...) where {iip}
```

Create a boundary value problem from the [`ODESystem`](@ref).

`u0map` is used to specify fixed initial values for the states. Every variable
must have either an initial guess supplied using `guesses` or a fixed initial
value specified using `u0map`.

`constraints` are used to specify boundary conditions to the ODESystem in the
form of equations. These values should specify values that state variables should
take at specific points, as in `x(0.5) ~ 1`). More general constraints that
should hold over the entire solution, such as `x(t)^2 + y(t)^2`, should be
specified as one of the equations used to build the `ODESystem`. Below is an example.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's odd that this is here. Problem constructors are only supposed to take values: constraints are structural. They should be added to the system I would think? @baggepinnen thoughts?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can the 0.5 in x(0.5) ~ 1 be a parameter? I can imagine this being something one would like to change without re-simplifying the system.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to have a BVPSystem that contains a ConstraintSystem like the OptimizationSystem?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes to both of these

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good, I will work on these conversions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two should be done now, I think


```julia
@parameters g
@variables x(..) y(t) [state_priority = 10] λ(t)
eqs = [D(D(x(t))) ~ λ * x(t)
D(D(y)) ~ λ * y - g
x(t)^2 + y^2 ~ 1]
@mtkbuild pend = ODESystem(eqs, t)

tspan = (0.0, 1.5)
u0map = [x(t) => 0.6, y => 0.8]
parammap = [g => 1]
guesses = [λ => 1]
constraints = [x(0.5) ~ 1]

bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, parammap; constraints, guesses, check_length = false)
```

If no `constraints` are specified, the problem will be treated as an initial value problem.

If the `ODESystem` has algebraic equations like `x(t)^2 + y(t)^2`, the resulting
`BVProblem` must be solved using BVDAE solvers, such as Ascher.
"""
function SciMLBase.BVProblem(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem(sys::AbstractODESystem,
u0map::StaticArray,
args...;
kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, u0map, args...; kwargs...)
end

function SciMLBase.BVProblem{true}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{false}(sys::AbstractODESystem, args...; kwargs...)
BVProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...)
end

function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = [],
tspan = get_tspan(sys),
parammap = DiffEqBase.NullParameters();
constraints = nothing, guesses = Dict(),
version = nothing, tgrad = false,
callback = nothing,
check_length = true,
warn_initialize_determined = true,
eval_expression = false,
eval_module = @__MODULE__,
kwargs...) where {iip, specialize}

if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
end
!isnothing(callback) && error("BVP solvers do not support callbacks.")

has_alg_eqs(sys) && error("The BVProblem currently does not support ODESystems with algebraic equations.") # Remove this when the BVDAE solvers get updated, the codegen should work when it does.

constraintsts = nothing
constraintps = nothing
sts = unknowns(sys)
ps = parameters(sys)

# Constraint validation
if !isnothing(constraints)
constraints isa Equation ||
constraints isa Vector{Equation} ||
error("Constraints must be specified as an equation or a vector of equations.")

(length(constraints) + length(u0map) > length(sts)) &&
error("The BVProblem is overdetermined. The total number of conditions (# constraints + # fixed initial values given by u0map) cannot exceed the total number of states.")
end

# ODESystems without algebraic equations should use both fixed values + guesses
# for initialization.
_u0map = has_alg_eqs(sys) ? u0map : merge(Dict(u0map), Dict(guesses))
f, u0, p = process_SciMLProblem(ODEFunction{iip, specialize}, sys, _u0map, parammap;
t = tspan !== nothing ? tspan[1] : tspan, guesses,
check_length, warn_initialize_determined, eval_expression, eval_module, kwargs...)

stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
u0_idxs = has_alg_eqs(sys) ? collect(1:length(sts)) : [stidxmap[k] for (k,v) in u0map]

bc = process_constraints(sys, constraints, u0, u0_idxs, tspan, iip)

return BVProblem{iip}(f, bc, u0, tspan, p; kwargs...)
end

get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")

# Validate that all the variables in the BVP constraints are well-formed states or parameters.
function validate_constraint_syms(eq, constraintsts, constraintps, sts, ps, iv)
for var in constraintsts
if length(arguments(var)) > 1
error("Too many arguments for variable $var.")
elseif arguments(var) == iv
var ∈ sts || error("Constraint equation $eq contains a variable $var that is not a variable of the ODESystem.")
error("Constraint equation $eq contains a variable $var that does not have a specified argument. Such equations should be specified as algebraic equations to the ODESystem rather than a boundary constraints.")
else
operation(var)(iv) ∈ sts || error("Constraint equation $eq contains a variable $(operation(var)) that is not a variable of the ODESystem.")
end
end

for var in constraintps
if !iscall(var)
var ∈ ps || error("Constraint equation $eq contains a parameter $var that is not a parameter of the ODESystem.")
else
length(arguments(var)) > 1 && error("Too many arguments for parameter $var.")
operation(var) ∈ ps || error("Constraint equations contain a parameter $var that is not a parameter of the ODESystem.")
end
end
end

"""
process_constraints(sys, constraints, u0, tspan, iip)

Given an ODESystem with some constraints, generate the boundary condition function.
"""
function process_constraints(sys::ODESystem, constraints, u0, u0_idxs, tspan, iip)

iv = get_iv(sys)
sts = get_unknowns(sys)
ps = get_ps(sys)
np = length(ps)
ns = length(sts)

stidxmap = Dict([v => i for (i, v) in enumerate(sts)])
pidxmap = Dict([v => i for (i, v) in enumerate(ps)])

@variables sol(..)[1:ns] p[1:np]
exprs = Any[]

constraintsts = OrderedSet()
constraintps = OrderedSet()

!isnothing(constraints) && for cons in constraints
collect_vars!(constraintsts, constraintps, cons, iv)
validate_constraint_syms(cons, constraintsts, constraintps, Set(sts), Set(ps), iv)
expr = cons.rhs - cons.lhs

for st in constraintsts
x = operation(st)
t = arguments(st)[1]
idx = stidxmap[x(iv)]

expr = Symbolics.substitute(expr, Dict(x(t) => sol(t)[idx]))
end

for var in constraintps
if iscall(var)
x = operation(var)
t = arguments(var)[1]
idx = pidxmap[x]

expr = Symbolics.substitute(expr, Dict(x(t) => p[idx]))
else
idx = pidxmap[var]
expr = Symbolics.substitute(expr, Dict(var => p[idx]))
end
end

empty!(constraintsts)
empty!(constraintps)
push!(exprs, expr)
end

init_cond_exprs = Any[]

for i in u0_idxs
expr = sol(tspan[1])[i] - u0[i]
push!(init_cond_exprs, expr)
end

exprs = vcat(init_cond_exprs, exprs)
bcs = Symbolics.build_function(exprs, sol, p, expression = Val{false})
if iip
return (resid, u, p, t) -> bcs[2](resid, u, p)
else
return (u, p, t) -> bcs[1](u, p)
end
end

"""
```julia
DiffEqBase.DAEProblem{iip}(sys::AbstractODESystem, du0map, u0map, tspan,
Expand Down
Loading
Loading