From a6f781d45cbbc885edf9167695682124cd9655bb Mon Sep 17 00:00:00 2001 From: Ben Chung Date: Thu, 1 Aug 2024 20:36:50 -0700 Subject: [PATCH 001/101] First pass at MutatingFunctionalAffect --- Project.toml | 1 + src/ModelingToolkit.jl | 1 + src/systems/callbacks.jl | 144 +++++++++++++++++++++++++++++-- src/systems/diffeqs/odesystem.jl | 27 +++++- test/symbolic_events.jl | 49 +++++++++++ 5 files changed, 215 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 808d68af06..729966fde0 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 2f57bb1765..f5262a1526 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -54,6 +54,7 @@ using Reexport using RecursiveArrayTools import Graphs: SimpleDiGraph, add_edge!, incidence_matrix import BlockArrays: BlockedArray, Block, blocksize, blocksizes +import ComponentArrays using RuntimeGeneratedFunctions using RuntimeGeneratedFunctions: drop_expr diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 86cab57634..f7d4baa4cb 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -60,8 +60,6 @@ function Base.hash(a::FunctionalAffect, s::UInt) hash(a.ctx, s) end -has_functional_affect(cb) = affects(cb) isa FunctionalAffect - namespace_affect(affect, s) = namespace_equation(affect, s) function namespace_affect(affect::FunctionalAffect, s) FunctionalAffect(func(affect), @@ -73,6 +71,67 @@ function namespace_affect(affect::FunctionalAffect, s) context(affect)) end +""" +`MutatingFunctionalAffect` differs from `FunctionalAffect` in two key ways: +* First, insetad of the `u` vector passed to `f` being a vector of indices into `integ.u` it's instead the result of evaluating `obs` at the current state, named as specified in `obs_syms`. This allows affects to easily access observed states and decouples affect inputs from the system structure. +* Second, it abstracts the assignment back to system states away. Instead of writing `integ.u[u.myvar] = [whatever]`, you instead declare in `mod_params` that you want to modify `myvar` and then either (out of place) return a named tuple with `myvar` or (in place) modify the associated element in the ComponentArray that's given. +Initially, we only support "flat" states in `modified`; these states will be marked as irreducible in the overarching system and they will simply be bulk assigned at mutation. In the future, this will be extended to perform a nonlinear solve to further decouple the affect from the system structure. +""" +@kwdef struct MutatingFunctionalAffect + f::Any + obs::Vector + obs_syms::Vector{Symbol} + modified::Vector + mod_syms::Vector{Symbol} + ctx::Any +end + +MutatingFunctionalAffect(f::Function; + observed::NamedTuple = NamedTuple{()}(()), + modified::NamedTuple = NamedTuple{()}(()), + ctx=nothing) = MutatingFunctionalAffect(f, collect(values(observed)), collect(keys(observed)), collect(values(modified)), collect(keys(modified)), ctx) +MutatingFunctionalAffect(f::Function, observed::NamedTuple; modified::NamedTuple = NamedTuple{()}(()), ctx=nothing) = + MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx) +MutatingFunctionalAffect(f::Function, observed::NamedTuple, modified::NamedTuple; ctx=nothing) = + MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx) +MutatingFunctionalAffect(f::Function, observed::NamedTuple, modified::NamedTuple, ctx) = + MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx) + +func(f::MutatingFunctionalAffect) = f.f +context(a::MutatingFunctionalAffect) = a.ctx +observed(a::MutatingFunctionalAffect) = a.obs +observed_syms(a::MutatingFunctionalAffect) = a.obs_syms +discretes(a::MutatingFunctionalAffect) = filter(ModelingToolkit.isparameter, a.modified) +modified(a::MutatingFunctionalAffect) = a.modified +modified_syms(a::MutatingFunctionalAffect) = a.mod_syms + +function Base.:(==)(a1::MutatingFunctionalAffect, a2::MutatingFunctionalAffect) + isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) && + isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms)&& isequal(a1.ctx, a2.ctx) +end + +function Base.hash(a::MutatingFunctionalAffect, s::UInt) + s = hash(a.f, s) + s = hash(a.obs, s) + s = hash(a.obs_syms, s) + s = hash(a.modified, s) + s = hash(a.mod_syms, s) + hash(a.ctx, s) +end + +function namespace_affect(affect::MutatingFunctionalAffect, s) + MutatingFunctionalAffect(func(affect), + renamespace.((s,), observed(affect)), + observed_syms(affect), + renamespace.((s,), modified(affect)), + modified_syms(affect), + context(affect)) +end + +function has_functional_affect(cb) + (affects(cb) isa FunctionalAffect || affects(cb) isa MutatingFunctionalAffect) +end + #################################### continuous events ##################################### const NULL_AFFECT = Equation[] @@ -109,8 +168,8 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either: """ struct SymbolicContinuousCallback eqs::Vector{Equation} - affect::Union{Vector{Equation}, FunctionalAffect} - affect_neg::Union{Vector{Equation}, FunctionalAffect, Nothing} + affect::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect} + affect_neg::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect, Nothing} rootfind::SciMLBase.RootfindOpt function SymbolicContinuousCallback(; eqs::Vector{Equation}, affect = NULL_AFFECT, affect_neg = affect, rootfind = SciMLBase.LeftRootFind) @@ -250,6 +309,7 @@ scalarize_affects(affects) = scalarize(affects) scalarize_affects(affects::Tuple) = FunctionalAffect(affects...) scalarize_affects(affects::NamedTuple) = FunctionalAffect(; affects...) scalarize_affects(affects::FunctionalAffect) = affects +scalarize_affects(affects::MutatingFunctionalAffect) = affects SymbolicDiscreteCallback(p::Pair) = SymbolicDiscreteCallback(p[1], p[2]) SymbolicDiscreteCallback(cb::SymbolicDiscreteCallback) = cb # passthrough @@ -257,7 +317,7 @@ SymbolicDiscreteCallback(cb::SymbolicDiscreteCallback) = cb # passthrough function Base.show(io::IO, db::SymbolicDiscreteCallback) println(io, "condition: ", db.condition) println(io, "affects:") - if db.affects isa FunctionalAffect + if db.affects isa FunctionalAffect || db.affects isa MutatingFunctionalAffect # TODO println(io, " ", db.affects) else @@ -749,6 +809,80 @@ function compile_user_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs. end end +invalid_variables(sys, expr) = filter(x -> !any(isequal(x), all_symbols(sys)), vars(expr)) +function unassignable_variables(sys, expr) + assignable_syms = vcat(unknowns(sys), parameters(sys)) + return filter(x -> !any(isequal(x), assignable_syms), vars(expr)) +end + +function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwargs...) + #= + Implementation sketch: + generate observed function (oop), should save to a component array under obs_syms + do the same stuff as the normal FA for pars_syms + call the affect method - test if it's OOP or IP using applicable + unpack and apply the resulting values + =# + obs_exprs = observed(affect) + for oexpr in obs_exprs + invalid_vars = invalid_variables(sys, oexpr) + if length(invalid_vars) > 0 + error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).") + end + end + obs_syms = observed_syms(affect) + obs_size = size.(obs_exprs) # we will generate a work buffer of a ComponentArray that maps obs_syms to arrays of size obs_size + + mod_exprs = modified(affect) + for mexpr in mod_exprs + if !is_observed(sys, mexpr) && parameter_index(sys, mexpr) === nothing + error("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.") + end + invalid_vars = unassignable_variables(sys, mexpr) + if length(invalid_vars) > 0 + error("Observed equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.") + end + end + mod_syms = modified_syms(affect) + _, mod_og_val_fun = build_explicit_observed_function(sys, mod_exprs; return_inplace=true) + + # sanity checks done! now build the data and update function for observed values + mkzero(sz) = if sz === () 0.0 else zeros(sz) end + _, obs_fun = build_explicit_observed_function(sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []); return_inplace=true) + obs_component_array = ComponentArrays.ComponentArray(NamedTuple{(obs_syms..., )}(mkzero.(obs_size))) + + # okay so now to generate the stuff to assign it back into the system + # note that we reorder the componentarray to make the views coherent wrt the base array + mod_pairs = mod_exprs .=> mod_syms + mod_param_pairs = filter(v -> is_parameter(sys, v[1]), mod_pairs) + mod_unk_pairs = filter(v -> !is_parameter(sys, v[1]), mod_pairs) + _, mod_og_val_fun = build_explicit_observed_function(sys, reduce(vcat, [first.(mod_param_pairs); first.(mod_unk_pairs)]; init = []); return_inplace=true) + upd_params_fun = setu(sys, reduce(vcat, Symbolics.scalarize.(first.(mod_param_pairs)); init = [])) + upd_unk_fun = setu(sys, reduce(vcat, Symbolics.scalarize.(first.(mod_unk_pairs)); init = [])) + + upd_component_array = ComponentArrays.ComponentArray(NamedTuple{([last.(mod_param_pairs); last.(mod_unk_pairs)]...,)}( + [collect(mkzero(size(e)) for e in first.(mod_param_pairs)); + collect(mkzero(size(e)) for e in first.(mod_unk_pairs))])) + upd_params_view = view(upd_component_array, last.(mod_param_pairs)) + upd_unks_view = view(upd_component_array, last.(mod_unk_pairs)) + let user_affect = func(affect), ctx = context(affect) + function (integ) + # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens + mod_og_val_fun(upd_component_array, integ.u, integ.p..., integ.t) + + # update the observed values + obs_fun(obs_component_array, integ.u, integ.p..., integ.t) + + # let the user do their thing + user_affect(upd_component_array, obs_component_array, integ, ctx) + + # write the new values back to the integrator + upd_params_fun(integ, upd_params_view) + upd_unk_fun(integ, upd_unks_view) + end + end +end + function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...) compile_user_affect(affect, cb, sys, dvs, ps; kwargs...) end diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index ec2c5f8157..daa4321ed0 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -404,8 +404,31 @@ ODESystem(eq::Equation, args...; kwargs...) = ODESystem([eq], args...; kwargs... """ $(SIGNATURES) -Build the observed function assuming the observed equations are all explicit, -i.e. there are no cycles. +Generates a function that computes the observed value(s) `ts` in the system `sys` assuming that there are no cycles in the equations. + +The return value will be either: +* a single function if the input is a scalar or if the input is a Vector but `return_inplace` is false +* the out of place and in-place functions `(ip, oop)` if `return_inplace` is true and the input is a `Vector` + +The function(s) will be: +* `RuntimeGeneratedFunction`s by default, +* A Julia `Expr` if `expression` is true, +* A directly evaluated Julia function in the module `eval_module` if `eval_expression` is true + +The signatures will be of the form `g(...)` with arguments: +* `output` for in-place functions +* `unknowns` if `params_only` is `false` +* `inputs` if `inputs` is an array of symbolic inputs that should be available in `ts` +* `p...` unconditionally; note that in the case of `MTKParameters` more than one parameters argument may be present, so it must be splatted +* `t` if the system is time-dependent; for example `NonlinearSystem` will not have `t` +For example, a function `g(op, unknowns, p, inputs, t)` will be the in-place function generated if `return_inplace` is true, `ts` is a vector, an array of inputs `inputs` is given, and `params_only` is false for a time-dependent system. + +Options not otherwise specified are: +* `output_type = Array` the type of the array generated by the out-of-place vector-valued function +* `checkbounds = true` checks bounds if true when destructuring parameters +* `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail. +* `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist +* `drop_expr` is deprecated. """ function build_explicit_observed_function(sys, ts; inputs = nothing, diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index e1d12814ef..351f8111ee 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -887,3 +887,52 @@ end @test sol[b] == [2.0, 5.0, 5.0] @test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] end +@testset "Heater" begin + @variables temp(t) + params = @parameters furnace_on_threshold=0.5 furnace_off_threshold=0.7 furnace_power=1.0 leakage=0.1 furnace_on::Bool=false + eqs = [ + D(temp) ~ furnace_on * furnace_power - temp^2 * leakage + ] + + furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o, i, c + x.furnace_on = false + end) + furnace_enable = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_on_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o, i, c + x.furnace_on = true + end) + + @named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) + ss = structural_simplify(sys) + prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) + sol = solve(prob, Tsit5(); dtmax=0.01) + @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) +end + +@testset "Quadrature" begin + @variables theta(t) omega(t) + params = @parameters qA=0 qB=0 + eqs = [ + D(theta) ~ omega + omega ~ sin(0.5*t) + ] + qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(1000 * theta) ~ 0], + ModelingToolkit.MutatingFunctionalAffect(modified=(; qA)) do x, o, i, c + x.qA = 1 + end, + affect_neg = ModelingToolkit.MutatingFunctionalAffect(modified=(; qA)) do x, o, i, c + x.qA = 0 + end) + qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(1000 * theta + π/2) ~ 0], + ModelingToolkit.MutatingFunctionalAffect(modified=(; qB)) do x, o, i, c + x.qB = 1 + end, + affect_neg = ModelingToolkit.MutatingFunctionalAffect(modified=(; qB)) do x, o, i, c + x.qB = 0 + end) + @named sys = ODESystem(eqs, t, [theta, omega], params; continuous_events = [qAevt, qBevt]) + ss = structural_simplify(sys) + prob = ODEProblem(ss, [theta => 0.0], (0.0, 1.0)) + sol = solve(prob, Tsit5(); dtmax=0.01) +end From f151e429cd9edb33b55ec019be069af12cb9d108 Mon Sep 17 00:00:00 2001 From: Ben Chung Date: Fri, 2 Aug 2024 16:19:51 -0700 Subject: [PATCH 002/101] Clarify documentation for SCC --- src/systems/callbacks.jl | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index f7d4baa4cb..a57d5c006d 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -145,17 +145,26 @@ By default `affect_neg = affect`; to only get rising edges specify `affect_neg = Assume without loss of generality that the equation is of the form `c(u,p,t) ~ 0`; we denote the integrator state as `i.u`. For compactness, we define `prev_sign = sign(c(u[t-1], p[t-1], t-1))` and `cur_sign = sign(c(u[t], p[t], t))`. A condition edge will be detected and the callback will be invoked iff `prev_sign * cur_sign <= 0`. +The positive edge `affect` will be triggered iff an edge is detected and if `prev_sign < 0`; similarly, `affect_neg` will be +triggered iff an edge is detected and `prev_sign > 0`. + Inter-sample condition activation is not guaranteed; for example if we use the dirac delta function as `c` to insert a sharp discontinuity between integrator steps (which in this example would not normally be identified by adaptivity) then the condition is not guaranteed to be triggered. Once detected the integrator will "wind back" through a root-finding process to identify the point when the condition became active; the method used -is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). Multiple callbacks in the same system with different `rootfind` operations will be resolved -into separate VectorContinuousCallbacks in the enumeration order of `SciMLBase.RootfindOpt`, which may cause some callbacks to not fire if several become -active at the same instant. See the `SciMLBase` documentation for more information on the semantic rules. - -The positive edge `affect` will be triggered iff an edge is detected and if `prev_sign < 0`; similarly, `affect_neg` will be -triggered iff an edge is detected `prev_sign > 0`. +is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). If we denote the time when the condition becomes active at tc, +the value in the integrator after windback will be: +* `u[tc-epsilon], p[tc-epsilon], tc` if `LeftRootFind` is used, +* `u[tc+epsilon], p[tc+epsilon], tc` if `RightRootFind` is used, +* or `u[t], p[t], t` if `NoRootFind` is used. +For example, if we want to detect when an unknown variable `x` satisfies `x > 0` using the condition `x ~ 0` on a positive edge (that is, `D(x) > 0`), +then left root finding will get us `x=-epsilon`, right root finding `x=epsilon` and no root finding whatever the next step of the integrator was after +it passed through 0. + +Multiple callbacks in the same system with different `rootfind` operations will be grouped +by their `rootfind` value into separate VectorContinuousCallbacks in the enumeration order of `SciMLBase.RootfindOpt`. This may cause some callbacks to not fire if several become +active at the same instant. See the `SciMLBase` documentation for more information on the semantic rules. Affects (i.e. `affect` and `affect_neg`) can be specified as either: * A list of equations that should be applied when the callback is triggered (e.g. `x ~ 3, y ~ 7`) which must be of the form `unknown ~ observed value` where each `unknown` appears only once. Equations will be applied in the order that they appear in the vector; parameters and state updates will become immediately visible to following equations. From 49d48b833645a17f18f7581009f2a15a9c708161 Mon Sep 17 00:00:00 2001 From: Ben Chung Date: Fri, 2 Aug 2024 16:20:05 -0700 Subject: [PATCH 003/101] MutatingFunctionalAffect test cases --- test/symbolic_events.jl | 160 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 149 insertions(+), 11 deletions(-) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 351f8111ee..ca9f0ad9c3 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -227,6 +227,117 @@ affect_neg = [x ~ 1] @test e[].affect == affect end +@testset "MutatingFunctionalAffect constructors" begin + fmfa(o, x, i, c) = nothing + m = ModelingToolkit.MutatingFunctionalAffect(fmfa) + @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m.f == fmfa + @test m.obs == [] + @test m.obs_syms == [] + @test m.modified == [] + @test m.mod_syms == [] + @test m.ctx === nothing + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (;)) + @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m.f == fmfa + @test m.obs == [] + @test m.obs_syms == [] + @test m.modified == [] + @test m.mod_syms == [] + @test m.ctx === nothing + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; x)) + @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m.f == fmfa + @test isequal(m.obs, [x]) + @test m.obs_syms == [:x] + @test m.modified == [] + @test m.mod_syms == [] + @test m.ctx === nothing + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; y=x)) + @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m.f == fmfa + @test isequal(m.obs, [x]) + @test m.obs_syms == [:y] + @test m.modified == [] + @test m.mod_syms == [] + @test m.ctx === nothing + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa; observed=(; y=x)) + @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m.f == fmfa + @test isequal(m.obs, [x]) + @test m.obs_syms == [:y] + @test m.modified == [] + @test m.mod_syms == [] + @test m.ctx === nothing + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa; modified=(; x)) + @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m.f == fmfa + @test isequal(m.obs, []) + @test m.obs_syms == [] + @test isequal(m.modified, [x]) + @test m.mod_syms == [:x] + @test m.ctx === nothing + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa; modified=(; y=x)) + @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m.f == fmfa + @test isequal(m.obs, []) + @test m.obs_syms == [] + @test isequal(m.modified, [x]) + @test m.mod_syms == [:y] + @test m.ctx === nothing + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; x), (; x)) + @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m.f == fmfa + @test isequal(m.obs, [x]) + @test m.obs_syms == [:x] + @test isequal(m.modified, [x]) + @test m.mod_syms == [:x] + @test m.ctx === nothing + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; y=x), (; y=x)) + @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m.f == fmfa + @test isequal(m.obs, [x]) + @test m.obs_syms == [:y] + @test isequal(m.modified, [x]) + @test m.mod_syms == [:y] + @test m.ctx === nothing + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa; modified=(; y=x), observed=(; y=x)) + @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m.f == fmfa + @test isequal(m.obs, [x]) + @test m.obs_syms == [:y] + @test isequal(m.modified, [x]) + @test m.mod_syms == [:y] + @test m.ctx === nothing + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa; modified=(; y=x), observed=(; y=x), ctx=3) + @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m.f == fmfa + @test isequal(m.obs, [x]) + @test m.obs_syms == [:y] + @test isequal(m.modified, [x]) + @test m.mod_syms == [:y] + @test m.ctx === 3 + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; x), (; x), 3) + @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m.f == fmfa + @test isequal(m.obs, [x]) + @test m.obs_syms == [:x] + @test isequal(m.modified, [x]) + @test m.mod_syms == [:x] + @test m.ctx === 3 +end + ## @named sys = ODESystem(eqs, t, continuous_events = [x ~ 1]) @@ -912,27 +1023,54 @@ end @testset "Quadrature" begin @variables theta(t) omega(t) - params = @parameters qA=0 qB=0 + params = @parameters qA=0 qB=0 hA=0 hB=0 cnt=0 eqs = [ D(theta) ~ omega - omega ~ sin(0.5*t) + omega ~ 1.0 ] - qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(1000 * theta) ~ 0], - ModelingToolkit.MutatingFunctionalAffect(modified=(; qA)) do x, o, i, c + function decoder(oldA, oldB, newA, newB) + state = (oldA, oldB, newA, newB) + if state == (0, 0, 1, 0) || state == (1, 0, 1, 1) || state == (1, 1, 0, 1) || state == (0, 1, 0, 0) + return 1 + elseif state == (0, 0, 0, 1) || state == (0, 1, 1, 1) || state == (1, 1, 1, 0) || state == (1, 0, 0, 0) + return -1 + elseif state == (0, 0, 0, 0) || state == (0, 1, 0, 1) || state == (1, 0, 1, 0) || state == (1, 1, 1, 1) + return 0 + else + return 0 # err is interpreted as no movement + end + end + # todo: warn about dups + # todo: warn if a variable appears in both observed and modified + qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0], + ModelingToolkit.MutatingFunctionalAffect((; qB), (; qA, hA, hB, cnt)) do x, o, i, c + x.hA = x.qA + x.hB = o.qB x.qA = 1 + x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) end, - affect_neg = ModelingToolkit.MutatingFunctionalAffect(modified=(; qA)) do x, o, i, c + affect_neg = ModelingToolkit.MutatingFunctionalAffect((; qB), (; qA, hA, hB, cnt)) do x, o, i, c + x.hA = x.qA + x.hB = o.qB x.qA = 0 - end) - qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(1000 * theta + π/2) ~ 0], - ModelingToolkit.MutatingFunctionalAffect(modified=(; qB)) do x, o, i, c + x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) + end; rootfind=SciMLBase.RightRootFind) + qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π/2) ~ 0], + ModelingToolkit.MutatingFunctionalAffect((; qA), (; qB, hA, hB, cnt)) do x, o, i, c + x.hA = o.qA + x.hB = x.qB x.qB = 1 + x.cnt += decoder(x.hA, x.hB, o.qA, x.qB) end, - affect_neg = ModelingToolkit.MutatingFunctionalAffect(modified=(; qB)) do x, o, i, c + affect_neg = ModelingToolkit.MutatingFunctionalAffect((; qA), (; qB, hA, hB, cnt)) do x, o, i, c + x.hA = o.qA + x.hB = x.qB x.qB = 0 - end) + x.cnt += decoder(x.hA, x.hB, o.qA, x.qB) + end; rootfind=SciMLBase.RightRootFind) @named sys = ODESystem(eqs, t, [theta, omega], params; continuous_events = [qAevt, qBevt]) ss = structural_simplify(sys) - prob = ODEProblem(ss, [theta => 0.0], (0.0, 1.0)) + prob = ODEProblem(ss, [theta => 0.0], (0.0, pi)) sol = solve(prob, Tsit5(); dtmax=0.01) + @test sol[cnt] == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state end From b96cd2eb273d640bf39d3a7f3c01a5e994cb7cc9 Mon Sep 17 00:00:00 2001 From: Ben Chung Date: Fri, 2 Aug 2024 16:42:55 -0700 Subject: [PATCH 004/101] More sanity checking --- src/systems/callbacks.jl | 28 ++++++++++++++++++++++++--- test/symbolic_events.jl | 42 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index a57d5c006d..612d80536c 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -818,10 +818,10 @@ function compile_user_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs. end end -invalid_variables(sys, expr) = filter(x -> !any(isequal(x), all_symbols(sys)), vars(expr)) +invalid_variables(sys, expr) = filter(x -> !any(isequal(x), all_symbols(sys)), reduce(vcat, vars(expr); init=[])) function unassignable_variables(sys, expr) assignable_syms = vcat(unknowns(sys), parameters(sys)) - return filter(x -> !any(isequal(x), assignable_syms), vars(expr)) + return filter(x -> !any(isequal(x), assignable_syms), reduce(vcat, vars(expr); init=[])) end function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwargs...) @@ -832,6 +832,21 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa call the affect method - test if it's OOP or IP using applicable unpack and apply the resulting values =# + function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup) + seen = Set{Symbol}() + syms_dedup = []; exprs_dedup = [] + for (sym, exp) in Iterators.zip(syms, exprs) + if !in(sym, seen) + push!(syms_dedup, sym) + push!(exprs_dedup, exp) + push!(seen, sym) + else + @warn "Expression $(expr) is aliased as $sym, which has already been used. The first definition will be used." + end + end + return (syms_dedup, exprs_dedup) + end + obs_exprs = observed(affect) for oexpr in obs_exprs invalid_vars = invalid_variables(sys, oexpr) @@ -840,6 +855,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa end end obs_syms = observed_syms(affect) + obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs) obs_size = size.(obs_exprs) # we will generate a work buffer of a ComponentArray that maps obs_syms to arrays of size obs_size mod_exprs = modified(affect) @@ -849,12 +865,18 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa end invalid_vars = unassignable_variables(sys, mexpr) if length(invalid_vars) > 0 - error("Observed equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.") + error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.") end end mod_syms = modified_syms(affect) + mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs) _, mod_og_val_fun = build_explicit_observed_function(sys, mod_exprs; return_inplace=true) + overlapping_syms = intersect(mod_syms, obs_syms) + if length(overlapping_syms) > 0 + @warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value." + end + # sanity checks done! now build the data and update function for observed values mkzero(sz) = if sz === () 0.0 else zeros(sz) end _, obs_fun = build_explicit_observed_function(sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []); return_inplace=true) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index ca9f0ad9c3..d24b590970 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1021,6 +1021,46 @@ end @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) end +@testset "MutatingFunctionalAffect errors and warnings" begin + @variables temp(t) + params = @parameters furnace_on_threshold=0.5 furnace_off_threshold=0.7 furnace_power=1.0 leakage=0.1 furnace_on::Bool=false + eqs = [ + D(temp) ~ furnace_on * furnace_power - temp^2 * leakage + ] + + furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on), observed=(; furnace_on)) do x, o, i, c + x.furnace_on = false + end) + @named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off]) + ss = structural_simplify(sys) + @test_logs (:warn, "The symbols Any[:furnace_on] are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value.") prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) + + @variables tempsq(t) # trivially eliminated + eqs = [ + tempsq ~ temp^2 + D(temp) ~ furnace_on * furnace_power - temp^2 * leakage + ] + + furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on, tempsq), observed=(; furnace_on)) do x, o, i, c + x.furnace_on = false + end) + @named sys = ODESystem(eqs, t, [temp, tempsq], params; continuous_events = [furnace_off]) + ss = structural_simplify(sys) + @test_throws "refers to missing variable(s)" prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) + + + @parameters not_actually_here + furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on), observed=(; furnace_on, not_actually_here)) do x, o, i, c + x.furnace_on = false + end) + @named sys = ODESystem(eqs, t, [temp, tempsq], params; continuous_events = [furnace_off]) + ss = structural_simplify(sys) + @test_throws "refers to missing variable(s)" prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) +end + @testset "Quadrature" begin @variables theta(t) omega(t) params = @parameters qA=0 qB=0 hA=0 hB=0 cnt=0 @@ -1040,8 +1080,6 @@ end return 0 # err is interpreted as no movement end end - # todo: warn about dups - # todo: warn if a variable appears in both observed and modified qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0], ModelingToolkit.MutatingFunctionalAffect((; qB), (; qA, hA, hB, cnt)) do x, o, i, c x.hA = x.qA From eec24cfb95546fb5b8a2415a8cd3a2be0d60a6a8 Mon Sep 17 00:00:00 2001 From: Ben Chung Date: Fri, 2 Aug 2024 16:56:56 -0700 Subject: [PATCH 005/101] Document MutatingFunctionalAffect --- src/systems/callbacks.jl | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 612d80536c..fefff94442 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -72,10 +72,29 @@ function namespace_affect(affect::FunctionalAffect, s) end """ -`MutatingFunctionalAffect` differs from `FunctionalAffect` in two key ways: -* First, insetad of the `u` vector passed to `f` being a vector of indices into `integ.u` it's instead the result of evaluating `obs` at the current state, named as specified in `obs_syms`. This allows affects to easily access observed states and decouples affect inputs from the system structure. -* Second, it abstracts the assignment back to system states away. Instead of writing `integ.u[u.myvar] = [whatever]`, you instead declare in `mod_params` that you want to modify `myvar` and then either (out of place) return a named tuple with `myvar` or (in place) modify the associated element in the ComponentArray that's given. -Initially, we only support "flat" states in `modified`; these states will be marked as irreducible in the overarching system and they will simply be bulk assigned at mutation. In the future, this will be extended to perform a nonlinear solve to further decouple the affect from the system structure. + MutatingFunctionalAffect(f::Function; observed::NamedTuple, modified::NamedTuple, ctx) + +`MutatingFunctionalAffect` is a helper for writing affect functions that will compute observed values and +ensure that modified values are correctly written back into the system. The affect function `f` needs to have +one of three signatures: +* `f(observed::ComponentArray)` if the function only reads observed values back from the system, +* `f(observed::ComponentArray, modified::ComponentArray)` if the function also writes values (unknowns or parameters) into the system, +* `f(observed::ComponentArray, modified::ComponentArray, ctx)` if the function needs the user-defined context, +* `f(observed::ComponentArray, modified::ComponentArray, ctx, integrator)` if the function needs the low-level integrator. + +The function `f` will be called with `observed` and `modified` `ComponentArray`s that are derived from their respective `NamedTuple` definitions. +Each `NamedTuple` should map an expression to a symbol; for example if we pass `observed=(; x = a + b)` this will alias the result of executing `a+b` in the system as `x` +so the value of `a + b` will be accessible as `observed.x` in `f`. `modified` currently restricts symbolic expressions to only bare variables, so only tuples of the form +`(; x = y)` or `(; x)` (which aliases `x` as itself) are allowed. + +Both `observed` and `modified` will be automatically populated with the current values of their corresponding expressions on function entry. +The values in `modified` will be written back to the system after `f` returns. For example, if we want to update the value of `x` to be the result of `x + y` we could write + + MutatingFunctionalAffect(observed=(; x_plus_y = x + y), modified=(; x)) do o, m + m.x = o.x_plus_y + end + +The affect function updates the value at `x` in `modified` to be the result of evaluating `x + y` as passed in the observed values. """ @kwdef struct MutatingFunctionalAffect f::Any @@ -174,6 +193,7 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either: + `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`. + `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition. + `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem. +* A [`MutatingFunctionalAffect`](@ref); refer to its documentation for details. """ struct SymbolicContinuousCallback eqs::Vector{Equation} From fd0125d36715693c0b7e2c5d5fe130e402c137ae Mon Sep 17 00:00:00 2001 From: Ben Chung Date: Fri, 2 Aug 2024 17:15:38 -0700 Subject: [PATCH 006/101] Flip modified and observed order; write docstring --- src/systems/callbacks.jl | 34 +++++++++++++------ test/symbolic_events.jl | 73 +++++++++++++++++++++++++++++++--------- 2 files changed, 80 insertions(+), 27 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index fefff94442..715b0e9a91 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -72,15 +72,16 @@ function namespace_affect(affect::FunctionalAffect, s) end """ - MutatingFunctionalAffect(f::Function; observed::NamedTuple, modified::NamedTuple, ctx) + MutatingFunctionalAffect(f::Function; modified::NamedTuple, observed::NamedTuple, ctx) `MutatingFunctionalAffect` is a helper for writing affect functions that will compute observed values and ensure that modified values are correctly written back into the system. The affect function `f` needs to have -one of three signatures: -* `f(observed::ComponentArray)` if the function only reads observed values back from the system, -* `f(observed::ComponentArray, modified::ComponentArray)` if the function also writes values (unknowns or parameters) into the system, -* `f(observed::ComponentArray, modified::ComponentArray, ctx)` if the function needs the user-defined context, -* `f(observed::ComponentArray, modified::ComponentArray, ctx, integrator)` if the function needs the low-level integrator. +one of four signatures: +* `f(modified::ComponentArray)` if the function only writes values (unknowns or parameters) to the system, +* `f(modified::ComponentArray, observed::ComponentArray)` if the function also reads observed values from the system, +* `f(modified::ComponentArray, observed::ComponentArray, ctx)` if the function needs the user-defined context, +* `f(modified::ComponentArray, observed::ComponentArray, ctx, integrator)` if the function needs the low-level integrator. +These will be checked in reverse order (that is, the four-argument version first, than the 3, etc). The function `f` will be called with `observed` and `modified` `ComponentArray`s that are derived from their respective `NamedTuple` definitions. Each `NamedTuple` should map an expression to a symbol; for example if we pass `observed=(; x = a + b)` this will alias the result of executing `a+b` in the system as `x` @@ -90,7 +91,7 @@ so the value of `a + b` will be accessible as `observed.x` in `f`. `modified` cu Both `observed` and `modified` will be automatically populated with the current values of their corresponding expressions on function entry. The values in `modified` will be written back to the system after `f` returns. For example, if we want to update the value of `x` to be the result of `x + y` we could write - MutatingFunctionalAffect(observed=(; x_plus_y = x + y), modified=(; x)) do o, m + MutatingFunctionalAffect(observed=(; x_plus_y = x + y), modified=(; x)) do m, o m.x = o.x_plus_y end @@ -109,11 +110,11 @@ MutatingFunctionalAffect(f::Function; observed::NamedTuple = NamedTuple{()}(()), modified::NamedTuple = NamedTuple{()}(()), ctx=nothing) = MutatingFunctionalAffect(f, collect(values(observed)), collect(keys(observed)), collect(values(modified)), collect(keys(modified)), ctx) -MutatingFunctionalAffect(f::Function, observed::NamedTuple; modified::NamedTuple = NamedTuple{()}(()), ctx=nothing) = +MutatingFunctionalAffect(f::Function, modified::NamedTuple; observed::NamedTuple = NamedTuple{()}(()), ctx=nothing) = MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx) -MutatingFunctionalAffect(f::Function, observed::NamedTuple, modified::NamedTuple; ctx=nothing) = +MutatingFunctionalAffect(f::Function, modified::NamedTuple, observed::NamedTuple; ctx=nothing) = MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx) -MutatingFunctionalAffect(f::Function, observed::NamedTuple, modified::NamedTuple, ctx) = +MutatingFunctionalAffect(f::Function, modified::NamedTuple, observed::NamedTuple, ctx) = MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx) func(f::MutatingFunctionalAffect) = f.f @@ -925,7 +926,18 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa obs_fun(obs_component_array, integ.u, integ.p..., integ.t) # let the user do their thing - user_affect(upd_component_array, obs_component_array, integ, ctx) + if applicable(user_affect, upd_component_array, obs_component_array, ctx, integ) + user_affect(upd_component_array, obs_component_array, ctx, integ) + elseif applicable(user_affect, upd_component_array, obs_component_array, ctx) + user_affect(upd_component_array, obs_component_array, ctx) + elseif applicable(user_affect, upd_component_array, obs_component_array) + user_affect(upd_component_array, obs_component_array) + elseif applicable(user_affect, upd_component_array) + user_affect(upd_component_array) + else + @error "User affect function $user_affect needs to implement one of the supported MutatingFunctionalAffect callback forms; see the MutatingFunctionalAffect docstring for more details" + user_affect(upd_component_array, obs_component_array, integ, ctx) # this WILL error but it'll give a more sensible message + end # write the new values back to the integrator upd_params_fun(integ, upd_params_view) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index d24b590970..779f41471a 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -250,19 +250,19 @@ end m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; x)) @test m isa ModelingToolkit.MutatingFunctionalAffect @test m.f == fmfa - @test isequal(m.obs, [x]) - @test m.obs_syms == [:x] - @test m.modified == [] - @test m.mod_syms == [] + @test isequal(m.obs, []) + @test m.obs_syms == [] + @test isequal(m.modified, [x]) + @test m.mod_syms == [:x] @test m.ctx === nothing m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; y=x)) @test m isa ModelingToolkit.MutatingFunctionalAffect @test m.f == fmfa - @test isequal(m.obs, [x]) - @test m.obs_syms == [:y] - @test m.modified == [] - @test m.mod_syms == [] + @test isequal(m.obs, []) + @test m.obs_syms == [] + @test isequal(m.modified, [x]) + @test m.mod_syms == [:y] @test m.ctx === nothing m = ModelingToolkit.MutatingFunctionalAffect(fmfa; observed=(; y=x)) @@ -1013,7 +1013,48 @@ end ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o, i, c x.furnace_on = true end) - + @named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) + ss = structural_simplify(sys) + prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) + sol = solve(prob, Tsit5(); dtmax=0.01) + @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) + + furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o, i + x.furnace_on = false + end) + furnace_enable = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_on_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o, i + x.furnace_on = true + end) + @named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) + ss = structural_simplify(sys) + prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) + sol = solve(prob, Tsit5(); dtmax=0.01) + @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) + + furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o + x.furnace_on = false + end) + furnace_enable = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_on_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o + x.furnace_on = true + end) + @named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) + ss = structural_simplify(sys) + prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) + sol = solve(prob, Tsit5(); dtmax=0.01) + @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) + + furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x + x.furnace_on = false + end) + furnace_enable = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_on_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x + x.furnace_on = true + end) @named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) ss = structural_simplify(sys) prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) @@ -1029,7 +1070,7 @@ end ] furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on), observed=(; furnace_on)) do x, o, i, c + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on), observed=(; furnace_on)) do x, o, c, i x.furnace_on = false end) @named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off]) @@ -1043,7 +1084,7 @@ end ] furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on, tempsq), observed=(; furnace_on)) do x, o, i, c + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on, tempsq), observed=(; furnace_on)) do x, o, c, i x.furnace_on = false end) @named sys = ODESystem(eqs, t, [temp, tempsq], params; continuous_events = [furnace_off]) @@ -1053,7 +1094,7 @@ end @parameters not_actually_here furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on), observed=(; furnace_on, not_actually_here)) do x, o, i, c + ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on), observed=(; furnace_on, not_actually_here)) do x, o, c, i x.furnace_on = false end) @named sys = ODESystem(eqs, t, [temp, tempsq], params; continuous_events = [furnace_off]) @@ -1081,26 +1122,26 @@ end end end qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0], - ModelingToolkit.MutatingFunctionalAffect((; qB), (; qA, hA, hB, cnt)) do x, o, i, c + ModelingToolkit.MutatingFunctionalAffect((; qA, hA, hB, cnt), (; qB)) do x, o, i, c x.hA = x.qA x.hB = o.qB x.qA = 1 x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) end, - affect_neg = ModelingToolkit.MutatingFunctionalAffect((; qB), (; qA, hA, hB, cnt)) do x, o, i, c + affect_neg = ModelingToolkit.MutatingFunctionalAffect((; qA, hA, hB, cnt), (; qB)) do x, o, c, i x.hA = x.qA x.hB = o.qB x.qA = 0 x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) end; rootfind=SciMLBase.RightRootFind) qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π/2) ~ 0], - ModelingToolkit.MutatingFunctionalAffect((; qA), (; qB, hA, hB, cnt)) do x, o, i, c + ModelingToolkit.MutatingFunctionalAffect((; qB, hA, hB, cnt), (; qA)) do x, o, i, c x.hA = o.qA x.hB = x.qB x.qB = 1 x.cnt += decoder(x.hA, x.hB, o.qA, x.qB) end, - affect_neg = ModelingToolkit.MutatingFunctionalAffect((; qA), (; qB, hA, hB, cnt)) do x, o, i, c + affect_neg = ModelingToolkit.MutatingFunctionalAffect((; qB, hA, hB, cnt), (; qA)) do x, o, c, i x.hA = o.qA x.hB = x.qB x.qB = 0 From 9948de076b491a114b395f8a9168eaaebe0959b1 Mon Sep 17 00:00:00 2001 From: Ben Chung Date: Fri, 2 Aug 2024 17:17:12 -0700 Subject: [PATCH 007/101] Run formatter --- src/systems/callbacks.jl | 87 ++++++++++++------- test/symbolic_events.jl | 175 ++++++++++++++++++++++----------------- 2 files changed, 158 insertions(+), 104 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 715b0e9a91..5c6f9b2801 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -106,16 +106,25 @@ The affect function updates the value at `x` in `modified` to be the result of e ctx::Any end -MutatingFunctionalAffect(f::Function; - observed::NamedTuple = NamedTuple{()}(()), - modified::NamedTuple = NamedTuple{()}(()), - ctx=nothing) = MutatingFunctionalAffect(f, collect(values(observed)), collect(keys(observed)), collect(values(modified)), collect(keys(modified)), ctx) -MutatingFunctionalAffect(f::Function, modified::NamedTuple; observed::NamedTuple = NamedTuple{()}(()), ctx=nothing) = - MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx) -MutatingFunctionalAffect(f::Function, modified::NamedTuple, observed::NamedTuple; ctx=nothing) = - MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx) -MutatingFunctionalAffect(f::Function, modified::NamedTuple, observed::NamedTuple, ctx) = - MutatingFunctionalAffect(f, observed=observed, modified=modified, ctx=ctx) +function MutatingFunctionalAffect(f::Function; + observed::NamedTuple = NamedTuple{()}(()), + modified::NamedTuple = NamedTuple{()}(()), + ctx = nothing) + MutatingFunctionalAffect(f, collect(values(observed)), collect(keys(observed)), + collect(values(modified)), collect(keys(modified)), ctx) +end +function MutatingFunctionalAffect(f::Function, modified::NamedTuple; + observed::NamedTuple = NamedTuple{()}(()), ctx = nothing) + MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx) +end +function MutatingFunctionalAffect( + f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing) + MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx) +end +function MutatingFunctionalAffect( + f::Function, modified::NamedTuple, observed::NamedTuple, ctx) + MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx) +end func(f::MutatingFunctionalAffect) = f.f context(a::MutatingFunctionalAffect) = a.ctx @@ -126,8 +135,9 @@ modified(a::MutatingFunctionalAffect) = a.modified modified_syms(a::MutatingFunctionalAffect) = a.mod_syms function Base.:(==)(a1::MutatingFunctionalAffect, a2::MutatingFunctionalAffect) - isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) && - isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms)&& isequal(a1.ctx, a2.ctx) + isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) && + isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms) && + isequal(a1.ctx, a2.ctx) end function Base.hash(a::MutatingFunctionalAffect, s::UInt) @@ -839,10 +849,13 @@ function compile_user_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs. end end -invalid_variables(sys, expr) = filter(x -> !any(isequal(x), all_symbols(sys)), reduce(vcat, vars(expr); init=[])) -function unassignable_variables(sys, expr) +function invalid_variables(sys, expr) + filter(x -> !any(isequal(x), all_symbols(sys)), reduce(vcat, vars(expr); init = [])) +end +function unassignable_variables(sys, expr) assignable_syms = vcat(unknowns(sys), parameters(sys)) - return filter(x -> !any(isequal(x), assignable_syms), reduce(vcat, vars(expr); init=[])) + return filter( + x -> !any(isequal(x), assignable_syms), reduce(vcat, vars(expr); init = [])) end function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwargs...) @@ -855,7 +868,8 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa =# function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup) seen = Set{Symbol}() - syms_dedup = []; exprs_dedup = [] + syms_dedup = [] + exprs_dedup = [] for (sym, exp) in Iterators.zip(syms, exprs) if !in(sym, seen) push!(syms_dedup, sym) @@ -869,7 +883,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa end obs_exprs = observed(affect) - for oexpr in obs_exprs + for oexpr in obs_exprs invalid_vars = invalid_variables(sys, oexpr) if length(invalid_vars) > 0 error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).") @@ -880,7 +894,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa obs_size = size.(obs_exprs) # we will generate a work buffer of a ComponentArray that maps obs_syms to arrays of size obs_size mod_exprs = modified(affect) - for mexpr in mod_exprs + for mexpr in mod_exprs if !is_observed(sys, mexpr) && parameter_index(sys, mexpr) === nothing error("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.") end @@ -891,7 +905,8 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa end mod_syms = modified_syms(affect) mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs) - _, mod_og_val_fun = build_explicit_observed_function(sys, mod_exprs; return_inplace=true) + _, mod_og_val_fun = build_explicit_observed_function( + sys, mod_exprs; return_inplace = true) overlapping_syms = intersect(mod_syms, obs_syms) if length(overlapping_syms) > 0 @@ -899,21 +914,33 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa end # sanity checks done! now build the data and update function for observed values - mkzero(sz) = if sz === () 0.0 else zeros(sz) end - _, obs_fun = build_explicit_observed_function(sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []); return_inplace=true) - obs_component_array = ComponentArrays.ComponentArray(NamedTuple{(obs_syms..., )}(mkzero.(obs_size))) + mkzero(sz) = + if sz === () + 0.0 + else + zeros(sz) + end + _, obs_fun = build_explicit_observed_function( + sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []); + return_inplace = true) + obs_component_array = ComponentArrays.ComponentArray(NamedTuple{(obs_syms...,)}(mkzero.(obs_size))) # okay so now to generate the stuff to assign it back into the system # note that we reorder the componentarray to make the views coherent wrt the base array mod_pairs = mod_exprs .=> mod_syms mod_param_pairs = filter(v -> is_parameter(sys, v[1]), mod_pairs) mod_unk_pairs = filter(v -> !is_parameter(sys, v[1]), mod_pairs) - _, mod_og_val_fun = build_explicit_observed_function(sys, reduce(vcat, [first.(mod_param_pairs); first.(mod_unk_pairs)]; init = []); return_inplace=true) - upd_params_fun = setu(sys, reduce(vcat, Symbolics.scalarize.(first.(mod_param_pairs)); init = [])) - upd_unk_fun = setu(sys, reduce(vcat, Symbolics.scalarize.(first.(mod_unk_pairs)); init = [])) - - upd_component_array = ComponentArrays.ComponentArray(NamedTuple{([last.(mod_param_pairs); last.(mod_unk_pairs)]...,)}( - [collect(mkzero(size(e)) for e in first.(mod_param_pairs)); + _, mod_og_val_fun = build_explicit_observed_function( + sys, reduce(vcat, [first.(mod_param_pairs); first.(mod_unk_pairs)]; init = []); + return_inplace = true) + upd_params_fun = setu( + sys, reduce(vcat, Symbolics.scalarize.(first.(mod_param_pairs)); init = [])) + upd_unk_fun = setu( + sys, reduce(vcat, Symbolics.scalarize.(first.(mod_unk_pairs)); init = [])) + + upd_component_array = ComponentArrays.ComponentArray(NamedTuple{([last.(mod_param_pairs); + last.(mod_unk_pairs)]...,)}( + [collect(mkzero(size(e)) for e in first.(mod_param_pairs)); collect(mkzero(size(e)) for e in first.(mod_unk_pairs))])) upd_params_view = view(upd_component_array, last.(mod_param_pairs)) upd_unks_view = view(upd_component_array, last.(mod_unk_pairs)) @@ -921,7 +948,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa function (integ) # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens mod_og_val_fun(upd_component_array, integ.u, integ.p..., integ.t) - + # update the observed values obs_fun(obs_component_array, integ.u, integ.p..., integ.t) @@ -934,7 +961,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa user_affect(upd_component_array, obs_component_array) elseif applicable(user_affect, upd_component_array) user_affect(upd_component_array) - else + else @error "User affect function $user_affect needs to implement one of the supported MutatingFunctionalAffect callback forms; see the MutatingFunctionalAffect docstring for more details" user_affect(upd_component_array, obs_component_array, integ, ctx) # this WILL error but it'll give a more sensible message end diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 779f41471a..7aba69dc7f 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -237,7 +237,7 @@ end @test m.modified == [] @test m.mod_syms == [] @test m.ctx === nothing - + m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (;)) @test m isa ModelingToolkit.MutatingFunctionalAffect @test m.f == fmfa @@ -246,7 +246,7 @@ end @test m.modified == [] @test m.mod_syms == [] @test m.ctx === nothing - + m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; x)) @test m isa ModelingToolkit.MutatingFunctionalAffect @test m.f == fmfa @@ -255,8 +255,8 @@ end @test isequal(m.modified, [x]) @test m.mod_syms == [:x] @test m.ctx === nothing - - m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; y=x)) + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; y = x)) @test m isa ModelingToolkit.MutatingFunctionalAffect @test m.f == fmfa @test isequal(m.obs, []) @@ -264,8 +264,8 @@ end @test isequal(m.modified, [x]) @test m.mod_syms == [:y] @test m.ctx === nothing - - m = ModelingToolkit.MutatingFunctionalAffect(fmfa; observed=(; y=x)) + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa; observed = (; y = x)) @test m isa ModelingToolkit.MutatingFunctionalAffect @test m.f == fmfa @test isequal(m.obs, [x]) @@ -273,8 +273,8 @@ end @test m.modified == [] @test m.mod_syms == [] @test m.ctx === nothing - - m = ModelingToolkit.MutatingFunctionalAffect(fmfa; modified=(; x)) + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa; modified = (; x)) @test m isa ModelingToolkit.MutatingFunctionalAffect @test m.f == fmfa @test isequal(m.obs, []) @@ -283,7 +283,7 @@ end @test m.mod_syms == [:x] @test m.ctx === nothing - m = ModelingToolkit.MutatingFunctionalAffect(fmfa; modified=(; y=x)) + m = ModelingToolkit.MutatingFunctionalAffect(fmfa; modified = (; y = x)) @test m isa ModelingToolkit.MutatingFunctionalAffect @test m.f == fmfa @test isequal(m.obs, []) @@ -291,7 +291,7 @@ end @test isequal(m.modified, [x]) @test m.mod_syms == [:y] @test m.ctx === nothing - + m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; x), (; x)) @test m isa ModelingToolkit.MutatingFunctionalAffect @test m.f == fmfa @@ -300,8 +300,8 @@ end @test isequal(m.modified, [x]) @test m.mod_syms == [:x] @test m.ctx === nothing - - m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; y=x), (; y=x)) + + m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; y = x), (; y = x)) @test m isa ModelingToolkit.MutatingFunctionalAffect @test m.f == fmfa @test isequal(m.obs, [x]) @@ -309,8 +309,9 @@ end @test isequal(m.modified, [x]) @test m.mod_syms == [:y] @test m.ctx === nothing - - m = ModelingToolkit.MutatingFunctionalAffect(fmfa; modified=(; y=x), observed=(; y=x)) + + m = ModelingToolkit.MutatingFunctionalAffect( + fmfa; modified = (; y = x), observed = (; y = x)) @test m isa ModelingToolkit.MutatingFunctionalAffect @test m.f == fmfa @test isequal(m.obs, [x]) @@ -318,8 +319,9 @@ end @test isequal(m.modified, [x]) @test m.mod_syms == [:y] @test m.ctx === nothing - - m = ModelingToolkit.MutatingFunctionalAffect(fmfa; modified=(; y=x), observed=(; y=x), ctx=3) + + m = ModelingToolkit.MutatingFunctionalAffect( + fmfa; modified = (; y = x), observed = (; y = x), ctx = 3) @test m isa ModelingToolkit.MutatingFunctionalAffect @test m.f == fmfa @test isequal(m.obs, [x]) @@ -327,7 +329,7 @@ end @test isequal(m.modified, [x]) @test m.mod_syms == [:y] @test m.ctx === 3 - + m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; x), (; x), 3) @test m isa ModelingToolkit.MutatingFunctionalAffect @test m.f == fmfa @@ -1005,151 +1007,176 @@ end D(temp) ~ furnace_on * furnace_power - temp^2 * leakage ] - furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o, i, c + furnace_off = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i, c x.furnace_on = false end) - furnace_enable = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_on_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o, i, c + furnace_enable = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_on_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i, c x.furnace_on = true end) - @named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) + @named sys = ODESystem( + eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) ss = structural_simplify(sys) prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) - sol = solve(prob, Tsit5(); dtmax=0.01) + sol = solve(prob, Tsit5(); dtmax = 0.01) @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) - furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o, i + furnace_off = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i x.furnace_on = false end) - furnace_enable = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_on_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o, i + furnace_enable = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_on_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i x.furnace_on = true end) - @named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) + @named sys = ODESystem( + eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) ss = structural_simplify(sys) prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) - sol = solve(prob, Tsit5(); dtmax=0.01) + sol = solve(prob, Tsit5(); dtmax = 0.01) @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) - furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o + furnace_off = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o x.furnace_on = false end) - furnace_enable = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_on_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x, o + furnace_enable = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_on_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o x.furnace_on = true end) - @named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) + @named sys = ODESystem( + eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) ss = structural_simplify(sys) prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) - sol = solve(prob, Tsit5(); dtmax=0.01) + sol = solve(prob, Tsit5(); dtmax = 0.01) @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) - furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x + furnace_off = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x x.furnace_on = false end) - furnace_enable = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_on_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on)) do x + furnace_enable = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_on_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x x.furnace_on = true end) - @named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) + @named sys = ODESystem( + eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) ss = structural_simplify(sys) prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) - sol = solve(prob, Tsit5(); dtmax=0.01) + sol = solve(prob, Tsit5(); dtmax = 0.01) @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) end -@testset "MutatingFunctionalAffect errors and warnings" begin +@testset "MutatingFunctionalAffect errors and warnings" begin @variables temp(t) params = @parameters furnace_on_threshold=0.5 furnace_off_threshold=0.7 furnace_power=1.0 leakage=0.1 furnace_on::Bool=false eqs = [ D(temp) ~ furnace_on * furnace_power - temp^2 * leakage ] - furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on), observed=(; furnace_on)) do x, o, c, i + furnace_off = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect( + modified = (; furnace_on), observed = (; furnace_on)) do x, o, c, i x.furnace_on = false end) @named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off]) ss = structural_simplify(sys) - @test_logs (:warn, "The symbols Any[:furnace_on] are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value.") prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) + @test_logs (:warn, + "The symbols Any[:furnace_on] are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value.") prob=ODEProblem( + ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) @variables tempsq(t) # trivially eliminated - eqs = [ - tempsq ~ temp^2 - D(temp) ~ furnace_on * furnace_power - temp^2 * leakage - ] + eqs = [tempsq ~ temp^2 + D(temp) ~ furnace_on * furnace_power - temp^2 * leakage] - furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on, tempsq), observed=(; furnace_on)) do x, o, c, i + furnace_off = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect( + modified = (; furnace_on, tempsq), observed = (; furnace_on)) do x, o, c, i x.furnace_on = false end) - @named sys = ODESystem(eqs, t, [temp, tempsq], params; continuous_events = [furnace_off]) + @named sys = ODESystem( + eqs, t, [temp, tempsq], params; continuous_events = [furnace_off]) ss = structural_simplify(sys) - @test_throws "refers to missing variable(s)" prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) + @test_throws "refers to missing variable(s)" prob=ODEProblem( + ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) - @parameters not_actually_here - furnace_off = ModelingToolkit.SymbolicContinuousCallback([temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified=(; furnace_on), observed=(; furnace_on, not_actually_here)) do x, o, c, i + furnace_off = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on), + observed = (; furnace_on, not_actually_here)) do x, o, c, i x.furnace_on = false end) - @named sys = ODESystem(eqs, t, [temp, tempsq], params; continuous_events = [furnace_off]) + @named sys = ODESystem( + eqs, t, [temp, tempsq], params; continuous_events = [furnace_off]) ss = structural_simplify(sys) - @test_throws "refers to missing variable(s)" prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) + @test_throws "refers to missing variable(s)" prob=ODEProblem( + ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) end -@testset "Quadrature" begin +@testset "Quadrature" begin @variables theta(t) omega(t) params = @parameters qA=0 qB=0 hA=0 hB=0 cnt=0 - eqs = [ - D(theta) ~ omega - omega ~ 1.0 - ] + eqs = [D(theta) ~ omega + omega ~ 1.0] function decoder(oldA, oldB, newA, newB) state = (oldA, oldB, newA, newB) - if state == (0, 0, 1, 0) || state == (1, 0, 1, 1) || state == (1, 1, 0, 1) || state == (0, 1, 0, 0) + if state == (0, 0, 1, 0) || state == (1, 0, 1, 1) || state == (1, 1, 0, 1) || + state == (0, 1, 0, 0) return 1 - elseif state == (0, 0, 0, 1) || state == (0, 1, 1, 1) || state == (1, 1, 1, 0) || state == (1, 0, 0, 0) + elseif state == (0, 0, 0, 1) || state == (0, 1, 1, 1) || state == (1, 1, 1, 0) || + state == (1, 0, 0, 0) return -1 - elseif state == (0, 0, 0, 0) || state == (0, 1, 0, 1) || state == (1, 0, 1, 0) || state == (1, 1, 1, 1) + elseif state == (0, 0, 0, 0) || state == (0, 1, 0, 1) || state == (1, 0, 1, 0) || + state == (1, 1, 1, 1) return 0 else return 0 # err is interpreted as no movement end end - qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0], + qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0], ModelingToolkit.MutatingFunctionalAffect((; qA, hA, hB, cnt), (; qB)) do x, o, i, c x.hA = x.qA x.hB = o.qB x.qA = 1 x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) end, - affect_neg = ModelingToolkit.MutatingFunctionalAffect((; qA, hA, hB, cnt), (; qB)) do x, o, c, i + affect_neg = ModelingToolkit.MutatingFunctionalAffect( + (; qA, hA, hB, cnt), (; qB)) do x, o, c, i x.hA = x.qA x.hB = o.qB x.qA = 0 x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) - end; rootfind=SciMLBase.RightRootFind) - qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π/2) ~ 0], + end; rootfind = SciMLBase.RightRootFind) + qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π / 2) ~ 0], ModelingToolkit.MutatingFunctionalAffect((; qB, hA, hB, cnt), (; qA)) do x, o, i, c x.hA = o.qA x.hB = x.qB x.qB = 1 x.cnt += decoder(x.hA, x.hB, o.qA, x.qB) end, - affect_neg = ModelingToolkit.MutatingFunctionalAffect((; qB, hA, hB, cnt), (; qA)) do x, o, c, i + affect_neg = ModelingToolkit.MutatingFunctionalAffect( + (; qB, hA, hB, cnt), (; qA)) do x, o, c, i x.hA = o.qA x.hB = x.qB x.qB = 0 x.cnt += decoder(x.hA, x.hB, o.qA, x.qB) - end; rootfind=SciMLBase.RightRootFind) - @named sys = ODESystem(eqs, t, [theta, omega], params; continuous_events = [qAevt, qBevt]) + end; rootfind = SciMLBase.RightRootFind) + @named sys = ODESystem( + eqs, t, [theta, omega], params; continuous_events = [qAevt, qBevt]) ss = structural_simplify(sys) prob = ODEProblem(ss, [theta => 0.0], (0.0, pi)) - sol = solve(prob, Tsit5(); dtmax=0.01) + sol = solve(prob, Tsit5(); dtmax = 0.01) @test sol[cnt] == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state end From 61cf6762c43e595448ed6f37f37bd889db663b10 Mon Sep 17 00:00:00 2001 From: Ben Chung Date: Fri, 2 Aug 2024 18:44:18 -0700 Subject: [PATCH 008/101] FIx SCC reconstruction --- src/systems/callbacks.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 5c6f9b2801..8ea7d9f506 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -284,13 +284,15 @@ end namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af] namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s) +namespace_affects(af::MutatingFunctionalAffect, s) = namespace_affect(af, s) namespace_affects(::Nothing, s) = nothing function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback - SymbolicContinuousCallback( - namespace_equation.(equations(cb), (s,)), - namespace_affects(affects(cb), s); - affect_neg = namespace_affects(affect_negs(cb), s)) + SymbolicContinuousCallback(; + eqs = namespace_equation.(equations(cb), (s,)), + affect = namespace_affects(affects(cb), s), + affect_neg = namespace_affects(affect_negs(cb), s), + rootfind = cb.rootfind) end """ From 8edef14b63637a4132e5139540d3ec89fa9d39fd Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Mon, 19 Aug 2024 07:49:33 -0700 Subject: [PATCH 009/101] Implement initialize and finalize affects for symbolic callbacks --- src/systems/callbacks.jl | 206 +++++++++++++++++++++++++++++---------- 1 file changed, 154 insertions(+), 52 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 8ea7d9f506..d14fd50691 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -104,28 +104,38 @@ The affect function updates the value at `x` in `modified` to be the result of e modified::Vector mod_syms::Vector{Symbol} ctx::Any + skip_checks::Bool end function MutatingFunctionalAffect(f::Function; observed::NamedTuple = NamedTuple{()}(()), modified::NamedTuple = NamedTuple{()}(()), - ctx = nothing) - MutatingFunctionalAffect(f, collect(values(observed)), collect(keys(observed)), - collect(values(modified)), collect(keys(modified)), ctx) + ctx = nothing, + skip_checks = false) + MutatingFunctionalAffect(f, + collect(values(observed)), collect(keys(observed)), + collect(values(modified)), collect(keys(modified)), + ctx, skip_checks) end function MutatingFunctionalAffect(f::Function, modified::NamedTuple; - observed::NamedTuple = NamedTuple{()}(()), ctx = nothing) - MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx) + observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks=false) + MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) end function MutatingFunctionalAffect( - f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing) - MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx) + f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing, skip_checks=false) + MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) end function MutatingFunctionalAffect( - f::Function, modified::NamedTuple, observed::NamedTuple, ctx) - MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx) + f::Function, modified::NamedTuple, observed::NamedTuple, ctx; skip_checks=false) + MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) end +function Base.show(io::IO, mfa::MutatingFunctionalAffect) + obs_vals = join(map((ob,nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ") + mod_vals = join(map((md,nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ") + affect = mfa.f + print(io, "MutatingFunctionalAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)") +end func(f::MutatingFunctionalAffect) = f.f context(a::MutatingFunctionalAffect) = a.ctx observed(a::MutatingFunctionalAffect) = a.obs @@ -208,12 +218,19 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either: """ struct SymbolicContinuousCallback eqs::Vector{Equation} + initialize::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect} + finalize::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect} affect::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect} affect_neg::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect, Nothing} rootfind::SciMLBase.RootfindOpt - function SymbolicContinuousCallback(; eqs::Vector{Equation}, affect = NULL_AFFECT, - affect_neg = affect, rootfind = SciMLBase.LeftRootFind) - new(eqs, make_affect(affect), make_affect(affect_neg), rootfind) + function SymbolicContinuousCallback(; + eqs::Vector{Equation}, + affect = NULL_AFFECT, + affect_neg = affect, + rootfind = SciMLBase.LeftRootFind, + initialize=NULL_AFFECT, + finalize=NULL_AFFECT) + new(eqs, initialize, finalize, make_affect(affect), make_affect(affect_neg), rootfind) end # Default affect to nothing end make_affect(affect) = affect @@ -221,18 +238,81 @@ make_affect(affect::Tuple) = FunctionalAffect(affect...) make_affect(affect::NamedTuple) = FunctionalAffect(; affect...) function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback) - isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) && + isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) && + isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize) && isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind) end Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs) function Base.hash(cb::SymbolicContinuousCallback, s::UInt) + hash_affect(affect::AbstractVector, s) = foldr(hash, affect, init = s) + hash_affect(affect, s) = hash(cb.affect, s) s = foldr(hash, cb.eqs, init = s) - s = cb.affect isa AbstractVector ? foldr(hash, cb.affect, init = s) : hash(cb.affect, s) - s = cb.affect_neg isa AbstractVector ? foldr(hash, cb.affect_neg, init = s) : - hash(cb.affect_neg, s) + s = hash_affect(cb.affect, s) + s = hash_affect(cb.affect_neg, s) + s = hash_affect(cb.initialize, s) + s = hash_affect(cb.finalize, s) hash(cb.rootfind, s) end + +function Base.show(io::IO, cb::SymbolicContinuousCallback) + indent = get(io, :indent, 0) + iio = IOContext(io, :indent => indent+1) + print(io, "SymbolicContinuousCallback(") + print(iio, "Equations:") + show(iio, equations(cb)) + print(iio, "; ") + if affects(cb) != NULL_AFFECT + print(iio, "Affect:") + show(iio, affects(cb)) + print(iio, ", ") + end + if affect_negs(cb) != NULL_AFFECT + print(iio, "Negative-edge affect:") + show(iio, affect_negs(cb)) + print(iio, ", ") + end + if initialize_affects(cb) != NULL_AFFECT + print(iio, "Initialization affect:") + show(iio, initialize_affects(cb)) + print(iio, ", ") + end + if finalize_affects(cb) != NULL_AFFECT + print(iio, "Finalization affect:") + show(iio, finalize_affects(cb)) + end + print(iio, ")") +end + +function Base.show(io::IO, mime::MIME"text/plain", cb::SymbolicContinuousCallback) + indent = get(io, :indent, 0) + iio = IOContext(io, :indent => indent+1) + println(io, "SymbolicContinuousCallback:") + println(iio, "Equations:") + show(iio, mime, equations(cb)) + print(iio, "\n") + if affects(cb) != NULL_AFFECT + println(iio, "Affect:") + show(iio, mime, affects(cb)) + print(iio, "\n") + end + if affect_negs(cb) != NULL_AFFECT + println(iio, "Negative-edge affect:") + show(iio, mime, affect_negs(cb)) + print(iio, "\n") + end + if initialize_affects(cb) != NULL_AFFECT + println(iio, "Initialization affect:") + show(iio, mime, initialize_affects(cb)) + print(iio, "\n") + end + if finalize_affects(cb) != NULL_AFFECT + println(iio, "Finalization affect:") + show(iio, mime, finalize_affects(cb)) + print(iio, "\n") + end +end + to_equation_vector(eq::Equation) = [eq] to_equation_vector(eqs::Vector{Equation}) = eqs function to_equation_vector(eqs::Vector{Any}) @@ -246,14 +326,14 @@ end # wrap eq in vector SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2]) SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT; - affect_neg = affect, rootfind = SciMLBase.LeftRootFind) + affect_neg = affect, rootfind = SciMLBase.LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT) SymbolicContinuousCallback( - eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind) + eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize=initialize, finalize=finalize) end function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT; - affect_neg = affect, rootfind = SciMLBase.LeftRootFind) + affect_neg = affect, rootfind = SciMLBase.LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT) SymbolicContinuousCallback( - eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind) + eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize=initialize, finalize=finalize) end SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb] @@ -282,6 +362,16 @@ function affect_negs(cbs::Vector{SymbolicContinuousCallback}) mapreduce(affect_negs, vcat, cbs, init = Equation[]) end +initialize_affects(cb::SymbolicContinuousCallback) = cb.initialize +function initialize_affects(cbs::Vector{SymbolicContinuousCallback}) + mapreduce(initialize_affects, vcat, cbs, init = Equation[]) +end + +finalize_affects(cb::SymbolicContinuousCallback) = cb.initialize +function finalize_affects(cbs::Vector{SymbolicContinuousCallback}) + mapreduce(finalize_affects, vcat, cbs, init = Equation[]) +end + namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af] namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s) namespace_affects(af::MutatingFunctionalAffect, s) = namespace_affect(af, s) @@ -292,6 +382,8 @@ function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuo eqs = namespace_equation.(equations(cb), (s,)), affect = namespace_affects(affects(cb), s), affect_neg = namespace_affects(affect_negs(cb), s), + initialize = namespace_affects(initialize_affects(cb), s), + finalize = namespace_affects(finalize_affects(cb), s), rootfind = cb.rootfind) end @@ -681,8 +773,9 @@ function generate_single_rootfinding_callback( initfn = SciMLBase.INITIALIZE_DEFAULT end return ContinuousCallback( - cond, affect_function.affect, affect_function.affect_neg, - rootfind = cb.rootfind, initialize = initfn) + cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind, + initialize = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function.initialize(i), + finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i)) end function generate_vector_rootfinding_callback( @@ -702,13 +795,12 @@ function generate_vector_rootfinding_callback( _, rf_ip = generate_custom_function( sys, rhss, dvs, ps; expression = Val{false}, kwargs...) - affect_functions = @NamedTuple{affect::Function, affect_neg::Union{Function, Nothing}}[compile_affect_fn( - cb, - sys, - dvs, - ps, - kwargs) - for cb in cbs] + affect_functions = @NamedTuple{ + affect::Function, + affect_neg::Union{Function, Nothing}, + initialize::Union{Function, Nothing}, + finalize::Union{Function, Nothing}}[ + compile_affect_fn(cb, sys, dvs, ps, kwargs) for cb in cbs] cond = function (out, u, t, integ) rf_ip(out, u, parameter_values(integ), t) end @@ -734,25 +826,27 @@ function generate_vector_rootfinding_callback( affect_neg(integ) end end - if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - save_idxs = mapreduce( - cb -> get(ic.callback_to_clocks, cb, Int[]), vcat, cbs; init = Int[]) - initfn = if isempty(save_idxs) - SciMLBase.INITIALIZE_DEFAULT + function handle_optional_setup_fn(funs, default) + if all(isnothing, funs) + return default else - let save_idxs = save_idxs - function (cb, u, t, integrator) - for idx in save_idxs - SciMLBase.save_discretes!(integrator, idx) + return let funs = funs + function (cb, u, t, integ) + for func in funs + if isnothing(func) + continue + else + func(integ) + end end end end end - else - initfn = SciMLBase.INITIALIZE_DEFAULT end + initialize = handle_optional_setup_fn(map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT) + finalize = handle_optional_setup_fn(map(fn -> fn.finalize, affect_functions), SciMLBase.FINALIZE_DEFAULT) return VectorContinuousCallback( - cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initfn) + cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initialize, finalize = finalize) end """ @@ -762,15 +856,23 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs) eq_aff = affects(cb) eq_neg_aff = affect_negs(cb) affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...) + function compile_optional_affect(aff) + if isnothing(aff) + return nothing + else + affspr = compile_affect(aff, cb, sys, dvs, ps; expression = Val{true}, kwargs...) + @show affspr + return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...) + end + end if eq_neg_aff === eq_aff affect_neg = affect - elseif isnothing(eq_neg_aff) - affect_neg = nothing else - affect_neg = compile_affect( - eq_neg_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...) + affect_neg = compile_optional_affect(eq_neg_aff) end - (affect = affect, affect_neg = affect_neg) + initialize = compile_optional_affect(initialize_affects(cb)) + finalize = compile_optional_affect(finalize_affects(cb)) + (affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize) end function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknowns(sys), @@ -877,7 +979,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa push!(syms_dedup, sym) push!(exprs_dedup, exp) push!(seen, sym) - else + elseif !affect.skip_checks @warn "Expression $(expr) is aliased as $sym, which has already been used. The first definition will be used." end end @@ -887,7 +989,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa obs_exprs = observed(affect) for oexpr in obs_exprs invalid_vars = invalid_variables(sys, oexpr) - if length(invalid_vars) > 0 + if length(invalid_vars) > 0 && !affect.skip_checks error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).") end end @@ -897,11 +999,11 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa mod_exprs = modified(affect) for mexpr in mod_exprs - if !is_observed(sys, mexpr) && parameter_index(sys, mexpr) === nothing - error("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.") + if !is_observed(sys, mexpr) && parameter_index(sys, mexpr) === nothing && !affect.skip_checks + @warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.") end invalid_vars = unassignable_variables(sys, mexpr) - if length(invalid_vars) > 0 + if length(invalid_vars) > 0 && !affect.skip_checks error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.") end end @@ -911,7 +1013,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa sys, mod_exprs; return_inplace = true) overlapping_syms = intersect(mod_syms, obs_syms) - if length(overlapping_syms) > 0 + if length(overlapping_syms) > 0 && !affect.skip_checks @warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value." end From 3c7afd8c40b526dd6f139aa2d119095ab6debeb7 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Mon, 19 Aug 2024 08:08:13 -0700 Subject: [PATCH 010/101] Test some simple initialization affects --- test/symbolic_events.jl | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 7aba69dc7f..b0f230d75e 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1074,6 +1074,26 @@ end prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) sol = solve(prob, Tsit5(); dtmax = 0.01) @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) + + furnace_off = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x + x.furnace_on = false + end; initialize = ModelingToolkit.MutatingFunctionalAffect(modified = (; temp)) do x + x.temp = 0.2 + end) + furnace_enable = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_on_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, c, i + x.furnace_on = true + end) + @named sys = ODESystem( + eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) + ss = structural_simplify(sys) + prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) + sol = solve(prob, Tsit5(); dtmax = 0.01) + @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) + @test all(sol[temp][sol.t .!= 0.0] .<= 0.79) && all(sol[temp][sol.t .!= 0.0] .>= 0.2) end @testset "MutatingFunctionalAffect errors and warnings" begin From 95956f363eaa8bbc5b492452946c6c2a21e35bbd Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Thu, 22 Aug 2024 07:17:50 -0700 Subject: [PATCH 011/101] Properly pass skip_checks through --- src/systems/callbacks.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index d14fd50691..ca0a64b452 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -165,7 +165,8 @@ function namespace_affect(affect::MutatingFunctionalAffect, s) observed_syms(affect), renamespace.((s,), modified(affect)), modified_syms(affect), - context(affect)) + context(affect), + affect.skip_checks) end function has_functional_affect(cb) From 4f928ae57e33b2d8bb5e52faed229941b5e589af Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Thu, 22 Aug 2024 07:35:38 -0700 Subject: [PATCH 012/101] Support time-indexed parameters --- src/systems/callbacks.jl | 16 ++++++++++++++-- src/systems/index_cache.jl | 2 +- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index ca0a64b452..279d31483e 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -963,7 +963,7 @@ function unassignable_variables(sys, expr) x -> !any(isequal(x), assignable_syms), reduce(vcat, vars(expr); init = [])) end -function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwargs...) +function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; kwargs...) #= Implementation sketch: generate observed function (oop), should save to a component array under obs_syms @@ -1049,6 +1049,13 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa collect(mkzero(size(e)) for e in first.(mod_unk_pairs))])) upd_params_view = view(upd_component_array, last.(mod_param_pairs)) upd_unks_view = view(upd_component_array, last.(mod_unk_pairs)) + + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + save_idxs = get(ic.callback_to_clocks, cb, Int[]) + else + save_idxs = Int[] + end + let user_affect = func(affect), ctx = context(affect) function (integ) # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens @@ -1074,11 +1081,16 @@ function compile_user_affect(affect::MutatingFunctionalAffect, sys, dvs, ps; kwa # write the new values back to the integrator upd_params_fun(integ, upd_params_view) upd_unk_fun(integ, upd_unks_view) + + + for idx in save_idxs + SciMLBase.save_discretes!(integ, idx) + end end end end -function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...) +function compile_affect(affect::Union{FunctionalAffect, MutatingFunctionalAffect}, cb, sys, dvs, ps; kwargs...) compile_user_affect(affect, cb, sys, dvs, ps; kwargs...) end diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 00f7837407..55d819990b 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -126,7 +126,7 @@ function IndexCache(sys::AbstractSystem) for affect in affs if affect isa Equation is_parameter(sys, affect.lhs) && push!(discs, affect.lhs) - elseif affect isa FunctionalAffect + elseif affect isa FunctionalAffect || affect isa MutatingFunctionalAffect union!(discs, unwrap.(discretes(affect))) else error("Unhandled affect type $(typeof(affect))") From 3eca9b96567008469ee6963745a6a4f2ea64ed50 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 27 Aug 2024 09:28:17 -0700 Subject: [PATCH 013/101] Fix bugs relating to array arguments to callbacks --- src/systems/callbacks.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 279d31483e..884d1d87f3 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -958,9 +958,10 @@ function invalid_variables(sys, expr) filter(x -> !any(isequal(x), all_symbols(sys)), reduce(vcat, vars(expr); init = [])) end function unassignable_variables(sys, expr) - assignable_syms = vcat(unknowns(sys), parameters(sys)) + assignable_syms = reduce(vcat, Symbolics.scalarize.(vcat(unknowns(sys), parameters(sys))); init=[]) + written = reduce(vcat, Symbolics.scalarize.(vars(expr)); init = []) return filter( - x -> !any(isequal(x), assignable_syms), reduce(vcat, vars(expr); init = [])) + x -> !any(isequal(x), assignable_syms), written) end function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; kwargs...) @@ -1000,7 +1001,10 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; mod_exprs = modified(affect) for mexpr in mod_exprs - if !is_observed(sys, mexpr) && parameter_index(sys, mexpr) === nothing && !affect.skip_checks + if affect.skip_checks + continue + end + if !is_variable(sys, mexpr) && parameter_index(sys, mexpr) === nothing && !affect.skip_checks @warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.") end invalid_vars = unassignable_variables(sys, mexpr) @@ -1036,7 +1040,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; mod_param_pairs = filter(v -> is_parameter(sys, v[1]), mod_pairs) mod_unk_pairs = filter(v -> !is_parameter(sys, v[1]), mod_pairs) _, mod_og_val_fun = build_explicit_observed_function( - sys, reduce(vcat, [first.(mod_param_pairs); first.(mod_unk_pairs)]; init = []); + sys, reduce(vcat, Symbolics.scalarize.([first.(mod_param_pairs); first.(mod_unk_pairs)]); init = []); return_inplace = true) upd_params_fun = setu( sys, reduce(vcat, Symbolics.scalarize.(first.(mod_param_pairs)); init = [])) From c41e7d4677077f69024c0f4f0a8e9ef6f83b2d38 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 27 Aug 2024 18:27:22 -0700 Subject: [PATCH 014/101] Remove debug logging --- src/systems/callbacks.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 884d1d87f3..a44b5093c8 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -861,8 +861,6 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs) if isnothing(aff) return nothing else - affspr = compile_affect(aff, cb, sys, dvs, ps; expression = Val{true}, kwargs...) - @show affspr return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...) end end From 95fa1ee554c0cf449c3c2ac9abdcff85546c80d4 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Mon, 9 Sep 2024 20:06:45 -0700 Subject: [PATCH 015/101] Fix the namespace operation used while namespacing MutatingFunctionalAffects --- src/systems/callbacks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index a44b5093c8..a7834d918e 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -161,7 +161,7 @@ end function namespace_affect(affect::MutatingFunctionalAffect, s) MutatingFunctionalAffect(func(affect), - renamespace.((s,), observed(affect)), + namespace_expr.(observed(affect), (s,)), observed_syms(affect), renamespace.((s,), modified(affect)), modified_syms(affect), From f57215ab42f5c4c3747744827b8793403c2e6761 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 10 Sep 2024 12:45:14 -0700 Subject: [PATCH 016/101] Add support for the initializealg argument in SciMLBase callbacks --- src/systems/callbacks.jl | 41 ++++++++++++++++++++++++++++++---------- test/symbolic_events.jl | 6 +++--- 2 files changed, 34 insertions(+), 13 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index a7834d918e..d2657d2ced 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -216,6 +216,11 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either: + `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition. + `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem. * A [`MutatingFunctionalAffect`](@ref); refer to its documentation for details. + +Callbacks that impact a DAE are applied, then the DAE is reinitialized using `reinitializealg` (which defaults to `SciMLBase.CheckInit`). +This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate +that the newly-assigned values indeed satisfy the algebraic system; see the documentation on DAE initialization for a more detailed discussion of +initialization. """ struct SymbolicContinuousCallback eqs::Vector{Equation} @@ -224,14 +229,16 @@ struct SymbolicContinuousCallback affect::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect} affect_neg::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect, Nothing} rootfind::SciMLBase.RootfindOpt + reinitializealg::SciMLBase.DAEInitializationAlgorithm function SymbolicContinuousCallback(; eqs::Vector{Equation}, affect = NULL_AFFECT, affect_neg = affect, rootfind = SciMLBase.LeftRootFind, initialize=NULL_AFFECT, - finalize=NULL_AFFECT) - new(eqs, initialize, finalize, make_affect(affect), make_affect(affect_neg), rootfind) + finalize=NULL_AFFECT, + reinitializealg=SciMLBase.CheckInit()) + new(eqs, initialize, finalize, make_affect(affect), make_affect(affect_neg), rootfind, reinitializealg) end # Default affect to nothing end make_affect(affect) = affect @@ -373,6 +380,10 @@ function finalize_affects(cbs::Vector{SymbolicContinuousCallback}) mapreduce(finalize_affects, vcat, cbs, init = Equation[]) end +reinitialization_alg(cb::SymbolicContinuousCallback) = cb.reinitializealg +reinitialization_algs(cbs::Vector{SymbolicContinuousCallback}) = + mapreduce(reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[]) + namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af] namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s) namespace_affects(af::MutatingFunctionalAffect, s) = namespace_affect(af, s) @@ -419,11 +430,12 @@ struct SymbolicDiscreteCallback # TODO: Iterative condition::Any affects::Any + reinitializealg::SciMLBase.DAEInitializationAlgorithm - function SymbolicDiscreteCallback(condition, affects = NULL_AFFECT) + function SymbolicDiscreteCallback(condition, affects = NULL_AFFECT, reinitializealg=SciMLBase.CheckInit()) c = scalarize_condition(condition) a = scalarize_affects(affects) - new(c, a) + new(c, a, reinitializealg) end # Default affect to nothing end @@ -481,6 +493,10 @@ function affects(cbs::Vector{SymbolicDiscreteCallback}) reduce(vcat, affects(cb) for cb in cbs; init = []) end +reinitialization_alg(cb::SymbolicDiscreteCallback) = cb.reinitializealg +reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback}) = + mapreduce(reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[]) + function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback af = affects(cb) af = af isa AbstractVector ? namespace_affect.(af, Ref(s)) : namespace_affect(af, s) @@ -776,12 +792,13 @@ function generate_single_rootfinding_callback( return ContinuousCallback( cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind, initialize = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function.initialize(i), - finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i)) + finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i), + initializealg = reinitialization_alg(cb)) end function generate_vector_rootfinding_callback( cbs, sys::AbstractODESystem, dvs = unknowns(sys), - ps = parameters(sys); rootfind = SciMLBase.RightRootFind, kwargs...) + ps = parameters(sys); rootfind = SciMLBase.RightRootFind, reinitialization = SciMLBase.CheckInit(), kwargs...) eqs = map(cb -> flatten_equations(cb.eqs), cbs) num_eqs = length.(eqs) # fuse equations to create VectorContinuousCallback @@ -847,7 +864,7 @@ function generate_vector_rootfinding_callback( initialize = handle_optional_setup_fn(map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT) finalize = handle_optional_setup_fn(map(fn -> fn.finalize, affect_functions), SciMLBase.FINALIZE_DEFAULT) return VectorContinuousCallback( - cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initialize, finalize = finalize) + cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initialize, finalize = finalize, initializealg = reinitialization) end """ @@ -893,10 +910,14 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow # group the cbs by what rootfind op they use # groupby would be very useful here, but alas cb_classes = Dict{ - @NamedTuple{rootfind::SciMLBase.RootfindOpt}, Vector{SymbolicContinuousCallback}}() + @NamedTuple{ + rootfind::SciMLBase.RootfindOpt, + reinitialization::SciMLBase.DAEInitializationAlgorithm}, Vector{SymbolicContinuousCallback}}() for cb in cbs push!( - get!(() -> SymbolicContinuousCallback[], cb_classes, (rootfind = cb.rootfind,)), + get!(() -> SymbolicContinuousCallback[], cb_classes, ( + rootfind = cb.rootfind, + reinitialization = reinitialization_alg(cb))), cb) end @@ -904,7 +925,7 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow compiled_callbacks = map(collect(pairs(sort!( OrderedDict(cb_classes); by = p -> p.rootfind)))) do (equiv_class, cbs_in_class) return generate_vector_rootfinding_callback( - cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, kwargs...) + cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, reinitialization=equiv_class.reinitialization, kwargs...) end if length(compiled_callbacks) == 1 return compiled_callbacks[] diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index b0f230d75e..dd60b99003 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -996,8 +996,8 @@ end @test sort(canonicalize(Discrete(), prob.p)[1]) == [0.0, 1.0, 2.0] sol = solve(prob, Tsit5()) - @test sol[a] == [1.0, -1.0] - @test sol[b] == [2.0, 5.0, 5.0] + @test sol[a] == [-1.0] + @test sol[b] == [5.0, 5.0] @test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] end @testset "Heater" begin @@ -1198,5 +1198,5 @@ end ss = structural_simplify(sys) prob = ODEProblem(ss, [theta => 0.0], (0.0, pi)) sol = solve(prob, Tsit5(); dtmax = 0.01) - @test sol[cnt] == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state + @test getp(sol, cnt)(sol) == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state end From d79d49de426ece41ab43df618c7960d88611b6c5 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 10 Sep 2024 16:23:58 -0700 Subject: [PATCH 017/101] Switch MutatingFunctionalAffect from using ComponentArrays to using NamedTuples for heterotyped operation support. --- src/systems/callbacks.jl | 78 +++++++++++++++--------------- src/systems/diffeqs/odesystem.jl | 10 ++-- test/symbolic_events.jl | 81 ++++++++++++++++++++------------ 3 files changed, 94 insertions(+), 75 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index d2657d2ced..34acb5b2ae 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -77,25 +77,29 @@ end `MutatingFunctionalAffect` is a helper for writing affect functions that will compute observed values and ensure that modified values are correctly written back into the system. The affect function `f` needs to have one of four signatures: -* `f(modified::ComponentArray)` if the function only writes values (unknowns or parameters) to the system, -* `f(modified::ComponentArray, observed::ComponentArray)` if the function also reads observed values from the system, -* `f(modified::ComponentArray, observed::ComponentArray, ctx)` if the function needs the user-defined context, -* `f(modified::ComponentArray, observed::ComponentArray, ctx, integrator)` if the function needs the low-level integrator. +* `f(modified::NamedTuple)::NamedTuple` if the function only writes values (unknowns or parameters) to the system, +* `f(modified::NamedTuple, observed::NamedTuple)::NamedTuple` if the function also reads observed values from the system, +* `f(modified::NamedTuple, observed::NamedTuple, ctx)::NamedTuple` if the function needs the user-defined context, +* `f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple` if the function needs the low-level integrator. These will be checked in reverse order (that is, the four-argument version first, than the 3, etc). -The function `f` will be called with `observed` and `modified` `ComponentArray`s that are derived from their respective `NamedTuple` definitions. -Each `NamedTuple` should map an expression to a symbol; for example if we pass `observed=(; x = a + b)` this will alias the result of executing `a+b` in the system as `x` +The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions. +Each declaration`NamedTuple` should map an expression to a symbol; for example if we pass `observed=(; x = a + b)` this will alias the result of executing `a+b` in the system as `x` so the value of `a + b` will be accessible as `observed.x` in `f`. `modified` currently restricts symbolic expressions to only bare variables, so only tuples of the form `(; x = y)` or `(; x)` (which aliases `x` as itself) are allowed. -Both `observed` and `modified` will be automatically populated with the current values of their corresponding expressions on function entry. -The values in `modified` will be written back to the system after `f` returns. For example, if we want to update the value of `x` to be the result of `x + y` we could write +The argument NamedTuples (for instance `(;x=y)`) will be populated with the declared values on function entry; if we require `(;x=y)` in `observed` and `y=2`, for example, +then the NamedTuple `(;x=2)` will be passed as `observed` to the affect function `f`. + +The NamedTuple returned from `f` includes the values to be written back to the system after `f` returns. For example, if we want to update the value of `x` to be the result of `x + y` we could write MutatingFunctionalAffect(observed=(; x_plus_y = x + y), modified=(; x)) do m, o - m.x = o.x_plus_y + @set! m.x = o.x_plus_y end -The affect function updates the value at `x` in `modified` to be the result of evaluating `x + y` as passed in the observed values. +Where we use Setfield to copy the tuple `m` with a new value for `x`, then return the modified value of `m`. All values updated by the tuple must have names originally declared in +`modified`; a runtime error will be produced if a value is written that does not appear in `modified`. The user can dynamically decide not to write a value back by not including it +in the returned tuple, in which case the associated field will not be updated. """ @kwdef struct MutatingFunctionalAffect f::Any @@ -983,6 +987,18 @@ function unassignable_variables(sys, expr) x -> !any(isequal(x), assignable_syms), written) end +@generated function _generated_writeback(integ, setters::NamedTuple{NS1,<:Tuple}, values::NamedTuple{NS2, <:Tuple}) where {NS1, NS2} + setter_exprs = [] + for name in NS2 + if !(name in NS1) + missing_name = "Tried to write back to $name from affect; only declared states ($NS1) may be written to." + error(missing_name) + end + push!(setter_exprs, :(setters.$name(integ, values.$name))) + end + return :(begin $(setter_exprs...) end) +end + function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; kwargs...) #= Implementation sketch: @@ -1016,7 +1032,6 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; end obs_syms = observed_syms(affect) obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs) - obs_size = size.(obs_exprs) # we will generate a work buffer of a ComponentArray that maps obs_syms to arrays of size obs_size mod_exprs = modified(affect) for mexpr in mod_exprs @@ -1033,8 +1048,6 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; end mod_syms = modified_syms(affect) mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs) - _, mod_og_val_fun = build_explicit_observed_function( - sys, mod_exprs; return_inplace = true) overlapping_syms = intersect(mod_syms, obs_syms) if length(overlapping_syms) > 0 && !affect.skip_checks @@ -1048,31 +1061,20 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; else zeros(sz) end - _, obs_fun = build_explicit_observed_function( + obs_fun = build_explicit_observed_function( sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []); - return_inplace = true) - obs_component_array = ComponentArrays.ComponentArray(NamedTuple{(obs_syms...,)}(mkzero.(obs_size))) + array_type = :tuple) + obs_sym_tuple = (obs_syms...,) # okay so now to generate the stuff to assign it back into the system - # note that we reorder the componentarray to make the views coherent wrt the base array mod_pairs = mod_exprs .=> mod_syms - mod_param_pairs = filter(v -> is_parameter(sys, v[1]), mod_pairs) - mod_unk_pairs = filter(v -> !is_parameter(sys, v[1]), mod_pairs) - _, mod_og_val_fun = build_explicit_observed_function( - sys, reduce(vcat, Symbolics.scalarize.([first.(mod_param_pairs); first.(mod_unk_pairs)]); init = []); - return_inplace = true) - upd_params_fun = setu( - sys, reduce(vcat, Symbolics.scalarize.(first.(mod_param_pairs)); init = [])) - upd_unk_fun = setu( - sys, reduce(vcat, Symbolics.scalarize.(first.(mod_unk_pairs)); init = [])) - - upd_component_array = ComponentArrays.ComponentArray(NamedTuple{([last.(mod_param_pairs); - last.(mod_unk_pairs)]...,)}( - [collect(mkzero(size(e)) for e in first.(mod_param_pairs)); - collect(mkzero(size(e)) for e in first.(mod_unk_pairs))])) - upd_params_view = view(upd_component_array, last.(mod_param_pairs)) - upd_unks_view = view(upd_component_array, last.(mod_unk_pairs)) + mod_names = (mod_syms..., ) + mod_og_val_fun = build_explicit_observed_function( + sys, reduce(vcat, Symbolics.scalarize.(first.(mod_pairs)); init = []); + array_type = :tuple) + upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,)) + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing save_idxs = get(ic.callback_to_clocks, cb, Int[]) else @@ -1082,13 +1084,13 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; let user_affect = func(affect), ctx = context(affect) function (integ) # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens - mod_og_val_fun(upd_component_array, integ.u, integ.p..., integ.t) + upd_component_array = NamedTuple{mod_names}(mod_og_val_fun(integ.u, integ.p..., integ.t)) # update the observed values - obs_fun(obs_component_array, integ.u, integ.p..., integ.t) + obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(integ.u, integ.p..., integ.t)) # let the user do their thing - if applicable(user_affect, upd_component_array, obs_component_array, ctx, integ) + modvals = if applicable(user_affect, upd_component_array, obs_component_array, ctx, integ) user_affect(upd_component_array, obs_component_array, ctx, integ) elseif applicable(user_affect, upd_component_array, obs_component_array, ctx) user_affect(upd_component_array, obs_component_array, ctx) @@ -1102,9 +1104,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; end # write the new values back to the integrator - upd_params_fun(integ, upd_params_view) - upd_unk_fun(integ, upd_unks_view) - + _generated_writeback(integ, upd_funs, modvals) for idx in save_idxs SciMLBase.save_discretes!(integ, idx) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index daa4321ed0..e99911eec6 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -429,6 +429,7 @@ Options not otherwise specified are: * `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail. * `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist * `drop_expr` is deprecated. +* `array_type`; only used if the output is an array (that is, `!isscalar(ts)`). If `:array`, then it will generate an array, if `:tuple` then it will generate a tuple. """ function build_explicit_observed_function(sys, ts; inputs = nothing, @@ -442,7 +443,8 @@ function build_explicit_observed_function(sys, ts; return_inplace = false, param_only = false, op = Operator, - throw = true) + throw = true, + array_type=:array) if (isscalar = symbolic_type(ts) !== NotSymbolic()) ts = [ts] end @@ -587,12 +589,10 @@ function build_explicit_observed_function(sys, ts; oop_mtkp_wrapper = mtkparams_wrapper end + output_expr = isscalar ? ts[1] : (array_type == :array ? MakeArray(ts, output_type) : MakeTuple(ts)) # Need to keep old method of building the function since it uses `output_type`, # which can't be provided to `build_function` - oop_fn = Func(args, [], - pre(Let(obsexprs, - isscalar ? ts[1] : MakeArray(ts, output_type), - false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr + oop_fn = Func(args, [], pre(Let(obsexprs, output_expr, false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module) if !isscalar diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index dd60b99003..bc455ec06e 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -8,6 +8,7 @@ using ModelingToolkit: SymbolicContinuousCallback, using StableRNGs import SciMLBase using SymbolicIndexingInterface +using Setfield rng = StableRNG(12345) @variables x(t) = 0 @@ -1010,12 +1011,12 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i, c - x.furnace_on = false + @set! x.furnace_on = false end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_on_threshold], ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i, c - x.furnace_on = true + @set! x.furnace_on = true end) @named sys = ODESystem( eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) @@ -1027,12 +1028,12 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i - x.furnace_on = false + @set! x.furnace_on = false end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_on_threshold], ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i - x.furnace_on = true + @set! x.furnace_on = true end) @named sys = ODESystem( eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) @@ -1044,12 +1045,12 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o - x.furnace_on = false + @set! x.furnace_on = false end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_on_threshold], ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o - x.furnace_on = true + @set! x.furnace_on = true end) @named sys = ODESystem( eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) @@ -1061,12 +1062,12 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x - x.furnace_on = false + @set! x.furnace_on = false end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_on_threshold], ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x - x.furnace_on = true + @set! x.furnace_on = true end) @named sys = ODESystem( eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) @@ -1078,14 +1079,14 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x - x.furnace_on = false + @set! x.furnace_on = false end; initialize = ModelingToolkit.MutatingFunctionalAffect(modified = (; temp)) do x - x.temp = 0.2 + @set! x.temp = 0.2 end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_on_threshold], ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, c, i - x.furnace_on = true + @set! x.furnace_on = true end) @named sys = ODESystem( eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) @@ -1107,7 +1108,7 @@ end [temp ~ furnace_off_threshold], ModelingToolkit.MutatingFunctionalAffect( modified = (; furnace_on), observed = (; furnace_on)) do x, o, c, i - x.furnace_on = false + @set! x.furnace_on = false end) @named sys = ODESystem(eqs, t, [temp], params; continuous_events = [furnace_off]) ss = structural_simplify(sys) @@ -1123,7 +1124,7 @@ end [temp ~ furnace_off_threshold], ModelingToolkit.MutatingFunctionalAffect( modified = (; furnace_on, tempsq), observed = (; furnace_on)) do x, o, c, i - x.furnace_on = false + @set! x.furnace_on = false end) @named sys = ODESystem( eqs, t, [temp, tempsq], params; continuous_events = [furnace_off]) @@ -1136,18 +1137,32 @@ end [temp ~ furnace_off_threshold], ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on), observed = (; furnace_on, not_actually_here)) do x, o, c, i - x.furnace_on = false + @set! x.furnace_on = false end) @named sys = ODESystem( eqs, t, [temp, tempsq], params; continuous_events = [furnace_off]) ss = structural_simplify(sys) @test_throws "refers to missing variable(s)" prob=ODEProblem( ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) + + + furnace_off = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_off_threshold], + ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on), + observed = (; furnace_on)) do x, o, c, i + return (;fictional2 = false) + end) + @named sys = ODESystem( + eqs, t, [temp, tempsq], params; continuous_events = [furnace_off]) + ss = structural_simplify(sys) + prob=ODEProblem( + ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) + @test_throws "Tried to write back to" solve(prob, Tsit5()) end @testset "Quadrature" begin @variables theta(t) omega(t) - params = @parameters qA=0 qB=0 hA=0 hB=0 cnt=0 + params = @parameters qA=0 qB=0 hA=0 hB=0 cnt::Int=0 eqs = [D(theta) ~ omega omega ~ 1.0] function decoder(oldA, oldB, newA, newB) @@ -1167,31 +1182,35 @@ end end qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0], ModelingToolkit.MutatingFunctionalAffect((; qA, hA, hB, cnt), (; qB)) do x, o, i, c - x.hA = x.qA - x.hB = o.qB - x.qA = 1 - x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) + @set! x.hA = x.qA + @set! x.hB = o.qB + @set! x.qA = 1 + @set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) + x end, affect_neg = ModelingToolkit.MutatingFunctionalAffect( (; qA, hA, hB, cnt), (; qB)) do x, o, c, i - x.hA = x.qA - x.hB = o.qB - x.qA = 0 - x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) + @set! x.hA = x.qA + @set! x.hB = o.qB + @set! x.qA = 0 + @set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) + x end; rootfind = SciMLBase.RightRootFind) qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π / 2) ~ 0], ModelingToolkit.MutatingFunctionalAffect((; qB, hA, hB, cnt), (; qA)) do x, o, i, c - x.hA = o.qA - x.hB = x.qB - x.qB = 1 - x.cnt += decoder(x.hA, x.hB, o.qA, x.qB) + @set! x.hA = o.qA + @set! x.hB = x.qB + @set! x.qB = 1 + @set! x.cnt += decoder(x.hA, x.hB, o.qA, x.qB) + x end, affect_neg = ModelingToolkit.MutatingFunctionalAffect( (; qB, hA, hB, cnt), (; qA)) do x, o, c, i - x.hA = o.qA - x.hB = x.qB - x.qB = 0 - x.cnt += decoder(x.hA, x.hB, o.qA, x.qB) + @set! x.hA = o.qA + @set! x.hB = x.qB + @set! x.qB = 0 + @set! x.cnt += decoder(x.hA, x.hB, o.qA, x.qB) + x end; rootfind = SciMLBase.RightRootFind) @named sys = ODESystem( eqs, t, [theta, omega], params; continuous_events = [qAevt, qBevt]) From c940e5ed27eeed1cafdb3d2a408a245b31950c3e Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Fri, 13 Sep 2024 09:10:20 -0700 Subject: [PATCH 018/101] Fix support for array forms in the NamedTuple version of MutatingFunctionalAffect --- src/systems/callbacks.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 34acb5b2ae..ad41807ab8 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -1062,7 +1062,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; zeros(sz) end obs_fun = build_explicit_observed_function( - sys, reduce(vcat, Symbolics.scalarize.(obs_exprs); init = []); + sys, Symbolics.scalarize.(obs_exprs); array_type = :tuple) obs_sym_tuple = (obs_syms...,) @@ -1070,7 +1070,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; mod_pairs = mod_exprs .=> mod_syms mod_names = (mod_syms..., ) mod_og_val_fun = build_explicit_observed_function( - sys, reduce(vcat, Symbolics.scalarize.(first.(mod_pairs)); init = []); + sys, Symbolics.scalarize.(first.(mod_pairs)); array_type = :tuple) upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,)) @@ -1084,7 +1084,8 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; let user_affect = func(affect), ctx = context(affect) function (integ) # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens - upd_component_array = NamedTuple{mod_names}(mod_og_val_fun(integ.u, integ.p..., integ.t)) + modvals = mod_og_val_fun(integ.u, integ.p..., integ.t) + upd_component_array = NamedTuple{mod_names}(modvals) # update the observed values obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(integ.u, integ.p..., integ.t)) From 2206425cae06f57c2985607c1c18d740bb7d439c Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Mon, 23 Sep 2024 08:18:16 -0700 Subject: [PATCH 019/101] Remove ComponentArrays dep, cleanup handling of skip_checks --- Project.toml | 1 - src/systems/callbacks.jl | 29 +++++++++++++++-------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 729966fde0..808d68af06 100644 --- a/Project.toml +++ b/Project.toml @@ -9,7 +9,6 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e" Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" -ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e" diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index ad41807ab8..cf12b9078b 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -1024,26 +1024,27 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; end obs_exprs = observed(affect) - for oexpr in obs_exprs - invalid_vars = invalid_variables(sys, oexpr) - if length(invalid_vars) > 0 && !affect.skip_checks - error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).") + if !affect.skip_checks + for oexpr in obs_exprs + invalid_vars = invalid_variables(sys, oexpr) + if length(invalid_vars) > 0 + error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).") + end end end obs_syms = observed_syms(affect) obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs) mod_exprs = modified(affect) - for mexpr in mod_exprs - if affect.skip_checks - continue - end - if !is_variable(sys, mexpr) && parameter_index(sys, mexpr) === nothing && !affect.skip_checks - @warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.") - end - invalid_vars = unassignable_variables(sys, mexpr) - if length(invalid_vars) > 0 && !affect.skip_checks - error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.") + if !affect.skip_checks + for mexpr in mod_exprs + if !is_variable(sys, mexpr) && parameter_index(sys, mexpr) === nothing + @warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.") + end + invalid_vars = unassignable_variables(sys, mexpr) + if length(invalid_vars) > 0 + error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.") + end end end mod_syms = modified_syms(affect) From 98dcd4ee1ff90a572c273977f780c4cd23f4ce0d Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Mon, 23 Sep 2024 09:59:43 -0700 Subject: [PATCH 020/101] Improve detection of writeback values --- src/systems/callbacks.jl | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index cf12b9078b..23b7ccac2c 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -999,6 +999,18 @@ end return :(begin $(setter_exprs...) end) end +function check_assignable(sys, sym) + if symbolic_type(sym) == ScalarSymbolic() + is_variable(sys, sym) || is_parameter(sys, sym) + elseif symbolic_type(sym) == ArraySymbolic() + is_variable(sys, sym) || is_parameter(sys, sym) || all(x -> check_assignable(sys, x), collect(sym)) + elseif sym isa Union{AbstractArray, Tuple} + all(x -> check_assignable(sys, x), sym) + else + false + end +end + function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; kwargs...) #= Implementation sketch: @@ -1038,7 +1050,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; mod_exprs = modified(affect) if !affect.skip_checks for mexpr in mod_exprs - if !is_variable(sys, mexpr) && parameter_index(sys, mexpr) === nothing + if !check_assignable(sys, mexpr) @warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.") end invalid_vars = unassignable_variables(sys, mexpr) From 3fd4462d31d0c10b77fae3a0c6678d36f9fe0046 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Mon, 23 Sep 2024 10:06:16 -0700 Subject: [PATCH 021/101] Remove ComponentArrays dep --- src/ModelingToolkit.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index f5262a1526..2f57bb1765 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -54,7 +54,6 @@ using Reexport using RecursiveArrayTools import Graphs: SimpleDiGraph, add_edge!, incidence_matrix import BlockArrays: BlockedArray, Block, blocksize, blocksizes -import ComponentArrays using RuntimeGeneratedFunctions using RuntimeGeneratedFunctions: drop_expr From e6ce6ab65fe9e62b5c34f8adc546ceafa91f475f Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Thu, 17 Oct 2024 12:00:44 -0700 Subject: [PATCH 022/101] Rename MutatingFunctionalAffect to ImperativeAffect --- src/systems/callbacks.jl | 74 +++++++++++++++---------------- src/systems/index_cache.jl | 2 +- test/symbolic_events.jl | 90 +++++++++++++++++++------------------- 3 files changed, 83 insertions(+), 83 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 23b7ccac2c..73f2a7a6cf 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -72,9 +72,9 @@ function namespace_affect(affect::FunctionalAffect, s) end """ - MutatingFunctionalAffect(f::Function; modified::NamedTuple, observed::NamedTuple, ctx) + ImperativeAffect(f::Function; modified::NamedTuple, observed::NamedTuple, ctx) -`MutatingFunctionalAffect` is a helper for writing affect functions that will compute observed values and +`ImperativeAffect` is a helper for writing affect functions that will compute observed values and ensure that modified values are correctly written back into the system. The affect function `f` needs to have one of four signatures: * `f(modified::NamedTuple)::NamedTuple` if the function only writes values (unknowns or parameters) to the system, @@ -93,7 +93,7 @@ then the NamedTuple `(;x=2)` will be passed as `observed` to the affect function The NamedTuple returned from `f` includes the values to be written back to the system after `f` returns. For example, if we want to update the value of `x` to be the result of `x + y` we could write - MutatingFunctionalAffect(observed=(; x_plus_y = x + y), modified=(; x)) do m, o + ImperativeAffect(observed=(; x_plus_y = x + y), modified=(; x)) do m, o @set! m.x = o.x_plus_y end @@ -101,7 +101,7 @@ Where we use Setfield to copy the tuple `m` with a new value for `x`, then retur `modified`; a runtime error will be produced if a value is written that does not appear in `modified`. The user can dynamically decide not to write a value back by not including it in the returned tuple, in which case the associated field will not be updated. """ -@kwdef struct MutatingFunctionalAffect +@kwdef struct ImperativeAffect f::Any obs::Vector obs_syms::Vector{Symbol} @@ -111,50 +111,50 @@ in the returned tuple, in which case the associated field will not be updated. skip_checks::Bool end -function MutatingFunctionalAffect(f::Function; +function ImperativeAffect(f::Function; observed::NamedTuple = NamedTuple{()}(()), modified::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false) - MutatingFunctionalAffect(f, + ImperativeAffect(f, collect(values(observed)), collect(keys(observed)), collect(values(modified)), collect(keys(modified)), ctx, skip_checks) end -function MutatingFunctionalAffect(f::Function, modified::NamedTuple; +function ImperativeAffect(f::Function, modified::NamedTuple; observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks=false) - MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) + ImperativeAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) end -function MutatingFunctionalAffect( +function ImperativeAffect( f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing, skip_checks=false) - MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) + ImperativeAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) end -function MutatingFunctionalAffect( +function ImperativeAffect( f::Function, modified::NamedTuple, observed::NamedTuple, ctx; skip_checks=false) - MutatingFunctionalAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) + ImperativeAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) end -function Base.show(io::IO, mfa::MutatingFunctionalAffect) +function Base.show(io::IO, mfa::ImperativeAffect) obs_vals = join(map((ob,nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ") mod_vals = join(map((md,nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ") affect = mfa.f - print(io, "MutatingFunctionalAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)") + print(io, "ImperativeAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)") end -func(f::MutatingFunctionalAffect) = f.f -context(a::MutatingFunctionalAffect) = a.ctx -observed(a::MutatingFunctionalAffect) = a.obs -observed_syms(a::MutatingFunctionalAffect) = a.obs_syms -discretes(a::MutatingFunctionalAffect) = filter(ModelingToolkit.isparameter, a.modified) -modified(a::MutatingFunctionalAffect) = a.modified -modified_syms(a::MutatingFunctionalAffect) = a.mod_syms +func(f::ImperativeAffect) = f.f +context(a::ImperativeAffect) = a.ctx +observed(a::ImperativeAffect) = a.obs +observed_syms(a::ImperativeAffect) = a.obs_syms +discretes(a::ImperativeAffect) = filter(ModelingToolkit.isparameter, a.modified) +modified(a::ImperativeAffect) = a.modified +modified_syms(a::ImperativeAffect) = a.mod_syms -function Base.:(==)(a1::MutatingFunctionalAffect, a2::MutatingFunctionalAffect) +function Base.:(==)(a1::ImperativeAffect, a2::ImperativeAffect) isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) && isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms) && isequal(a1.ctx, a2.ctx) end -function Base.hash(a::MutatingFunctionalAffect, s::UInt) +function Base.hash(a::ImperativeAffect, s::UInt) s = hash(a.f, s) s = hash(a.obs, s) s = hash(a.obs_syms, s) @@ -163,8 +163,8 @@ function Base.hash(a::MutatingFunctionalAffect, s::UInt) hash(a.ctx, s) end -function namespace_affect(affect::MutatingFunctionalAffect, s) - MutatingFunctionalAffect(func(affect), +function namespace_affect(affect::ImperativeAffect, s) + ImperativeAffect(func(affect), namespace_expr.(observed(affect), (s,)), observed_syms(affect), renamespace.((s,), modified(affect)), @@ -174,7 +174,7 @@ function namespace_affect(affect::MutatingFunctionalAffect, s) end function has_functional_affect(cb) - (affects(cb) isa FunctionalAffect || affects(cb) isa MutatingFunctionalAffect) + (affects(cb) isa FunctionalAffect || affects(cb) isa ImperativeAffect) end #################################### continuous events ##################################### @@ -219,7 +219,7 @@ Affects (i.e. `affect` and `affect_neg`) can be specified as either: + `read_parameters` is a vector of the parameters that are *used* by `f!`. Their indices are passed to `f` in `p` similarly to the indices of `unknowns` passed in `u`. + `modified_parameters` is a vector of the parameters that are *modified* by `f!`. Note that a parameter will not appear in `p` if it only appears in `modified_parameters`; it must appear in both `parameters` and `modified_parameters` if it is used in the affect definition. + `ctx` is a user-defined context object passed to `f!` when invoked. This value is aliased for each problem. -* A [`MutatingFunctionalAffect`](@ref); refer to its documentation for details. +* A [`ImperativeAffect`](@ref); refer to its documentation for details. Callbacks that impact a DAE are applied, then the DAE is reinitialized using `reinitializealg` (which defaults to `SciMLBase.CheckInit`). This reinitialization algorithm ensures that the DAE is satisfied after the callback runs. The default value of `CheckInit` will simply validate @@ -228,10 +228,10 @@ initialization. """ struct SymbolicContinuousCallback eqs::Vector{Equation} - initialize::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect} - finalize::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect} - affect::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect} - affect_neg::Union{Vector{Equation}, FunctionalAffect, MutatingFunctionalAffect, Nothing} + initialize::Union{Vector{Equation}, FunctionalAffect, ImperativeAffect} + finalize::Union{Vector{Equation}, FunctionalAffect, ImperativeAffect} + affect::Union{Vector{Equation}, FunctionalAffect, ImperativeAffect} + affect_neg::Union{Vector{Equation}, FunctionalAffect, ImperativeAffect, Nothing} rootfind::SciMLBase.RootfindOpt reinitializealg::SciMLBase.DAEInitializationAlgorithm function SymbolicContinuousCallback(; @@ -390,7 +390,7 @@ reinitialization_algs(cbs::Vector{SymbolicContinuousCallback}) = namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af] namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s) -namespace_affects(af::MutatingFunctionalAffect, s) = namespace_affect(af, s) +namespace_affects(af::ImperativeAffect, s) = namespace_affect(af, s) namespace_affects(::Nothing, s) = nothing function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback @@ -460,7 +460,7 @@ scalarize_affects(affects) = scalarize(affects) scalarize_affects(affects::Tuple) = FunctionalAffect(affects...) scalarize_affects(affects::NamedTuple) = FunctionalAffect(; affects...) scalarize_affects(affects::FunctionalAffect) = affects -scalarize_affects(affects::MutatingFunctionalAffect) = affects +scalarize_affects(affects::ImperativeAffect) = affects SymbolicDiscreteCallback(p::Pair) = SymbolicDiscreteCallback(p[1], p[2]) SymbolicDiscreteCallback(cb::SymbolicDiscreteCallback) = cb # passthrough @@ -468,7 +468,7 @@ SymbolicDiscreteCallback(cb::SymbolicDiscreteCallback) = cb # passthrough function Base.show(io::IO, db::SymbolicDiscreteCallback) println(io, "condition: ", db.condition) println(io, "affects:") - if db.affects isa FunctionalAffect || db.affects isa MutatingFunctionalAffect + if db.affects isa FunctionalAffect || db.affects isa ImperativeAffect # TODO println(io, " ", db.affects) else @@ -1011,7 +1011,7 @@ function check_assignable(sys, sym) end end -function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; kwargs...) +function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs...) #= Implementation sketch: generate observed function (oop), should save to a component array under obs_syms @@ -1113,7 +1113,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; elseif applicable(user_affect, upd_component_array) user_affect(upd_component_array) else - @error "User affect function $user_affect needs to implement one of the supported MutatingFunctionalAffect callback forms; see the MutatingFunctionalAffect docstring for more details" + @error "User affect function $user_affect needs to implement one of the supported ImperativeAffect callback forms; see the ImperativeAffect docstring for more details" user_affect(upd_component_array, obs_component_array, integ, ctx) # this WILL error but it'll give a more sensible message end @@ -1127,7 +1127,7 @@ function compile_user_affect(affect::MutatingFunctionalAffect, cb, sys, dvs, ps; end end -function compile_affect(affect::Union{FunctionalAffect, MutatingFunctionalAffect}, cb, sys, dvs, ps; kwargs...) +function compile_affect(affect::Union{FunctionalAffect, ImperativeAffect}, cb, sys, dvs, ps; kwargs...) compile_user_affect(affect, cb, sys, dvs, ps; kwargs...) end diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 55d819990b..bd295055fa 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -126,7 +126,7 @@ function IndexCache(sys::AbstractSystem) for affect in affs if affect isa Equation is_parameter(sys, affect.lhs) && push!(discs, affect.lhs) - elseif affect isa FunctionalAffect || affect isa MutatingFunctionalAffect + elseif affect isa FunctionalAffect || affect isa ImperativeAffect union!(discs, unwrap.(discretes(affect))) else error("Unhandled affect type $(typeof(affect))") diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index bc455ec06e..b301b65650 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -228,10 +228,10 @@ affect_neg = [x ~ 1] @test e[].affect == affect end -@testset "MutatingFunctionalAffect constructors" begin +@testset "ImperativeAffect constructors" begin fmfa(o, x, i, c) = nothing - m = ModelingToolkit.MutatingFunctionalAffect(fmfa) - @test m isa ModelingToolkit.MutatingFunctionalAffect + m = ModelingToolkit.ImperativeAffect(fmfa) + @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa @test m.obs == [] @test m.obs_syms == [] @@ -239,8 +239,8 @@ end @test m.mod_syms == [] @test m.ctx === nothing - m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (;)) - @test m isa ModelingToolkit.MutatingFunctionalAffect + m = ModelingToolkit.ImperativeAffect(fmfa, (;)) + @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa @test m.obs == [] @test m.obs_syms == [] @@ -248,8 +248,8 @@ end @test m.mod_syms == [] @test m.ctx === nothing - m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; x)) - @test m isa ModelingToolkit.MutatingFunctionalAffect + m = ModelingToolkit.ImperativeAffect(fmfa, (; x)) + @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa @test isequal(m.obs, []) @test m.obs_syms == [] @@ -257,8 +257,8 @@ end @test m.mod_syms == [:x] @test m.ctx === nothing - m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; y = x)) - @test m isa ModelingToolkit.MutatingFunctionalAffect + m = ModelingToolkit.ImperativeAffect(fmfa, (; y = x)) + @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa @test isequal(m.obs, []) @test m.obs_syms == [] @@ -266,8 +266,8 @@ end @test m.mod_syms == [:y] @test m.ctx === nothing - m = ModelingToolkit.MutatingFunctionalAffect(fmfa; observed = (; y = x)) - @test m isa ModelingToolkit.MutatingFunctionalAffect + m = ModelingToolkit.ImperativeAffect(fmfa; observed = (; y = x)) + @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa @test isequal(m.obs, [x]) @test m.obs_syms == [:y] @@ -275,8 +275,8 @@ end @test m.mod_syms == [] @test m.ctx === nothing - m = ModelingToolkit.MutatingFunctionalAffect(fmfa; modified = (; x)) - @test m isa ModelingToolkit.MutatingFunctionalAffect + m = ModelingToolkit.ImperativeAffect(fmfa; modified = (; x)) + @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa @test isequal(m.obs, []) @test m.obs_syms == [] @@ -284,8 +284,8 @@ end @test m.mod_syms == [:x] @test m.ctx === nothing - m = ModelingToolkit.MutatingFunctionalAffect(fmfa; modified = (; y = x)) - @test m isa ModelingToolkit.MutatingFunctionalAffect + m = ModelingToolkit.ImperativeAffect(fmfa; modified = (; y = x)) + @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa @test isequal(m.obs, []) @test m.obs_syms == [] @@ -293,8 +293,8 @@ end @test m.mod_syms == [:y] @test m.ctx === nothing - m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; x), (; x)) - @test m isa ModelingToolkit.MutatingFunctionalAffect + m = ModelingToolkit.ImperativeAffect(fmfa, (; x), (; x)) + @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa @test isequal(m.obs, [x]) @test m.obs_syms == [:x] @@ -302,8 +302,8 @@ end @test m.mod_syms == [:x] @test m.ctx === nothing - m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; y = x), (; y = x)) - @test m isa ModelingToolkit.MutatingFunctionalAffect + m = ModelingToolkit.ImperativeAffect(fmfa, (; y = x), (; y = x)) + @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa @test isequal(m.obs, [x]) @test m.obs_syms == [:y] @@ -311,9 +311,9 @@ end @test m.mod_syms == [:y] @test m.ctx === nothing - m = ModelingToolkit.MutatingFunctionalAffect( + m = ModelingToolkit.ImperativeAffect( fmfa; modified = (; y = x), observed = (; y = x)) - @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa @test isequal(m.obs, [x]) @test m.obs_syms == [:y] @@ -321,9 +321,9 @@ end @test m.mod_syms == [:y] @test m.ctx === nothing - m = ModelingToolkit.MutatingFunctionalAffect( + m = ModelingToolkit.ImperativeAffect( fmfa; modified = (; y = x), observed = (; y = x), ctx = 3) - @test m isa ModelingToolkit.MutatingFunctionalAffect + @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa @test isequal(m.obs, [x]) @test m.obs_syms == [:y] @@ -331,8 +331,8 @@ end @test m.mod_syms == [:y] @test m.ctx === 3 - m = ModelingToolkit.MutatingFunctionalAffect(fmfa, (; x), (; x), 3) - @test m isa ModelingToolkit.MutatingFunctionalAffect + m = ModelingToolkit.ImperativeAffect(fmfa, (; x), (; x), 3) + @test m isa ModelingToolkit.ImperativeAffect @test m.f == fmfa @test isequal(m.obs, [x]) @test m.obs_syms == [:x] @@ -1010,12 +1010,12 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i, c + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, i, c @set! x.furnace_on = false end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_on_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i, c + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, i, c @set! x.furnace_on = true end) @named sys = ODESystem( @@ -1027,12 +1027,12 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, i @set! x.furnace_on = false end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_on_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, i + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, i @set! x.furnace_on = true end) @named sys = ODESystem( @@ -1044,12 +1044,12 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o @set! x.furnace_on = false end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_on_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o @set! x.furnace_on = true end) @named sys = ODESystem( @@ -1061,12 +1061,12 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x @set! x.furnace_on = false end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_on_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x @set! x.furnace_on = true end) @named sys = ODESystem( @@ -1078,14 +1078,14 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x @set! x.furnace_on = false - end; initialize = ModelingToolkit.MutatingFunctionalAffect(modified = (; temp)) do x + end; initialize = ModelingToolkit.ImperativeAffect(modified = (; temp)) do x @set! x.temp = 0.2 end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_on_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on)) do x, o, c, i + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, c, i @set! x.furnace_on = true end) @named sys = ODESystem( @@ -1097,7 +1097,7 @@ end @test all(sol[temp][sol.t .!= 0.0] .<= 0.79) && all(sol[temp][sol.t .!= 0.0] .>= 0.2) end -@testset "MutatingFunctionalAffect errors and warnings" begin +@testset "ImperativeAffect errors and warnings" begin @variables temp(t) params = @parameters furnace_on_threshold=0.5 furnace_off_threshold=0.7 furnace_power=1.0 leakage=0.1 furnace_on::Bool=false eqs = [ @@ -1106,7 +1106,7 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect( + ModelingToolkit.ImperativeAffect( modified = (; furnace_on), observed = (; furnace_on)) do x, o, c, i @set! x.furnace_on = false end) @@ -1122,7 +1122,7 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect( + ModelingToolkit.ImperativeAffect( modified = (; furnace_on, tempsq), observed = (; furnace_on)) do x, o, c, i @set! x.furnace_on = false end) @@ -1135,7 +1135,7 @@ end @parameters not_actually_here furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on), + ModelingToolkit.ImperativeAffect(modified = (; furnace_on), observed = (; furnace_on, not_actually_here)) do x, o, c, i @set! x.furnace_on = false end) @@ -1148,7 +1148,7 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], - ModelingToolkit.MutatingFunctionalAffect(modified = (; furnace_on), + ModelingToolkit.ImperativeAffect(modified = (; furnace_on), observed = (; furnace_on)) do x, o, c, i return (;fictional2 = false) end) @@ -1181,14 +1181,14 @@ end end end qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0], - ModelingToolkit.MutatingFunctionalAffect((; qA, hA, hB, cnt), (; qB)) do x, o, i, c + ModelingToolkit.ImperativeAffect((; qA, hA, hB, cnt), (; qB)) do x, o, i, c @set! x.hA = x.qA @set! x.hB = o.qB @set! x.qA = 1 @set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) x end, - affect_neg = ModelingToolkit.MutatingFunctionalAffect( + affect_neg = ModelingToolkit.ImperativeAffect( (; qA, hA, hB, cnt), (; qB)) do x, o, c, i @set! x.hA = x.qA @set! x.hB = o.qB @@ -1197,14 +1197,14 @@ end x end; rootfind = SciMLBase.RightRootFind) qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π / 2) ~ 0], - ModelingToolkit.MutatingFunctionalAffect((; qB, hA, hB, cnt), (; qA)) do x, o, i, c + ModelingToolkit.ImperativeAffect((; qB, hA, hB, cnt), (; qA)) do x, o, i, c @set! x.hA = o.qA @set! x.hB = x.qB @set! x.qB = 1 @set! x.cnt += decoder(x.hA, x.hB, o.qA, x.qB) x end, - affect_neg = ModelingToolkit.MutatingFunctionalAffect( + affect_neg = ModelingToolkit.ImperativeAffect( (; qB, hA, hB, cnt), (; qA)) do x, o, c, i @set! x.hA = o.qA @set! x.hB = x.qB From 54bb95af05f6212375ba3a9e2717147b45246ad9 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Thu, 17 Oct 2024 17:02:04 -0700 Subject: [PATCH 023/101] Fix tests --- src/systems/callbacks.jl | 19 +++++++++++-------- test/symbolic_events.jl | 14 ++++++++++++++ 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 73f2a7a6cf..96758ccf04 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -781,21 +781,24 @@ function generate_single_rootfinding_callback( end end + user_initfun = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function.initialize(i) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing && (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing initfn = let save_idxs = save_idxs function (cb, u, t, integrator) + user_initfun(cb, u, t, integrator) for idx in save_idxs SciMLBase.save_discretes!(integrator, idx) end end end else - initfn = SciMLBase.INITIALIZE_DEFAULT + initfn = user_initfun end + return ContinuousCallback( cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind, - initialize = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function.initialize(i), + initialize = initfn, finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i), initializealg = reinitialization_alg(cb)) end @@ -878,8 +881,8 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs) eq_aff = affects(cb) eq_neg_aff = affect_negs(cb) affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...) - function compile_optional_affect(aff) - if isnothing(aff) + function compile_optional_affect(aff, default=nothing) + if isnothing(aff) || aff==default return nothing else return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...) @@ -890,8 +893,8 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs) else affect_neg = compile_optional_affect(eq_neg_aff) end - initialize = compile_optional_affect(initialize_affects(cb)) - finalize = compile_optional_affect(finalize_affects(cb)) + initialize = compile_optional_affect(initialize_affects(cb), NULL_AFFECT) + finalize = compile_optional_affect(finalize_affects(cb), NULL_AFFECT) (affect = affect, affect_neg = affect_neg, initialize = initialize, finalize = finalize) end @@ -1097,11 +1100,11 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. let user_affect = func(affect), ctx = context(affect) function (integ) # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens - modvals = mod_og_val_fun(integ.u, integ.p..., integ.t) + modvals = mod_og_val_fun(integ.u, integ.p, integ.t) upd_component_array = NamedTuple{mod_names}(modvals) # update the observed values - obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(integ.u, integ.p..., integ.t)) + obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(integ.u, integ.p, integ.t)) # let the user do their thing modvals = if applicable(user_affect, upd_component_array, obs_component_array, ctx, integ) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index b301b65650..c2c26aae7f 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1219,3 +1219,17 @@ end sol = solve(prob, Tsit5(); dtmax = 0.01) @test getp(sol, cnt)(sol) == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state end + + + +import RuntimeGeneratedFunctions +function (f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id})(args::Vararg{Any, N}) where {N, argnames, cache_tag, context_tag, id} + try + RuntimeGeneratedFunctions.generated_callfunc(f, args...) + catch e + @error "Caught error in RuntimeGeneratedFunction; source code follows" + func_expr = Expr(:->, Expr(:tuple, argnames...), RuntimeGeneratedFunctions._lookup_body(cache_tag, id)) + @show func_expr + rethrow(e) + end +end From 065490971ad5bb91718f0cd2a61463060b962794 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 22 Oct 2024 18:36:18 -0700 Subject: [PATCH 024/101] Document ImperativeEffect and the SymbolicContinousCallback changes --- docs/Project.toml | 1 + docs/src/basics/Events.md | 198 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 199 insertions(+) diff --git a/docs/Project.toml b/docs/Project.toml index 078df7d696..15f1a6a7f0 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -14,6 +14,7 @@ OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StochasticDiffEq = "789caeaf-c7a9-5a7d-9973-96adeb23e2a0" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" diff --git a/docs/src/basics/Events.md b/docs/src/basics/Events.md index f425fdce5b..9d3ba30780 100644 --- a/docs/src/basics/Events.md +++ b/docs/src/basics/Events.md @@ -378,3 +378,201 @@ sol.ps[c] # sol[c] will error, since `c` is not a timeseries value ``` It can be seen that the timeseries for `c` is not saved. + + +## [(Experimental) Imperative affects](@id imp_affects) +The `ImperativeAffect` can be used as an alternative to the aforementioned functional affect form. Note +that `ImperativeAffect` is still experimental; to emphasize this, we do not export it and it should be +included as `ModelingToolkit.ImperativeAffect`. It abstracts over how values are written back to the +system, simplifying the definitions and (in the future) allowing assignments back to observed values +by solving the nonlinear reinitialization problem afterwards. + +We will use two examples to describe `ImperativeAffect`: a simple heater and a quadrature encoder. +These examples will also demonstrate advanced usage of `ModelingToolkit.SymbolicContinousCallback`, +the low-level interface that the aforementioned tuple form converts into and allows control over the +exact SciMLCallbacks event that is generated for a continous event. + +### [Heater](@id heater_events) +Bang-bang control of a heater connected to a leaky plant requires hysteresis in order to prevent control oscillation. + +```@example events +@variables temp(t) +params = @parameters furnace_on_threshold=0.5 furnace_off_threshold=0.7 furnace_power=1.0 leakage=0.1 furnace_on(t)::Bool=false +eqs = [ + D(temp) ~ furnace_on * furnace_power - temp^2 * leakage +] +``` +Our plant is simple. We have a heater that's turned on and off by the clocked parameter `furnace_on` +which adds `furnace_power` forcing to the system when enabled. We then leak heat porportional to `leakage` +as a function of the square of the current temperature. + +We need a controller with hysteresis to conol the plant. We wish the furnace to turn on when the temperature +is below `furnace_on_threshold` and off when above `furnace_off_threshold`, while maintaining its current state +in between. To do this, we create two continous callbacks: +```@example events +using Setfield +furnace_disable = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_off_threshold], + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, i, c + @set! x.furnace_on = false + end) +furnace_enable = ModelingToolkit.SymbolicContinuousCallback( + [temp ~ furnace_on_threshold], + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, i, c + @set! x.furnace_on = true + end) +``` +We're using the explicit form of `SymbolicContinuousCallback` here, though +so far we aren't using anything that's not possible with the implicit interface. +You can also write +```julia +[temp ~ furnace_off_threshold] => ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, i, c + @set! x.furnace_on = false +end +``` +and it would work the same. + +The `ImperativeAffect` is the larger change in this example. `ImperativeAffect` has the constructor signature +```julia + ImperativeAffect(f::Function; modified::NamedTuple, observed::NamedTuple, ctx) +``` +that accepts the function to call, a named tuple of both the names of and symbolic values representing +values in the system to be modified, a named tuple of the values that are merely observed (that is, used from +the system but not modified), and a context that's passed to the affect function. + +In our example, each event merely changes whether the furnace is on or off. Accordingly, we pass a `modified` tuple +`(; furnace_on)` (creating a `NamedTuple` equivalent to `(furnace_on = furnace_on)`). `ImperativeAffect` will then +evaluate this before calling our function to fill out all of the numerical values, then apply them back to the system +once our affect function returns. Furthermore, it will check that it is possible to do this assignment. + +The function given to `ImperativeAffect` needs to have one of four signatures, checked in this order: +* `f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple` if the function needs the low-level integrator, +* `f(modified::NamedTuple, observed::NamedTuple, ctx)::NamedTuple` if the function needs the user-defined context, +* `f(modified::NamedTuple, observed::NamedTuple)::NamedTuple` if the function also reads observed values from the system, +* `f(modified::NamedTuple)::NamedTuple` if the function only writes values (unknowns or parameters) to the system. +The `do` block in the example implicitly constructs said function inline. For exposition, we use the full version (e.g. `x, o, i, c`) but this could be simplified to merely `x`. + +The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions. +In our example, if `furnace_on` is `false`, then the value of the `x` that's passed in as `modified` will be `(furnace_on = false)`. +The modified values should be passed out in the same format: to set `furnace_on` to `true` we need to return a tuple `(furnace_on = true)`. +We use Setfield to do this in the example, recreating the result tuple before returning it. + +Accordingly, we can now interpret the `ImperativeAffect` definitions to mean that when `temp = furnace_off_threshold` we +will write `furnace_on = false` back to the system, and when `temp = furnace_on_threshold` we will write `furnace_on = true` back +to the system. + +```@example events +@named sys = ODESystem( + eqs, t, [temp], params; continuous_events = [furnace_disable, furnace_enable]) +ss = structural_simplify(sys) +prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 10.0)) +sol = solve(prob, Tsit5()) +plot(sol) +hline!([sol.ps[furnace_off_threshold], sol.ps[furnace_on_threshold]], l = (:black, 1), primary = false) +``` + +Here we see exactly the desired hysteresis. The heater starts on until the temperature hits +`furnace_off_threshold`. The temperature then bleeds away until `furnace_on_threshold` at which +point the furnace turns on again until `furnace_off_threshold` and so on and so forth. The controller +is effectively regulating the temperature of the plant. + +### [Quadrature Encoder](@id quadrature) +For a more complex application we'll look at modeling a quadrature encoder attached to a shaft spinning at a constant speed. +Traditionally, a quadrature encoder is built out of a code wheel that interrupts the sensors at constant intervals and two sensors slightly out of phase with one another. +A state machine can take the pattern of pulses produced by the two sensors and determine the number of steps that the shaft has spun. The state machine takes the new value +from each sensor and the old values and decodes them into the direction that the wheel has spun in this step. + +```@example events + @variables theta(t) omega(t) + params = @parameters qA=0 qB=0 hA=0 hB=0 cnt::Int=0 + eqs = [D(theta) ~ omega + omega ~ 1.0] +``` +Our continous-time system is extremely simple. We have two states, `theta` for the angle of the shaft +and `omega` for the rate at which it's spinning. We then have parameters for the state machine `qA, qB, hA, hB` +and a step count `cnt`. + +We'll then implement the decoder as a simple Julia function. +```@example events + function decoder(oldA, oldB, newA, newB) + state = (oldA, oldB, newA, newB) + if state == (0, 0, 1, 0) || state == (1, 0, 1, 1) || state == (1, 1, 0, 1) || + state == (0, 1, 0, 0) + return 1 + elseif state == (0, 0, 0, 1) || state == (0, 1, 1, 1) || state == (1, 1, 1, 0) || + state == (1, 0, 0, 0) + return -1 + elseif state == (0, 0, 0, 0) || state == (0, 1, 0, 1) || state == (1, 0, 1, 0) || + state == (1, 1, 1, 1) + return 0 + else + return 0 # err is interpreted as no movement + end + end +``` +Based on the current and old state, this function will return 1 if the wheel spun in the positive direction, +-1 if in the negative, and 0 otherwise. + +The encoder state advances when the occlusion begins or ends. We model the +code wheel as simply detecting when `cos(100*theta)` is 0; if we're at a positive +edge of the 0 crossing, then we interpret that as occlusion (so the discrete `qA` goes to 1). Otherwise, if `cos` is +going negative, we interpret that as lack of occlusion (so the discrete goes to 0). The decoder function is +then invoked to update the count with this new information. + +We can implement this in one of two ways: using edge sign detection or right root finding. For exposition, we +will implement each sensor differently. + +For sensor A, we're using the edge detction method. By providing a different affect to `SymbolicContinuousCallback`'s +`affect_neg` argument, we can specify different behaviour for the negative crossing vs. the positive crossing of the root. +In our encoder, we interpret this as occlusion or nonocclusion of the sensor, update the internal state, and tick the decoder. +```@example events + qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0], + ModelingToolkit.ImperativeAffect((; qA, hA, hB, cnt), (; qB)) do x, o, i, c + @set! x.hA = x.qA + @set! x.hB = o.qB + @set! x.qA = 1 + @set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) + x + end, + affect_neg = ModelingToolkit.ImperativeAffect( + (; qA, hA, hB, cnt), (; qB)) do x, o, c, i + @set! x.hA = x.qA + @set! x.hB = o.qB + @set! x.qA = 0 + @set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) + x + end) +``` + +The other way we can implement a sensor is by changing the root find. +Normally, we use left root finding; the affect will be invoked instantaneously before +the root is crossed. This makes it trickier to figure out what the new state is. +Instead, we can use right root finding: + +```@example events + qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π / 2) ~ 0], + ModelingToolkit.ImperativeAffect((; qB, hA, hB, cnt), (; qA, theta)) do x, o, i, c + @set! x.hA = o.qA + @set! x.hB = x.qB + @set! x.qB = clamp(sign(cos(100 * o.theta - π / 2)), 0.0, 1.0) + @set! x.cnt += decoder(x.hA, x.hB, o.qA, x.qB) + x + end; rootfind = SciMLBase.RightRootFind) +``` +Here, sensor B is located `π / 2` behind sensor A in angular space, so we're adjusting our +trigger function accordingly. We here ask for right root finding on the callback, so we know +that the value of said function will have the "new" sign rather than the old one. Thus, we can +determine the new state of the sensor from the sign of the indicator function evaluated at the +affect activation point, with -1 mapped to 0. + +We can now simulate the encoder. +```@example events + @named sys = ODESystem( + eqs, t, [theta, omega], params; continuous_events = [qAevt, qBevt]) + ss = structural_simplify(sys) + prob = ODEProblem(ss, [theta => 0.0], (0.0, pi)) + sol = solve(prob, Tsit5(); dtmax = 0.01) + sol.ps[cnt] +``` +`cos(100*theta)` will have 200 crossings in the half rotation we've gone through, so the encoder would notionally count 200 steps. +Our encoder counts 198 steps (it loses one step to initialization and one step due to the final state falling squarely on an edge). \ No newline at end of file From aecd59bfea213d5e0650057135afe7eb6262b988 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 22 Oct 2024 18:40:34 -0700 Subject: [PATCH 025/101] Formatter --- src/systems/callbacks.jl | 164 +++++++++++++++++++++++---------------- test/symbolic_events.jl | 20 ++--- 2 files changed, 107 insertions(+), 77 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 96758ccf04..e7198817b1 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -116,29 +116,33 @@ function ImperativeAffect(f::Function; modified::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false) - ImperativeAffect(f, + ImperativeAffect(f, collect(values(observed)), collect(keys(observed)), - collect(values(modified)), collect(keys(modified)), + collect(values(modified)), collect(keys(modified)), ctx, skip_checks) end function ImperativeAffect(f::Function, modified::NamedTuple; - observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks=false) - ImperativeAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) + observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false) + ImperativeAffect( + f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) end function ImperativeAffect( - f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing, skip_checks=false) - ImperativeAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) + f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing, skip_checks = false) + ImperativeAffect( + f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) end function ImperativeAffect( - f::Function, modified::NamedTuple, observed::NamedTuple, ctx; skip_checks=false) - ImperativeAffect(f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) + f::Function, modified::NamedTuple, observed::NamedTuple, ctx; skip_checks = false) + ImperativeAffect( + f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) end -function Base.show(io::IO, mfa::ImperativeAffect) - obs_vals = join(map((ob,nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ") - mod_vals = join(map((md,nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ") +function Base.show(io::IO, mfa::ImperativeAffect) + obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ") + mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ") affect = mfa.f - print(io, "ImperativeAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)") + print(io, + "ImperativeAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)") end func(f::ImperativeAffect) = f.f context(a::ImperativeAffect) = a.ctx @@ -234,15 +238,16 @@ struct SymbolicContinuousCallback affect_neg::Union{Vector{Equation}, FunctionalAffect, ImperativeAffect, Nothing} rootfind::SciMLBase.RootfindOpt reinitializealg::SciMLBase.DAEInitializationAlgorithm - function SymbolicContinuousCallback(; - eqs::Vector{Equation}, - affect = NULL_AFFECT, - affect_neg = affect, - rootfind = SciMLBase.LeftRootFind, - initialize=NULL_AFFECT, - finalize=NULL_AFFECT, - reinitializealg=SciMLBase.CheckInit()) - new(eqs, initialize, finalize, make_affect(affect), make_affect(affect_neg), rootfind, reinitializealg) + function SymbolicContinuousCallback(; + eqs::Vector{Equation}, + affect = NULL_AFFECT, + affect_neg = affect, + rootfind = SciMLBase.LeftRootFind, + initialize = NULL_AFFECT, + finalize = NULL_AFFECT, + reinitializealg = SciMLBase.CheckInit()) + new(eqs, initialize, finalize, make_affect(affect), + make_affect(affect_neg), rootfind, reinitializealg) end # Default affect to nothing end make_affect(affect) = affect @@ -250,8 +255,8 @@ make_affect(affect::Tuple) = FunctionalAffect(affect...) make_affect(affect::NamedTuple) = FunctionalAffect(; affect...) function Base.:(==)(e1::SymbolicContinuousCallback, e2::SymbolicContinuousCallback) - isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) && - isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize) && + isequal(e1.eqs, e2.eqs) && isequal(e1.affect, e2.affect) && + isequal(e1.initialize, e2.initialize) && isequal(e1.finalize, e2.finalize) && isequal(e1.affect_neg, e2.affect_neg) && isequal(e1.rootfind, e2.rootfind) end Base.isempty(cb::SymbolicContinuousCallback) = isempty(cb.eqs) @@ -266,10 +271,9 @@ function Base.hash(cb::SymbolicContinuousCallback, s::UInt) hash(cb.rootfind, s) end - function Base.show(io::IO, cb::SymbolicContinuousCallback) indent = get(io, :indent, 0) - iio = IOContext(io, :indent => indent+1) + iio = IOContext(io, :indent => indent + 1) print(io, "SymbolicContinuousCallback(") print(iio, "Equations:") show(iio, equations(cb)) @@ -298,7 +302,7 @@ end function Base.show(io::IO, mime::MIME"text/plain", cb::SymbolicContinuousCallback) indent = get(io, :indent, 0) - iio = IOContext(io, :indent => indent+1) + iio = IOContext(io, :indent => indent + 1) println(io, "SymbolicContinuousCallback:") println(iio, "Equations:") show(iio, mime, equations(cb)) @@ -338,14 +342,18 @@ end # wrap eq in vector SymbolicContinuousCallback(p::Pair) = SymbolicContinuousCallback(p[1], p[2]) SymbolicContinuousCallback(cb::SymbolicContinuousCallback) = cb # passthrough function SymbolicContinuousCallback(eqs::Equation, affect = NULL_AFFECT; - affect_neg = affect, rootfind = SciMLBase.LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT) + affect_neg = affect, rootfind = SciMLBase.LeftRootFind, + initialize = NULL_AFFECT, finalize = NULL_AFFECT) SymbolicContinuousCallback( - eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize=initialize, finalize=finalize) + eqs = [eqs], affect = affect, affect_neg = affect_neg, rootfind = rootfind, + initialize = initialize, finalize = finalize) end function SymbolicContinuousCallback(eqs::Vector{Equation}, affect = NULL_AFFECT; - affect_neg = affect, rootfind = SciMLBase.LeftRootFind, initialize = NULL_AFFECT, finalize = NULL_AFFECT) + affect_neg = affect, rootfind = SciMLBase.LeftRootFind, + initialize = NULL_AFFECT, finalize = NULL_AFFECT) SymbolicContinuousCallback( - eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind, initialize=initialize, finalize=finalize) + eqs = eqs, affect = affect, affect_neg = affect_neg, rootfind = rootfind, + initialize = initialize, finalize = finalize) end SymbolicContinuousCallbacks(cb::SymbolicContinuousCallback) = [cb] @@ -385,8 +393,10 @@ function finalize_affects(cbs::Vector{SymbolicContinuousCallback}) end reinitialization_alg(cb::SymbolicContinuousCallback) = cb.reinitializealg -reinitialization_algs(cbs::Vector{SymbolicContinuousCallback}) = - mapreduce(reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[]) +function reinitialization_algs(cbs::Vector{SymbolicContinuousCallback}) + mapreduce( + reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[]) +end namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af] namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s) @@ -436,7 +446,8 @@ struct SymbolicDiscreteCallback affects::Any reinitializealg::SciMLBase.DAEInitializationAlgorithm - function SymbolicDiscreteCallback(condition, affects = NULL_AFFECT, reinitializealg=SciMLBase.CheckInit()) + function SymbolicDiscreteCallback( + condition, affects = NULL_AFFECT, reinitializealg = SciMLBase.CheckInit()) c = scalarize_condition(condition) a = scalarize_affects(affects) new(c, a, reinitializealg) @@ -498,8 +509,10 @@ function affects(cbs::Vector{SymbolicDiscreteCallback}) end reinitialization_alg(cb::SymbolicDiscreteCallback) = cb.reinitializealg -reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback}) = - mapreduce(reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[]) +function reinitialization_algs(cbs::Vector{SymbolicDiscreteCallback}) + mapreduce( + reinitialization_alg, vcat, cbs, init = SciMLBase.DAEInitializationAlgorithm[]) +end function namespace_callback(cb::SymbolicDiscreteCallback, s)::SymbolicDiscreteCallback af = affects(cb) @@ -781,7 +794,8 @@ function generate_single_rootfinding_callback( end end - user_initfun = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : (c, u, t, i) -> affect_function.initialize(i) + user_initfun = isnothing(affect_function.initialize) ? SciMLBase.INITIALIZE_DEFAULT : + (c, u, t, i) -> affect_function.initialize(i) if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing && (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing initfn = let save_idxs = save_idxs @@ -795,17 +809,19 @@ function generate_single_rootfinding_callback( else initfn = user_initfun end - + return ContinuousCallback( - cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind, - initialize = initfn, - finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : (c, u, t, i) -> affect_function.finalize(i), + cond, affect_function.affect, affect_function.affect_neg, rootfind = cb.rootfind, + initialize = initfn, + finalize = isnothing(affect_function.finalize) ? SciMLBase.FINALIZE_DEFAULT : + (c, u, t, i) -> affect_function.finalize(i), initializealg = reinitialization_alg(cb)) end function generate_vector_rootfinding_callback( cbs, sys::AbstractODESystem, dvs = unknowns(sys), - ps = parameters(sys); rootfind = SciMLBase.RightRootFind, reinitialization = SciMLBase.CheckInit(), kwargs...) + ps = parameters(sys); rootfind = SciMLBase.RightRootFind, + reinitialization = SciMLBase.CheckInit(), kwargs...) eqs = map(cb -> flatten_equations(cb.eqs), cbs) num_eqs = length.(eqs) # fuse equations to create VectorContinuousCallback @@ -821,11 +837,12 @@ function generate_vector_rootfinding_callback( sys, rhss, dvs, ps; expression = Val{false}, kwargs...) affect_functions = @NamedTuple{ - affect::Function, - affect_neg::Union{Function, Nothing}, - initialize::Union{Function, Nothing}, + affect::Function, + affect_neg::Union{Function, Nothing}, + initialize::Union{Function, Nothing}, finalize::Union{Function, Nothing}}[ - compile_affect_fn(cb, sys, dvs, ps, kwargs) for cb in cbs] + compile_affect_fn(cb, sys, dvs, ps, kwargs) + for cb in cbs] cond = function (out, u, t, integ) rf_ip(out, u, parameter_values(integ), t) end @@ -861,17 +878,20 @@ function generate_vector_rootfinding_callback( if isnothing(func) continue else - func(integ) + func(integ) end end end end end end - initialize = handle_optional_setup_fn(map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT) - finalize = handle_optional_setup_fn(map(fn -> fn.finalize, affect_functions), SciMLBase.FINALIZE_DEFAULT) + initialize = handle_optional_setup_fn( + map(fn -> fn.initialize, affect_functions), SciMLBase.INITIALIZE_DEFAULT) + finalize = handle_optional_setup_fn( + map(fn -> fn.finalize, affect_functions), SciMLBase.FINALIZE_DEFAULT) return VectorContinuousCallback( - cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initialize, finalize = finalize, initializealg = reinitialization) + cond, affect, affect_neg, length(eqs), rootfind = rootfind, initialize = initialize, + finalize = finalize, initializealg = reinitialization) end """ @@ -881,8 +901,8 @@ function compile_affect_fn(cb, sys::AbstractODESystem, dvs, ps, kwargs) eq_aff = affects(cb) eq_neg_aff = affect_negs(cb) affect = compile_affect(eq_aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...) - function compile_optional_affect(aff, default=nothing) - if isnothing(aff) || aff==default + function compile_optional_affect(aff, default = nothing) + if isnothing(aff) || aff == default return nothing else return compile_affect(aff, cb, sys, dvs, ps; expression = Val{false}, kwargs...) @@ -918,13 +938,14 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow # groupby would be very useful here, but alas cb_classes = Dict{ @NamedTuple{ - rootfind::SciMLBase.RootfindOpt, + rootfind::SciMLBase.RootfindOpt, reinitialization::SciMLBase.DAEInitializationAlgorithm}, Vector{SymbolicContinuousCallback}}() for cb in cbs push!( - get!(() -> SymbolicContinuousCallback[], cb_classes, ( - rootfind = cb.rootfind, - reinitialization = reinitialization_alg(cb))), + get!(() -> SymbolicContinuousCallback[], cb_classes, + ( + rootfind = cb.rootfind, + reinitialization = reinitialization_alg(cb))), cb) end @@ -932,7 +953,8 @@ function generate_rootfinding_callback(cbs, sys::AbstractODESystem, dvs = unknow compiled_callbacks = map(collect(pairs(sort!( OrderedDict(cb_classes); by = p -> p.rootfind)))) do (equiv_class, cbs_in_class) return generate_vector_rootfinding_callback( - cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, reinitialization=equiv_class.reinitialization, kwargs...) + cbs_in_class, sys, dvs, ps; rootfind = equiv_class.rootfind, + reinitialization = equiv_class.reinitialization, kwargs...) end if length(compiled_callbacks) == 1 return compiled_callbacks[] @@ -984,29 +1006,34 @@ function invalid_variables(sys, expr) filter(x -> !any(isequal(x), all_symbols(sys)), reduce(vcat, vars(expr); init = [])) end function unassignable_variables(sys, expr) - assignable_syms = reduce(vcat, Symbolics.scalarize.(vcat(unknowns(sys), parameters(sys))); init=[]) + assignable_syms = reduce( + vcat, Symbolics.scalarize.(vcat(unknowns(sys), parameters(sys))); init = []) written = reduce(vcat, Symbolics.scalarize.(vars(expr)); init = []) return filter( x -> !any(isequal(x), assignable_syms), written) end -@generated function _generated_writeback(integ, setters::NamedTuple{NS1,<:Tuple}, values::NamedTuple{NS2, <:Tuple}) where {NS1, NS2} +@generated function _generated_writeback(integ, setters::NamedTuple{NS1, <:Tuple}, + values::NamedTuple{NS2, <:Tuple}) where {NS1, NS2} setter_exprs = [] - for name in NS2 + for name in NS2 if !(name in NS1) missing_name = "Tried to write back to $name from affect; only declared states ($NS1) may be written to." error(missing_name) end push!(setter_exprs, :(setters.$name(integ, values.$name))) end - return :(begin $(setter_exprs...) end) + return :(begin + $(setter_exprs...) + end) end function check_assignable(sys, sym) if symbolic_type(sym) == ScalarSymbolic() is_variable(sys, sym) || is_parameter(sys, sym) elseif symbolic_type(sym) == ArraySymbolic() - is_variable(sys, sym) || is_parameter(sys, sym) || all(x -> check_assignable(sys, x), collect(sym)) + is_variable(sys, sym) || is_parameter(sys, sym) || + all(x -> check_assignable(sys, x), collect(sym)) elseif sym isa Union{AbstractArray, Tuple} all(x -> check_assignable(sys, x), sym) else @@ -1084,13 +1111,13 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. # okay so now to generate the stuff to assign it back into the system mod_pairs = mod_exprs .=> mod_syms - mod_names = (mod_syms..., ) + mod_names = (mod_syms...,) mod_og_val_fun = build_explicit_observed_function( sys, Symbolics.scalarize.(first.(mod_pairs)); array_type = :tuple) upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,)) - + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing save_idxs = get(ic.callback_to_clocks, cb, Int[]) else @@ -1104,10 +1131,12 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. upd_component_array = NamedTuple{mod_names}(modvals) # update the observed values - obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun(integ.u, integ.p, integ.t)) + obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun( + integ.u, integ.p, integ.t)) # let the user do their thing - modvals = if applicable(user_affect, upd_component_array, obs_component_array, ctx, integ) + modvals = if applicable( + user_affect, upd_component_array, obs_component_array, ctx, integ) user_affect(upd_component_array, obs_component_array, ctx, integ) elseif applicable(user_affect, upd_component_array, obs_component_array, ctx) user_affect(upd_component_array, obs_component_array, ctx) @@ -1122,7 +1151,7 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. # write the new values back to the integrator _generated_writeback(integ, upd_funs, modvals) - + for idx in save_idxs SciMLBase.save_discretes!(integ, idx) end @@ -1130,7 +1159,8 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. end end -function compile_affect(affect::Union{FunctionalAffect, ImperativeAffect}, cb, sys, dvs, ps; kwargs...) +function compile_affect( + affect::Union{FunctionalAffect, ImperativeAffect}, cb, sys, dvs, ps; kwargs...) compile_user_affect(affect, cb, sys, dvs, ps; kwargs...) end diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index c2c26aae7f..f4f97fb2b9 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1001,7 +1001,7 @@ end @test sol[b] == [5.0, 5.0] @test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] end -@testset "Heater" begin +@testset "Heater" begin @variables temp(t) params = @parameters furnace_on_threshold=0.5 furnace_off_threshold=0.7 furnace_power=1.0 leakage=0.1 furnace_on::Bool=false eqs = [ @@ -1080,7 +1080,7 @@ end [temp ~ furnace_off_threshold], ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x @set! x.furnace_on = false - end; initialize = ModelingToolkit.ImperativeAffect(modified = (; temp)) do x + end; initialize = ModelingToolkit.ImperativeAffect(modified = (; temp)) do x @set! x.temp = 0.2 end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( @@ -1145,17 +1145,16 @@ end @test_throws "refers to missing variable(s)" prob=ODEProblem( ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) - furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], ModelingToolkit.ImperativeAffect(modified = (; furnace_on), observed = (; furnace_on)) do x, o, c, i - return (;fictional2 = false) + return (; fictional2 = false) end) @named sys = ODESystem( eqs, t, [temp, tempsq], params; continuous_events = [furnace_off]) ss = structural_simplify(sys) - prob=ODEProblem( + prob = ODEProblem( ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) @test_throws "Tried to write back to" solve(prob, Tsit5()) end @@ -1220,15 +1219,16 @@ end @test getp(sol, cnt)(sol) == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state end - - import RuntimeGeneratedFunctions -function (f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{argnames, cache_tag, context_tag, id})(args::Vararg{Any, N}) where {N, argnames, cache_tag, context_tag, id} +function (f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{ + argnames, cache_tag, context_tag, + id})(args::Vararg{Any, N}) where {N, argnames, cache_tag, context_tag, id} try RuntimeGeneratedFunctions.generated_callfunc(f, args...) - catch e + catch e @error "Caught error in RuntimeGeneratedFunction; source code follows" - func_expr = Expr(:->, Expr(:tuple, argnames...), RuntimeGeneratedFunctions._lookup_body(cache_tag, id)) + func_expr = Expr(:->, Expr(:tuple, argnames...), + RuntimeGeneratedFunctions._lookup_body(cache_tag, id)) @show func_expr rethrow(e) end From 89954e46437a7a6fd3536d699d2e3d1846020a48 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 22 Oct 2024 18:43:31 -0700 Subject: [PATCH 026/101] Formatter (2) --- docs/src/basics/Events.md | 161 +++++++++++++++++-------------- src/systems/diffeqs/odesystem.jl | 8 +- 2 files changed, 95 insertions(+), 74 deletions(-) diff --git a/docs/src/basics/Events.md b/docs/src/basics/Events.md index 9d3ba30780..4a59149308 100644 --- a/docs/src/basics/Events.md +++ b/docs/src/basics/Events.md @@ -379,36 +379,39 @@ sol.ps[c] # sol[c] will error, since `c` is not a timeseries value It can be seen that the timeseries for `c` is not saved. - ## [(Experimental) Imperative affects](@id imp_affects) + The `ImperativeAffect` can be used as an alternative to the aforementioned functional affect form. Note that `ImperativeAffect` is still experimental; to emphasize this, we do not export it and it should be -included as `ModelingToolkit.ImperativeAffect`. It abstracts over how values are written back to the +included as `ModelingToolkit.ImperativeAffect`. It abstracts over how values are written back to the system, simplifying the definitions and (in the future) allowing assignments back to observed values by solving the nonlinear reinitialization problem afterwards. -We will use two examples to describe `ImperativeAffect`: a simple heater and a quadrature encoder. +We will use two examples to describe `ImperativeAffect`: a simple heater and a quadrature encoder. These examples will also demonstrate advanced usage of `ModelingToolkit.SymbolicContinousCallback`, the low-level interface that the aforementioned tuple form converts into and allows control over the exact SciMLCallbacks event that is generated for a continous event. ### [Heater](@id heater_events) + Bang-bang control of a heater connected to a leaky plant requires hysteresis in order to prevent control oscillation. -```@example events +```@example events @variables temp(t) params = @parameters furnace_on_threshold=0.5 furnace_off_threshold=0.7 furnace_power=1.0 leakage=0.1 furnace_on(t)::Bool=false eqs = [ D(temp) ~ furnace_on * furnace_power - temp^2 * leakage ] ``` + Our plant is simple. We have a heater that's turned on and off by the clocked parameter `furnace_on` which adds `furnace_power` forcing to the system when enabled. We then leak heat porportional to `leakage` -as a function of the square of the current temperature. +as a function of the square of the current temperature. We need a controller with hysteresis to conol the plant. We wish the furnace to turn on when the temperature is below `furnace_on_threshold` and off when above `furnace_off_threshold`, while maintaining its current state in between. To do this, we create two continous callbacks: + ```@example events using Setfield furnace_disable = ModelingToolkit.SymbolicContinuousCallback( @@ -422,42 +425,49 @@ furnace_enable = ModelingToolkit.SymbolicContinuousCallback( @set! x.furnace_on = true end) ``` + We're using the explicit form of `SymbolicContinuousCallback` here, though so far we aren't using anything that's not possible with the implicit interface. -You can also write +You can also write + ```julia -[temp ~ furnace_off_threshold] => ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, i, c +[temp ~ furnace_off_threshold] => ModelingToolkit.ImperativeAffect(modified = (; + furnace_on)) do x, o, i, c @set! x.furnace_on = false end ``` + and it would work the same. The `ImperativeAffect` is the larger change in this example. `ImperativeAffect` has the constructor signature + ```julia - ImperativeAffect(f::Function; modified::NamedTuple, observed::NamedTuple, ctx) +ImperativeAffect(f::Function; modified::NamedTuple, observed::NamedTuple, ctx) ``` + that accepts the function to call, a named tuple of both the names of and symbolic values representing values in the system to be modified, a named tuple of the values that are merely observed (that is, used from the system but not modified), and a context that's passed to the affect function. In our example, each event merely changes whether the furnace is on or off. Accordingly, we pass a `modified` tuple -`(; furnace_on)` (creating a `NamedTuple` equivalent to `(furnace_on = furnace_on)`). `ImperativeAffect` will then +`(; furnace_on)` (creating a `NamedTuple` equivalent to `(furnace_on = furnace_on)`). `ImperativeAffect` will then evaluate this before calling our function to fill out all of the numerical values, then apply them back to the system once our affect function returns. Furthermore, it will check that it is possible to do this assignment. The function given to `ImperativeAffect` needs to have one of four signatures, checked in this order: -* `f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple` if the function needs the low-level integrator, -* `f(modified::NamedTuple, observed::NamedTuple, ctx)::NamedTuple` if the function needs the user-defined context, -* `f(modified::NamedTuple, observed::NamedTuple)::NamedTuple` if the function also reads observed values from the system, -* `f(modified::NamedTuple)::NamedTuple` if the function only writes values (unknowns or parameters) to the system. -The `do` block in the example implicitly constructs said function inline. For exposition, we use the full version (e.g. `x, o, i, c`) but this could be simplified to merely `x`. + + - `f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple` if the function needs the low-level integrator, + - `f(modified::NamedTuple, observed::NamedTuple, ctx)::NamedTuple` if the function needs the user-defined context, + - `f(modified::NamedTuple, observed::NamedTuple)::NamedTuple` if the function also reads observed values from the system, + - `f(modified::NamedTuple)::NamedTuple` if the function only writes values (unknowns or parameters) to the system. + The `do` block in the example implicitly constructs said function inline. For exposition, we use the full version (e.g. `x, o, i, c`) but this could be simplified to merely `x`. The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions. -In our example, if `furnace_on` is `false`, then the value of the `x` that's passed in as `modified` will be `(furnace_on = false)`. +In our example, if `furnace_on` is `false`, then the value of the `x` that's passed in as `modified` will be `(furnace_on = false)`. The modified values should be passed out in the same format: to set `furnace_on` to `true` we need to return a tuple `(furnace_on = true)`. We use Setfield to do this in the example, recreating the result tuple before returning it. -Accordingly, we can now interpret the `ImperativeAffect` definitions to mean that when `temp = furnace_off_threshold` we +Accordingly, we can now interpret the `ImperativeAffect` definitions to mean that when `temp = furnace_off_threshold` we will write `furnace_on = false` back to the system, and when `temp = furnace_on_threshold` we will write `furnace_on = true` back to the system. @@ -468,7 +478,8 @@ ss = structural_simplify(sys) prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 10.0)) sol = solve(prob, Tsit5()) plot(sol) -hline!([sol.ps[furnace_off_threshold], sol.ps[furnace_on_threshold]], l = (:black, 1), primary = false) +hline!([sol.ps[furnace_off_threshold], sol.ps[furnace_on_threshold]], + l = (:black, 1), primary = false) ``` Here we see exactly the desired hysteresis. The heater starts on until the temperature hits @@ -477,71 +488,76 @@ point the furnace turns on again until `furnace_off_threshold` and so on and so is effectively regulating the temperature of the plant. ### [Quadrature Encoder](@id quadrature) + For a more complex application we'll look at modeling a quadrature encoder attached to a shaft spinning at a constant speed. Traditionally, a quadrature encoder is built out of a code wheel that interrupts the sensors at constant intervals and two sensors slightly out of phase with one another. A state machine can take the pattern of pulses produced by the two sensors and determine the number of steps that the shaft has spun. The state machine takes the new value from each sensor and the old values and decodes them into the direction that the wheel has spun in this step. ```@example events - @variables theta(t) omega(t) - params = @parameters qA=0 qB=0 hA=0 hB=0 cnt::Int=0 - eqs = [D(theta) ~ omega - omega ~ 1.0] +@variables theta(t) omega(t) +params = @parameters qA=0 qB=0 hA=0 hB=0 cnt::Int=0 +eqs = [D(theta) ~ omega + omega ~ 1.0] ``` + Our continous-time system is extremely simple. We have two states, `theta` for the angle of the shaft -and `omega` for the rate at which it's spinning. We then have parameters for the state machine `qA, qB, hA, hB` +and `omega` for the rate at which it's spinning. We then have parameters for the state machine `qA, qB, hA, hB` and a step count `cnt`. We'll then implement the decoder as a simple Julia function. + ```@example events - function decoder(oldA, oldB, newA, newB) - state = (oldA, oldB, newA, newB) - if state == (0, 0, 1, 0) || state == (1, 0, 1, 1) || state == (1, 1, 0, 1) || - state == (0, 1, 0, 0) - return 1 - elseif state == (0, 0, 0, 1) || state == (0, 1, 1, 1) || state == (1, 1, 1, 0) || - state == (1, 0, 0, 0) - return -1 - elseif state == (0, 0, 0, 0) || state == (0, 1, 0, 1) || state == (1, 0, 1, 0) || - state == (1, 1, 1, 1) - return 0 - else - return 0 # err is interpreted as no movement - end +function decoder(oldA, oldB, newA, newB) + state = (oldA, oldB, newA, newB) + if state == (0, 0, 1, 0) || state == (1, 0, 1, 1) || state == (1, 1, 0, 1) || + state == (0, 1, 0, 0) + return 1 + elseif state == (0, 0, 0, 1) || state == (0, 1, 1, 1) || state == (1, 1, 1, 0) || + state == (1, 0, 0, 0) + return -1 + elseif state == (0, 0, 0, 0) || state == (0, 1, 0, 1) || state == (1, 0, 1, 0) || + state == (1, 1, 1, 1) + return 0 + else + return 0 # err is interpreted as no movement end +end ``` + Based on the current and old state, this function will return 1 if the wheel spun in the positive direction, -1 if in the negative, and 0 otherwise. -The encoder state advances when the occlusion begins or ends. We model the +The encoder state advances when the occlusion begins or ends. We model the code wheel as simply detecting when `cos(100*theta)` is 0; if we're at a positive edge of the 0 crossing, then we interpret that as occlusion (so the discrete `qA` goes to 1). Otherwise, if `cos` is going negative, we interpret that as lack of occlusion (so the discrete goes to 0). The decoder function is then invoked to update the count with this new information. We can implement this in one of two ways: using edge sign detection or right root finding. For exposition, we -will implement each sensor differently. +will implement each sensor differently. For sensor A, we're using the edge detction method. By providing a different affect to `SymbolicContinuousCallback`'s `affect_neg` argument, we can specify different behaviour for the negative crossing vs. the positive crossing of the root. In our encoder, we interpret this as occlusion or nonocclusion of the sensor, update the internal state, and tick the decoder. + ```@example events - qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0], - ModelingToolkit.ImperativeAffect((; qA, hA, hB, cnt), (; qB)) do x, o, i, c - @set! x.hA = x.qA - @set! x.hB = o.qB - @set! x.qA = 1 - @set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) - x - end, - affect_neg = ModelingToolkit.ImperativeAffect( - (; qA, hA, hB, cnt), (; qB)) do x, o, c, i - @set! x.hA = x.qA - @set! x.hB = o.qB - @set! x.qA = 0 - @set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) - x - end) +qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0], + ModelingToolkit.ImperativeAffect((; qA, hA, hB, cnt), (; qB)) do x, o, i, c + @set! x.hA = x.qA + @set! x.hB = o.qB + @set! x.qA = 1 + @set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) + x + end, + affect_neg = ModelingToolkit.ImperativeAffect( + (; qA, hA, hB, cnt), (; qB)) do x, o, c, i + @set! x.hA = x.qA + @set! x.hB = o.qB + @set! x.qA = 0 + @set! x.cnt += decoder(x.hA, x.hB, x.qA, o.qB) + x + end) ``` The other way we can implement a sensor is by changing the root find. @@ -550,29 +566,32 @@ the root is crossed. This makes it trickier to figure out what the new state is. Instead, we can use right root finding: ```@example events - qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π / 2) ~ 0], - ModelingToolkit.ImperativeAffect((; qB, hA, hB, cnt), (; qA, theta)) do x, o, i, c - @set! x.hA = o.qA - @set! x.hB = x.qB - @set! x.qB = clamp(sign(cos(100 * o.theta - π / 2)), 0.0, 1.0) - @set! x.cnt += decoder(x.hA, x.hB, o.qA, x.qB) - x - end; rootfind = SciMLBase.RightRootFind) +qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π / 2) ~ 0], + ModelingToolkit.ImperativeAffect((; qB, hA, hB, cnt), (; qA, theta)) do x, o, i, c + @set! x.hA = o.qA + @set! x.hB = x.qB + @set! x.qB = clamp(sign(cos(100 * o.theta - π / 2)), 0.0, 1.0) + @set! x.cnt += decoder(x.hA, x.hB, o.qA, x.qB) + x + end; rootfind = SciMLBase.RightRootFind) ``` -Here, sensor B is located `π / 2` behind sensor A in angular space, so we're adjusting our + +Here, sensor B is located `π / 2` behind sensor A in angular space, so we're adjusting our trigger function accordingly. We here ask for right root finding on the callback, so we know -that the value of said function will have the "new" sign rather than the old one. Thus, we can +that the value of said function will have the "new" sign rather than the old one. Thus, we can determine the new state of the sensor from the sign of the indicator function evaluated at the affect activation point, with -1 mapped to 0. We can now simulate the encoder. + ```@example events - @named sys = ODESystem( - eqs, t, [theta, omega], params; continuous_events = [qAevt, qBevt]) - ss = structural_simplify(sys) - prob = ODEProblem(ss, [theta => 0.0], (0.0, pi)) - sol = solve(prob, Tsit5(); dtmax = 0.01) - sol.ps[cnt] +@named sys = ODESystem( + eqs, t, [theta, omega], params; continuous_events = [qAevt, qBevt]) +ss = structural_simplify(sys) +prob = ODEProblem(ss, [theta => 0.0], (0.0, pi)) +sol = solve(prob, Tsit5(); dtmax = 0.01) +sol.ps[cnt] ``` + `cos(100*theta)` will have 200 crossings in the half rotation we've gone through, so the encoder would notionally count 200 steps. -Our encoder counts 198 steps (it loses one step to initialization and one step due to the final state falling squarely on an edge). \ No newline at end of file +Our encoder counts 198 steps (it loses one step to initialization and one step due to the final state falling squarely on an edge). diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index e99911eec6..1cc8273b4d 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -444,7 +444,7 @@ function build_explicit_observed_function(sys, ts; param_only = false, op = Operator, throw = true, - array_type=:array) + array_type = :array) if (isscalar = symbolic_type(ts) !== NotSymbolic()) ts = [ts] end @@ -589,10 +589,12 @@ function build_explicit_observed_function(sys, ts; oop_mtkp_wrapper = mtkparams_wrapper end - output_expr = isscalar ? ts[1] : (array_type == :array ? MakeArray(ts, output_type) : MakeTuple(ts)) + output_expr = isscalar ? ts[1] : + (array_type == :array ? MakeArray(ts, output_type) : MakeTuple(ts)) # Need to keep old method of building the function since it uses `output_type`, # which can't be provided to `build_function` - oop_fn = Func(args, [], pre(Let(obsexprs, output_expr, false))) |> array_wrapper[1] |> oop_mtkp_wrapper |> toexpr + oop_fn = Func(args, [], pre(Let(obsexprs, output_expr, false))) |> array_wrapper[1] |> + oop_mtkp_wrapper |> toexpr oop_fn = expression ? oop_fn : eval_or_rgf(oop_fn; eval_expression, eval_module) if !isscalar From 711fb8c654be38ce7377b85f3025ee7934cd5de8 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 22 Oct 2024 18:44:45 -0700 Subject: [PATCH 027/101] Spelling --- docs/src/basics/Events.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/src/basics/Events.md b/docs/src/basics/Events.md index 4a59149308..8088e0f52a 100644 --- a/docs/src/basics/Events.md +++ b/docs/src/basics/Events.md @@ -388,9 +388,9 @@ system, simplifying the definitions and (in the future) allowing assignments bac by solving the nonlinear reinitialization problem afterwards. We will use two examples to describe `ImperativeAffect`: a simple heater and a quadrature encoder. -These examples will also demonstrate advanced usage of `ModelingToolkit.SymbolicContinousCallback`, +These examples will also demonstrate advanced usage of `ModelingToolkit.SymbolicContinuousCallback`, the low-level interface that the aforementioned tuple form converts into and allows control over the -exact SciMLCallbacks event that is generated for a continous event. +exact SciMLCallbacks event that is generated for a continuous event. ### [Heater](@id heater_events) @@ -405,7 +405,7 @@ eqs = [ ``` Our plant is simple. We have a heater that's turned on and off by the clocked parameter `furnace_on` -which adds `furnace_power` forcing to the system when enabled. We then leak heat porportional to `leakage` +which adds `furnace_power` forcing to the system when enabled. We then leak heat proportional to `leakage` as a function of the square of the current temperature. We need a controller with hysteresis to conol the plant. We wish the furnace to turn on when the temperature @@ -537,7 +537,7 @@ then invoked to update the count with this new information. We can implement this in one of two ways: using edge sign detection or right root finding. For exposition, we will implement each sensor differently. -For sensor A, we're using the edge detction method. By providing a different affect to `SymbolicContinuousCallback`'s +For sensor A, we're using the edge detection method. By providing a different affect to `SymbolicContinuousCallback`'s `affect_neg` argument, we can specify different behaviour for the negative crossing vs. the positive crossing of the root. In our encoder, we interpret this as occlusion or nonocclusion of the sensor, update the internal state, and tick the decoder. From a8ea3698e71afe91dd9040a6fb80bab22c7d6c95 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 22 Oct 2024 18:46:54 -0700 Subject: [PATCH 028/101] Spelling (2) --- docs/src/basics/Events.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/basics/Events.md b/docs/src/basics/Events.md index 8088e0f52a..e583af8d9e 100644 --- a/docs/src/basics/Events.md +++ b/docs/src/basics/Events.md @@ -410,7 +410,7 @@ as a function of the square of the current temperature. We need a controller with hysteresis to conol the plant. We wish the furnace to turn on when the temperature is below `furnace_on_threshold` and off when above `furnace_off_threshold`, while maintaining its current state -in between. To do this, we create two continous callbacks: +in between. To do this, we create two continuous callbacks: ```@example events using Setfield @@ -501,7 +501,7 @@ eqs = [D(theta) ~ omega omega ~ 1.0] ``` -Our continous-time system is extremely simple. We have two states, `theta` for the angle of the shaft +Our continuous-time system is extremely simple. We have two states, `theta` for the angle of the shaft and `omega` for the rate at which it's spinning. We then have parameters for the state machine `qA, qB, hA, hB` and a step count `cnt`. From 6a304e7d7f8960d7f2cc63d1be93d8f41f6db824 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 22 Oct 2024 18:48:15 -0700 Subject: [PATCH 029/101] Remove debug RGF shim --- test/symbolic_events.jl | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index f4f97fb2b9..96a57ec40e 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1217,19 +1217,4 @@ end prob = ODEProblem(ss, [theta => 0.0], (0.0, pi)) sol = solve(prob, Tsit5(); dtmax = 0.01) @test getp(sol, cnt)(sol) == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state -end - -import RuntimeGeneratedFunctions -function (f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{ - argnames, cache_tag, context_tag, - id})(args::Vararg{Any, N}) where {N, argnames, cache_tag, context_tag, id} - try - RuntimeGeneratedFunctions.generated_callfunc(f, args...) - catch e - @error "Caught error in RuntimeGeneratedFunction; source code follows" - func_expr = Expr(:->, Expr(:tuple, argnames...), - RuntimeGeneratedFunctions._lookup_body(cache_tag, id)) - @show func_expr - rethrow(e) - end -end +end \ No newline at end of file From 35ec1c59b31ab3d0c6a421482d7e48459a00111c Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 22 Oct 2024 18:50:32 -0700 Subject: [PATCH 030/101] Formatter (3) --- test/symbolic_events.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 96a57ec40e..c021f99eea 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1217,4 +1217,4 @@ end prob = ODEProblem(ss, [theta => 0.0], (0.0, pi)) sol = solve(prob, Tsit5(); dtmax = 0.01) @test getp(sol, cnt)(sol) == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state -end \ No newline at end of file +end From 4c3d6fcc1eaedbec33cfd3f979a958110ff4eace Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 23 Oct 2024 06:21:05 +0000 Subject: [PATCH 031/101] Update docs/src/basics/Events.md Co-authored-by: Fredrik Bagge Carlson --- docs/src/basics/Events.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/basics/Events.md b/docs/src/basics/Events.md index e583af8d9e..41fa96d619 100644 --- a/docs/src/basics/Events.md +++ b/docs/src/basics/Events.md @@ -394,7 +394,7 @@ exact SciMLCallbacks event that is generated for a continuous event. ### [Heater](@id heater_events) -Bang-bang control of a heater connected to a leaky plant requires hysteresis in order to prevent control oscillation. +Bang-bang control of a heater connected to a leaky plant requires hysteresis in order to prevent rapid control oscillation. ```@example events @variables temp(t) From cdddeb9f62ffcd12357dc96648167fe8514993be Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 23 Oct 2024 06:21:34 +0000 Subject: [PATCH 032/101] Update docs/src/basics/Events.md Co-authored-by: Fredrik Bagge Carlson --- docs/src/basics/Events.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/basics/Events.md b/docs/src/basics/Events.md index 41fa96d619..add870b7f3 100644 --- a/docs/src/basics/Events.md +++ b/docs/src/basics/Events.md @@ -408,7 +408,7 @@ Our plant is simple. We have a heater that's turned on and off by the clocked pa which adds `furnace_power` forcing to the system when enabled. We then leak heat proportional to `leakage` as a function of the square of the current temperature. -We need a controller with hysteresis to conol the plant. We wish the furnace to turn on when the temperature +We need a controller with hysteresis to control the plant. We wish the furnace to turn on when the temperature is below `furnace_on_threshold` and off when above `furnace_off_threshold`, while maintaining its current state in between. To do this, we create two continuous callbacks: From 2ac2d230664916fdd0151cb6257adbf9dd7b2ccb Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 23 Oct 2024 06:21:54 +0000 Subject: [PATCH 033/101] Update docs/src/basics/Events.md Co-authored-by: Fredrik Bagge Carlson --- docs/src/basics/Events.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/basics/Events.md b/docs/src/basics/Events.md index add870b7f3..0d9149e23c 100644 --- a/docs/src/basics/Events.md +++ b/docs/src/basics/Events.md @@ -561,7 +561,7 @@ qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0], ``` The other way we can implement a sensor is by changing the root find. -Normally, we use left root finding; the affect will be invoked instantaneously before +Normally, we use left root finding; the affect will be invoked instantaneously _before_ the root is crossed. This makes it trickier to figure out what the new state is. Instead, we can use right root finding: From 9bf734dcbc03123b72fb179a47370d9aa484869f Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Wed, 23 Oct 2024 15:28:56 -0700 Subject: [PATCH 034/101] Simplify callback interface, fix references --- docs/src/basics/Events.md | 22 +++++++-------- src/systems/callbacks.jl | 31 ++++++-------------- test/symbolic_events.jl | 59 +++------------------------------------ 3 files changed, 23 insertions(+), 89 deletions(-) diff --git a/docs/src/basics/Events.md b/docs/src/basics/Events.md index 0d9149e23c..de1c1000ce 100644 --- a/docs/src/basics/Events.md +++ b/docs/src/basics/Events.md @@ -416,12 +416,12 @@ in between. To do this, we create two continuous callbacks: using Setfield furnace_disable = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], - ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, i, c + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, c, i @set! x.furnace_on = false end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_on_threshold], - ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, i, c + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, c, i @set! x.furnace_on = true end) ``` @@ -454,18 +454,16 @@ In our example, each event merely changes whether the furnace is on or off. Acco evaluate this before calling our function to fill out all of the numerical values, then apply them back to the system once our affect function returns. Furthermore, it will check that it is possible to do this assignment. -The function given to `ImperativeAffect` needs to have one of four signatures, checked in this order: - - - `f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple` if the function needs the low-level integrator, - - `f(modified::NamedTuple, observed::NamedTuple, ctx)::NamedTuple` if the function needs the user-defined context, - - `f(modified::NamedTuple, observed::NamedTuple)::NamedTuple` if the function also reads observed values from the system, - - `f(modified::NamedTuple)::NamedTuple` if the function only writes values (unknowns or parameters) to the system. - The `do` block in the example implicitly constructs said function inline. For exposition, we use the full version (e.g. `x, o, i, c`) but this could be simplified to merely `x`. +The function given to `ImperativeAffect` needs to have the signature: +```julia + f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple +``` The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions. In our example, if `furnace_on` is `false`, then the value of the `x` that's passed in as `modified` will be `(furnace_on = false)`. The modified values should be passed out in the same format: to set `furnace_on` to `true` we need to return a tuple `(furnace_on = true)`. -We use Setfield to do this in the example, recreating the result tuple before returning it. +The examples does this with Setfield, recreating the result tuple before returning it; the returned tuple may optionally be missing values as +well, in which case those values will not be written back to the problem. Accordingly, we can now interpret the `ImperativeAffect` definitions to mean that when `temp = furnace_off_threshold` we will write `furnace_on = false` back to the system, and when `temp = furnace_on_threshold` we will write `furnace_on = true` back @@ -543,7 +541,7 @@ In our encoder, we interpret this as occlusion or nonocclusion of the sensor, up ```@example events qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0], - ModelingToolkit.ImperativeAffect((; qA, hA, hB, cnt), (; qB)) do x, o, i, c + ModelingToolkit.ImperativeAffect((; qA, hA, hB, cnt), (; qB)) do x, o, c, i @set! x.hA = x.qA @set! x.hB = o.qB @set! x.qA = 1 @@ -567,7 +565,7 @@ Instead, we can use right root finding: ```@example events qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π / 2) ~ 0], - ModelingToolkit.ImperativeAffect((; qB, hA, hB, cnt), (; qA, theta)) do x, o, i, c + ModelingToolkit.ImperativeAffect((; qB, hA, hB, cnt), (; qA, theta)) do x, o, c, i @set! x.hA = o.qA @set! x.hB = x.qB @set! x.qB = clamp(sign(cos(100 * o.theta - π / 2)), 0.0, 1.0) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index e7198817b1..618a1f3a83 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -76,12 +76,11 @@ end `ImperativeAffect` is a helper for writing affect functions that will compute observed values and ensure that modified values are correctly written back into the system. The affect function `f` needs to have -one of four signatures: -* `f(modified::NamedTuple)::NamedTuple` if the function only writes values (unknowns or parameters) to the system, -* `f(modified::NamedTuple, observed::NamedTuple)::NamedTuple` if the function also reads observed values from the system, -* `f(modified::NamedTuple, observed::NamedTuple, ctx)::NamedTuple` if the function needs the user-defined context, -* `f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple` if the function needs the low-level integrator. -These will be checked in reverse order (that is, the four-argument version first, than the 3, etc). +the signature + +``` + f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple +``` The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions. Each declaration`NamedTuple` should map an expression to a symbol; for example if we pass `observed=(; x = a + b)` this will alias the result of executing `a+b` in the system as `x` @@ -1046,7 +1045,7 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. Implementation sketch: generate observed function (oop), should save to a component array under obs_syms do the same stuff as the normal FA for pars_syms - call the affect method - test if it's OOP or IP using applicable + call the affect method unpack and apply the resulting values =# function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup) @@ -1135,22 +1134,10 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. integ.u, integ.p, integ.t)) # let the user do their thing - modvals = if applicable( - user_affect, upd_component_array, obs_component_array, ctx, integ) - user_affect(upd_component_array, obs_component_array, ctx, integ) - elseif applicable(user_affect, upd_component_array, obs_component_array, ctx) - user_affect(upd_component_array, obs_component_array, ctx) - elseif applicable(user_affect, upd_component_array, obs_component_array) - user_affect(upd_component_array, obs_component_array) - elseif applicable(user_affect, upd_component_array) - user_affect(upd_component_array) - else - @error "User affect function $user_affect needs to implement one of the supported ImperativeAffect callback forms; see the ImperativeAffect docstring for more details" - user_affect(upd_component_array, obs_component_array, integ, ctx) # this WILL error but it'll give a more sensible message - end - + upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ) + # write the new values back to the integrator - _generated_writeback(integ, upd_funs, modvals) + _generated_writeback(integ, upd_funs, upd_vals) for idx in save_idxs SciMLBase.save_discretes!(integ, idx) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index c021f99eea..3c3e9168e7 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1027,60 +1027,9 @@ end furnace_off = ModelingToolkit.SymbolicContinuousCallback( [temp ~ furnace_off_threshold], - ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, i - @set! x.furnace_on = false - end) - furnace_enable = ModelingToolkit.SymbolicContinuousCallback( - [temp ~ furnace_on_threshold], - ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, i - @set! x.furnace_on = true - end) - @named sys = ODESystem( - eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) - ss = structural_simplify(sys) - prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) - sol = solve(prob, Tsit5(); dtmax = 0.01) - @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) - - furnace_off = ModelingToolkit.SymbolicContinuousCallback( - [temp ~ furnace_off_threshold], - ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o - @set! x.furnace_on = false - end) - furnace_enable = ModelingToolkit.SymbolicContinuousCallback( - [temp ~ furnace_on_threshold], - ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o - @set! x.furnace_on = true - end) - @named sys = ODESystem( - eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) - ss = structural_simplify(sys) - prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) - sol = solve(prob, Tsit5(); dtmax = 0.01) - @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) - - furnace_off = ModelingToolkit.SymbolicContinuousCallback( - [temp ~ furnace_off_threshold], - ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x - @set! x.furnace_on = false - end) - furnace_enable = ModelingToolkit.SymbolicContinuousCallback( - [temp ~ furnace_on_threshold], - ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x - @set! x.furnace_on = true - end) - @named sys = ODESystem( - eqs, t, [temp], params; continuous_events = [furnace_off, furnace_enable]) - ss = structural_simplify(sys) - prob = ODEProblem(ss, [temp => 0.0, furnace_on => true], (0.0, 100.0)) - sol = solve(prob, Tsit5(); dtmax = 0.01) - @test all(sol[temp][sol.t .> 1.0] .<= 0.79) && all(sol[temp][sol.t .> 1.0] .>= 0.49) - - furnace_off = ModelingToolkit.SymbolicContinuousCallback( - [temp ~ furnace_off_threshold], - ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x + ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, c, i @set! x.furnace_on = false - end; initialize = ModelingToolkit.ImperativeAffect(modified = (; temp)) do x + end; initialize = ModelingToolkit.ImperativeAffect(modified = (; temp)) do x, o, c, i @set! x.temp = 0.2 end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( @@ -1180,7 +1129,7 @@ end end end qAevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta) ~ 0], - ModelingToolkit.ImperativeAffect((; qA, hA, hB, cnt), (; qB)) do x, o, i, c + ModelingToolkit.ImperativeAffect((; qA, hA, hB, cnt), (; qB)) do x, o, c, i @set! x.hA = x.qA @set! x.hB = o.qB @set! x.qA = 1 @@ -1196,7 +1145,7 @@ end x end; rootfind = SciMLBase.RightRootFind) qBevt = ModelingToolkit.SymbolicContinuousCallback([cos(100 * theta - π / 2) ~ 0], - ModelingToolkit.ImperativeAffect((; qB, hA, hB, cnt), (; qA)) do x, o, i, c + ModelingToolkit.ImperativeAffect((; qB, hA, hB, cnt), (; qA)) do x, o, c, i @set! x.hA = o.qA @set! x.hB = x.qB @set! x.qB = 1 From 1a3f7d4e4916c12d26a4659334a859a9c3d5ca8b Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Wed, 23 Oct 2024 15:33:20 -0700 Subject: [PATCH 035/101] Make array_type an actual type, though only in a limited sense --- src/systems/callbacks.jl | 4 ++-- src/systems/diffeqs/odesystem.jl | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 618a1f3a83..c34681a20c 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -1105,7 +1105,7 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. end obs_fun = build_explicit_observed_function( sys, Symbolics.scalarize.(obs_exprs); - array_type = :tuple) + array_type = Tuple) obs_sym_tuple = (obs_syms...,) # okay so now to generate the stuff to assign it back into the system @@ -1113,7 +1113,7 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. mod_names = (mod_syms...,) mod_og_val_fun = build_explicit_observed_function( sys, Symbolics.scalarize.(first.(mod_pairs)); - array_type = :tuple) + array_type = Tuple) upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,)) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 1cc8273b4d..374cbab5f8 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -429,7 +429,7 @@ Options not otherwise specified are: * `op = Operator` sets the recursion terminator for the walk done by `vars` to identify the variables that appear in `ts`. See the documentation for `vars` for more detail. * `throw = true` if true, throw an error when generating a function for `ts` that reference variables that do not exist * `drop_expr` is deprecated. -* `array_type`; only used if the output is an array (that is, `!isscalar(ts)`). If `:array`, then it will generate an array, if `:tuple` then it will generate a tuple. +* `array_type`; only used if the output is an array (that is, `!isscalar(ts)`). If it is `Vector`, then it will generate an array, if `Tuple` then it will generate a tuple. """ function build_explicit_observed_function(sys, ts; inputs = nothing, @@ -444,7 +444,7 @@ function build_explicit_observed_function(sys, ts; param_only = false, op = Operator, throw = true, - array_type = :array) + array_type = Vector) if (isscalar = symbolic_type(ts) !== NotSymbolic()) ts = [ts] end @@ -590,7 +590,7 @@ function build_explicit_observed_function(sys, ts; end output_expr = isscalar ? ts[1] : - (array_type == :array ? MakeArray(ts, output_type) : MakeTuple(ts)) + (array_type <: Vector ? MakeArray(ts, output_type) : MakeTuple(ts)) # Need to keep old method of building the function since it uses `output_type`, # which can't be provided to `build_function` oop_fn = Func(args, [], pre(Let(obsexprs, output_expr, false))) |> array_wrapper[1] |> From 6c7e4b84ae40e0bf1cbe1c61f6e4c41202ee449c Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Wed, 23 Oct 2024 15:34:59 -0700 Subject: [PATCH 036/101] Clean up language indocs --- docs/src/basics/Events.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/src/basics/Events.md b/docs/src/basics/Events.md index de1c1000ce..902be98fd4 100644 --- a/docs/src/basics/Events.md +++ b/docs/src/basics/Events.md @@ -404,7 +404,7 @@ eqs = [ ] ``` -Our plant is simple. We have a heater that's turned on and off by the clocked parameter `furnace_on` +Our plant is simple. We have a heater that's turned on and off by the time-indexed parameter `furnace_on` which adds `furnace_power` forcing to the system when enabled. We then leak heat proportional to `leakage` as a function of the square of the current temperature. @@ -499,9 +499,9 @@ eqs = [D(theta) ~ omega omega ~ 1.0] ``` -Our continuous-time system is extremely simple. We have two states, `theta` for the angle of the shaft +Our continuous-time system is extremely simple. We have two unknown variables `theta` for the angle of the shaft and `omega` for the rate at which it's spinning. We then have parameters for the state machine `qA, qB, hA, hB` -and a step count `cnt`. +(corresponding to the current quadrature of the A/B sensors and the historical ones) and a step count `cnt`. We'll then implement the decoder as a simple Julia function. From 06ecd2ee541d94bcf92c39543583f9bc0646286d Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Wed, 23 Oct 2024 21:24:46 -0700 Subject: [PATCH 037/101] Format --- docs/src/basics/Events.md | 7 ++++--- src/systems/callbacks.jl | 2 +- test/symbolic_events.jl | 3 ++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/src/basics/Events.md b/docs/src/basics/Events.md index 902be98fd4..4c9803ef57 100644 --- a/docs/src/basics/Events.md +++ b/docs/src/basics/Events.md @@ -456,13 +456,14 @@ once our affect function returns. Furthermore, it will check that it is possible The function given to `ImperativeAffect` needs to have the signature: -```julia - f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple +```julia +f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple ``` + The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions. In our example, if `furnace_on` is `false`, then the value of the `x` that's passed in as `modified` will be `(furnace_on = false)`. The modified values should be passed out in the same format: to set `furnace_on` to `true` we need to return a tuple `(furnace_on = true)`. -The examples does this with Setfield, recreating the result tuple before returning it; the returned tuple may optionally be missing values as +The examples does this with Setfield, recreating the result tuple before returning it; the returned tuple may optionally be missing values as well, in which case those values will not be written back to the problem. Accordingly, we can now interpret the `ImperativeAffect` definitions to mean that when `temp = furnace_off_threshold` we diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index c34681a20c..1061139a8a 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -1135,7 +1135,7 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. # let the user do their thing upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ) - + # write the new values back to the integrator _generated_writeback(integ, upd_funs, upd_vals) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 3c3e9168e7..77961cb01f 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1029,7 +1029,8 @@ end [temp ~ furnace_off_threshold], ModelingToolkit.ImperativeAffect(modified = (; furnace_on)) do x, o, c, i @set! x.furnace_on = false - end; initialize = ModelingToolkit.ImperativeAffect(modified = (; temp)) do x, o, c, i + end; initialize = ModelingToolkit.ImperativeAffect(modified = (; + temp)) do x, o, c, i @set! x.temp = 0.2 end) furnace_enable = ModelingToolkit.SymbolicContinuousCallback( From 7f7f65cc3e9dde8fec99037f4df27221088ed818 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Fri, 25 Oct 2024 10:17:26 -0700 Subject: [PATCH 038/101] Clear up some of the documentation language --- docs/src/basics/Events.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/docs/src/basics/Events.md b/docs/src/basics/Events.md index 4c9803ef57..90b1b4036d 100644 --- a/docs/src/basics/Events.md +++ b/docs/src/basics/Events.md @@ -383,14 +383,13 @@ It can be seen that the timeseries for `c` is not saved. The `ImperativeAffect` can be used as an alternative to the aforementioned functional affect form. Note that `ImperativeAffect` is still experimental; to emphasize this, we do not export it and it should be -included as `ModelingToolkit.ImperativeAffect`. It abstracts over how values are written back to the -system, simplifying the definitions and (in the future) allowing assignments back to observed values -by solving the nonlinear reinitialization problem afterwards. +included as `ModelingToolkit.ImperativeAffect`. `ImperativeAffect` aims to simplify the manipulation of +system state. We will use two examples to describe `ImperativeAffect`: a simple heater and a quadrature encoder. These examples will also demonstrate advanced usage of `ModelingToolkit.SymbolicContinuousCallback`, -the low-level interface that the aforementioned tuple form converts into and allows control over the -exact SciMLCallbacks event that is generated for a continuous event. +the low-level interface of the tuple form converts into that allows control over the SciMLBase-level +event that is generated for a continuous event. ### [Heater](@id heater_events) From e750219fcb2d746b5b3bf2c52bcf22c19c48955c Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Fri, 25 Oct 2024 10:49:42 -0700 Subject: [PATCH 039/101] Format --- docs/src/basics/Events.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/src/basics/Events.md b/docs/src/basics/Events.md index 90b1b4036d..23e1e6d7d1 100644 --- a/docs/src/basics/Events.md +++ b/docs/src/basics/Events.md @@ -383,12 +383,12 @@ It can be seen that the timeseries for `c` is not saved. The `ImperativeAffect` can be used as an alternative to the aforementioned functional affect form. Note that `ImperativeAffect` is still experimental; to emphasize this, we do not export it and it should be -included as `ModelingToolkit.ImperativeAffect`. `ImperativeAffect` aims to simplify the manipulation of +included as `ModelingToolkit.ImperativeAffect`. `ImperativeAffect` aims to simplify the manipulation of system state. We will use two examples to describe `ImperativeAffect`: a simple heater and a quadrature encoder. These examples will also demonstrate advanced usage of `ModelingToolkit.SymbolicContinuousCallback`, -the low-level interface of the tuple form converts into that allows control over the SciMLBase-level +the low-level interface of the tuple form converts into that allows control over the SciMLBase-level event that is generated for a continuous event. ### [Heater](@id heater_events) From 3bc5d236053cebf4b08315f64379c3666d5c9e84 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 12 Nov 2024 16:29:47 -0800 Subject: [PATCH 040/101] Fix merge issues --- src/systems/callbacks.jl | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index c2c895626c..fab573cffb 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -247,8 +247,6 @@ struct SymbolicContinuousCallback initialize = NULL_AFFECT, finalize = NULL_AFFECT, rootfind = SciMLBase.LeftRootFind, - initialize = NULL_AFFECT, - finalize = NULL_AFFECT, reinitializealg = SciMLBase.CheckInit()) new(eqs, initialize, finalize, make_affect(affect), make_affect(affect_neg), rootfind, reinitializealg) @@ -387,16 +385,6 @@ function affect_negs(cbs::Vector{SymbolicContinuousCallback}) mapreduce(affect_negs, vcat, cbs, init = Equation[]) end -initialize_affects(cb::SymbolicContinuousCallback) = cb.initialize -function initialize_affects(cbs::Vector{SymbolicContinuousCallback}) - mapreduce(initialize_affects, vcat, cbs, init = Equation[]) -end - -finalize_affects(cb::SymbolicContinuousCallback) = cb.initialize -function finalize_affects(cbs::Vector{SymbolicContinuousCallback}) - mapreduce(finalize_affects, vcat, cbs, init = Equation[]) -end - reinitialization_alg(cb::SymbolicContinuousCallback) = cb.reinitializealg function reinitialization_algs(cbs::Vector{SymbolicContinuousCallback}) mapreduce( From e69ef194e685fb94016541255030b75764acbbb5 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 12 Nov 2024 23:45:49 -0800 Subject: [PATCH 041/101] Clean up usage of custom init --- src/systems/callbacks.jl | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index fab573cffb..9c8187f6b7 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -922,18 +922,11 @@ function generate_vector_rootfinding_callback( (cb, fn) -> begin if (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing let save_idxs = save_idxs - if !isnothing(fn.initialize) - (i) -> begin - fn.initialize(i) - for idx in save_idxs - SciMLBase.save_discretes!(i, idx) - end - end - else - (i) -> begin - for idx in save_idxs - SciMLBase.save_discretes!(i, idx) - end + custom_init = fn.initialize + (i) -> begin + isnothing(custom_init) && custom_init(i) + for idx in save_idxs + SciMLBase.save_discretes!(i, idx) end end end From 3c6b37ed3852d3590afe9b29f9690b0824f84a09 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 12 Nov 2024 23:46:52 -0800 Subject: [PATCH 042/101] Simplify setup function construction --- src/systems/callbacks.jl | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 9c8187f6b7..b9423b8f07 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -918,24 +918,21 @@ function generate_vector_rootfinding_callback( initialize = nothing if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing initialize = handle_optional_setup_fn( - map( - (cb, fn) -> begin - if (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing - let save_idxs = save_idxs - custom_init = fn.initialize - (i) -> begin - isnothing(custom_init) && custom_init(i) - for idx in save_idxs - SciMLBase.save_discretes!(i, idx) - end + map(cbs, affect_functions) do cb, fn + if (save_idxs = get(ic.callback_to_clocks, cb, nothing)) !== nothing + let save_idxs = save_idxs + custom_init = fn.initialize + (i) -> begin + isnothing(custom_init) && custom_init(i) + for idx in save_idxs + SciMLBase.save_discretes!(i, idx) end end - else - fn.initialize end - end, - cbs, - affect_functions), + else + fn.initialize + end + end, SciMLBase.INITIALIZE_DEFAULT) else From 9589a1f347673cb85b639fbccdab508201d64402 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 21 Nov 2024 12:50:57 +0530 Subject: [PATCH 043/101] feat: store args and kwargs in `EmptySciMLFunction` --- src/systems/problem_utils.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 1cf110df13..1cfb8eb40a 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -431,12 +431,16 @@ end $(TYPEDEF) A simple utility meant to be used as the `constructor` passed to `process_SciMLProblem` in -case constructing a SciMLFunction is not required. +case constructing a SciMLFunction is not required. The arguments passed to it are available +in the `args` field, and the keyword arguments in the `kwargs` field. """ -struct EmptySciMLFunction end +struct EmptySciMLFunction{A, K} + args::A + kwargs::K +end function EmptySciMLFunction(args...; kwargs...) - return nothing + return EmptySciMLFunction{typeof(args), typeof(kwargs)}(args, kwargs) end """ From 8d4f5423e424279ba0855bf79e008a25e6544a77 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 20 Nov 2024 12:54:33 +0530 Subject: [PATCH 044/101] feat: propagate `ODEProblem` guesses to `remake` --- src/systems/diffeqs/abstractodesystem.jl | 4 +- src/systems/nonlinear/initializesystem.jl | 90 ++++++----------------- src/systems/problem_utils.jl | 14 +++- test/initializationsystem.jl | 23 ++++++ 4 files changed, 60 insertions(+), 71 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index cbf835f48d..af96c4fcfe 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1310,11 +1310,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, elseif isempty(u0map) && get_initializesystem(sys) === nothing isys = structural_simplify( generate_initializesystem( - sys; initialization_eqs, check_units, pmap = parammap); fully_determined) + sys; initialization_eqs, check_units, pmap = parammap, guesses); fully_determined) else isys = structural_simplify( generate_initializesystem( - sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined) + sys; u0map, initialization_eqs, check_units, pmap = parammap, guesses); fully_determined) end ts = get_tearing_state(isys) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index eff19afb07..fe41e44d84 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -30,7 +30,8 @@ function generate_initializesystem(sys::ODESystem; # 1) process dummy derivatives and u0map into initialization system eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations defs = copy(defaults(sys)) # copy so we don't modify sys.defaults - guesses = merge(get_guesses(sys), todict(guesses)) + additional_guesses = anydict(guesses) + guesses = merge(get_guesses(sys), additional_guesses) schedule = getfield(sys, :schedule) if !isnothing(schedule) for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub) @@ -178,7 +179,7 @@ function generate_initializesystem(sys::ODESystem; for k in keys(defs) defs[k] = substitute(defs[k], paramsubs) end - meta = InitializationSystemMetadata(Dict{Any, Any}(u0map), Dict{Any, Any}(pmap)) + meta = InitializationSystemMetadata(anydict(u0map), anydict(pmap), additional_guesses) return NonlinearSystem(eqs_ics, vars, pars; @@ -193,6 +194,7 @@ end struct InitializationSystemMetadata u0map::Dict{Any, Any} pmap::Dict{Any, Any} + additional_guesses::Dict{Any, Any} end function is_parameter_solvable(p, pmap, defs, guesses) @@ -263,75 +265,29 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) return initprob, odefn.update_initializeprob!, odefn.initializeprobmap, odefn.initializeprobpmap end - if u0 === missing || isempty(u0) - u0 = Dict() - elseif !(eltype(u0) <: Pair) - u0 = Dict(unknowns(sys) .=> u0) - end - if p === missing - p = Dict() + dvs = unknowns(sys) + ps = parameters(sys) + u0map = to_varmap(u0, dvs) + pmap = to_varmap(p, ps) + guesses = Dict() + if SciMLBase.has_initializeprob(odefn) + oldsys = odefn.initializeprob.f.sys + meta = get_metadata(oldsys) + if meta isa InitializationSystemMetadata + u0map = merge(meta.u0map, u0map) + pmap = merge(meta.pmap, pmap) + merge!(guesses, meta.additional_guesses) + end end if t0 === nothing t0 = 0.0 end - u0 = todict(u0) - defs = defaults(sys) - varmap = merge(defs, u0) - for k in collect(keys(varmap)) - if varmap[k] === nothing - delete!(varmap, k) - end - end - varmap = canonicalize_varmap(varmap) - missingvars = setdiff(unknowns(sys), collect(keys(varmap))) - setobserved = filter(keys(varmap)) do var - has_observed_with_lhs(sys, var) || has_observed_with_lhs(sys, default_toterm(var)) - end - p = todict(p) - guesses = ModelingToolkit.guesses(sys) - solvablepars = [par - for par in parameters(sys) - if is_parameter_solvable(par, p, defs, guesses)] - pvarmap = merge(defs, p) - setparobserved = filter(keys(pvarmap)) do var - has_parameter_dependency_with_lhs(sys, var) - end - if (((!isempty(missingvars) || !isempty(solvablepars) || - !isempty(setobserved) || !isempty(setparobserved)) && - ModelingToolkit.get_tearing_state(sys) !== nothing) || - !isempty(initialization_equations(sys))) - if SciMLBase.has_initializeprob(odefn) - oldsys = odefn.initializeprob.f.sys - meta = get_metadata(oldsys) - if meta isa InitializationSystemMetadata - u0 = merge(meta.u0map, u0) - p = merge(meta.pmap, p) - end - end - for k in collect(keys(u0)) - if u0[k] === nothing - delete!(u0, k) - end - end - for k in collect(keys(p)) - if p[k] === nothing - delete!(p, k) - end - end - - initprob = InitializationProblem(sys, t0, u0, p) - initprobmap = getu(initprob, unknowns(sys)) - punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)] - getpunknowns = getu(initprob, punknowns) - setpunknowns = setp(sys, punknowns) - initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns) - reqd_syms = parameter_symbols(initprob) - update_initializeprob! = UpdateInitializeprob( - getu(sys, reqd_syms), setu(initprob, reqd_syms)) - return initprob, update_initializeprob!, initprobmap, initprobpmap - else - return nothing, nothing, nothing, nothing - end + filter_missing_values!(u0map) + filter_missing_values!(pmap) + f, _ = process_SciMLProblem(EmptySciMLFunction, sys, u0map, pmap; guesses, t = t0) + kws = f.kwargs + return get(kws, :initializeprob, nothing), get(kws, :update_initializeprob!, nothing), get(kws, :initializeprobmap, nothing), + get(kws, :initializeprobpmap, nothing) end """ diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 1cfb8eb40a..9521df7872 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -4,12 +4,13 @@ const AnyDict = Dict{Any, Any} $(TYPEDSIGNATURES) If called without arguments, return `Dict{Any, Any}`. Otherwise, interpret the input -as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters` -and `nothing`. +as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters`, +`missing` and `nothing`. """ anydict() = AnyDict() anydict(::SciMLBase.NullParameters) = AnyDict() anydict(::Nothing) = AnyDict() +anydict(::Missing) = AnyDict() anydict(x::AnyDict) = x anydict(x) = AnyDict(x) @@ -388,6 +389,15 @@ function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100) end end +""" + $(TYPEDSIGNATURES) + +Remove keys in `varmap` whose values are `nothing`. +""" +function filter_missing_values!(varmap::AbstractDict) + filter!(kvp -> kvp[2] !== nothing, varmap) +end + struct GetUpdatedMTKParameters{G, S} # `getu` functor which gets parameters that are unknowns during initialization getpunknowns::G diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 770f9f12bd..edb4199356 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -975,3 +975,26 @@ end @test integ.ps[p] ≈ 1.0 @test integ.ps[q]≈cbrt(2) rtol=1e-6 end + +@testset "Guesses provided to `ODEProblem` are used in `remake`" begin + @variables x(t) y(t)=2x + @parameters p q=3x + @mtkbuild sys = ODESystem([D(x) ~ x * p + q, x^3 + y^3 ~ 3], t) + prob = ODEProblem( + sys, [], (0.0, 1.0), [p => 1.0]; guesses = [x => 1.0, y => 1.0, q => 1.0]) + @test prob[x] == 0.0 + @test prob[y] == 0.0 + @test prob.ps[p] == 1.0 + @test prob.ps[q] == 0.0 + integ = init(prob) + @test integ[x] ≈ 1 / cbrt(3) + @test integ[y] ≈ 2 / cbrt(3) + @test integ.ps[p] == 1.0 + @test integ.ps[q] ≈ 3 / cbrt(3) + prob2 = remake(prob; u0 = [y => 3x], p = [q => 2x]) + integ2 = init(prob2) + @test integ2[x] ≈ cbrt(3 / 28) + @test integ2[y] ≈ 3cbrt(3 / 28) + @test integ2.ps[p] == 1.0 + @test integ2.ps[q] ≈ 2cbrt(3 / 28) +end From 2641fd8a97cfc785d340e2fa7cea80a3a3815182 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 22 Nov 2024 14:04:13 +0530 Subject: [PATCH 045/101] fix: handle initial values passed as `Symbol`s --- src/systems/problem_utils.jl | 38 ++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 9521df7872..b837eef98e 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -52,6 +52,42 @@ function add_toterms(varmap::AbstractDict; toterm = default_toterm) return cp end +""" + $(TYPEDSIGNATURES) + +Turn any `Symbol` keys in `varmap` to the appropriate symbolic variables in `sys`. Any +symbols that cannot be converted are ignored. +""" +function symbols_to_symbolics!(sys::AbstractSystem, varmap::AbstractDict) + if is_split(sys) + ic = get_index_cache(sys) + for k in collect(keys(varmap)) + k isa Symbol || continue + newk = get(ic.symbol_to_variable, k, nothing) + newk === nothing && continue + varmap[newk] = varmap[k] + delete!(varmap, k) + end + else + syms = all_symbols(sys) + for k in collect(keys(varmap)) + k isa Symbol || continue + idx = findfirst(syms) do sym + hasname(sym) || return false + name = getname(sym) + return name == k + end + idx === nothing && continue + newk = syms[idx] + if iscall(newk) && operation(newk) === getindex + newk = arguments(newk)[1] + end + varmap[newk] = varmap[k] + delete!(varmap, k) + end + end +end + """ $(TYPEDSIGNATURES) @@ -530,8 +566,10 @@ function process_SciMLProblem( pType = typeof(pmap) _u0map = u0map u0map = to_varmap(u0map, dvs) + symbols_to_symbolics!(sys, u0map) _pmap = pmap pmap = to_varmap(pmap, ps) + symbols_to_symbolics!(sys, pmap) defs = add_toterms(recursive_unwrap(defaults(sys))) cmap, cs = get_cmap(sys) kwargs = NamedTuple(kwargs) From 26980151eacd8d3b1977f544f4a284022ce71ad9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 22 Nov 2024 14:04:34 +0530 Subject: [PATCH 046/101] fix: handle `remake` with no pre-existing initializeprob --- src/systems/nonlinear/initializesystem.jl | 58 ++++++++++++++++++----- 1 file changed, 47 insertions(+), 11 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index fe41e44d84..4abb345822 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -210,17 +210,16 @@ function is_parameter_solvable(p, pmap, defs, guesses) _val1 === nothing && _val2 !== nothing)) && _val3 !== nothing end -function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) +function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, newu0, newp) if u0 === missing && p === missing - return odefn.initializeprob, odefn.update_initializeprob!, odefn.initializeprobmap, - odefn.initializeprobpmap + return odefn.initialization_data end if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair) oldinitprob = odefn.initializeprob - if oldinitprob === nothing || !SciMLBase.has_sys(oldinitprob.f) || - !(oldinitprob.f.sys isa NonlinearSystem) - return oldinitprob, odefn.update_initializeprob!, odefn.initializeprobmap, - odefn.initializeprobpmap + oldinitprob === nothing && return nothing + if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem) + return SciMLBase.OverrideInitData(oldinitprob, odefn.update_initializeprob!, + odefn.initializeprobmap, odefn.initializeprobpmap) end pidxs = ParameterIndex[] pvals = [] @@ -262,14 +261,17 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals) end initprob = remake(oldinitprob; u0 = newu0, p = newp) - return initprob, odefn.update_initializeprob!, odefn.initializeprobmap, - odefn.initializeprobpmap + return SciMLBase.OverrideInitData(initprob, odefn.update_initializeprob!, + odefn.initializeprobmap, odefn.initializeprobpmap) end dvs = unknowns(sys) ps = parameters(sys) u0map = to_varmap(u0, dvs) + symbols_to_symbolics!(sys, u0map) pmap = to_varmap(p, ps) + symbols_to_symbolics!(sys, pmap) guesses = Dict() + defs = defaults(sys) if SciMLBase.has_initializeprob(odefn) oldsys = odefn.initializeprob.f.sys meta = get_metadata(oldsys) @@ -278,6 +280,35 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) pmap = merge(meta.pmap, pmap) merge!(guesses, meta.additional_guesses) end + else + # there is no initializeprob, so the original problem construction + # had no solvable parameters and had the differential variables + # specified in `u0map`. + if u0 === missing + # the user didn't pass `u0` to `remake`, so they want to retain + # existing values. Fill the differential variables in `u0map`, + # initialization will either be elided or solve for the algebraic + # variables + diff_idxs = isdiffeq.(equations(sys)) + for i in eachindex(dvs) + diff_idxs[i] || continue + u0map[dvs[i]] = newu0[i] + end + end + if p === missing + # the user didn't pass `p` to `remake`, so they want to retain + # existing values. Fill all parameters in `pmap` so that none of + # them are solvable. + for p in ps + pmap[p] = getp(sys, p)(newp) + end + end + # all non-solvable parameters need values regardless + for p in ps + haskey(pmap, p) && continue + is_parameter_solvable(p, pmap, defs, guesses) && continue + pmap[p] = getp(sys, p)(newp) + end end if t0 === nothing t0 = 0.0 @@ -286,8 +317,13 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) filter_missing_values!(pmap) f, _ = process_SciMLProblem(EmptySciMLFunction, sys, u0map, pmap; guesses, t = t0) kws = f.kwargs - return get(kws, :initializeprob, nothing), get(kws, :update_initializeprob!, nothing), get(kws, :initializeprobmap, nothing), - get(kws, :initializeprobpmap, nothing) + initprob = get(kws, :initializeprob, nothing) + if initprob === nothing + return nothing + end + return SciMLBase.OverrideInitData(initprob, get(kws, :update_initializeprob!, nothing), + get(kws, :initializeprobmap, nothing), + get(kws, :initializeprobpmap, nothing)) end """ From 3383bcede7c841530a9d82f601ff70e40f32deca Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 22 Nov 2024 14:05:11 +0530 Subject: [PATCH 047/101] test: test `remake` without initializeprob and `Symbol` values --- test/initializationsystem.jl | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index edb4199356..0b0dc42c1e 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -998,3 +998,37 @@ end @test integ2.ps[p] == 1.0 @test integ2.ps[q] ≈ 2cbrt(3 / 28) end + +@testset "Remake problem with no initializeprob" begin + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] + @parameters p [guess = 1.0] q [guess = 1.0] + @mtkbuild sys = ODESystem( + [D(x) ~ p * x + q * y, y ~ 2x], t; parameter_dependencies = [q ~ 2p]) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0]) + @test prob.f.initialization_data === nothing + prob2 = remake(prob; u0 = [x => 2.0]) + @test prob2[x] == 2.0 + @test prob2.f.initialization_data === nothing + prob3 = remake(prob; u0 = [y => 2.0]) + @test prob3.f.initialization_data !== nothing + @test init(prob3)[x] ≈ 1.0 + prob4 = remake(prob; p = [p => 1.0]) + @test prob4.f.initialization_data === nothing + prob5 = remake(prob; p = [p => missing, q => 2.0]) + @test prob5.f.initialization_data !== nothing + @test init(prob5).ps[p] ≈ 1.0 +end + +@testset "Variables provided as symbols" begin + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] + @parameters p [guess = 1.0] q [guess = 1.0] + @mtkbuild sys = ODESystem( + [D(x) ~ p * x + q * y, y ~ 2x], t; parameter_dependencies = [q ~ 2p]) + prob = ODEProblem(sys, [:x => 1.0], (0.0, 1.0), [p => 1.0]) + @test prob.f.initialization_data === nothing + prob2 = remake(prob; u0 = [:x => 2.0]) + @test prob2.f.initialization_data === nothing + prob3 = remake(prob; u0 = [:y => 1.0]) + @test prob3.f.initialization_data !== nothing + @test init(prob3)[x] ≈ 0.5 +end From 83ed891cd7cf2319ffbef0c5bb870424de5ab6ff Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 30 Nov 2024 23:37:12 +0530 Subject: [PATCH 048/101] build: bump SciMLBase compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 093c632c42..1621669c2c 100644 --- a/Project.toml +++ b/Project.toml @@ -126,7 +126,7 @@ REPL = "1" RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" -SciMLBase = "2.57.1" +SciMLBase = "2.64" SciMLStructures = "1.0" Serialization = "1" Setfield = "0.7, 0.8, 1" From 10cc9c16854fc277b88967bbcd709bf42a3b14d2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 30 Nov 2024 23:40:22 +0530 Subject: [PATCH 049/101] refactor: format --- src/systems/optimization/optimizationsystem.jl | 3 ++- src/utils.jl | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index 3425216b5f..43e9294dd3 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -168,7 +168,8 @@ function OptimizationSystem(objective; constraints = [], kwargs...) push!(new_ps, p) end end - return OptimizationSystem(objective, collect(allunknowns), collect(new_ps); constraints, kwargs...) + return OptimizationSystem( + objective, collect(allunknowns), collect(new_ps); constraints, kwargs...) end function flatten(sys::OptimizationSystem) diff --git a/src/utils.jl b/src/utils.jl index 1555cd624e..416efd8f2c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -569,7 +569,6 @@ function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differ return nothing end - function collect_var!(unknowns, parameters, var, iv; depth = 0) isequal(var, iv) && return nothing check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing From b063e6b7cc12747b6ffb7a474d5bbe06165867d5 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 2 Dec 2024 13:18:11 +0530 Subject: [PATCH 050/101] refactor: separate out `resid_prototype` calculation --- src/systems/nonlinear/nonlinearsystem.jl | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 94b44e0a6f..1289388197 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -283,6 +283,16 @@ function hessian_sparsity(sys::NonlinearSystem) unknowns(sys)) for eq in equations(sys)] end +function calculate_resid_prototype(N, u0, p) + u0ElType = u0 === nothing ? Float64 : eltype(u0) + if SciMLStructures.isscimlstructure(p) + u0ElType = promote_type( + eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]), + u0ElType) + end + return zeros(u0ElType, N) +end + """ ```julia SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(sys), @@ -337,13 +347,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s if length(dvs) == length(equations(sys)) resid_prototype = nothing else - u0ElType = u0 === nothing ? Float64 : eltype(u0) - if SciMLStructures.isscimlstructure(p) - u0ElType = promote_type( - eltype(SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)[1]), - u0ElType) - end - resid_prototype = zeros(u0ElType, length(equations(sys))) + resid_prototype = calculate_resid_prototype(length(equations(sys)), u0, p) end NonlinearFunction{iip}(f, From c578da182393e8ba8efac181b796e9838d093680 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 2 Dec 2024 13:18:33 +0530 Subject: [PATCH 051/101] fix: recalculate `resid_prototype` in `remake_initialization_data` --- src/systems/nonlinear/initializesystem.jl | 11 ++++++++++- test/initializationsystem.jl | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 4abb345822..726d171bd0 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -260,7 +260,16 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, newp = remake_buffer( oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals) end - initprob = remake(oldinitprob; u0 = newu0, p = newp) + if oldinitprob.f.resid_prototype === nothing + newf = oldinitprob.f + else + newf = NonlinearFunction{ + SciMLBase.isinplace(oldinitprob.f), SciMLBase.specialization(oldinitprob.f)}( + oldinitprob.f; + resid_prototype = calculate_resid_prototype( + length(oldinitprob.f.resid_prototype), newu0, newp)) + end + initprob = remake(oldinitprob; f = newf, u0 = newu0, p = newp) return SciMLBase.OverrideInitData(initprob, odefn.update_initializeprob!, odefn.initializeprobmap, odefn.initializeprobpmap) end diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 0b0dc42c1e..f3015f7db0 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1032,3 +1032,20 @@ end @test prob3.f.initialization_data !== nothing @test init(prob3)[x] ≈ 0.5 end + +@testset "Issue#3246: type promotion with parameter dependent initialization_eqs" begin + @variables x(t)=1 y(t)=1 + @parameters a = 1 + @named sys = ODESystem([D(x) ~ 0, D(y) ~ x + a], t; initialization_eqs = [y ~ a]) + + ssys = structural_simplify(sys) + prob = ODEProblem(ssys, [], (0, 1), []) + + @test SciMLBase.successful_retcode(solve(prob)) + + seta = setsym_oop(prob, [a]) + (newu0, newp) = seta(prob, ForwardDiff.Dual{ForwardDiff.Tag{:tag, Float64}}.([1.0], 1)) + newprob = remake(prob, u0 = newu0, p = newp) + + @test SciMLBase.successful_retcode(solve(newprob)) +end From 44a60c5c171f312f7ee0afe7bc3acad62ad0714a Mon Sep 17 00:00:00 2001 From: ArnoStrouwen Date: Tue, 3 Dec 2024 01:09:00 +0100 Subject: [PATCH 052/101] update higher order documentation to modern MTK --- docs/src/examples/higher_order.md | 55 ++++++++++++++++--------------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/docs/src/examples/higher_order.md b/docs/src/examples/higher_order.md index 7dafe758dc..fac707525f 100644 --- a/docs/src/examples/higher_order.md +++ b/docs/src/examples/higher_order.md @@ -3,7 +3,7 @@ ModelingToolkit has a system for transformations of mathematical systems. These transformations allow for symbolically changing the representation of the model to problems that are easier to -numerically solve. One simple to demonstrate transformation is the +numerically solve. One simple to demonstrate transformation, is `structural_simplify`, which does a lot of tricks, one being the transformation that turns an Nth order ODE into N coupled 1st order ODEs. @@ -15,16 +15,28 @@ We utilize the derivative operator twice here to define the second order: using ModelingToolkit, OrdinaryDiffEq using ModelingToolkit: t_nounits as t, D_nounits as D -@parameters σ ρ β -@variables x(t) y(t) z(t) - -eqs = [D(D(x)) ~ σ * (y - x), - D(y) ~ x * (ρ - z) - y, - D(z) ~ x * y - β * z] - -@named sys = ODESystem(eqs, t) +@mtkmodel SECOND_ORDER begin + @parameters begin + σ = 28.0 + ρ = 10.0 + β = 8 / 3 + end + @variables begin + x(t) = 1.0 + y(t) = 0.0 + z(t) = 0.0 + end + @equations begin + D(D(x)) ~ σ * (y - x) + D(y) ~ x * (ρ - z) - y + D(z) ~ x * y - β * z + end +end +@mtkbuild sys = SECOND_ORDER() ``` +The second order ODE has been automatically transformed to two first order ODEs. + Note that we could've used an alternative syntax for 2nd order, i.e. `D = Differential(t)^2` and then `D(x)` would be the second derivative, and this syntax extends to `N`-th order. Also, we can use `*` or `∘` to compose @@ -33,28 +45,17 @@ and this syntax extends to `N`-th order. Also, we can use `*` or `∘` to compos Now let's transform this into the `ODESystem` of first order components. We do this by calling `structural_simplify`: -```@example orderlowering -sys = structural_simplify(sys) -``` - Now we can directly numerically solve the lowered system. Note that, following the original problem, the solution requires knowing the -initial condition for `x'`, and thus we include that in our input -specification: +initial condition for both `x` and `D(x)`. +The former already got assigned a default value in the `@mtkmodel`, +but we still have to provide a value for the latter. ```@example orderlowering -u0 = [D(x) => 2.0, - x => 1.0, - y => 0.0, - z => 0.0] - -p = [σ => 28.0, - ρ => 10.0, - β => 8 / 3] - +u0 = [D(sys.x) => 2.0] tspan = (0.0, 100.0) -prob = ODEProblem(sys, u0, tspan, p, jac = true) +prob = ODEProblem(sys, u0, tspan, [], jac = true) sol = solve(prob, Tsit5()) -using Plots; -plot(sol, idxs = (x, y)); +using Plots +plot(sol, idxs = (sys.x, sys.y)) ``` From c11e76a6bdeee93ffbfacd144356026ee6589eb9 Mon Sep 17 00:00:00 2001 From: ArnoStrouwen Date: Wed, 4 Dec 2024 01:32:46 +0100 Subject: [PATCH 053/101] add Lagrangian explanation to DAE reduction tutorial. --- docs/src/examples/modelingtoolkitize_index_reduction.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/docs/src/examples/modelingtoolkitize_index_reduction.md b/docs/src/examples/modelingtoolkitize_index_reduction.md index 415d5b85ff..8686fd60d4 100644 --- a/docs/src/examples/modelingtoolkitize_index_reduction.md +++ b/docs/src/examples/modelingtoolkitize_index_reduction.md @@ -51,6 +51,14 @@ In this tutorial, we will look at the pendulum system: \end{aligned} ``` +These equations can be derived using the [Lagrangian equation of the first kind.](https://en.wikipedia.org/wiki/Lagrangian_mechanics#Lagrangian) +Specifically, for a pendulum with unit mass and length $L$, which thus has +kinetic energy $\frac{1}{2}(v_x^2 + v_y^2)$, +potential energy $gy$, +and holonomic constraint $x^2 + y^2 - L^2 = 0$. +The Lagrange multiplier related to this constraint is equal to half of $T$, +and represents the tension in the rope of the pendulum. + As a good DifferentialEquations.jl user, one would follow [the mass matrix DAE tutorial](https://docs.sciml.ai/DiffEqDocs/stable/tutorials/dae_example/#Mass-Matrix-Differential-Algebraic-Equations-(DAEs)) to arrive at code for simulating the model: From 56a9bb33b582aff924d8b1a30c140f2f3382ca40 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 2 Dec 2024 17:18:07 +0530 Subject: [PATCH 054/101] fix: retain system data on `structural_simplify` of `SDESystem` --- src/systems/systems.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index a54206d1dd..862718968d 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -154,8 +154,9 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal end noise_eqs = StructuralTransformations.tearing_substitute_expr(ode_sys, noise_eqs) - return SDESystem(full_equations(ode_sys), noise_eqs, + return SDESystem(Vector{Equation}(full_equations(ode_sys)), noise_eqs, get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys); - name = nameof(ode_sys), is_scalar_noise, observed = observed(ode_sys)) + name = nameof(ode_sys), is_scalar_noise, observed = observed(ode_sys), defaults = defaults(sys), + parameter_dependencies = parameter_dependencies(sys)) end end From 3b63f826b4b69838cff4e12d0a6146c32e614f40 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 4 Dec 2024 13:18:01 +0530 Subject: [PATCH 055/101] test: test observed equations are retained after simplifying `SDESystem` --- test/dde.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/test/dde.jl b/test/dde.jl index 2030a90d06..c7561e6c24 100644 --- a/test/dde.jl +++ b/test/dde.jl @@ -76,12 +76,13 @@ prob = SDDEProblem(hayes_modelf, hayes_modelg, [1.0], h, tspan, pmul; constant_lags = (pmul[1],)); sol = solve(prob, RKMil(), seed = 100) -@variables x(..) +@variables x(..) delx(t) @parameters a=-4.0 b=-2.0 c=10.0 α=-1.3 β=-1.2 γ=1.1 @brownian η τ = 1.0 -eqs = [D(x(t)) ~ a * x(t) + b * x(t - τ) + c + (α * x(t) + γ) * η] +eqs = [D(x(t)) ~ a * x(t) + b * x(t - τ) + c + (α * x(t) + γ) * η, delx ~ x(t - τ)] @mtkbuild sys = System(eqs, t) +@test ModelingToolkit.has_observed_with_lhs(sys, delx) @test ModelingToolkit.is_dde(sys) @test !is_markovian(sys) @test equations(sys) == [D(x(t)) ~ a * x(t) + b * x(t - τ) + c] From 9c1bed9f22aeb0c1e490871509cc5e3361e5b1f7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 4 Dec 2024 14:17:55 +0530 Subject: [PATCH 056/101] refactor: format --- docs/src/examples/modelingtoolkitize_index_reduction.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/examples/modelingtoolkitize_index_reduction.md b/docs/src/examples/modelingtoolkitize_index_reduction.md index 8686fd60d4..b19ea46701 100644 --- a/docs/src/examples/modelingtoolkitize_index_reduction.md +++ b/docs/src/examples/modelingtoolkitize_index_reduction.md @@ -56,7 +56,7 @@ Specifically, for a pendulum with unit mass and length $L$, which thus has kinetic energy $\frac{1}{2}(v_x^2 + v_y^2)$, potential energy $gy$, and holonomic constraint $x^2 + y^2 - L^2 = 0$. -The Lagrange multiplier related to this constraint is equal to half of $T$, +The Lagrange multiplier related to this constraint is equal to half of $T$, and represents the tension in the rope of the pendulum. As a good DifferentialEquations.jl user, one would follow From 565f02ada5ec6ee30bdf49f104067e2a44a52c51 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 4 Dec 2024 15:43:33 +0530 Subject: [PATCH 057/101] build: bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 1621669c2c..b1f5701d13 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ModelingToolkit" uuid = "961ee093-0014-501f-94e3-6117800e7a78" authors = ["Yingbo Ma ", "Chris Rackauckas and contributors"] -version = "9.54.0" +version = "9.55.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" From 17359ce72a9c690ce53b809c041888038cd3f468 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 15 Nov 2024 18:49:53 +0530 Subject: [PATCH 058/101] feat: initial implementation of `SCCNonlinearProblem` codegen --- .../bipartite_tearing/modia_tearing.jl | 14 +- src/systems/abstractsystem.jl | 190 ++++++------------ src/systems/nonlinear/nonlinearsystem.jl | 111 ++++++++++ src/systems/parameter_buffer.jl | 35 +++- src/utils.jl | 47 +++++ 5 files changed, 250 insertions(+), 147 deletions(-) diff --git a/src/structural_transformation/bipartite_tearing/modia_tearing.jl b/src/structural_transformation/bipartite_tearing/modia_tearing.jl index cef2f5f6d7..5da873afdf 100644 --- a/src/structural_transformation/bipartite_tearing/modia_tearing.jl +++ b/src/structural_transformation/bipartite_tearing/modia_tearing.jl @@ -62,6 +62,15 @@ function tear_graph_block_modia!(var_eq_matching, ict, solvable_graph, eqs, vars return nothing end +function build_var_eq_matching(structure::SystemStructure, ::Type{U} = Unassigned; + varfilter::F2 = v -> true, eqfilter::F3 = eq -> true) where {U, F2, F3} + @unpack graph, solvable_graph = structure + var_eq_matching = maximal_matching(graph, eqfilter, varfilter, U) + matching_len = max(length(var_eq_matching), + maximum(x -> x isa Int ? x : 0, var_eq_matching, init = 0)) + return complete(var_eq_matching, matching_len), matching_len +end + function tear_graph_modia(structure::SystemStructure, isder::F = nothing, ::Type{U} = Unassigned; varfilter::F2 = v -> true, @@ -78,10 +87,7 @@ function tear_graph_modia(structure::SystemStructure, isder::F = nothing, # find them here [TODO: It would be good to have an explicit example of this.] @unpack graph, solvable_graph = structure - var_eq_matching = maximal_matching(graph, eqfilter, varfilter, U) - matching_len = max(length(var_eq_matching), - maximum(x -> x isa Int ? x : 0, var_eq_matching, init = 0)) - var_eq_matching = complete(var_eq_matching, matching_len) + var_eq_matching, matching_len = build_var_eq_matching(structure, U; varfilter, eqfilter) full_var_eq_matching = copy(var_eq_matching) var_sccs = find_var_sccs(graph, var_eq_matching) vargraph = DiCMOBiGraph{true}(graph, 0, Matching(matching_len)) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 3a9bbd33e1..60ed2fa1ce 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -162,11 +162,12 @@ object. """ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys), ps = parameters(sys); wrap_code = nothing, postprocess_fbody = nothing, states = nothing, - expression = Val{true}, eval_expression = false, eval_module = @__MODULE__, kwargs...) + expression = Val{true}, eval_expression = false, eval_module = @__MODULE__, + cachesyms::Tuple = (), kwargs...) if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system.") end - p = reorder_parameters(sys, unwrap.(ps)) + p = (reorder_parameters(sys, unwrap.(ps))..., cachesyms...) isscalar = !(exprs isa AbstractArray) if wrap_code === nothing wrap_code = isscalar ? identity : (identity, identity) @@ -187,7 +188,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys postprocess_fbody, states, wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘ - wrap_array_vars(sys, exprs; dvs) .∘ + wrap_array_vars(sys, exprs; dvs, cachesyms) .∘ wrap_parameter_dependencies(sys, isscalar), expression = Val{true} ) @@ -199,7 +200,7 @@ function generate_custom_function(sys::AbstractSystem, exprs, dvs = unknowns(sys postprocess_fbody, states, wrap_code = wrap_code .∘ wrap_mtkparameters(sys, isscalar) .∘ - wrap_array_vars(sys, exprs; dvs) .∘ + wrap_array_vars(sys, exprs; dvs, cachesyms) .∘ wrap_parameter_dependencies(sys, isscalar), expression = Val{true} ) @@ -231,116 +232,51 @@ end function wrap_array_vars( sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), - inputs = nothing, history = false) + inputs = nothing, history = false, cachesyms::Tuple = ()) isscalar = !(exprs isa AbstractArray) - array_vars = Dict{Any, AbstractArray{Int}}() - if dvs !== nothing - for (j, x) in enumerate(dvs) - if iscall(x) && operation(x) == getindex - arg = arguments(x)[1] - inds = get!(() -> Int[], array_vars, arg) - push!(inds, j) - end - end - for (k, inds) in array_vars - if inds == (inds′ = inds[1]:inds[end]) - array_vars[k] = inds′ - end - end + var_to_arridxs = Dict() - uind = 1 - else + if dvs === nothing uind = 0 - end - # values are (indexes, index of buffer, size of parameter) - array_parameters = Dict{Any, Tuple{AbstractArray{Int}, Int, Tuple{Vararg{Int}}}}() - # If for some reason different elements of an array parameter are in different buffers - other_array_parameters = Dict{Any, Any}() - - hasinputs = inputs !== nothing - input_vars = Dict{Any, AbstractArray{Int}}() - if hasinputs - for (j, x) in enumerate(inputs) - if iscall(x) && operation(x) == getindex - arg = arguments(x)[1] - inds = get!(() -> Int[], input_vars, arg) - push!(inds, j) - end - end - for (k, inds) in input_vars - if inds == (inds′ = inds[1]:inds[end]) - input_vars[k] = inds′ - end - end - end - if has_index_cache(sys) - ic = get_index_cache(sys) else - ic = nothing - end - if ps isa Tuple && eltype(ps) <: AbstractArray - ps = Iterators.flatten(ps) - end - for p in ps - p = unwrap(p) - if iscall(p) && operation(p) == getindex - p = arguments(p)[1] - end - symtype(p) <: AbstractArray && Symbolics.shape(p) != Symbolics.Unknown() || continue - scal = collect(p) - # all scalarized variables are in `ps` - any(isequal(p), ps) || all(x -> any(isequal(x), ps), scal) || continue - (haskey(array_parameters, p) || haskey(other_array_parameters, p)) && continue - - idx = parameter_index(sys, p) - idx isa Int && continue - if idx isa ParameterIndex - if idx.portion != SciMLStructures.Tunable() + uind = 1 + for (i, x) in enumerate(dvs) + iscall(x) && operation(x) == getindex || continue + arg = arguments(x)[1] + inds = get!(() -> [], var_to_arridxs, arg) + push!(inds, (uind, i)) + end + end + p_start = uind + 1 + (inputs !== nothing) + history + input_ind = inputs === nothing ? -1 : (p_start - 1) + rps = (reorder_parameters(sys, ps)..., cachesyms...) + for sym in reduce(vcat, rps; init = []) + iscall(sym) && operation(sym) == getindex || continue + arg = arguments(sym)[1] + if inputs !== nothing + idx = findfirst(isequal(sym), inputs) + if idx !== nothing + inds = get!(() -> [], var_to_arridxs, arg) + push!(inds, (input_ind, idx)) continue end - array_parameters[p] = (vec(idx.idx), 1, size(idx.idx)) - else - # idx === nothing - idxs = map(Base.Fix1(parameter_index, sys), scal) - if first(idxs) isa ParameterIndex - buffer_idxs = map(Base.Fix1(iterated_buffer_index, ic), idxs) - if allequal(buffer_idxs) - buffer_idx = first(buffer_idxs) - if first(idxs).portion == SciMLStructures.Tunable() - idxs = map(x -> x.idx, idxs) - else - idxs = map(x -> x.idx[end], idxs) - end - else - other_array_parameters[p] = scal - continue - end - else - buffer_idx = 1 - end - - sz = size(idxs) - if vec(idxs) == idxs[begin]:idxs[end] - idxs = idxs[begin]:idxs[end] - elseif vec(idxs) == idxs[begin]:-1:idxs[end] - idxs = idxs[begin]:-1:idxs[end] - end - idxs = vec(idxs) - array_parameters[p] = (idxs, buffer_idx, sz) end + bufferidx = findfirst(buf -> any(isequal(sym), buf), rps) + idxinbuffer = findfirst(isequal(sym), rps[bufferidx]) + inds = get!(() -> [], var_to_arridxs, arg) + push!(inds, (p_start + bufferidx - 1, idxinbuffer)) end - inputind = if history - uind + 2 - else - uind + 1 - end - params_offset = if history && hasinputs - uind + 2 - elseif history || hasinputs - uind + 1 - else - uind + viewsyms = Dict() + splitsyms = Dict() + for (arrsym, idxs) in var_to_arridxs + length(idxs) == length(arrsym) || continue + # allequal(first, idxs) is a 1.11 feature + if allequal(Iterators.map(first, idxs)) + viewsyms[arrsym] = (first(first(idxs)), reshape(last.(idxs), size(arrsym))) + else + splitsyms[arrsym] = reshape(idxs, size(arrsym)) + end end if isscalar function (expr) @@ -349,15 +285,11 @@ function wrap_array_vars( [], Let( vcat( - [k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars], - [k ← :(view($(expr.args[inputind].name), $v)) - for (k, v) in input_vars], - [k ← :(reshape( - view($(expr.args[params_offset + buffer_idx].name), $idxs), - $sz)) - for (k, (idxs, buffer_idx, sz)) in array_parameters], - [k ← Code.MakeArray(v, symtype(k)) - for (k, v) in other_array_parameters] + [sym ← :(view($(expr.args[i].name), $idxs)) + for (sym, (i, idxs)) in viewsyms], + [sym ← + MakeArray([expr.args[bufi].elems[vali] for (bufi, vali) in idxs], + expr.args[idxs[1][1]]) for (sym, idxs) in splitsyms] ), expr.body, false @@ -371,15 +303,11 @@ function wrap_array_vars( [], Let( vcat( - [k ← :(view($(expr.args[uind].name), $v)) for (k, v) in array_vars], - [k ← :(view($(expr.args[inputind].name), $v)) - for (k, v) in input_vars], - [k ← :(reshape( - view($(expr.args[params_offset + buffer_idx].name), $idxs), - $sz)) - for (k, (idxs, buffer_idx, sz)) in array_parameters], - [k ← Code.MakeArray(v, symtype(k)) - for (k, v) in other_array_parameters] + [sym ← :(view($(expr.args[i].name), $idxs)) + for (sym, (i, idxs)) in viewsyms], + [sym ← + MakeArray([expr.args[bufi].elems[vali] for (bufi, vali) in idxs], + expr.args[idxs[1][1]]) for (sym, idxs) in splitsyms] ), expr.body, false @@ -392,17 +320,11 @@ function wrap_array_vars( [], Let( vcat( - [k ← :(view($(expr.args[uind + 1].name), $v)) - for (k, v) in array_vars], - [k ← :(view($(expr.args[inputind + 1].name), $v)) - for (k, v) in input_vars], - [k ← :(reshape( - view($(expr.args[params_offset + buffer_idx + 1].name), - $idxs), - $sz)) - for (k, (idxs, buffer_idx, sz)) in array_parameters], - [k ← Code.MakeArray(v, symtype(k)) - for (k, v) in other_array_parameters] + [sym ← :(view($(expr.args[i + 1].name), $idxs)) + for (sym, (i, idxs)) in viewsyms], + [sym ← MakeArray( + [expr.args[bufi + 1].elems[vali] for (bufi, vali) in idxs], + expr.args[idxs[1][1] + 1]) for (sym, idxs) in splitsyms] ), expr.body, false diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 1289388197..a7028b3a1e 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -535,6 +535,117 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...) end +struct CacheWriter{F} + fn::F +end + +function (cw::CacheWriter)(p, sols) + cw.fn(p.caches[1], sols, p...) +end + +function CacheWriter(sys::AbstractSystem, exprs, solsyms; + eval_expression = false, eval_module = @__MODULE__) + ps = parameters(sys) + rps = reorder_parameters(sys, ps) + fn = Func( + [:out, DestructuredArgs(DestructuredArgs.(solsyms)), + DestructuredArgs.(rps)...], + [], + SetArray(true, :out, exprs) + ) |> wrap_parameter_dependencies(sys, false)[2] |> + wrap_array_vars(sys, exprs; dvs = nothing)[2] |> toexpr + return CacheWriter(eval_or_rgf(fn; eval_expression, eval_module)) +end + +struct SCCNonlinearFunction{iip} end + +function SCCNonlinearFunction{iip}( + sys::NonlinearSystem, vscc, escc, cachesyms; eval_expression = false, + eval_module = @__MODULE__, kwargs...) where {iip} + dvs = unknowns(sys) + ps = parameters(sys) + rps = reorder_parameters(sys, ps) + eqs = equations(sys) + obs = observed(sys) + + _dvs = dvs[vscc] + _eqs = eqs[escc] + obsidxs = observed_equations_used_by(sys, _eqs) + _obs = obs[obsidxs] + + cmap, cs = get_cmap(sys) + assignments = [eq.lhs ← eq.rhs for eq in cmap] + rhss = [eq.rhs - eq.lhs for eq in _eqs] + wrap_code = wrap_assignments(false, assignments) .∘ + (wrap_array_vars(sys, rhss; dvs = _dvs, cachesyms)) .∘ + wrap_parameter_dependencies(sys, false) + @show _dvs + f_gen = build_function( + rhss, _dvs, ps..., cachesyms...; wrap_code, expression = Val{true}) + f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module) + + f(u, p) = f_oop(u, p) + f(u, p::MTKParameters) = f_oop(u, p...) + f(resid, u, p) = f_iip(resid, u, p) + f(resid, u, p::MTKParameters) = f_iip(resid, u, p...) + + return NonlinearFunction{iip}(f) +end + +function SciMLBase.SCCNonlinearProblem(sys::NonlinearSystem, args...; kwargs...) + SCCNonlinearProblem{true}(sys, args...; kwargs...) +end + +function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, + parammap = SciMLBase.NullParameters(); eval_expression = false, eval_module = @__MODULE__, kwargs...) where {iip} + if !iscomplete(sys) || get_tearing_state(sys) === nothing + error("A simplified `NonlinearSystem` is required. Call `structural_simplify` on the system before creating an `SCCNonlinearProblem`.") + end + + if !is_split(sys) + error("The system has been simplified with `split = false`. `SCCNonlinearProblem` is not compatible with this system. Pass `split = true` to `structural_simplify` to use `SCCNonlinearProblem`.") + end + + ts = get_tearing_state(sys) + var_eq_matching, var_sccs = StructuralTransformations.algebraic_variables_scc(ts) + condensed_graph = StructuralTransformations.MatchedCondensationGraph( + StructuralTransformations.DiCMOBiGraph{true}(ts.structure.graph, var_eq_matching), + var_sccs) + toporder = topological_sort_by_dfs(condensed_graph) + var_sccs = var_sccs[toporder] + eq_sccs = map(Base.Fix1(getindex, var_eq_matching), var_sccs) + + dvs = unknowns(sys) + ps = parameters(sys) + eqs = equations(sys) + obs = observed(sys) + + _, u0, p = process_SciMLProblem( + EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module, kwargs...) + p = rebuild_with_caches(p, BufferTemplate(eltype(u0), length(u0))) + + subprobs = [] + explicitfuns = [] + for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs)) + oldvars = dvs[reduce(vcat, view(var_sccs, 1:(i - 1)); init = Int[])] + if isempty(oldvars) + push!(explicitfuns, (_...) -> nothing) + else + solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1))) + push!(explicitfuns, + CacheWriter(sys, oldvars, solsyms; eval_expression, eval_module)) + end + prob = NonlinearProblem( + SCCNonlinearFunction{iip}( + sys, vscc, escc, (oldvars,); eval_expression, eval_module, kwargs...), + u0[vscc], + p) + push!(subprobs, prob) + end + + return SCCNonlinearProblem(subprobs, explicitfuns) +end + """ $(TYPEDSIGNATURES) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 72af9f594d..7d64054acc 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -3,11 +3,12 @@ symconvert(::Type{T}, x) where {T} = convert(T, x) symconvert(::Type{Real}, x::Integer) = convert(Float64, x) symconvert(::Type{V}, x) where {V <: AbstractArray} = convert(V, symconvert.(eltype(V), x)) -struct MTKParameters{T, D, C, N} +struct MTKParameters{T, D, C, N, H} tunable::T discrete::D constant::C nonnumeric::N + caches::H end """ @@ -181,11 +182,18 @@ function MTKParameters( mtkps = MTKParameters{ typeof(tunable_buffer), typeof(disc_buffer), typeof(const_buffer), - typeof(nonnumeric_buffer)}(tunable_buffer, disc_buffer, const_buffer, - nonnumeric_buffer) + typeof(nonnumeric_buffer), typeof(())}(tunable_buffer, + disc_buffer, const_buffer, nonnumeric_buffer, ()) return mtkps end +function rebuild_with_caches(p::MTKParameters, cache_templates::BufferTemplate...) + buffers = map(cache_templates) do template + Vector{template.type}(undef, template.length) + end + @set p.caches = buffers +end + function narrow_buffer_type(buffer::AbstractArray) type = Union{} for x in buffer @@ -297,7 +305,8 @@ end for (Portion, field, recurse) in [(SciMLStructures.Discrete, :discrete, 1) (SciMLStructures.Constants, :constant, 1) - (Nonnumeric, :nonnumeric, 1)] + (Nonnumeric, :nonnumeric, 1) + (SciMLStructures.Caches, :caches, 1)] @eval function SciMLStructures.canonicalize(::$Portion, p::MTKParameters) as_vector = buffer_to_arraypartition(p.$field) repack = let p = p @@ -324,11 +333,13 @@ function Base.copy(p::MTKParameters) discrete = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.discrete) constant = Tuple(eltype(buf) <: Real ? copy(buf) : copy.(buf) for buf in p.constant) nonnumeric = copy.(p.nonnumeric) + caches = copy.(p.caches) return MTKParameters( tunable, discrete, constant, - nonnumeric + nonnumeric, + caches ) end @@ -640,7 +651,7 @@ end # getindex indexes the vectors, setindex! linearly indexes values # it's inconsistent, but we need it to be this way @generated function Base.getindex( - ps::MTKParameters{T, D, C, N}, idx::Int) where {T, D, C, N} + ps::MTKParameters{T, D, C, N, H}, idx::Int) where {T, D, C, N, H} paths = [] if !(T <: SizedVector{0, Float64}) push!(paths, :(ps.tunable)) @@ -654,6 +665,9 @@ end for i in 1:fieldcount(N) push!(paths, :(ps.nonnumeric[$i])) end + for i in 1:fieldcount(H) + push!(paths, :(ps.caches[$i])) + end expr = Expr(:if, :(idx == 1), :(return $(paths[1]))) curexpr = expr for i in 2:length(paths) @@ -663,12 +677,12 @@ end return Expr(:block, expr, :(throw(BoundsError(ps, idx)))) end -@generated function Base.length(ps::MTKParameters{T, D, C, N}) where {T, D, C, N} +@generated function Base.length(ps::MTKParameters{T, D, C, N, H}) where {T, D, C, N, H} len = 0 if !(T <: SizedVector{0, Float64}) len += 1 end - len += fieldcount(D) + fieldcount(C) + fieldcount(N) + len += fieldcount(D) + fieldcount(C) + fieldcount(N) + fieldcount(H) return len end @@ -691,7 +705,10 @@ end function Base.:(==)(a::MTKParameters, b::MTKParameters) return a.tunable == b.tunable && a.discrete == b.discrete && - a.constant == b.constant && a.nonnumeric == b.nonnumeric + a.constant == b.constant && a.nonnumeric == b.nonnumeric && + all(Iterators.map(a.caches, b.caches) do acache, bcache + eltype(acache) == eltype(bcache) && length(acache) == length(bcache) + end) end # to support linearize/linearization_function diff --git a/src/utils.jl b/src/utils.jl index 416efd8f2c..7be247429c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1007,3 +1007,50 @@ function is_variable_floatingpoint(sym) return T == Real || T <: AbstractFloat || T <: AbstractArray{Real} || T <: AbstractArray{<:AbstractFloat} end + +function observed_dependency_graph(eqs::Vector{Equation}) + for eq in eqs + if symbolic_type(eq.lhs) == NotSymbolic() + error("All equations must be observed equations of the form `var ~ expr`. Got $eq") + end + end + + idxmap = Dict(eq.lhs => i for (i, eq) in enumerate(eqs)) + g = SimpleDiGraph(length(eqs)) + + syms = Set() + for (i, eq) in enumerate(eqs) + vars!(syms, eq) + for sym in syms + idx = get(idxmap, sym, nothing) + idx === nothing && continue + add_edge!(g, i, idx) + end + end + + return g +end + +function observed_equations_used_by(sys::AbstractSystem, exprs) + obs = observed(sys) + + obsvars = getproperty.(obs, :lhs) + graph = observed_dependency_graph(obs) + + syms = vars(exprs) + + obsidxs = BitSet() + for sym in syms + idx = findfirst(isequal(sym), obsvars) + idx === nothing && continue + parents = dfs_parents(graph, idx) + for i in eachindex(parents) + parents[i] == 0 && continue + push!(obsidxs, i) + end + end + + obsidxs = collect(obsidxs) + sort!(obsidxs) + return obsidxs +end From 7d3b3f41dc7cd9cd746b2cdef7ca24896ba2e6ed Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 18 Nov 2024 15:35:49 +0530 Subject: [PATCH 059/101] refactor: no need to re-sort SCCs --- src/systems/nonlinear/nonlinearsystem.jl | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index a7028b3a1e..68758de169 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -579,9 +579,8 @@ function SCCNonlinearFunction{iip}( wrap_code = wrap_assignments(false, assignments) .∘ (wrap_array_vars(sys, rhss; dvs = _dvs, cachesyms)) .∘ wrap_parameter_dependencies(sys, false) - @show _dvs f_gen = build_function( - rhss, _dvs, ps..., cachesyms...; wrap_code, expression = Val{true}) + rhss, _dvs, rps..., cachesyms...; wrap_code, expression = Val{true}) f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module) f(u, p) = f_oop(u, p) @@ -608,11 +607,9 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, ts = get_tearing_state(sys) var_eq_matching, var_sccs = StructuralTransformations.algebraic_variables_scc(ts) - condensed_graph = StructuralTransformations.MatchedCondensationGraph( - StructuralTransformations.DiCMOBiGraph{true}(ts.structure.graph, var_eq_matching), - var_sccs) - toporder = topological_sort_by_dfs(condensed_graph) - var_sccs = var_sccs[toporder] + # The system is simplified, so SCCs are already in sorted order. We just need to get them and sort + # according to index in unknowns(sys) + sort!(var_sccs) eq_sccs = map(Base.Fix1(getindex, var_eq_matching), var_sccs) dvs = unknowns(sys) From 3e4a648488e9bd1a1fc276798a5b1274d8463eac Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 18 Nov 2024 19:27:58 +0530 Subject: [PATCH 060/101] fix: minor bug fix --- src/systems/nonlinear/nonlinearsystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 68758de169..1742e83371 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -553,7 +553,7 @@ function CacheWriter(sys::AbstractSystem, exprs, solsyms; [], SetArray(true, :out, exprs) ) |> wrap_parameter_dependencies(sys, false)[2] |> - wrap_array_vars(sys, exprs; dvs = nothing)[2] |> toexpr + wrap_array_vars(sys, exprs; dvs = nothing, inputs = [])[2] |> toexpr return CacheWriter(eval_or_rgf(fn; eval_expression, eval_module)) end From b99ab7d12c897333d28c4fe3bd01c34103069065 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 18 Nov 2024 21:51:43 +0530 Subject: [PATCH 061/101] fix: fix observed equations not being generated --- src/systems/nonlinear/nonlinearsystem.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 1742e83371..9e6208ed10 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -572,13 +572,15 @@ function SCCNonlinearFunction{iip}( _eqs = eqs[escc] obsidxs = observed_equations_used_by(sys, _eqs) _obs = obs[obsidxs] + obs_assignments = [eq.lhs ← eq.rhs for eq in _obs] cmap, cs = get_cmap(sys) - assignments = [eq.lhs ← eq.rhs for eq in cmap] + cmap_assignments = [eq.lhs ← eq.rhs for eq in cmap] rhss = [eq.rhs - eq.lhs for eq in _eqs] - wrap_code = wrap_assignments(false, assignments) .∘ + wrap_code = wrap_assignments(false, cmap_assignments) .∘ (wrap_array_vars(sys, rhss; dvs = _dvs, cachesyms)) .∘ - wrap_parameter_dependencies(sys, false) + wrap_parameter_dependencies(sys, false) .∘ + wrap_assignments(false, obs_assignments) f_gen = build_function( rhss, _dvs, rps..., cachesyms...; wrap_code, expression = Val{true}) f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module) From 04c0cf883fe5714d32d3b04ecdc69c503318292b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 18 Nov 2024 21:52:02 +0530 Subject: [PATCH 062/101] test: add tests for `SCCNonlinearProblem` codegen --- test/runtests.jl | 1 + test/scc_nonlinear_problem.jl | 145 ++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+) create mode 100644 test/scc_nonlinear_problem.jl diff --git a/test/runtests.jl b/test/runtests.jl index 44846eed57..677f40c717 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -78,6 +78,7 @@ end @safetestset "SDESystem Test" include("sdesystem.jl") @safetestset "DDESystem Test" include("dde.jl") @safetestset "NonlinearSystem Test" include("nonlinearsystem.jl") + @safetestset "SCCNonlinearProblem Test" include("scc_nonlinear_problem.jl") @safetestset "PDE Construction Test" include("pde.jl") @safetestset "JumpSystem Test" include("jumpsystem.jl") @safetestset "print_tree" include("print_tree.jl") diff --git a/test/scc_nonlinear_problem.jl b/test/scc_nonlinear_problem.jl new file mode 100644 index 0000000000..c4368bf0b9 --- /dev/null +++ b/test/scc_nonlinear_problem.jl @@ -0,0 +1,145 @@ +using ModelingToolkit +using NonlinearSolve, SCCNonlinearSolve +using OrdinaryDiffEq +using SciMLBase, Symbolics +using LinearAlgebra, Test + +@testset "Trivial case" begin + function f!(du, u, p) + du[1] = cos(u[2]) - u[1] + du[2] = sin(u[1] + u[2]) + u[2] + du[3] = 2u[4] + u[3] + 1.0 + du[4] = u[5]^2 + u[4] + du[5] = u[3]^2 + u[5] + du[6] = u[1] + u[2] + u[3] + u[4] + u[5] + 2.0u[6] + 2.5u[7] + 1.5u[8] + du[7] = u[1] + u[2] + u[3] + 2.0u[4] + u[5] + 4.0u[6] - 1.5u[7] + 1.5u[8] + du[8] = u[1] + 2.0u[2] + 3.0u[3] + 5.0u[4] + 6.0u[5] + u[6] - u[7] - u[8] + end + @variables u[1:8] [irreducible = true] + eqs = Any[0 for _ in 1:8] + f!(eqs, u, nothing) + eqs = 0 .~ eqs + @named model = NonlinearSystem(eqs) + @test_throws ["simplified", "required"] SCCNonlinearProblem(model, []) + _model = structural_simplify(model; split = false) + @test_throws ["not compatible"] SCCNonlinearProblem(_model, []) + model = structural_simplify(model) + prob = NonlinearProblem(model, [u => zeros(8)]) + sccprob = SCCNonlinearProblem(model, [u => zeros(8)]) + sol1 = solve(prob, NewtonRaphson()) + sol2 = solve(sccprob, NewtonRaphson()) + @test SciMLBase.successful_retcode(sol1) + @test SciMLBase.successful_retcode(sol2) + @test sol1.u ≈ sol2.u +end + +@testset "With parameters" begin + function f!(du, u, (p1, p2), t) + x = (*)(p1[4], u[1]) + y = (*)(p1[4], (+)(0.1016, (*)(-1, u[1]))) + z1 = ifelse((<)(p2[1], 0), + (*)((*)(457896.07999999996, p1[2]), sqrt((*)(1.1686468413521012e-5, p1[3]))), + 0) + z2 = ifelse((>)(p2[1], 0), + (*)((*)((*)(0.58, p1[2]), sqrt((*)(1 // 86100, p1[3]))), u[4]), + 0) + z3 = ifelse((>)(p2[1], 0), + (*)((*)(457896.07999999996, p1[2]), sqrt((*)(1.1686468413521012e-5, p1[3]))), + 0) + z4 = ifelse((<)(p2[1], 0), + (*)((*)((*)(0.58, p1[2]), sqrt((*)(1 // 86100, p1[3]))), u[5]), + 0) + du[1] = p2[1] + du[2] = (+)(z1, (*)(-1, z2)) + du[3] = (+)(z3, (*)(-1, z4)) + du[4] = (+)((*)(-1, u[2]), (*)((*)(1 // 86100, y), u[4])) + du[5] = (+)((*)(-1, u[3]), (*)((*)(1 // 86100, x), u[5])) + end + p = ( + [0.04864391799335977, 7.853981633974484e-5, 1.4034843205574914, + 0.018241469247509915, 300237.05, 9.226186337232914], + [0.0508]) + u0 = [0.0, 0.0, 0.0, 789476.0, 101325.0] + tspan = (0.0, 1.0) + mass_matrix = [1.0 0.0 0.0 0.0 0.0; 0.0 1.0 0.0 0.0 0.0; 0.0 0.0 1.0 0.0 0.0; + 0.0 0.0 0.0 0.0 0.0; 0.0 0.0 0.0 0.0 0.0] + dt = 1e-3 + function nlf(u1, (u0, p)) + resid = Any[0 for _ in u0] + f!(resid, u1, p, 0.0) + return mass_matrix * (u1 - u0) - dt * resid + end + + prob = NonlinearProblem(nlf, u0, (u0, p)) + @test_throws Exception solve(prob, SimpleNewtonRaphson(), abstol = 1e-9) + sol = solve(prob, TrustRegion(); abstol = 1e-9) + + @variables u[1:5] [irreducible = true] + @parameters p1[1:6] p2 + eqs = 0 .~ collect(nlf(u, (u0, (p1, p2)))) + @mtkbuild sys = NonlinearSystem(eqs, [u], [p1, p2]) + sccprob = SCCNonlinearProblem(sys, [u => u0], [p1 => p[1], p2 => p[2][]]) + sccsol = solve(sccprob, SimpleNewtonRaphson(); abstol = 1e-9) + @test SciMLBase.successful_retcode(sccsol) + @test norm(sccsol.resid) < norm(sol.resid) +end + +@testset "Transistor amplifier" begin + C = [k * 1e-6 for k in 1:5] + Ub = 6 + UF = 0.026 + α = 0.99 + β = 1e-6 + R0 = 1000 + R = 9000 + Ue(t) = 0.1 * sin(200 * π * t) + + function transamp(out, du, u, p, t) + g(x) = 1e-6 * (exp(x / 0.026) - 1) + y1, y2, y3, y4, y5, y6, y7, y8 = u + out[1] = -Ue(t) / R0 + y1 / R0 + C[1] * du[1] - C[1] * du[2] + out[2] = -Ub / R + y2 * 2 / R - (α - 1) * g(y2 - y3) - C[1] * du[1] + C[1] * du[2] + out[3] = -g(y2 - y3) + y3 / R + C[2] * du[3] + out[4] = -Ub / R + y4 / R + α * g(y2 - y3) + C[3] * du[4] - C[3] * du[5] + out[5] = -Ub / R + y5 * 2 / R - (α - 1) * g(y5 - y6) - C[3] * du[4] + C[3] * du[5] + out[6] = -g(y5 - y6) + y6 / R + C[4] * du[6] + out[7] = -Ub / R + y7 / R + α * g(y5 - y6) + C[5] * du[7] - C[5] * du[8] + out[8] = y8 / R - C[5] * du[7] + C[5] * du[8] + end + + u0 = [0, Ub / 2, Ub / 2, Ub, Ub / 2, Ub / 2, Ub, 0] + du0 = [ + 51.338775, + 51.338775, + -Ub / (2 * (C[2] * R)), + -24.9757667, + -24.9757667, + -Ub / (2 * (C[4] * R)), + -10.00564453, + -10.00564453 + ] + daeprob = DAEProblem(transamp, du0, u0, (0.0, 0.1)) + daesol = solve(daeprob, DImplicitEuler()) + + t0 = daesol.t[5] + t1 = daesol.t[6] + u0 = daesol.u[5] + u1 = daesol.u[6] + dt = t1 - t0 + + @variables y(t)[1:8] + eqs = Any[0 for _ in 1:8] + transamp(eqs, collect(D(y)), y, nothing, t) + eqs = 0 .~ eqs + subrules = Dict(Symbolics.unwrap(D(y[i])) => ((y[i] - u0[i]) / dt) for i in 1:8) + eqs = substitute.(eqs, (subrules,)) + @mtkbuild sys = NonlinearSystem(eqs) + prob = NonlinearProblem(sys, [y => u0], [t => t0]) + sol = solve(prob, NewtonRaphson(); abstol = 1e-12) + + sccprob = SCCNonlinearProblem(sys, [y => u0], [t => t0]) + sccsol = solve(sccprob, NewtonRaphson(); abstol = 1e-12) + + @test sol.u≈sccsol.u atol=1e-10 +end + From 5639bd11c390974a37ead01cca46d4959680dee8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 19 Nov 2024 14:37:05 +0530 Subject: [PATCH 063/101] feat: pre-compute observed equations of previous SCCs --- src/systems/nonlinear/nonlinearsystem.jl | 53 +++++++++++++++--------- 1 file changed, 34 insertions(+), 19 deletions(-) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 9e6208ed10..b608fd642a 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -560,18 +560,11 @@ end struct SCCNonlinearFunction{iip} end function SCCNonlinearFunction{iip}( - sys::NonlinearSystem, vscc, escc, cachesyms; eval_expression = false, + sys::NonlinearSystem, _eqs, _dvs, _obs, cachesyms; eval_expression = false, eval_module = @__MODULE__, kwargs...) where {iip} - dvs = unknowns(sys) ps = parameters(sys) rps = reorder_parameters(sys, ps) - eqs = equations(sys) - obs = observed(sys) - _dvs = dvs[vscc] - _eqs = eqs[escc] - obsidxs = observed_equations_used_by(sys, _eqs) - _obs = obs[obsidxs] obs_assignments = [eq.lhs ← eq.rhs for eq in _obs] cmap, cs = get_cmap(sys) @@ -621,24 +614,46 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, _, u0, p = process_SciMLProblem( EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module, kwargs...) - p = rebuild_with_caches(p, BufferTemplate(eltype(u0), length(u0))) - subprobs = [] explicitfuns = [] + nlfuns = [] + prevobsidxs = Int[] + cachevars = [] + cacheexprs = [] for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs)) - oldvars = dvs[reduce(vcat, view(var_sccs, 1:(i - 1)); init = Int[])] - if isempty(oldvars) - push!(explicitfuns, (_...) -> nothing) + # subset unknowns and equations + _dvs = dvs[vscc] + _eqs = eqs[escc] + # get observed equations required by this SCC + obsidxs = observed_equations_used_by(sys, _eqs) + # the ones used by previous SCCs can be precomputed into the cache + setdiff!(obsidxs, prevobsidxs) + _obs = obs[obsidxs] + + if isempty(cachevars) + push!(explicitfuns, Returns(nothing)) else solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1))) push!(explicitfuns, - CacheWriter(sys, oldvars, solsyms; eval_expression, eval_module)) + CacheWriter(sys, cacheexprs, solsyms; eval_expression, eval_module)) + end + f = SCCNonlinearFunction{iip}( + sys, _eqs, _dvs, _obs, (cachevars,); eval_expression, eval_module, kwargs...) + push!(nlfuns, f) + append!(cachevars, _dvs) + append!(cacheexprs, _dvs) + for i in obsidxs + push!(cachevars, obs[i].lhs) + push!(cacheexprs, obs[i].rhs) end - prob = NonlinearProblem( - SCCNonlinearFunction{iip}( - sys, vscc, escc, (oldvars,); eval_expression, eval_module, kwargs...), - u0[vscc], - p) + append!(prevobsidxs, obsidxs) + end + + p = rebuild_with_caches(p, BufferTemplate(eltype(u0), length(cachevars))) + + subprobs = [] + for (f, vscc) in zip(nlfuns, var_sccs) + prob = NonlinearProblem(f, u0[vscc], p) push!(subprobs, prob) end From 26269d974ebddd971e918b840a0d2d7fdf4f39b2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 19 Nov 2024 14:37:22 +0530 Subject: [PATCH 064/101] feat: subset system and pass to SCC problems --- src/systems/index_cache.jl | 25 ++++++++++++++++++++++++ src/systems/nonlinear/nonlinearsystem.jl | 8 +++++++- 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index ab0dd08764..aca6a0547d 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -594,3 +594,28 @@ function reorder_dimension_by_tunables( reorder_dimension_by_tunables!(buffer, sys, arr, syms; dim) return buffer end + +function subset_unknowns_observed( + ic::IndexCache, sys::AbstractSystem, newunknowns, newobsvars) + unknown_idx = copy(ic.unknown_idx) + empty!(unknown_idx) + for (i, sym) in enumerate(newunknowns) + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) + unknown_idx[sym] = unknown_idx[ttsym] = unknown_idx[rsym] = unknown_idx[rttsym] = i + end + observed_syms_to_timeseries = copy(ic.observed_syms_to_timeseries) + empty!(observed_syms_to_timeseries) + for sym in newobsvars + ttsym = default_toterm(sym) + rsym = renamespace(sys, sym) + rttsym = renamespace(sys, ttsym) + for s in (sym, ttsym, rsym, rttsym) + observed_syms_to_timeseries[s] = ic.observed_syms_to_timeseries[sym] + end + end + ic = @set ic.unknown_idx = unknown_idx + @set! ic.observed_syms_to_timeseries = observed_syms_to_timeseries + return ic +end diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index b608fd642a..40e883222d 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -583,7 +583,13 @@ function SCCNonlinearFunction{iip}( f(resid, u, p) = f_iip(resid, u, p) f(resid, u, p::MTKParameters) = f_iip(resid, u, p...) - return NonlinearFunction{iip}(f) + subsys = NonlinearSystem(_eqs, _dvs, ps; observed = _obs, parameter_dependencies = parameter_dependencies(sys), name = nameof(sys)) + if get_index_cache(sys) !== nothing + @set! subsys.index_cache = subset_unknowns_observed(get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,))) + @set! subsys.complete = true + end + + return NonlinearFunction{iip}(f; sys = subsys) end function SciMLBase.SCCNonlinearProblem(sys::NonlinearSystem, args...; kwargs...) From beb7070234a6cb24e9601c1fb7bedc89de847922 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 20 Nov 2024 18:12:24 +0530 Subject: [PATCH 065/101] refactor: improve `observed_dependency_graph` and add docstrings --- src/utils.jl | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 7be247429c..d223c835e2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1008,29 +1008,27 @@ function is_variable_floatingpoint(sym) T <: AbstractArray{<:AbstractFloat} end +""" + $(TYPEDSIGNATURES) + +Return the `DiCMOBiGraph` denoting the dependencies between observed equations `eqs`. +""" function observed_dependency_graph(eqs::Vector{Equation}) for eq in eqs if symbolic_type(eq.lhs) == NotSymbolic() error("All equations must be observed equations of the form `var ~ expr`. Got $eq") end end - - idxmap = Dict(eq.lhs => i for (i, eq) in enumerate(eqs)) - g = SimpleDiGraph(length(eqs)) - - syms = Set() - for (i, eq) in enumerate(eqs) - vars!(syms, eq) - for sym in syms - idx = get(idxmap, sym, nothing) - idx === nothing && continue - add_edge!(g, i, idx) - end - end - - return g + graph, assigns = observed2graph(eqs, getproperty.(eqs, (:lhs,))) + matching = complete(Matching(Vector{Union{Unassigned, Int}}(assigns))) + return DiCMOBiGraph{false}(graph, matching) end +""" + $(TYPEDSIGNATURES) + +Return the indexes of observed equations of `sys` used by expression `exprs`. +""" function observed_equations_used_by(sys::AbstractSystem, exprs) obs = observed(sys) From 27d7f706dfb56eb995e1d7c686fbb16ce14c9d95 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 20 Nov 2024 18:14:15 +0530 Subject: [PATCH 066/101] feat: cache subexpressions dependent only on previous SCCs --- src/systems/nonlinear/nonlinearsystem.jl | 39 +++++++++-- src/utils.jl | 89 ++++++++++++++++++++++++ test/scc_nonlinear_problem.jl | 17 +++++ 3 files changed, 140 insertions(+), 5 deletions(-) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 40e883222d..84cfe58189 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -583,9 +583,11 @@ function SCCNonlinearFunction{iip}( f(resid, u, p) = f_iip(resid, u, p) f(resid, u, p::MTKParameters) = f_iip(resid, u, p...) - subsys = NonlinearSystem(_eqs, _dvs, ps; observed = _obs, parameter_dependencies = parameter_dependencies(sys), name = nameof(sys)) + subsys = NonlinearSystem(_eqs, _dvs, ps; observed = _obs, + parameter_dependencies = parameter_dependencies(sys), name = nameof(sys)) if get_index_cache(sys) !== nothing - @set! subsys.index_cache = subset_unknowns_observed(get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,))) + @set! subsys.index_cache = subset_unknowns_observed( + get_index_cache(sys), sys, _dvs, getproperty.(_obs, (:lhs,))) @set! subsys.complete = true end @@ -624,8 +626,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, explicitfuns = [] nlfuns = [] prevobsidxs = Int[] - cachevars = [] - cacheexprs = [] + cachesize = 0 for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs)) # subset unknowns and equations _dvs = dvs[vscc] @@ -636,6 +637,32 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, setdiff!(obsidxs, prevobsidxs) _obs = obs[obsidxs] + # get all subexpressions in the RHS which we can precompute in the cache + banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,)))) + for var in banned_vars + iscall(var) || continue + operation(var) === getindex || continue + push!(banned_vars, arguments(var)[1]) + end + state = Dict() + for i in eachindex(_obs) + _obs[i] = _obs[i].lhs ~ subexpressions_not_involving_vars!( + _obs[i].rhs, banned_vars, state) + end + for i in eachindex(_eqs) + _eqs[i] = _eqs[i].lhs ~ subexpressions_not_involving_vars!( + _eqs[i].rhs, banned_vars, state) + end + + # cached variables and their corresponding expressions + cachevars = Any[obs[i].lhs for i in prevobsidxs] + cacheexprs = Any[obs[i].rhs for i in prevobsidxs] + for (k, v) in state + push!(cachevars, unwrap(v)) + push!(cacheexprs, unwrap(k)) + end + cachesize = max(cachesize, length(cachevars)) + if isempty(cachevars) push!(explicitfuns, Returns(nothing)) else @@ -655,7 +682,9 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, append!(prevobsidxs, obsidxs) end - p = rebuild_with_caches(p, BufferTemplate(eltype(u0), length(cachevars))) + if cachesize != 0 + p = rebuild_with_caches(p, BufferTemplate(eltype(u0), cachesize)) + end subprobs = [] for (f, vscc) in zip(nlfuns, var_sccs) diff --git a/src/utils.jl b/src/utils.jl index d223c835e2..dd7113cbe6 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1001,6 +1001,11 @@ end diff2term_with_unit(x, t) = _with_unit(diff2term, x, t) lower_varname_with_unit(var, iv, order) = _with_unit(lower_varname, var, iv, iv, order) +""" + $(TYPEDSIGNATURES) + +Check if `sym` represents a symbolic floating point number or array of such numbers. +""" function is_variable_floatingpoint(sym) sym = unwrap(sym) T = symtype(sym) @@ -1052,3 +1057,87 @@ function observed_equations_used_by(sys::AbstractSystem, exprs) sort!(obsidxs) return obsidxs end + +""" + $(TYPEDSIGNATURES) + +Given an expression `expr`, return a dictionary mapping subexpressions of `expr` that do +not involve variables in `vars` to anonymous symbolic variables. Also return the modified +`expr` with the substitutions indicated by the dictionary. If `expr` is a function +of only `vars`, then all of the returned subexpressions can be precomputed. + +Note that this will only process subexpressions floating point value. Additionally, +array variables must be passed in both scalarized and non-scalarized forms in `vars`. +""" +function subexpressions_not_involving_vars(expr, vars) + expr = unwrap(expr) + vars = map(unwrap, vars) + state = Dict() + newexpr = subexpressions_not_involving_vars!(expr, vars, state) + return state, newexpr +end + +""" + $(TYPEDSIGNATURES) + +Mutating version of `subexpressions_not_involving_vars` which writes to `state`. Only +returns the modified `expr`. +""" +function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any}) + expr = unwrap(expr) + symbolic_type(expr) == NotSymbolic() && return expr + iscall(expr) || return expr + is_variable_floatingpoint(expr) || return expr + symtype(expr) <: Union{Real, AbstractArray{<:Real}} || return expr + Symbolics.shape(expr) == Symbolics.Unknown() && return expr + haskey(state, expr) && return state[expr] + vs = ModelingToolkit.vars(expr) + intersect!(vs, vars) + if isempty(vs) + sym = gensym(:subexpr) + stype = symtype(expr) + var = similar_variable(expr, sym) + state[expr] = var + return var + end + op = operation(expr) + args = arguments(expr) + if (op == (+) || op == (*)) && symbolic_type(expr) !== ArraySymbolic() + indep_args = [] + dep_args = [] + for arg in args + _vs = ModelingToolkit.vars(arg) + intersect!(_vs, vars) + if !isempty(_vs) + push!(dep_args, subexpressions_not_involving_vars!(arg, vars, state)) + else + push!(indep_args, arg) + end + end + indep_term = reduce(op, indep_args; init = Int(op == (*))) + indep_term = subexpressions_not_involving_vars!(indep_term, vars, state) + dep_term = reduce(op, dep_args; init = Int(op == (*))) + return op(indep_term, dep_term) + end + newargs = map(args) do arg + symbolic_type(arg) != NotSymbolic() || is_array_of_symbolics(arg) || return arg + subexpressions_not_involving_vars!(arg, vars, state) + end + return maketerm(typeof(expr), op, newargs, metadata(expr)) +end + +""" + $(TYPEDSIGNATURES) + +Create an anonymous symbolic variable of the same shape, size and symtype as `var`, with +name `gensym(name)`. Does not support unsized array symbolics. +""" +function similar_variable(var::BasicSymbolic, name = :anon) + name = gensym(name) + stype = symtype(var) + sym = Symbolics.variable(name; T = stype) + if size(var) !== () + sym = setmetadata(sym, Symbolics.ArrayShapeCtx, map(Base.OneTo, size(var))) + end + return sym +end diff --git a/test/scc_nonlinear_problem.jl b/test/scc_nonlinear_problem.jl index c4368bf0b9..0c4bfe61d0 100644 --- a/test/scc_nonlinear_problem.jl +++ b/test/scc_nonlinear_problem.jl @@ -143,3 +143,20 @@ end @test sol.u≈sccsol.u atol=1e-10 end +@testset "Expression caching" begin + @variables x[1:4] = rand(4) + val = Ref(0) + function func(x, y) + val[] += 1 + x + y + end + @register_symbolic func(x, y) + @mtkbuild sys = NonlinearSystem([0 ~ x[1]^3 + x[2]^3 - 5 + 0 ~ sin(x[1] - x[2]) - 0.5 + 0 ~ func(x[1], x[2]) * exp(x[3]) - x[4]^3 - 5 + 0 ~ func(x[1], x[2]) * exp(x[4]) - x[3]^3 - 4]) + sccprob = SCCNonlinearProblem(sys, []) + sccsol = solve(sccprob, NewtonRaphson()) + @test SciMLBase.successful_retcode(sccsol) + @test val[] == 1 +end From 691747aa1053a6d18eb0ce8fdc80e063eaa3eaed Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 26 Nov 2024 12:16:31 +0530 Subject: [PATCH 067/101] feat: use SCCNonlinearProblem for initialization --- Project.toml | 4 +++- src/ModelingToolkit.jl | 1 + src/systems/diffeqs/abstractodesystem.jl | 13 ++++++++----- src/systems/nonlinear/nonlinearsystem.jl | 2 +- src/systems/problem_utils.jl | 6 ++++-- 5 files changed, 17 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index b1f5701d13..b520fa3e0c 100644 --- a/Project.toml +++ b/Project.toml @@ -44,6 +44,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" +SCCNonlinearSolve = "9dfe8606-65a1-4bb3-9748-cb89d1561431" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b" @@ -126,7 +127,8 @@ REPL = "1" RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" -SciMLBase = "2.64" +SCCNonlinearSolve = "1.0.0" +SciMLBase = "2.65" SciMLStructures = "1.0" Serialization = "1" Setfield = "0.7, 0.8, 1" diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index de8e69c41f..59be358349 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -50,6 +50,7 @@ using Distributed import JuliaFormatter using MLStyle using NonlinearSolve +import SCCNonlinearSolve using Reexport using RecursiveArrayTools import Graphs: SimpleDiGraph, add_edge!, incidence_matrix diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index af96c4fcfe..ac8870fcaa 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1301,6 +1301,7 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, initialization_eqs = [], fully_determined = nothing, check_units = true, + use_scc = true, kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEProblem`") @@ -1318,8 +1319,8 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, end ts = get_tearing_state(isys) - if warn_initialize_determined && - (unassigned_vars = StructuralTransformations.singular_check(ts); !isempty(unassigned_vars)) + unassigned_vars = StructuralTransformations.singular_check(ts) + if warn_initialize_determined && !isempty(unassigned_vars) errmsg = """ The initialization system is structurally singular. Guess values may \ significantly affect the initial values of the ODE. The problematic variables \ @@ -1381,9 +1382,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, end for (k, v) in u0map) end - if neqs == nunknown - NonlinearProblem(isys, u0map, parammap; kwargs...) + + TProb = if neqs == nunknown && isempty(unassigned_vars) + use_scc && neqs > 0 && is_split(isys) ? SCCNonlinearProblem : NonlinearProblem else - NonlinearLeastSquaresProblem(isys, u0map, parammap; kwargs...) + NonlinearLeastSquaresProblem end + TProb(isys, u0map, parammap; kwargs...) end diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 84cfe58189..4299eba62d 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -692,7 +692,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, push!(subprobs, prob) end - return SCCNonlinearProblem(subprobs, explicitfuns) + return SCCNonlinearProblem(subprobs, explicitfuns, sys, p) end """ diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index b837eef98e..ebc5a5bd75 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -541,6 +541,8 @@ Keyword arguments: Only applicable if `warn_cyclic_dependency == true`. - `substitution_limit`: The number times to substitute initial conditions into each other to attempt to arrive at a numeric value. +- `use_scc`: Whether to use `SCCNonlinearProblem` for initialization if the system is fully + determined. All other keyword arguments are passed as-is to `constructor`. """ @@ -554,7 +556,7 @@ function process_SciMLProblem( symbolic_u0 = false, warn_cyclic_dependency = false, circular_dependency_max_cycle_length = length(all_symbols(sys)), circular_dependency_max_cycles = 10, - substitution_limit = 100, kwargs...) + substitution_limit = 100, use_scc = true, kwargs...) dvs = unknowns(sys) ps = parameters(sys) iv = has_iv(sys) ? get_iv(sys) : nothing @@ -607,7 +609,7 @@ function process_SciMLProblem( sys, t, u0map, pmap; guesses, warn_initialize_determined, initialization_eqs, eval_expression, eval_module, fully_determined, warn_cyclic_dependency, check_units = check_initialization_units, - circular_dependency_max_cycle_length, circular_dependency_max_cycles) + circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc) initializeprobmap = getu(initializeprob, unknowns(sys)) punknowns = [p From b5fc3ed1e19b9a0d38ba22c67afdf9dcc43751c1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 26 Nov 2024 16:57:58 +0530 Subject: [PATCH 068/101] refactor: better handle inputs in `wrap_array_vars` --- src/systems/abstractsystem.jl | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 60ed2fa1ce..e44f250a7f 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -247,20 +247,15 @@ function wrap_array_vars( push!(inds, (uind, i)) end end - p_start = uind + 1 + (inputs !== nothing) + history - input_ind = inputs === nothing ? -1 : (p_start - 1) + p_start = uind + 1 + history rps = (reorder_parameters(sys, ps)..., cachesyms...) + if inputs !== nothing + rps = (inputs, rps...) + end for sym in reduce(vcat, rps; init = []) iscall(sym) && operation(sym) == getindex || continue arg = arguments(sym)[1] - if inputs !== nothing - idx = findfirst(isequal(sym), inputs) - if idx !== nothing - inds = get!(() -> [], var_to_arridxs, arg) - push!(inds, (input_ind, idx)) - continue - end - end + bufferidx = findfirst(buf -> any(isequal(sym), buf), rps) idxinbuffer = findfirst(isequal(sym), rps[bufferidx]) inds = get!(() -> [], var_to_arridxs, arg) From e79dc902fc3377a8ff5adb5623d1bb26a969ecbe Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 26 Nov 2024 16:58:22 +0530 Subject: [PATCH 069/101] fix: properly sort SCCs --- src/systems/nonlinear/nonlinearsystem.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 4299eba62d..bb37dadb8c 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -610,9 +610,17 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, ts = get_tearing_state(sys) var_eq_matching, var_sccs = StructuralTransformations.algebraic_variables_scc(ts) - # The system is simplified, so SCCs are already in sorted order. We just need to get them and sort - # according to index in unknowns(sys) - sort!(var_sccs) + + if length(var_sccs) == 1 + return NonlinearProblem{iip}(sys, u0map, parammap; eval_expression, eval_module, kwargs...) + end + + condensed_graph = MatchedCondensationGraph( + DiCMOBiGraph{true}(complete(ts.structure.graph), + complete(var_eq_matching)), + var_sccs) + toporder = topological_sort_by_dfs(condensed_graph) + var_sccs = var_sccs[toporder] eq_sccs = map(Base.Fix1(getindex, var_eq_matching), var_sccs) dvs = unknowns(sys) From c1e1523cae45dc2e5a1ecb6d8aa8a914cb36aa3b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 26 Nov 2024 16:58:38 +0530 Subject: [PATCH 070/101] feat: better handle observed variables, constants in SCCNonlinearProblem --- src/systems/nonlinear/nonlinearsystem.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index bb37dadb8c..636c1202f7 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -543,17 +543,22 @@ function (cw::CacheWriter)(p, sols) cw.fn(p.caches[1], sols, p...) end -function CacheWriter(sys::AbstractSystem, exprs, solsyms; +function CacheWriter(sys::AbstractSystem, exprs, solsyms, obseqs::Vector{Equation}; eval_expression = false, eval_module = @__MODULE__) ps = parameters(sys) rps = reorder_parameters(sys, ps) + obs_assigns = [eq.lhs ← eq.rhs for eq in obseqs] + cmap, cs = get_cmap(sys) + cmap_assigns = [eq.lhs ← eq.rhs for eq in cmap] fn = Func( [:out, DestructuredArgs(DestructuredArgs.(solsyms)), DestructuredArgs.(rps)...], [], SetArray(true, :out, exprs) - ) |> wrap_parameter_dependencies(sys, false)[2] |> - wrap_array_vars(sys, exprs; dvs = nothing, inputs = [])[2] |> toexpr + ) |> wrap_assignments(false, obs_assigns)[2] |> + wrap_parameter_dependencies(sys, false)[2] |> + wrap_array_vars(sys, exprs; dvs = nothing, inputs = [])[2] |> + wrap_assignments(false, cmap_assigns)[2] |> toexpr return CacheWriter(eval_or_rgf(fn; eval_expression, eval_module)) end @@ -612,7 +617,8 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, var_eq_matching, var_sccs = StructuralTransformations.algebraic_variables_scc(ts) if length(var_sccs) == 1 - return NonlinearProblem{iip}(sys, u0map, parammap; eval_expression, eval_module, kwargs...) + return NonlinearProblem{iip}( + sys, u0map, parammap; eval_expression, eval_module, kwargs...) end condensed_graph = MatchedCondensationGraph( @@ -664,7 +670,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, # cached variables and their corresponding expressions cachevars = Any[obs[i].lhs for i in prevobsidxs] - cacheexprs = Any[obs[i].rhs for i in prevobsidxs] + cacheexprs = Any[obs[i].lhs for i in prevobsidxs] for (k, v) in state push!(cachevars, unwrap(v)) push!(cacheexprs, unwrap(k)) @@ -676,7 +682,8 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, else solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1))) push!(explicitfuns, - CacheWriter(sys, cacheexprs, solsyms; eval_expression, eval_module)) + CacheWriter(sys, cacheexprs, solsyms, obs[prevobsidxs]; + eval_expression, eval_module)) end f = SCCNonlinearFunction{iip}( sys, _eqs, _dvs, _obs, (cachevars,); eval_expression, eval_module, kwargs...) From 205d76a2f6b64820f59c17a2fdf8e2253f914cf4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 26 Nov 2024 16:59:17 +0530 Subject: [PATCH 071/101] test: fix tests --- test/initializationsystem.jl | 2 +- test/mtkparameters.jl | 4 ++-- test/scc_nonlinear_problem.jl | 3 ++- test/split_parameters.jl | 2 +- 4 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index f3015f7db0..9c06dd2030 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -28,7 +28,7 @@ sol = solve(initprob) initprob = ModelingToolkit.InitializationProblem(pend, 0.0, [x => 1, y => 0], [g => 1]; guesses = ModelingToolkit.missing_variable_defaults(pend)) -@test initprob isa NonlinearProblem +@test initprob isa NonlinearLeastSquaresProblem sol = solve(initprob) @test SciMLBase.successful_retcode(sol) @test sol.u == [0.0, 0.0, 0.0, 0.0] diff --git a/test/mtkparameters.jl b/test/mtkparameters.jl index 603426aaf7..ce524fdb76 100644 --- a/test/mtkparameters.jl +++ b/test/mtkparameters.jl @@ -299,7 +299,7 @@ end # Parameter timeseries ps = MTKParameters(([1.0, 1.0],), (BlockedArray(zeros(4), [2, 2]),), - (), ()) + (), (), ()) ps2 = SciMLStructures.replace(Discrete(), ps, ones(4)) @test typeof(ps2.discrete) == typeof(ps.discrete) with_updated_parameter_timeseries_values( @@ -316,7 +316,7 @@ with_updated_parameter_timeseries_values( ps = MTKParameters( (), (BlockedArray([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [3, 3]), BlockedArray(falses(1), [1, 0])), - (), ()) + (), (), ()) @test SciMLBase.get_saveable_values(sys, ps, 1).x isa Tuple{Vector{Float64}, Vector{Bool}} tsidx1 = 1 tsidx2 = 2 diff --git a/test/scc_nonlinear_problem.jl b/test/scc_nonlinear_problem.jl index 0c4bfe61d0..fdf1646343 100644 --- a/test/scc_nonlinear_problem.jl +++ b/test/scc_nonlinear_problem.jl @@ -3,6 +3,7 @@ using NonlinearSolve, SCCNonlinearSolve using OrdinaryDiffEq using SciMLBase, Symbolics using LinearAlgebra, Test +using ModelingToolkit: t_nounits as t, D_nounits as D @testset "Trivial case" begin function f!(du, u, p) @@ -30,7 +31,7 @@ using LinearAlgebra, Test sol2 = solve(sccprob, NewtonRaphson()) @test SciMLBase.successful_retcode(sol1) @test SciMLBase.successful_retcode(sol2) - @test sol1.u ≈ sol2.u + @test sol1[u] ≈ sol2[u] end @testset "With parameters" begin diff --git a/test/split_parameters.jl b/test/split_parameters.jl index 22c90edf7a..2f8667faa8 100644 --- a/test/split_parameters.jl +++ b/test/split_parameters.jl @@ -206,7 +206,7 @@ S = get_sensitivity(closed_loop, :u) BlockedArray([[1 2; 3 4], [2 4; 6 8]], [1, 1])), # (BlockedArray([[true, false], [false, true]]), BlockedArray([[[1 2; 3 4]], [[2 4; 6 8]]])), ([5, 6],), - (["hi", "bye"], [:lie, :die])) + (["hi", "bye"], [:lie, :die]), ()) @test ps[ParameterIndex(Tunable(), 1)] == 1.0 @test ps[ParameterIndex(Tunable(), 2:4)] == collect(2.0:4.0) @test ps[ParameterIndex(Tunable(), reshape(4:7, 2, 2))] == reshape(4.0:7.0, 2, 2) From ed9cdf380488fdbd1e43753fb689f76e9474437e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 26 Nov 2024 23:54:58 +0530 Subject: [PATCH 072/101] fix: reorder system in `SCCNonlinearProblem` --- src/systems/nonlinear/nonlinearsystem.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 636c1202f7..24680fead4 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -707,6 +707,11 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, push!(subprobs, prob) end + new_dvs = dvs[reduce(vcat, var_sccs)] + new_eqs = eqs[reduce(vcat, eq_sccs)] + @set! sys.unknowns = new_dvs + @set! sys.eqs = new_eqs + sys = complete(sys) return SCCNonlinearProblem(subprobs, explicitfuns, sys, p) end From a3819300c07a9e8c1a70ccff93b71574a832bc75 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 30 Nov 2024 23:28:06 +0530 Subject: [PATCH 073/101] build: bump OrdinaryDiffEqCore, OrdinaryDiffEqNonlinearSolve compats --- Project.toml | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index b520fa3e0c..7ac6f940be 100644 --- a/Project.toml +++ b/Project.toml @@ -121,7 +121,8 @@ NonlinearSolve = "3.14, 4" OffsetArrays = "1" OrderedCollections = "1" OrdinaryDiffEq = "6.82.0" -OrdinaryDiffEqCore = "1.7.0" +OrdinaryDiffEqCore = "1.13.0" +OrdinaryDiffEqNonlinearSolve = "1.3.0" PrecompileTools = "1" REPL = "1" RecursiveArrayTools = "3.26" @@ -162,6 +163,7 @@ OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" +OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" @@ -176,4 +178,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET"] +test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"] From 52eba502e7e85670095ee4dafe7b131370c5d14f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 4 Dec 2024 14:14:27 +0530 Subject: [PATCH 074/101] refactor: update to new `SCCNonlinearProblem` constructor --- Project.toml | 2 +- src/systems/nonlinear/nonlinearsystem.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 7ac6f940be..d81b3d4ef9 100644 --- a/Project.toml +++ b/Project.toml @@ -129,7 +129,7 @@ RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" SCCNonlinearSolve = "1.0.0" -SciMLBase = "2.65" +SciMLBase = "2.66" SciMLStructures = "1.0" Serialization = "1" Setfield = "0.7, 0.8, 1" diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 24680fead4..7826e06f76 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -712,7 +712,7 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map, @set! sys.unknowns = new_dvs @set! sys.eqs = new_eqs sys = complete(sys) - return SCCNonlinearProblem(subprobs, explicitfuns, sys, p) + return SCCNonlinearProblem(subprobs, explicitfuns, p, true; sys) end """ From 24d39af42431ac0fff321bb785e7ff6dac864120 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 4 Dec 2024 14:14:41 +0530 Subject: [PATCH 075/101] test: fix test to new SciMLBase error message --- test/symbolic_events.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 61690acdce..9d3aac4074 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1063,7 +1063,7 @@ end cb = [x ~ 0.0] => [x ~ 0, y ~ 1] @mtkbuild pend = ODESystem(eqs, t; continuous_events = [cb]) prob = ODEProblem(pend, [x => 1], (0.0, 3.0), guesses = [y => x]) - @test_throws "CheckInit specified but initialization" solve(prob, Rodas5()) + @test_throws "DAE initialization failed" solve(prob, Rodas5()) cb = [x ~ 0.0] => [y ~ 1] @mtkbuild pend = ODESystem(eqs, t; continuous_events = [cb]) From 1320fc0bd6bd11e819405512bb48ab6737094fa8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 5 Dec 2024 13:06:04 +0530 Subject: [PATCH 076/101] refactor: separate out operating point and initializeprob construction --- src/systems/problem_utils.jl | 154 +++++++++++++++++++++-------------- 1 file changed, 94 insertions(+), 60 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index ebc5a5bd75..6b75fb2c6d 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -489,6 +489,89 @@ function EmptySciMLFunction(args...; kwargs...) return EmptySciMLFunction{typeof(args), typeof(kwargs)}(args, kwargs) end +""" + $(TYPEDSIGNATURES) + +Construct the operating point of the system from the user-provided `u0map` and `pmap`, system +defaults `defs`, constant equations `cmap` (from `get_cmap(sys)`), unknowns `dvs` and +parameters `ps`. Return the operating point as a dictionary, the list of unknowns for which +no values can be determined, and the list of parameters for which no values can be determined. +""" +function build_operating_point( + u0map::AbstractDict, pmap::AbstractDict, defs::AbstractDict, cmap, dvs, ps) + op = add_toterms(u0map) + missing_unknowns = add_fallbacks!(op, dvs, defs) + for (k, v) in defs + haskey(op, k) && continue + op[k] = v + end + merge!(op, pmap) + missing_pars = add_fallbacks!(op, ps, defs) + for eq in cmap + op[eq.lhs] = eq.rhs + end + return op, missing_unknowns, missing_pars +end + +""" + $(TYPEDSIGNATURES) + +Build and return the initialization problem and associated data as a `NamedTuple` to be passed +to the `SciMLFunction` constructor. Requires the system `sys`, operating point `op`, +user-provided `u0map` and `pmap`, initial time `t`, system defaults `defs`, user-provided +`guesses`, and list of unknowns which don't have a value in `op`. The keyword `implicit_dae` +denotes whether the `SciMLProblem` being constructed is in implicit DAE form (`DAEProblem`). +All other keyword arguments are forwarded to `InitializationProblem`. +""" +function maybe_build_initialization_problem( + sys::AbstractSystem, op::AbstractDict, u0map, pmap, t, defs, + guesses, missing_unknowns; implicit_dae = false, kwargs...) + guesses = merge(ModelingToolkit.guesses(sys), todict(guesses)) + has_observed_u0s = any( + k -> has_observed_with_lhs(sys, k) || has_parameter_dependency_with_lhs(sys, k), + keys(op)) + solvablepars = [p + for p in parameters(sys) + if is_parameter_solvable(p, pmap, defs, guesses)] + has_dependent_unknowns = any(unknowns(sys)) do sym + val = get(op, sym, nothing) + val === nothing && return false + return symbolic_type(val) != NotSymbolic() || is_array_of_symbolics(val) + end + if (((implicit_dae || has_observed_u0s || !isempty(missing_unknowns) || + !isempty(solvablepars) || has_dependent_unknowns) && + get_tearing_state(sys) !== nothing) || + !isempty(initialization_equations(sys))) && t !== nothing + initializeprob = ModelingToolkit.InitializationProblem( + sys, t, u0map, pmap; guesses, kwargs...) + initializeprobmap = getu(initializeprob, unknowns(sys)) + + punknowns = [p + for p in all_variable_symbols(initializeprob) + if is_parameter(sys, p)] + getpunknowns = getu(initializeprob, punknowns) + setpunknowns = setp(sys, punknowns) + initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns) + + reqd_syms = parameter_symbols(initializeprob) + update_initializeprob! = UpdateInitializeprob( + getu(sys, reqd_syms), setu(initializeprob, reqd_syms)) + for p in punknowns + p = unwrap(p) + stype = symtype(p) + op[p] = get_temporary_value(p) + end + + for v in missing_unknowns + op[v] = zero_var(v) + end + empty!(missing_unknowns) + return (; + initializeprob, initializeprobmap, initializeprobpmap, update_initializeprob!) + end + return (;) +end + """ $(TYPEDSIGNATURES) @@ -576,67 +659,18 @@ function process_SciMLProblem( cmap, cs = get_cmap(sys) kwargs = NamedTuple(kwargs) - op = add_toterms(u0map) - missing_unknowns = add_fallbacks!(op, dvs, defs) - for (k, v) in defs - haskey(op, k) && continue - op[k] = v - end - merge!(op, pmap) - missing_pars = add_fallbacks!(op, ps, defs) - for eq in cmap - op[eq.lhs] = eq.rhs - end - if sys isa ODESystem - guesses = merge(ModelingToolkit.guesses(sys), todict(guesses)) - has_observed_u0s = any( - k -> has_observed_with_lhs(sys, k) || has_parameter_dependency_with_lhs(sys, k), - keys(op)) - solvablepars = [p - for p in parameters(sys) - if is_parameter_solvable(p, pmap, defs, guesses)] - has_dependent_unknowns = any(unknowns(sys)) do sym - val = get(op, sym, nothing) - val === nothing && return false - return symbolic_type(val) != NotSymbolic() || is_array_of_symbolics(val) - end - if build_initializeprob && - (((implicit_dae || has_observed_u0s || !isempty(missing_unknowns) || - !isempty(solvablepars) || has_dependent_unknowns) && - get_tearing_state(sys) !== nothing) || - !isempty(initialization_equations(sys))) && t !== nothing - initializeprob = ModelingToolkit.InitializationProblem( - sys, t, u0map, pmap; guesses, warn_initialize_determined, - initialization_eqs, eval_expression, eval_module, fully_determined, - warn_cyclic_dependency, check_units = check_initialization_units, - circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc) - initializeprobmap = getu(initializeprob, unknowns(sys)) - - punknowns = [p - for p in all_variable_symbols(initializeprob) - if is_parameter(sys, p)] - getpunknowns = getu(initializeprob, punknowns) - setpunknowns = setp(sys, punknowns) - initializeprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns) - - reqd_syms = parameter_symbols(initializeprob) - update_initializeprob! = UpdateInitializeprob( - getu(sys, reqd_syms), setu(initializeprob, reqd_syms)) - for p in punknowns - p = unwrap(p) - stype = symtype(p) - op[p] = get_temporary_value(p) - delete!(missing_pars, p) - end + op, missing_unknowns, missing_pars = build_operating_point( + u0map, pmap, defs, cmap, dvs, ps) - for v in missing_unknowns - op[v] = zero_var(v) - end - empty!(missing_unknowns) - kwargs = merge(kwargs, - (; initializeprob, initializeprobmap, - initializeprobpmap, update_initializeprob!)) - end + if sys isa ODESystem && build_initializeprob + kws = maybe_build_initialization_problem( + sys, op, u0map, pmap, t, defs, guesses, missing_unknowns; + implicit_dae, warn_initialize_determined, initialization_eqs, + eval_expression, eval_module, fully_determined, + warn_cyclic_dependency, check_units = check_initialization_units, + circular_dependency_max_cycle_length, circular_dependency_max_cycles, use_scc) + + kwargs = merge(kwargs, kws) end if t !== nothing && !(constructor <: Union{DDEFunction, SDDEFunction}) From 35b407c728ba9a8590f1567140822cba9c57e924 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 5 Dec 2024 03:12:26 +0530 Subject: [PATCH 077/101] fix: fix `remake_initialization_data` on problems with no initprob --- src/systems/nonlinear/initializesystem.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 726d171bd0..4544f0fa96 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -281,6 +281,8 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, symbols_to_symbolics!(sys, pmap) guesses = Dict() defs = defaults(sys) + cmap, cs = get_cmap(sys) + if SciMLBase.has_initializeprob(odefn) oldsys = odefn.initializeprob.f.sys meta = get_metadata(oldsys) @@ -324,8 +326,9 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, end filter_missing_values!(u0map) filter_missing_values!(pmap) - f, _ = process_SciMLProblem(EmptySciMLFunction, sys, u0map, pmap; guesses, t = t0) - kws = f.kwargs + + op, missing_unknowns, missing_pars = build_operating_point(u0map, pmap, defs, cmap, dvs, ps) + kws = maybe_build_initialization_problem(sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; use_scc = true) initprob = get(kws, :initializeprob, nothing) if initprob === nothing return nothing From 99435326bd009f38bdd9545801dfcfd2ac1e28b6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 5 Dec 2024 13:11:17 +0530 Subject: [PATCH 078/101] refactor: add better warnings when SCC initialization cannot be used --- src/systems/diffeqs/abstractodesystem.jl | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index ac8870fcaa..8b1476e582 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1342,11 +1342,17 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, neqs = length(equations(isys)) nunknown = length(unknowns(isys)) + if use_scc + scc_message = "`SCCNonlinearProblem` can only be used for initialization of fully determined systems and hence will not be used here. " + else + scc_message = "" + end + if warn_initialize_determined && neqs > nunknown - @warn "Initialization system is overdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false. To make this warning into an error, pass fully_determined = true" + @warn "Initialization system is overdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. $(scc_message)To suppress this warning pass warn_initialize_determined = false. To make this warning into an error, pass fully_determined = true" end if warn_initialize_determined && neqs < nunknown - @warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. To suppress this warning pass warn_initialize_determined = false. To make this warning into an error, pass fully_determined = true" + @warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. $(scc_message)To suppress this warning pass warn_initialize_determined = false. To make this warning into an error, pass fully_determined = true" end parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ? @@ -1384,7 +1390,16 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, end TProb = if neqs == nunknown && isempty(unassigned_vars) - use_scc && neqs > 0 && is_split(isys) ? SCCNonlinearProblem : NonlinearProblem + if use_scc && neqs > 0 + if is_split(isys) + SCCNonlinearProblem + else + @warn "`SCCNonlinearProblem` can only be used with `split = true` systems. Simplify your `ODESystem` with `split = true` or pass `use_scc = false` to disable this warning" + NonlinearProblem + end + else + NonlinearProblem + end else NonlinearLeastSquaresProblem end From 0537715df6c5a0ed03c5d47b26a0768b42b2b184 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 5 Dec 2024 13:22:29 +0530 Subject: [PATCH 079/101] refactor: propagate `use_scc` to `remake_initialization_data` --- src/systems/diffeqs/abstractodesystem.jl | 6 ++++-- src/systems/nonlinear/initializesystem.jl | 14 ++++++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 8b1476e582..f4e29346ff 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1311,11 +1311,13 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, elseif isempty(u0map) && get_initializesystem(sys) === nothing isys = structural_simplify( generate_initializesystem( - sys; initialization_eqs, check_units, pmap = parammap, guesses); fully_determined) + sys; initialization_eqs, check_units, pmap = parammap, + guesses, extra_metadata = (; use_scc)); fully_determined) else isys = structural_simplify( generate_initializesystem( - sys; u0map, initialization_eqs, check_units, pmap = parammap, guesses); fully_determined) + sys; u0map, initialization_eqs, check_units, + pmap = parammap, guesses, extra_metadata = (; use_scc)); fully_determined) end ts = get_tearing_state(isys) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 4544f0fa96..2344727920 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -11,7 +11,7 @@ function generate_initializesystem(sys::ODESystem; default_dd_guess = 0.0, algebraic_only = false, check_units = true, check_defguess = false, - name = nameof(sys), kwargs...) + name = nameof(sys), extra_metadata = (;), kwargs...) trueobs, eqs = unhack_observed(observed(sys), equations(sys)) vars = unique([unknowns(sys); getfield.(trueobs, :lhs)]) vars_set = Set(vars) # for efficient in-lookup @@ -179,7 +179,8 @@ function generate_initializesystem(sys::ODESystem; for k in keys(defs) defs[k] = substitute(defs[k], paramsubs) end - meta = InitializationSystemMetadata(anydict(u0map), anydict(pmap), additional_guesses) + meta = InitializationSystemMetadata( + anydict(u0map), anydict(pmap), additional_guesses, extra_metadata) return NonlinearSystem(eqs_ics, vars, pars; @@ -195,6 +196,7 @@ struct InitializationSystemMetadata u0map::Dict{Any, Any} pmap::Dict{Any, Any} additional_guesses::Dict{Any, Any} + extra_metadata::NamedTuple end function is_parameter_solvable(p, pmap, defs, guesses) @@ -282,6 +284,7 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, guesses = Dict() defs = defaults(sys) cmap, cs = get_cmap(sys) + use_scc = true if SciMLBase.has_initializeprob(odefn) oldsys = odefn.initializeprob.f.sys @@ -290,6 +293,7 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, u0map = merge(meta.u0map, u0map) pmap = merge(meta.pmap, pmap) merge!(guesses, meta.additional_guesses) + use_scc = get(meta.extra_metadata, :use_scc, true) end else # there is no initializeprob, so the original problem construction @@ -327,8 +331,10 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, filter_missing_values!(u0map) filter_missing_values!(pmap) - op, missing_unknowns, missing_pars = build_operating_point(u0map, pmap, defs, cmap, dvs, ps) - kws = maybe_build_initialization_problem(sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; use_scc = true) + op, missing_unknowns, missing_pars = build_operating_point( + u0map, pmap, defs, cmap, dvs, ps) + kws = maybe_build_initialization_problem( + sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; use_scc) initprob = get(kws, :initializeprob, nothing) if initprob === nothing return nothing From ab5747f208a9cf5c2924fa251c739053cccf859d Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 5 Dec 2024 07:30:41 -0700 Subject: [PATCH 080/101] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index d81b3d4ef9..fe3a9e88e5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ModelingToolkit" uuid = "961ee093-0014-501f-94e3-6117800e7a78" authors = ["Yingbo Ma ", "Chris Rackauckas and contributors"] -version = "9.55.0" +version = "9.56.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" From 82198b7f49cb707b31b06af6d49dd0a7ecd33160 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 30 Nov 2024 14:09:11 +0530 Subject: [PATCH 081/101] refactor: move HomtopyContinuation-related code to its own file --- src/ModelingToolkit.jl | 1 + .../nonlinear/homotopy_continuation.jl | 68 ++++++++++++++++++ src/systems/nonlinear/nonlinearsystem.jl | 69 ------------------- 3 files changed, 69 insertions(+), 69 deletions(-) create mode 100644 src/systems/nonlinear/homotopy_continuation.jl diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 59be358349..de9a06d774 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -150,6 +150,7 @@ include("systems/callbacks.jl") include("systems/problem_utils.jl") include("systems/nonlinear/nonlinearsystem.jl") +include("systems/nonlinear/homotopy_continuation.jl") include("systems/diffeqs/odesystem.jl") include("systems/diffeqs/sdesystem.jl") include("systems/diffeqs/abstractodesystem.jl") diff --git a/src/systems/nonlinear/homotopy_continuation.jl b/src/systems/nonlinear/homotopy_continuation.jl new file mode 100644 index 0000000000..e41a3f0ce7 --- /dev/null +++ b/src/systems/nonlinear/homotopy_continuation.jl @@ -0,0 +1,68 @@ +""" +$(TYPEDEF) + +A type of Nonlinear problem which specializes on polynomial systems and uses +HomotopyContinuation.jl to solve the system. Requires importing HomotopyContinuation.jl to +create and solve. +""" +struct HomotopyContinuationProblem{uType, H, D, O, SS, U} <: + SciMLBase.AbstractNonlinearProblem{uType, true} + """ + The initial values of states in the system. If there are multiple real roots of + the system, the one closest to this point is returned. + """ + u0::uType + """ + A subtype of `HomotopyContinuation.AbstractSystem` to solve. Also contains the + parameter object. + """ + homotopy_continuation_system::H + """ + A function with signature `(u, p) -> resid`. In case of rational functions, this + is used to rule out roots of the system which would cause the denominator to be + zero. + """ + denominator::D + """ + The `NonlinearSystem` used to create this problem. Used for symbolic indexing. + """ + sys::NonlinearSystem + """ + A function which generates and returns observed expressions for the given system. + """ + obsfn::O + """ + The HomotopyContinuation.jl solver and start system, obtained through + `HomotopyContinuation.solver_startsystems`. + """ + solver_and_starts::SS + """ + A function which takes a solution of the transformed system, and returns a vector + of solutions for the original system. This is utilized when converting systems + to polynomials. + """ + unpack_solution::U +end + +function HomotopyContinuationProblem(::AbstractSystem, _u0, _p; kwargs...) + error("HomotopyContinuation.jl is required to create and solve `HomotopyContinuationProblem`s. Please run `Pkg.add(\"HomotopyContinuation\")` to continue.") +end + +SymbolicIndexingInterface.symbolic_container(p::HomotopyContinuationProblem) = p.sys +SymbolicIndexingInterface.state_values(p::HomotopyContinuationProblem) = p.u0 +function SymbolicIndexingInterface.set_state!(p::HomotopyContinuationProblem, args...) + set_state!(p.u0, args...) +end +function SymbolicIndexingInterface.parameter_values(p::HomotopyContinuationProblem) + parameter_values(p.homotopy_continuation_system) +end +function SymbolicIndexingInterface.set_parameter!(p::HomotopyContinuationProblem, args...) + set_parameter!(parameter_values(p), args...) +end +function SymbolicIndexingInterface.observed(p::HomotopyContinuationProblem, sym) + if p.obsfn !== nothing + return p.obsfn(sym) + else + return SymbolicIndexingInterface.observed(p.sys, sym) + end +end diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 7826e06f76..ff4a7d0eb2 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -866,72 +866,3 @@ function Base.:(==)(sys1::NonlinearSystem, sys2::NonlinearSystem) _eq_unordered(get_ps(sys1), get_ps(sys2)) && all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2))) end - -""" -$(TYPEDEF) - -A type of Nonlinear problem which specializes on polynomial systems and uses -HomotopyContinuation.jl to solve the system. Requires importing HomotopyContinuation.jl to -create and solve. -""" -struct HomotopyContinuationProblem{uType, H, D, O, SS, U} <: - SciMLBase.AbstractNonlinearProblem{uType, true} - """ - The initial values of states in the system. If there are multiple real roots of - the system, the one closest to this point is returned. - """ - u0::uType - """ - A subtype of `HomotopyContinuation.AbstractSystem` to solve. Also contains the - parameter object. - """ - homotopy_continuation_system::H - """ - A function with signature `(u, p) -> resid`. In case of rational functions, this - is used to rule out roots of the system which would cause the denominator to be - zero. - """ - denominator::D - """ - The `NonlinearSystem` used to create this problem. Used for symbolic indexing. - """ - sys::NonlinearSystem - """ - A function which generates and returns observed expressions for the given system. - """ - obsfn::O - """ - The HomotopyContinuation.jl solver and start system, obtained through - `HomotopyContinuation.solver_startsystems`. - """ - solver_and_starts::SS - """ - A function which takes a solution of the transformed system, and returns a vector - of solutions for the original system. This is utilized when converting systems - to polynomials. - """ - unpack_solution::U -end - -function HomotopyContinuationProblem(::AbstractSystem, _u0, _p; kwargs...) - error("HomotopyContinuation.jl is required to create and solve `HomotopyContinuationProblem`s. Please run `Pkg.add(\"HomotopyContinuation\")` to continue.") -end - -SymbolicIndexingInterface.symbolic_container(p::HomotopyContinuationProblem) = p.sys -SymbolicIndexingInterface.state_values(p::HomotopyContinuationProblem) = p.u0 -function SymbolicIndexingInterface.set_state!(p::HomotopyContinuationProblem, args...) - set_state!(p.u0, args...) -end -function SymbolicIndexingInterface.parameter_values(p::HomotopyContinuationProblem) - parameter_values(p.homotopy_continuation_system) -end -function SymbolicIndexingInterface.set_parameter!(p::HomotopyContinuationProblem, args...) - set_parameter!(parameter_values(p), args...) -end -function SymbolicIndexingInterface.observed(p::HomotopyContinuationProblem, sym) - if p.obsfn !== nothing - return p.obsfn(sym) - else - return SymbolicIndexingInterface.observed(p.sys, sym) - end -end From 2401aa22fc2ef382589efb57e75bede4a8ed14a3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 30 Nov 2024 16:08:38 +0530 Subject: [PATCH 082/101] refactor: move HomotopyContinuation internals to MTK --- ext/MTKHomotopyContinuationExt.jl | 355 +------------- .../nonlinear/homotopy_continuation.jl | 437 ++++++++++++++++++ 2 files changed, 458 insertions(+), 334 deletions(-) diff --git a/ext/MTKHomotopyContinuationExt.jl b/ext/MTKHomotopyContinuationExt.jl index fa5d1bfcd4..586d608156 100644 --- a/ext/MTKHomotopyContinuationExt.jl +++ b/ext/MTKHomotopyContinuationExt.jl @@ -11,217 +11,6 @@ using ModelingToolkit: iscomplete, parameters, has_index_cache, get_index_cache, const MTK = ModelingToolkit -function contains_variable(x, wrt) - any(y -> occursin(y, x), wrt) -end - -""" -Possible reasons why a term is not polynomial -""" -MTK.EnumX.@enumx NonPolynomialReason begin - NonIntegerExponent - ExponentContainsUnknowns - BaseNotPolynomial - UnrecognizedOperation -end - -function display_reason(reason::NonPolynomialReason.T, sym) - if reason == NonPolynomialReason.NonIntegerExponent - pow = arguments(sym)[2] - "In $sym: Exponent $pow is not an integer" - elseif reason == NonPolynomialReason.ExponentContainsUnknowns - pow = arguments(sym)[2] - "In $sym: Exponent $pow contains unknowns of the system" - elseif reason == NonPolynomialReason.BaseNotPolynomial - base = arguments(sym)[1] - "In $sym: Base $base is not a polynomial in the unknowns" - elseif reason == NonPolynomialReason.UnrecognizedOperation - op = operation(sym) - """ - In $sym: Operation $op is not recognized. Allowed polynomial operations are \ - `*, /, +, -, ^`. - """ - else - error("This should never happen. Please open an issue in ModelingToolkit.jl.") - end -end - -mutable struct PolynomialData - non_polynomial_terms::Vector{BasicSymbolic} - reasons::Vector{NonPolynomialReason.T} - has_parametric_exponent::Bool -end - -PolynomialData() = PolynomialData(BasicSymbolic[], NonPolynomialReason.T[], false) - -abstract type PolynomialTransformationError <: Exception end - -struct MultivarTerm <: PolynomialTransformationError - term::Any - vars::Any -end - -function Base.showerror(io::IO, err::MultivarTerm) - println(io, - "Cannot convert system to polynomial: Found term $(err.term) which is a function of multiple unknowns $(err.vars).") -end - -struct MultipleTermsOfSameVar <: PolynomialTransformationError - terms::Any - var::Any -end - -function Base.showerror(io::IO, err::MultipleTermsOfSameVar) - println(io, - "Cannot convert system to polynomial: Found multiple non-polynomial terms $(err.terms) involving the same unknown $(err.var).") -end - -struct SymbolicSolveFailure <: PolynomialTransformationError - term::Any - var::Any -end - -function Base.showerror(io::IO, err::SymbolicSolveFailure) - println(io, - "Cannot convert system to polynomial: Unable to symbolically solve $(err.term) for $(err.var).") -end - -struct NemoNotLoaded <: PolynomialTransformationError end - -function Base.showerror(io::IO, err::NemoNotLoaded) - println(io, - "ModelingToolkit may be able to solve this system as a polynomial system if `Nemo` is loaded. Run `import Nemo` and try again.") -end - -struct VariablesAsPolyAndNonPoly <: PolynomialTransformationError - vars::Any -end - -function Base.showerror(io::IO, err::VariablesAsPolyAndNonPoly) - println(io, - "Cannot convert convert system to polynomial: Variables $(err.vars) occur in both polynomial and non-polynomial terms in the system.") -end - -struct NotPolynomialError <: Exception - transformation_err::Union{PolynomialTransformationError, Nothing} - eq::Vector{Equation} - data::Vector{PolynomialData} -end - -function Base.showerror(io::IO, err::NotPolynomialError) - if err.transformation_err !== nothing - Base.showerror(io, err.transformation_err) - end - for (eq, data) in zip(err.eq, err.data) - if isempty(data.non_polynomial_terms) - continue - end - println(io, - "Equation $(eq) is not a polynomial in the unknowns for the following reasons:") - for (term, reason) in zip(data.non_polynomial_terms, data.reasons) - println(io, display_reason(reason, term)) - end - end -end - -function is_polynomial!(data, y, wrt) - process_polynomial!(data, y, wrt) - isempty(data.reasons) -end - -""" -$(TYPEDSIGNATURES) - -Return information about the polynmial `x` with respect to variables in `wrt`, -writing said information to `data`. -""" -function process_polynomial!(data::PolynomialData, x, wrt) - x = unwrap(x) - symbolic_type(x) == NotSymbolic() && return true - iscall(x) || return true - contains_variable(x, wrt) || return true - any(isequal(x), wrt) && return true - - if operation(x) in (*, +, -, /) - # `map` because `all` will early exit, but we want to search - # through everything to get all the non-polynomial terms - return all(map(y -> is_polynomial!(data, y, wrt), arguments(x))) - end - if operation(x) == (^) - b, p = arguments(x) - is_pow_integer = symtype(p) <: Integer - if !is_pow_integer - push!(data.non_polynomial_terms, x) - push!(data.reasons, NonPolynomialReason.NonIntegerExponent) - end - if symbolic_type(p) != NotSymbolic() - data.has_parametric_exponent = true - end - - exponent_has_unknowns = contains_variable(p, wrt) - if exponent_has_unknowns - push!(data.non_polynomial_terms, x) - push!(data.reasons, NonPolynomialReason.ExponentContainsUnknowns) - end - base_polynomial = is_polynomial!(data, b, wrt) - return base_polynomial && !exponent_has_unknowns && is_pow_integer - end - push!(data.non_polynomial_terms, x) - push!(data.reasons, NonPolynomialReason.UnrecognizedOperation) - return false -end - -""" -$(TYPEDSIGNATURES) - -Given a `x`, a polynomial in variables in `wrt` which may contain rational functions, -express `x` as a single rational function with polynomial `num` and denominator `den`. -Return `(num, den)`. -""" -function handle_rational_polynomials(x, wrt) - x = unwrap(x) - symbolic_type(x) == NotSymbolic() && return x, 1 - iscall(x) || return x, 1 - contains_variable(x, wrt) || return x, 1 - any(isequal(x), wrt) && return x, 1 - - # simplify_fractions cancels out some common factors - # and expands (a / b)^c to a^c / b^c, so we only need - # to handle these cases - x = simplify_fractions(x) - op = operation(x) - args = arguments(x) - - if op == / - # numerator and denominator are trivial - num, den = args - # but also search for rational functions in numerator - n, d = handle_rational_polynomials(num, wrt) - num, den = n, den * d - elseif op == + - num = 0 - den = 1 - - # we don't need to do common denominator - # because we don't care about cases where denominator - # is zero. The expression is zero when all the numerators - # are zero. - for arg in args - n, d = handle_rational_polynomials(arg, wrt) - num += n - den *= d - end - else - return x, 1 - end - # if the denominator isn't a polynomial in `wrt`, better to not include it - # to reduce the size of the gcd polynomial - if !contains_variable(den, wrt) - return num / den, 1 - end - return num, den -end - """ $(TYPEDSIGNATURES) @@ -289,12 +78,6 @@ end SymbolicIndexingInterface.parameter_values(s::MTKHomotopySystem) = s.p -struct PolynomialTransformationData - new_var::BasicSymbolic - term::BasicSymbolic - inv_term::Vector -end - """ $(TYPEDSIGNATURES) @@ -312,128 +95,31 @@ Keyword arguments: All other keyword arguments are forwarded to `HomotopyContinuation.solver_startsystems`. """ function MTK.HomotopyContinuationProblem( - sys::NonlinearSystem, u0map, parammap = nothing; eval_expression = false, - eval_module = ModelingToolkit, warn_parametric_exponent = true, kwargs...) + sys::NonlinearSystem, u0map, parammap = nothing; kwargs...) if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`") end - - dvs = unknowns(sys) - # we need to consider `full_equations` because observed also should be - # polynomials (if used in equations) and we don't know if observed is used - # in denominator. - # This is not the most efficient, and would be improved significantly with - # CSE/hashconsing. - eqs = full_equations(sys) - - polydata = map(eqs) do eq - data = PolynomialData() - process_polynomial!(data, eq.lhs, dvs) - process_polynomial!(data, eq.rhs, dvs) - data - end - - has_parametric_exponents = any(d -> d.has_parametric_exponent, polydata) - - all_non_poly_terms = mapreduce(d -> d.non_polynomial_terms, vcat, polydata) - unique!(all_non_poly_terms) - - var_to_nonpoly = Dict{BasicSymbolic, PolynomialTransformationData}() - - is_poly = true - transformation_err = nothing - for t in all_non_poly_terms - # if the term involves multiple unknowns, we can't invert it - dvs_in_term = map(x -> occursin(x, t), dvs) - if count(dvs_in_term) > 1 - transformation_err = MultivarTerm(t, dvs[dvs_in_term]) - is_poly = false - break - end - # we already have a substitution solving for `var` - var = dvs[findfirst(dvs_in_term)] - if haskey(var_to_nonpoly, var) && !isequal(var_to_nonpoly[var].term, t) - transformation_err = MultipleTermsOfSameVar([t, var_to_nonpoly[var].term], var) - is_poly = false - break - end - # we want to solve `term - new_var` for `var` - new_var = gensym(Symbol(var)) - new_var = unwrap(only(@variables $new_var)) - invterm = Symbolics.ia_solve( - t - new_var, var; complex_roots = false, periodic_roots = false, warns = false) - # if we can't invert it, quit - if invterm === nothing || isempty(invterm) - transformation_err = SymbolicSolveFailure(t, var) - is_poly = false - break - end - # `ia_solve` returns lazy terms i.e. `asin(1.0)` instead of `pi/2` - # this just evaluates the constant expressions - invterm = Symbolics.substitute.(invterm, (Dict(),)) - # RootsOf implies Symbolics couldn't solve the inner polynomial because - # `Nemo` wasn't loaded. - if any(x -> MTK.iscall(x) && MTK.operation(x) == Symbolics.RootsOf, invterm) - transformation_err = NemoNotLoaded() - is_poly = false - break - end - var_to_nonpoly[var] = PolynomialTransformationData(new_var, t, invterm) + transformation = MTK.PolynomialTransformation(sys) + if transformation isa MTK.NotPolynomialError + throw(transformation) end - - if !is_poly - throw(NotPolynomialError(transformation_err, eqs, polydata)) - end - - subrules = Dict() - combinations = Vector[] - new_dvs = [] - for x in dvs - if haskey(var_to_nonpoly, x) - _data = var_to_nonpoly[x] - subrules[_data.term] = _data.new_var - push!(combinations, _data.inv_term) - push!(new_dvs, _data.new_var) - else - push!(combinations, [x]) - push!(new_dvs, x) - end - end - all_solutions = collect.(collect(Iterators.product(combinations...))) - - denoms = [] - eqs2 = map(eqs) do eq - t = eq.rhs - eq.lhs - t = Symbolics.fixpoint_sub(t, subrules; maxiters = length(dvs)) - # the substituted variable occurs outside the substituted term - poly_and_nonpoly = map(dvs) do x - haskey(var_to_nonpoly, x) && occursin(x, t) - end - if any(poly_and_nonpoly) - throw(NotPolynomialError( - VariablesAsPolyAndNonPoly(dvs[poly_and_nonpoly]), eqs, polydata)) - end - - num, den = handle_rational_polynomials(t, new_dvs) - # make factors different elements, otherwise the nonzero factors artificially - # inflate the error of the zero factor. - if iscall(den) && operation(den) == * - for arg in arguments(den) - # ignore constant factors - symbolic_type(arg) == NotSymbolic() && continue - push!(denoms, abs(arg)) - end - elseif symbolic_type(den) != NotSymbolic() - push!(denoms, abs(den)) - end - return 0 ~ num + result = MTK.transform_system(sys, transformation) + if result isa MTK.NotPolynomialError + throw(result) end + MTK.HomotopyContinuationProblem(sys, transformation, result, u0map, parammap; kwargs...) +end - sys2 = MTK.@set sys.eqs = eqs2 - MTK.@set! sys2.unknowns = new_dvs - # remove observed equations to avoid adding them in codegen - MTK.@set! sys2.observed = Equation[] - MTK.@set! sys2.substitutions = nothing +function MTK.HomotopyContinuationProblem( + sys::MTK.NonlinearSystem, transformation::MTK.PolynomialTransformation, + result::MTK.PolynomialTransformationResult, u0map, + parammap = nothing; eval_expression = false, + eval_module = ModelingToolkit, warn_parametric_exponent = true, kwargs...) + sys2 = result.sys + denoms = result.denominators + polydata = transformation.polydata + new_dvs = transformation.new_dvs + all_solutions = transformation.all_solutions _, u0, p = MTK.process_SciMLProblem( MTK.EmptySciMLFunction, sys, u0map, parammap; eval_expression, eval_module) @@ -443,10 +129,11 @@ function MTK.HomotopyContinuationProblem( unpack_solution = MTK.build_explicit_observed_function(sys2, all_solutions) hvars = symbolics_to_hc.(new_dvs) - mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(eqs)) + mtkhsys = MTKHomotopySystem(nlfn.f, p, nlfn.jac, hvars, length(new_dvs)) obsfn = MTK.ObservedFunctionCache(sys; eval_expression, eval_module) + has_parametric_exponents = any(d -> d.has_parametric_exponent, polydata) if has_parametric_exponents if warn_parametric_exponent @warn """ diff --git a/src/systems/nonlinear/homotopy_continuation.jl b/src/systems/nonlinear/homotopy_continuation.jl index e41a3f0ce7..044f44f70f 100644 --- a/src/systems/nonlinear/homotopy_continuation.jl +++ b/src/systems/nonlinear/homotopy_continuation.jl @@ -66,3 +66,440 @@ function SymbolicIndexingInterface.observed(p::HomotopyContinuationProblem, sym) return SymbolicIndexingInterface.observed(p.sys, sym) end end + +function contains_variable(x, wrt) + any(y -> occursin(y, x), wrt) +end + +""" +Possible reasons why a term is not polynomial +""" +EnumX.@enumx NonPolynomialReason begin + NonIntegerExponent + ExponentContainsUnknowns + BaseNotPolynomial + UnrecognizedOperation +end + +function display_reason(reason::NonPolynomialReason.T, sym) + if reason == NonPolynomialReason.NonIntegerExponent + pow = arguments(sym)[2] + "In $sym: Exponent $pow is not an integer" + elseif reason == NonPolynomialReason.ExponentContainsUnknowns + pow = arguments(sym)[2] + "In $sym: Exponent $pow contains unknowns of the system" + elseif reason == NonPolynomialReason.BaseNotPolynomial + base = arguments(sym)[1] + "In $sym: Base $base is not a polynomial in the unknowns" + elseif reason == NonPolynomialReason.UnrecognizedOperation + op = operation(sym) + """ + In $sym: Operation $op is not recognized. Allowed polynomial operations are \ + `*, /, +, -, ^`. + """ + else + error("This should never happen. Please open an issue in ModelingToolkit.jl.") + end +end + +""" + $(TYPEDEF) + +Information about an expression about its polynomial nature. +""" +mutable struct PolynomialData + """ + A list of all non-polynomial terms in the expression. + """ + non_polynomial_terms::Vector{BasicSymbolic} + """ + Corresponding to `non_polynomial_terms`, a list of reasons why they are + not polynomial. + """ + reasons::Vector{NonPolynomialReason.T} + """ + Whether the polynomial contains parametric exponents of unknowns. + """ + has_parametric_exponent::Bool +end + +PolynomialData() = PolynomialData(BasicSymbolic[], NonPolynomialReason.T[], false) + +abstract type PolynomialTransformationError <: Exception end + +struct MultivarTerm <: PolynomialTransformationError + term::Any + vars::Any +end + +function Base.showerror(io::IO, err::MultivarTerm) + println(io, + "Cannot convert system to polynomial: Found term $(err.term) which is a function of multiple unknowns $(err.vars).") +end + +struct MultipleTermsOfSameVar <: PolynomialTransformationError + terms::Any + var::Any +end + +function Base.showerror(io::IO, err::MultipleTermsOfSameVar) + println(io, + "Cannot convert system to polynomial: Found multiple non-polynomial terms $(err.terms) involving the same unknown $(err.var).") +end + +struct SymbolicSolveFailure <: PolynomialTransformationError + term::Any + var::Any +end + +function Base.showerror(io::IO, err::SymbolicSolveFailure) + println(io, + "Cannot convert system to polynomial: Unable to symbolically solve $(err.term) for $(err.var).") +end + +struct NemoNotLoaded <: PolynomialTransformationError end + +function Base.showerror(io::IO, err::NemoNotLoaded) + println(io, + "ModelingToolkit may be able to solve this system as a polynomial system if `Nemo` is loaded. Run `import Nemo` and try again.") +end + +struct VariablesAsPolyAndNonPoly <: PolynomialTransformationError + vars::Any +end + +function Base.showerror(io::IO, err::VariablesAsPolyAndNonPoly) + println(io, + "Cannot convert convert system to polynomial: Variables $(err.vars) occur in both polynomial and non-polynomial terms in the system.") +end + +struct NotPolynomialError <: Exception + transformation_err::Union{PolynomialTransformationError, Nothing} + eq::Vector{Equation} + data::Vector{PolynomialData} +end + +function Base.showerror(io::IO, err::NotPolynomialError) + if err.transformation_err !== nothing + Base.showerror(io, err.transformation_err) + end + for (eq, data) in zip(err.eq, err.data) + if isempty(data.non_polynomial_terms) + continue + end + println(io, + "Equation $(eq) is not a polynomial in the unknowns for the following reasons:") + for (term, reason) in zip(data.non_polynomial_terms, data.reasons) + println(io, display_reason(reason, term)) + end + end +end + +function is_polynomial!(data, y, wrt) + process_polynomial!(data, y, wrt) + isempty(data.reasons) +end + +""" +$(TYPEDSIGNATURES) + +Return information about the polynmial `x` with respect to variables in `wrt`, +writing said information to `data`. +""" +function process_polynomial!(data::PolynomialData, x, wrt) + x = unwrap(x) + symbolic_type(x) == NotSymbolic() && return true + iscall(x) || return true + contains_variable(x, wrt) || return true + any(isequal(x), wrt) && return true + + if operation(x) in (*, +, -, /) + # `map` because `all` will early exit, but we want to search + # through everything to get all the non-polynomial terms + return all(map(y -> is_polynomial!(data, y, wrt), arguments(x))) + end + if operation(x) == (^) + b, p = arguments(x) + is_pow_integer = symtype(p) <: Integer + if !is_pow_integer + push!(data.non_polynomial_terms, x) + push!(data.reasons, NonPolynomialReason.NonIntegerExponent) + end + if symbolic_type(p) != NotSymbolic() + data.has_parametric_exponent = true + end + + exponent_has_unknowns = contains_variable(p, wrt) + if exponent_has_unknowns + push!(data.non_polynomial_terms, x) + push!(data.reasons, NonPolynomialReason.ExponentContainsUnknowns) + end + base_polynomial = is_polynomial!(data, b, wrt) + return base_polynomial && !exponent_has_unknowns && is_pow_integer + end + push!(data.non_polynomial_terms, x) + push!(data.reasons, NonPolynomialReason.UnrecognizedOperation) + return false +end + +""" + $(TYPEDEF) + +Information about how an unknown in the system is substituted for a non-polynomial +expression to turn the system into a polynomial. Used in `PolynomialTransformation`. +""" +struct PolynomialTransformationData + """ + The new variable to use as an unknown of the transformed system. + """ + new_var::BasicSymbolic + """ + The non-polynomial expression being substituted. + """ + term::BasicSymbolic + """ + A vector of expressions corresponding to the solutions of + the non-polynomial expression `term` in terms of the new unknown `new_var`, + used to backsolve for the original unknown of the system. + """ + inv_term::Vector{BasicSymbolic} +end + +""" + $(TYPEDEF) + +Information representing how to transform a `NonlinearSystem` into a polynomial +system. +""" +struct PolynomialTransformation + """ + Substitutions mapping non-polynomial terms to temporary unknowns. The system + is a polynomial in the new unknowns. Currently, each non-polynomial term is a + function of a single unknown of the original system. + """ + substitution_rules::Dict{BasicSymbolic, BasicSymbolic} + """ + A vector of expressions involving unknowns of the transformed system, mapping + back to solutions of the original system. + """ + all_solutions::Vector{Vector{BasicSymbolic}} + """ + The new unknowns of the transformed system. + """ + new_dvs::Vector{BasicSymbolic} + """ + The polynomial data for each equation. + """ + polydata::Vector{PolynomialData} +end + +function PolynomialTransformation(sys::NonlinearSystem) + # we need to consider `full_equations` because observed also should be + # polynomials (if used in equations) and we don't know if observed is used + # in denominator. + # This is not the most efficient, and would be improved significantly with + # CSE/hashconsing. + eqs = full_equations(sys) + dvs = unknowns(sys) + + # Collect polynomial information about all equations + polydata = map(eqs) do eq + data = PolynomialData() + process_polynomial!(data, eq.lhs, dvs) + process_polynomial!(data, eq.rhs, dvs) + data + end + + # Get all unique non-polynomial terms + # NOTE: + # Is there a better way to check for uniqueness? `simplify` is relatively slow + # (maybe use the threaded version?) and `expand` can blow up expression size. + # Could metatheory help? + all_non_poly_terms = mapreduce(d -> d.non_polynomial_terms, vcat, polydata) + unique!(all_non_poly_terms) + + # each variable can only be replaced by one non-polynomial expression involving + # that variable. Keep track of this mapping. + var_to_nonpoly = Dict{BasicSymbolic, PolynomialTransformationData}() + + is_poly = true + transformation_err = nothing + for t in all_non_poly_terms + # if the term involves multiple unknowns, we can't invert it + dvs_in_term = map(x -> occursin(x, t), dvs) + if count(dvs_in_term) > 1 + transformation_err = MultivarTerm(t, dvs[dvs_in_term]) + is_poly = false + break + end + # we already have a substitution solving for `var` + var = dvs[findfirst(dvs_in_term)] + if haskey(var_to_nonpoly, var) && !isequal(var_to_nonpoly[var].term, t) + transformation_err = MultipleTermsOfSameVar([t, var_to_nonpoly[var].term], var) + is_poly = false + break + end + # we want to solve `term - new_var` for `var` + new_var = gensym(Symbol(var)) + new_var = unwrap(only(@variables $new_var)) + invterm = Symbolics.ia_solve( + t - new_var, var; complex_roots = false, periodic_roots = false, warns = false) + # if we can't invert it, quit + if invterm === nothing || isempty(invterm) + transformation_err = SymbolicSolveFailure(t, var) + is_poly = false + break + end + # `ia_solve` returns lazy terms i.e. `asin(1.0)` instead of `pi/2` + # this just evaluates the constant expressions + invterm = Symbolics.substitute.(invterm, (Dict(),)) + # RootsOf implies Symbolics couldn't solve the inner polynomial because + # `Nemo` wasn't loaded. + if any(x -> iscall(x) && operation(x) == Symbolics.RootsOf, invterm) + transformation_err = NemoNotLoaded() + is_poly = false + break + end + var_to_nonpoly[var] = PolynomialTransformationData(new_var, t, invterm) + end + + # return the error instead of throwing it, so the user can choose what to do + # without having to catch the exception + if !is_poly + return NotPolynomialError(transformation_err, eqs, polydata) + end + + subrules = Dict{BasicSymbolic, BasicSymbolic}() + # corresponding to each unknown in `dvs`, the list of its possible solutions + # in terms of the new unknown. + combinations = Vector{BasicSymbolic}[] + new_dvs = BasicSymbolic[] + for x in dvs + if haskey(var_to_nonpoly, x) + _data = var_to_nonpoly[x] + # map term to new unknown + subrules[_data.term] = _data.new_var + push!(combinations, _data.inv_term) + push!(new_dvs, _data.new_var) + else + push!(combinations, BasicSymbolic[x]) + push!(new_dvs, x) + end + end + all_solutions = vec(collect.(collect(Iterators.product(combinations...)))) + return PolynomialTransformation(subrules, all_solutions, new_dvs, polydata) +end + +""" + $(TYPEDEF) + +A struct containing the result of transforming a system into a polynomial system +using the appropriate `PolynomialTransformation`. Also contains the denominators +in the equations, to rule out invalid roots. +""" +struct PolynomialTransformationResult + sys::NonlinearSystem + denominators::Vector{BasicSymbolic} +end + +""" + $(TYPEDSIGNATURES) + +Transform the system `sys` with `transformation` and return a +`PolynomialTransformationResult`, or a `NotPolynomialError` if the system cannot +be transformed. +""" +function transform_system(sys::NonlinearSystem, transformation::PolynomialTransformation) + subrules = transformation.substitution_rules + dvs = unknowns(sys) + eqs = full_equations(sys) + polydata = transformation.polydata + new_dvs = transformation.new_dvs + all_solutions = transformation.all_solutions + + eqs2 = Equation[] + denoms = BasicSymbolic[] + for eq in eqs + t = eq.rhs - eq.lhs + t = Symbolics.fixpoint_sub(t, subrules; maxiters = length(dvs)) + # the substituted variable occurs outside the substituted term + poly_and_nonpoly = map(dvs) do x + all(!isequal(x), new_dvs) && occursin(x, t) + end + if any(poly_and_nonpoly) + return NotPolynomialError( + VariablesAsPolyAndNonPoly(dvs[poly_and_nonpoly]), eqs, polydata) + end + num, den = handle_rational_polynomials(t, new_dvs) + # make factors different elements, otherwise the nonzero factors artificially + # inflate the error of the zero factor. + if iscall(den) && operation(den) == * + for arg in arguments(den) + # ignore constant factors + symbolic_type(arg) == NotSymbolic() && continue + push!(denoms, abs(arg)) + end + elseif symbolic_type(den) != NotSymbolic() + push!(denoms, abs(den)) + end + push!(eqs2, 0 ~ num) + end + + sys2 = @set sys.eqs = eqs2 + @set! sys2.unknowns = new_dvs + # remove observed equations to avoid adding them in codegen + @set! sys2.observed = Equation[] + @set! sys2.substitutions = nothing + return PolynomialTransformationResult(sys2, denoms) +end + +""" +$(TYPEDSIGNATURES) + +Given a `x`, a polynomial in variables in `wrt` which may contain rational functions, +express `x` as a single rational function with polynomial `num` and denominator `den`. +Return `(num, den)`. +""" +function handle_rational_polynomials(x, wrt) + x = unwrap(x) + symbolic_type(x) == NotSymbolic() && return x, 1 + iscall(x) || return x, 1 + contains_variable(x, wrt) || return x, 1 + any(isequal(x), wrt) && return x, 1 + + # simplify_fractions cancels out some common factors + # and expands (a / b)^c to a^c / b^c, so we only need + # to handle these cases + x = simplify_fractions(x) + op = operation(x) + args = arguments(x) + + if op == / + # numerator and denominator are trivial + num, den = args + # but also search for rational functions in numerator + n, d = handle_rational_polynomials(num, wrt) + num, den = n, den * d + elseif op == + + num = 0 + den = 1 + + # we don't need to do common denominator + # because we don't care about cases where denominator + # is zero. The expression is zero when all the numerators + # are zero. + for arg in args + n, d = handle_rational_polynomials(arg, wrt) + num += n + den *= d + end + else + return x, 1 + end + # if the denominator isn't a polynomial in `wrt`, better to not include it + # to reduce the size of the gcd polynomial + if !contains_variable(den, wrt) + return num / den, 1 + end + return num, den +end From 669bc84e41897d8c3a0a270adf941829be1cdd10 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 30 Nov 2024 19:53:48 +0530 Subject: [PATCH 083/101] feat: add `safe_HomotopyContinuationProblem` --- ext/MTKHomotopyContinuationExt.jl | 10 +++++++-- .../nonlinear/homotopy_continuation.jl | 21 +++++++++++++++++++ test/extensions/homotopy_continuation.jl | 20 ++++++++++++++++++ 3 files changed, 49 insertions(+), 2 deletions(-) diff --git a/ext/MTKHomotopyContinuationExt.jl b/ext/MTKHomotopyContinuationExt.jl index 586d608156..c4a090d9a8 100644 --- a/ext/MTKHomotopyContinuationExt.jl +++ b/ext/MTKHomotopyContinuationExt.jl @@ -96,16 +96,22 @@ All other keyword arguments are forwarded to `HomotopyContinuation.solver_starts """ function MTK.HomotopyContinuationProblem( sys::NonlinearSystem, u0map, parammap = nothing; kwargs...) + prob = MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap; kwargs...) + prob isa MTK.HomotopyContinuationProblem || throw(prob) + return prob +end + +function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing; kwargs...) if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`") end transformation = MTK.PolynomialTransformation(sys) if transformation isa MTK.NotPolynomialError - throw(transformation) + return transformation end result = MTK.transform_system(sys, transformation) if result isa MTK.NotPolynomialError - throw(result) + return result end MTK.HomotopyContinuationProblem(sys, transformation, result, u0map, parammap; kwargs...) end diff --git a/src/systems/nonlinear/homotopy_continuation.jl b/src/systems/nonlinear/homotopy_continuation.jl index 044f44f70f..ed69db1aee 100644 --- a/src/systems/nonlinear/homotopy_continuation.jl +++ b/src/systems/nonlinear/homotopy_continuation.jl @@ -48,6 +48,27 @@ function HomotopyContinuationProblem(::AbstractSystem, _u0, _p; kwargs...) error("HomotopyContinuation.jl is required to create and solve `HomotopyContinuationProblem`s. Please run `Pkg.add(\"HomotopyContinuation\")` to continue.") end +""" + $(TYPEDSIGNATURES) + +Utility function for `safe_HomotopyContinuationProblem`, implemented in the extension. +""" +function _safe_HomotopyContinuationProblem end + +""" + $(TYPEDSIGNATURES) + +Return a `HomotopyContinuationProblem` if the extension is loaded and the system is +polynomial. If the extension is not loaded, return `nothing`. If the system is not +polynomial, return the appropriate `NotPolynomialError`. +""" +function safe_HomotopyContinuationProblem(sys::NonlinearSystem, args...; kwargs...) + if Base.get_extension(ModelingToolkit, :MTKHomotopyContinuationExt) === nothing + return nothing + end + return _safe_HomotopyContinuationProblem(sys, args...; kwargs...) +end + SymbolicIndexingInterface.symbolic_container(p::HomotopyContinuationProblem) = p.sys SymbolicIndexingInterface.state_values(p::HomotopyContinuationProblem) = p.u0 function SymbolicIndexingInterface.set_state!(p::HomotopyContinuationProblem, args...) diff --git a/test/extensions/homotopy_continuation.jl b/test/extensions/homotopy_continuation.jl index 81c252c84a..3bdfa06a5d 100644 --- a/test/extensions/homotopy_continuation.jl +++ b/test/extensions/homotopy_continuation.jl @@ -1,6 +1,19 @@ using ModelingToolkit, NonlinearSolve, SymbolicIndexingInterface +import ModelingToolkit as MTK using LinearAlgebra using Test + +@testset "Safe HCProblem" begin + @variables x y z + eqs = [0 ~ x^2 + y^2 + 2x * y + 0 ~ x^2 + 4x + 4 + 0 ~ y * z + 4x^2] + @mtkbuild sys = NonlinearSystem(eqs) + prob = MTK.safe_HomotopyContinuationProblem(sys, [x => 1.0, y => 1.0, z => 1.0], []) + @test prob === nothing +end + + import HomotopyContinuation @testset "No parameters" begin @@ -78,30 +91,37 @@ end @test_throws ["Cannot convert", "Unable", "symbolically solve", "Exponent", "not an integer", "not a polynomial"] HomotopyContinuationProblem( sys, []) + @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError @mtkbuild sys = NonlinearSystem([x^x - x ~ 0]) @test_throws ["Cannot convert", "Unable", "symbolically solve", "Exponent", "unknowns", "not a polynomial"] HomotopyContinuationProblem( sys, []) + @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError @mtkbuild sys = NonlinearSystem([((x^2) / sin(x))^2 + x ~ 0]) @test_throws ["Cannot convert", "both polynomial", "non-polynomial", "recognized", "sin", "not a polynomial"] HomotopyContinuationProblem( sys, []) + @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError @variables y = 2.0 @mtkbuild sys = NonlinearSystem([x^2 + y^2 + 2 ~ 0, y ~ sin(x)]) @test_throws ["Cannot convert", "recognized", "sin", "not a polynomial"] HomotopyContinuationProblem( sys, []) + @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError @mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2 ~ 0, sin(x + y) ~ 0]) @test_throws ["Cannot convert", "function of multiple unknowns"] HomotopyContinuationProblem( sys, []) + @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError @mtkbuild sys = NonlinearSystem([sin(x)^2 + 1 ~ 0, cos(y) - cos(x) - 1 ~ 0]) @test_throws ["Cannot convert", "multiple non-polynomial terms", "same unknown"] HomotopyContinuationProblem( sys, []) + @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError @mtkbuild sys = NonlinearSystem([sin(x^2)^2 + sin(x^2) - 1 ~ 0]) @test_throws ["import Nemo"] HomotopyContinuationProblem(sys, []) + @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError end import Nemo From 1ef12170e0206694ac627639b28589871e2128a4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 30 Nov 2024 19:58:00 +0530 Subject: [PATCH 084/101] feat: use `HomotopyContinuationProblem` in `NonlinearProblem` if possible --- src/systems/nonlinear/nonlinearsystem.jl | 6 +++++- test/extensions/homotopy_continuation.jl | 18 ++++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index ff4a7d0eb2..3cb68853aa 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -496,10 +496,14 @@ end function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map, parammap = DiffEqBase.NullParameters(); - check_length = true, kwargs...) where {iip} + check_length = true, use_homotopy_continuation = true, kwargs...) where {iip} if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblem`") end + prob = safe_HomotopyContinuationProblem(sys, u0map, parammap; check_length, kwargs...) + if prob isa HomotopyContinuationProblem + return prob + end f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap; check_length, kwargs...) pt = something(get_metadata(sys), StandardNonlinearProblem()) diff --git a/test/extensions/homotopy_continuation.jl b/test/extensions/homotopy_continuation.jl index 3bdfa06a5d..e57cd0b2f7 100644 --- a/test/extensions/homotopy_continuation.jl +++ b/test/extensions/homotopy_continuation.jl @@ -13,7 +13,6 @@ using Test @test prob === nothing end - import HomotopyContinuation @testset "No parameters" begin @@ -22,12 +21,19 @@ import HomotopyContinuation 0 ~ x^2 + 4x + 4 0 ~ y * z + 4x^2] @mtkbuild sys = NonlinearSystem(eqs) - prob = HomotopyContinuationProblem(sys, [x => 1.0, y => 1.0, z => 1.0], []) + u0 = [x => 1.0, y => 1.0, z => 1.0] + prob = HomotopyContinuationProblem(sys, u0) @test prob[x] == prob[y] == prob[z] == 1.0 @test prob[x + y] == 2.0 sol = solve(prob; threading = false) @test SciMLBase.successful_retcode(sol) @test norm(sol.resid)≈0.0 atol=1e-10 + + prob2 = NonlinearProblem(sys, u0) + @test prob2 isa HomotopyContinuationProblem + sol = solve(prob2; threading = false) + @test SciMLBase.successful_retcode(sol) + @test norm(sol.resid)≈0.0 atol=1e-10 end struct Wrapper @@ -92,36 +98,44 @@ end "Exponent", "not an integer", "not a polynomial"] HomotopyContinuationProblem( sys, []) @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError + @test NonlinearProblem(sys, []) isa NonlinearProblem + @mtkbuild sys = NonlinearSystem([x^x - x ~ 0]) @test_throws ["Cannot convert", "Unable", "symbolically solve", "Exponent", "unknowns", "not a polynomial"] HomotopyContinuationProblem( sys, []) @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError + @test NonlinearProblem(sys, []) isa NonlinearProblem @mtkbuild sys = NonlinearSystem([((x^2) / sin(x))^2 + x ~ 0]) @test_throws ["Cannot convert", "both polynomial", "non-polynomial", "recognized", "sin", "not a polynomial"] HomotopyContinuationProblem( sys, []) @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError + @test NonlinearProblem(sys, []) isa NonlinearProblem @variables y = 2.0 @mtkbuild sys = NonlinearSystem([x^2 + y^2 + 2 ~ 0, y ~ sin(x)]) @test_throws ["Cannot convert", "recognized", "sin", "not a polynomial"] HomotopyContinuationProblem( sys, []) @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError + @test NonlinearProblem(sys, []) isa NonlinearProblem @mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2 ~ 0, sin(x + y) ~ 0]) @test_throws ["Cannot convert", "function of multiple unknowns"] HomotopyContinuationProblem( sys, []) @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError + @test NonlinearProblem(sys, []) isa NonlinearProblem @mtkbuild sys = NonlinearSystem([sin(x)^2 + 1 ~ 0, cos(y) - cos(x) - 1 ~ 0]) @test_throws ["Cannot convert", "multiple non-polynomial terms", "same unknown"] HomotopyContinuationProblem( sys, []) @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError + @test NonlinearProblem(sys, []) isa NonlinearProblem @mtkbuild sys = NonlinearSystem([sin(x^2)^2 + sin(x^2) - 1 ~ 0]) @test_throws ["import Nemo"] HomotopyContinuationProblem(sys, []) @test MTK.safe_HomotopyContinuationProblem(sys, []) isa MTK.NotPolynomialError + @test NonlinearProblem(sys, []) isa NonlinearProblem end import Nemo From aa0e08ad700d5c9f52c0a6fd258ee11c3a5add3e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 30 Nov 2024 23:16:17 +0530 Subject: [PATCH 085/101] docs: add docstrings to `NonPolynomialReason` enum variants --- src/systems/nonlinear/homotopy_continuation.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/src/systems/nonlinear/homotopy_continuation.jl b/src/systems/nonlinear/homotopy_continuation.jl index ed69db1aee..03aeed1edf 100644 --- a/src/systems/nonlinear/homotopy_continuation.jl +++ b/src/systems/nonlinear/homotopy_continuation.jl @@ -96,9 +96,21 @@ end Possible reasons why a term is not polynomial """ EnumX.@enumx NonPolynomialReason begin + """ + Exponent of an expression involving unknowns is not an integer. + """ NonIntegerExponent + """ + Exponent is an expression containing unknowns. + """ ExponentContainsUnknowns + """ + The base of an exponent is not a polynomial in the unknowns. + """ BaseNotPolynomial + """ + An expression involves a non-polynomial operation involving unknowns. + """ UnrecognizedOperation end From b9b396ddb80537b71d9c06dca20df22e44eb6bc3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sun, 8 Dec 2024 10:56:17 +0530 Subject: [PATCH 086/101] test: add logs to debug tests --- test/extensions/homotopy_continuation.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/extensions/homotopy_continuation.jl b/test/extensions/homotopy_continuation.jl index e57cd0b2f7..f9724692e7 100644 --- a/test/extensions/homotopy_continuation.jl +++ b/test/extensions/homotopy_continuation.jl @@ -160,6 +160,9 @@ end @test prob[x] ≈ 0.25 @test prob[y] ≈ 0.125 sol = solve(prob; threading = false) + # can't replicate the solve failure locally, so CI logs might help + @show sol.u HomotopyContinuation.real_solutions(sol.original) + @test SciMLBase.successful_retcode(sol) @test sol[a]≈0.5 atol=1e-6 @test sol[b]≈0.25 atol=1e-6 end From f428df4bc0dbc298a0f15280bb980679d7f895ef Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 9 Dec 2024 12:41:19 +0530 Subject: [PATCH 087/101] fixup! test: add logs to debug tests --- test/extensions/homotopy_continuation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/extensions/homotopy_continuation.jl b/test/extensions/homotopy_continuation.jl index f9724692e7..554f9e1e1d 100644 --- a/test/extensions/homotopy_continuation.jl +++ b/test/extensions/homotopy_continuation.jl @@ -161,7 +161,7 @@ end @test prob[y] ≈ 0.125 sol = solve(prob; threading = false) # can't replicate the solve failure locally, so CI logs might help - @show sol.u HomotopyContinuation.real_solutions(sol.original) + @show sol.u sol.original.path_results @test SciMLBase.successful_retcode(sol) @test sol[a]≈0.5 atol=1e-6 @test sol[b]≈0.25 atol=1e-6 From 16d3f5c7a57bffd698a6c69da3ea89e0089b2d54 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Mon, 9 Dec 2024 02:52:15 -0800 Subject: [PATCH 088/101] Refactor ImperativeAffect into its own file --- src/ModelingToolkit.jl | 1 + src/systems/callbacks.jl | 221 +------------------------------ src/systems/diffeqs/odesystem.jl | 2 - src/systems/imperative_affect.jl | 220 ++++++++++++++++++++++++++++++ 4 files changed, 224 insertions(+), 220 deletions(-) create mode 100644 src/systems/imperative_affect.jl diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index de8e69c41f..ccdaa15e99 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -145,6 +145,7 @@ include("systems/parameter_buffer.jl") include("systems/abstractsystem.jl") include("systems/model_parsing.jl") include("systems/connectors.jl") +include("systems/imperative_affect.jl") include("systems/callbacks.jl") include("systems/problem_utils.jl") diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 0cb7f16c9f..4533534029 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -73,111 +73,6 @@ function namespace_affect(affect::FunctionalAffect, s) context(affect)) end -""" - ImperativeAffect(f::Function; modified::NamedTuple, observed::NamedTuple, ctx) - -`ImperativeAffect` is a helper for writing affect functions that will compute observed values and -ensure that modified values are correctly written back into the system. The affect function `f` needs to have -the signature - -``` - f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple -``` - -The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions. -Each declaration`NamedTuple` should map an expression to a symbol; for example if we pass `observed=(; x = a + b)` this will alias the result of executing `a+b` in the system as `x` -so the value of `a + b` will be accessible as `observed.x` in `f`. `modified` currently restricts symbolic expressions to only bare variables, so only tuples of the form -`(; x = y)` or `(; x)` (which aliases `x` as itself) are allowed. - -The argument NamedTuples (for instance `(;x=y)`) will be populated with the declared values on function entry; if we require `(;x=y)` in `observed` and `y=2`, for example, -then the NamedTuple `(;x=2)` will be passed as `observed` to the affect function `f`. - -The NamedTuple returned from `f` includes the values to be written back to the system after `f` returns. For example, if we want to update the value of `x` to be the result of `x + y` we could write - - ImperativeAffect(observed=(; x_plus_y = x + y), modified=(; x)) do m, o - @set! m.x = o.x_plus_y - end - -Where we use Setfield to copy the tuple `m` with a new value for `x`, then return the modified value of `m`. All values updated by the tuple must have names originally declared in -`modified`; a runtime error will be produced if a value is written that does not appear in `modified`. The user can dynamically decide not to write a value back by not including it -in the returned tuple, in which case the associated field will not be updated. -""" -@kwdef struct ImperativeAffect - f::Any - obs::Vector - obs_syms::Vector{Symbol} - modified::Vector - mod_syms::Vector{Symbol} - ctx::Any - skip_checks::Bool -end - -function ImperativeAffect(f::Function; - observed::NamedTuple = NamedTuple{()}(()), - modified::NamedTuple = NamedTuple{()}(()), - ctx = nothing, - skip_checks = false) - ImperativeAffect(f, - collect(values(observed)), collect(keys(observed)), - collect(values(modified)), collect(keys(modified)), - ctx, skip_checks) -end -function ImperativeAffect(f::Function, modified::NamedTuple; - observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false) - ImperativeAffect( - f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) -end -function ImperativeAffect( - f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing, skip_checks = false) - ImperativeAffect( - f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) -end -function ImperativeAffect( - f::Function, modified::NamedTuple, observed::NamedTuple, ctx; skip_checks = false) - ImperativeAffect( - f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) -end - -function Base.show(io::IO, mfa::ImperativeAffect) - obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ") - mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ") - affect = mfa.f - print(io, - "ImperativeAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)") -end -func(f::ImperativeAffect) = f.f -context(a::ImperativeAffect) = a.ctx -observed(a::ImperativeAffect) = a.obs -observed_syms(a::ImperativeAffect) = a.obs_syms -discretes(a::ImperativeAffect) = filter(ModelingToolkit.isparameter, a.modified) -modified(a::ImperativeAffect) = a.modified -modified_syms(a::ImperativeAffect) = a.mod_syms - -function Base.:(==)(a1::ImperativeAffect, a2::ImperativeAffect) - isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) && - isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms) && - isequal(a1.ctx, a2.ctx) -end - -function Base.hash(a::ImperativeAffect, s::UInt) - s = hash(a.f, s) - s = hash(a.obs, s) - s = hash(a.obs_syms, s) - s = hash(a.modified, s) - s = hash(a.mod_syms, s) - hash(a.ctx, s) -end - -function namespace_affect(affect::ImperativeAffect, s) - ImperativeAffect(func(affect), - namespace_expr.(observed(affect), (s,)), - observed_syms(affect), - renamespace.((s,), modified(affect)), - modified_syms(affect), - context(affect), - affect.skip_checks) -end - function has_functional_affect(cb) (affects(cb) isa FunctionalAffect || affects(cb) isa ImperativeAffect) end @@ -203,13 +98,13 @@ sharp discontinuity between integrator steps (which in this example would not no guaranteed to be triggered. Once detected the integrator will "wind back" through a root-finding process to identify the point when the condition became active; the method used -is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). If we denote the time when the condition becomes active at tc, +is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). If we denote the time when the condition becomes active as `tc``, the value in the integrator after windback will be: * `u[tc-epsilon], p[tc-epsilon], tc` if `LeftRootFind` is used, * `u[tc+epsilon], p[tc+epsilon], tc` if `RightRootFind` is used, * or `u[t], p[t], t` if `NoRootFind` is used. For example, if we want to detect when an unknown variable `x` satisfies `x > 0` using the condition `x ~ 0` on a positive edge (that is, `D(x) > 0`), -then left root finding will get us `x=-epsilon`, right root finding `x=epsilon` and no root finding whatever the next step of the integrator was after +then left root finding will get us `x=-epsilon`, right root finding `x=epsilon` and no root finding will produce whatever the next step of the integrator was after it passed through 0. Multiple callbacks in the same system with different `rootfind` operations will be grouped @@ -405,7 +300,6 @@ end namespace_affects(af::Vector, s) = Equation[namespace_affect(a, s) for a in af] namespace_affects(af::FunctionalAffect, s) = namespace_affect(af, s) -namespace_affects(af::ImperativeAffect, s) = namespace_affect(af, s) namespace_affects(::Nothing, s) = nothing function namespace_callback(cb::SymbolicContinuousCallback, s)::SymbolicContinuousCallback @@ -480,7 +374,6 @@ scalarize_affects(affects) = scalarize(affects) scalarize_affects(affects::Tuple) = FunctionalAffect(affects...) scalarize_affects(affects::NamedTuple) = FunctionalAffect(; affects...) scalarize_affects(affects::FunctionalAffect) = affects -scalarize_affects(affects::ImperativeAffect) = affects SymbolicDiscreteCallback(p::Pair) = SymbolicDiscreteCallback(p[1], p[2]) SymbolicDiscreteCallback(cb::SymbolicDiscreteCallback) = cb # passthrough @@ -1099,117 +992,9 @@ function check_assignable(sys, sym) end end -function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs...) - #= - Implementation sketch: - generate observed function (oop), should save to a component array under obs_syms - do the same stuff as the normal FA for pars_syms - call the affect method - unpack and apply the resulting values - =# - function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup) - seen = Set{Symbol}() - syms_dedup = [] - exprs_dedup = [] - for (sym, exp) in Iterators.zip(syms, exprs) - if !in(sym, seen) - push!(syms_dedup, sym) - push!(exprs_dedup, exp) - push!(seen, sym) - elseif !affect.skip_checks - @warn "Expression $(expr) is aliased as $sym, which has already been used. The first definition will be used." - end - end - return (syms_dedup, exprs_dedup) - end - - obs_exprs = observed(affect) - if !affect.skip_checks - for oexpr in obs_exprs - invalid_vars = invalid_variables(sys, oexpr) - if length(invalid_vars) > 0 - error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).") - end - end - end - obs_syms = observed_syms(affect) - obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs) - - mod_exprs = modified(affect) - if !affect.skip_checks - for mexpr in mod_exprs - if !check_assignable(sys, mexpr) - @warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.") - end - invalid_vars = unassignable_variables(sys, mexpr) - if length(invalid_vars) > 0 - error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.") - end - end - end - mod_syms = modified_syms(affect) - mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs) - - overlapping_syms = intersect(mod_syms, obs_syms) - if length(overlapping_syms) > 0 && !affect.skip_checks - @warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value." - end - - # sanity checks done! now build the data and update function for observed values - mkzero(sz) = - if sz === () - 0.0 - else - zeros(sz) - end - obs_fun = build_explicit_observed_function( - sys, Symbolics.scalarize.(obs_exprs); - array_type = Tuple) - obs_sym_tuple = (obs_syms...,) - - # okay so now to generate the stuff to assign it back into the system - mod_pairs = mod_exprs .=> mod_syms - mod_names = (mod_syms...,) - mod_og_val_fun = build_explicit_observed_function( - sys, Symbolics.scalarize.(first.(mod_pairs)); - array_type = Tuple) - - upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,)) - - if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing - save_idxs = get(ic.callback_to_clocks, cb, Int[]) - else - save_idxs = Int[] - end - - let user_affect = func(affect), ctx = context(affect) - function (integ) - # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens - modvals = mod_og_val_fun(integ.u, integ.p, integ.t) - upd_component_array = NamedTuple{mod_names}(modvals) - - # update the observed values - obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun( - integ.u, integ.p, integ.t)) - - # let the user do their thing - upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ) - - # write the new values back to the integrator - _generated_writeback(integ, upd_funs, upd_vals) - - for idx in save_idxs - SciMLBase.save_discretes!(integ, idx) - end - end - end -end - -function compile_affect( - affect::Union{FunctionalAffect, ImperativeAffect}, cb, sys, dvs, ps; kwargs...) +function compile_affect(affect::FunctionalAffect, cb, sys, dvs, ps; kwargs...) compile_user_affect(affect, cb, sys, dvs, ps; kwargs...) end - function _compile_optional_affect(default, aff, cb, sys, dvs, ps; kwargs...) if isnothing(aff) || aff == default return nothing diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 4040b7a646..2b0bd8c8d7 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -629,8 +629,6 @@ function build_explicit_observed_function(sys, ts; oop_mtkp_wrapper = mtkparams_wrapper end - output_expr = isscalar ? ts[1] : - (array_type <: Vector ? MakeArray(ts, output_type) : MakeTuple(ts)) # Need to keep old method of building the function since it uses `output_type`, # which can't be provided to `build_function` return_value = if isscalar diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl new file mode 100644 index 0000000000..f20ad941da --- /dev/null +++ b/src/systems/imperative_affect.jl @@ -0,0 +1,220 @@ + +""" + ImperativeAffect(f::Function; modified::NamedTuple, observed::NamedTuple, ctx) + +`ImperativeAffect` is a helper for writing affect functions that will compute observed values and +ensure that modified values are correctly written back into the system. The affect function `f` needs to have +the signature + +``` + f(modified::NamedTuple, observed::NamedTuple, ctx, integrator)::NamedTuple +``` + +The function `f` will be called with `observed` and `modified` `NamedTuple`s that are derived from their respective `NamedTuple` definitions. +Each declaration`NamedTuple` should map an expression to a symbol; for example if we pass `observed=(; x = a + b)` this will alias the result of executing `a+b` in the system as `x` +so the value of `a + b` will be accessible as `observed.x` in `f`. `modified` currently restricts symbolic expressions to only bare variables, so only tuples of the form +`(; x = y)` or `(; x)` (which aliases `x` as itself) are allowed. + +The argument NamedTuples (for instance `(;x=y)`) will be populated with the declared values on function entry; if we require `(;x=y)` in `observed` and `y=2`, for example, +then the NamedTuple `(;x=2)` will be passed as `observed` to the affect function `f`. + +The NamedTuple returned from `f` includes the values to be written back to the system after `f` returns. For example, if we want to update the value of `x` to be the result of `x + y` we could write + + ImperativeAffect(observed=(; x_plus_y = x + y), modified=(; x)) do m, o + @set! m.x = o.x_plus_y + end + +Where we use Setfield to copy the tuple `m` with a new value for `x`, then return the modified value of `m`. All values updated by the tuple must have names originally declared in +`modified`; a runtime error will be produced if a value is written that does not appear in `modified`. The user can dynamically decide not to write a value back by not including it +in the returned tuple, in which case the associated field will not be updated. +""" +@kwdef struct ImperativeAffect + f::Any + obs::Vector + obs_syms::Vector{Symbol} + modified::Vector + mod_syms::Vector{Symbol} + ctx::Any + skip_checks::Bool +end + +function ImperativeAffect(f::Function; + observed::NamedTuple = NamedTuple{()}(()), + modified::NamedTuple = NamedTuple{()}(()), + ctx = nothing, + skip_checks = false) + ImperativeAffect(f, + collect(values(observed)), collect(keys(observed)), + collect(values(modified)), collect(keys(modified)), + ctx, skip_checks) +end +function ImperativeAffect(f::Function, modified::NamedTuple; + observed::NamedTuple = NamedTuple{()}(()), ctx = nothing, skip_checks = false) + ImperativeAffect( + f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) +end +function ImperativeAffect( + f::Function, modified::NamedTuple, observed::NamedTuple; ctx = nothing, skip_checks = false) + ImperativeAffect( + f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) +end +function ImperativeAffect( + f::Function, modified::NamedTuple, observed::NamedTuple, ctx; skip_checks = false) + ImperativeAffect( + f, observed = observed, modified = modified, ctx = ctx, skip_checks = skip_checks) +end + +function Base.show(io::IO, mfa::ImperativeAffect) + obs_vals = join(map((ob, nm) -> "$ob => $nm", mfa.obs, mfa.obs_syms), ", ") + mod_vals = join(map((md, nm) -> "$md => $nm", mfa.modified, mfa.mod_syms), ", ") + affect = mfa.f + print(io, + "ImperativeAffect(observed: [$obs_vals], modified: [$mod_vals], affect:$affect)") +end +func(f::ImperativeAffect) = f.f +context(a::ImperativeAffect) = a.ctx +observed(a::ImperativeAffect) = a.obs +observed_syms(a::ImperativeAffect) = a.obs_syms +discretes(a::ImperativeAffect) = filter(ModelingToolkit.isparameter, a.modified) +modified(a::ImperativeAffect) = a.modified +modified_syms(a::ImperativeAffect) = a.mod_syms + +function Base.:(==)(a1::ImperativeAffect, a2::ImperativeAffect) + isequal(a1.f, a2.f) && isequal(a1.obs, a2.obs) && isequal(a1.modified, a2.modified) && + isequal(a1.obs_syms, a2.obs_syms) && isequal(a1.mod_syms, a2.mod_syms) && + isequal(a1.ctx, a2.ctx) +end + +function Base.hash(a::ImperativeAffect, s::UInt) + s = hash(a.f, s) + s = hash(a.obs, s) + s = hash(a.obs_syms, s) + s = hash(a.modified, s) + s = hash(a.mod_syms, s) + hash(a.ctx, s) +end + + +namespace_affects(af::ImperativeAffect, s) = namespace_affect(af, s) +function namespace_affect(affect::ImperativeAffect, s) + ImperativeAffect(func(affect), + namespace_expr.(observed(affect), (s,)), + observed_syms(affect), + renamespace.((s,), modified(affect)), + modified_syms(affect), + context(affect), + affect.skip_checks) +end + +function compile_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs...) + compile_user_affect(affect, cb, sys, dvs, ps; kwargs...) +end + +function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs...) + #= + Implementation sketch: + generate observed function (oop), should save to a component array under obs_syms + do the same stuff as the normal FA for pars_syms + call the affect method + unpack and apply the resulting values + =# + function check_dups(syms, exprs) # = (syms_dedup, exprs_dedup) + seen = Set{Symbol}() + syms_dedup = [] + exprs_dedup = [] + for (sym, exp) in Iterators.zip(syms, exprs) + if !in(sym, seen) + push!(syms_dedup, sym) + push!(exprs_dedup, exp) + push!(seen, sym) + elseif !affect.skip_checks + @warn "Expression $(expr) is aliased as $sym, which has already been used. The first definition will be used." + end + end + return (syms_dedup, exprs_dedup) + end + + obs_exprs = observed(affect) + if !affect.skip_checks + for oexpr in obs_exprs + invalid_vars = invalid_variables(sys, oexpr) + if length(invalid_vars) > 0 + error("Observed equation $(oexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing).") + end + end + end + obs_syms = observed_syms(affect) + obs_syms, obs_exprs = check_dups(obs_syms, obs_exprs) + + mod_exprs = modified(affect) + if !affect.skip_checks + for mexpr in mod_exprs + if !check_assignable(sys, mexpr) + @warn ("Expression $mexpr cannot be assigned to; currently only unknowns and parameters may be updated by an affect.") + end + invalid_vars = unassignable_variables(sys, mexpr) + if length(invalid_vars) > 0 + error("Modified equation $(mexpr) in affect refers to missing variable(s) $(invalid_vars); the variables may not have been added (e.g. if a component is missing) or they may have been reduced away.") + end + end + end + mod_syms = modified_syms(affect) + mod_syms, mod_exprs = check_dups(mod_syms, mod_exprs) + + overlapping_syms = intersect(mod_syms, obs_syms) + if length(overlapping_syms) > 0 && !affect.skip_checks + @warn "The symbols $overlapping_syms are declared as both observed and modified; this is a code smell because it becomes easy to confuse them and assign/not assign a value." + end + + # sanity checks done! now build the data and update function for observed values + mkzero(sz) = + if sz === () + 0.0 + else + zeros(sz) + end + obs_fun = build_explicit_observed_function( + sys, Symbolics.scalarize.(obs_exprs); + array_type = Tuple) + obs_sym_tuple = (obs_syms...,) + + # okay so now to generate the stuff to assign it back into the system + mod_pairs = mod_exprs .=> mod_syms + mod_names = (mod_syms...,) + mod_og_val_fun = build_explicit_observed_function( + sys, Symbolics.scalarize.(first.(mod_pairs)); + array_type = Tuple) + + upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,)) + + if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing + save_idxs = get(ic.callback_to_clocks, cb, Int[]) + else + save_idxs = Int[] + end + + let user_affect = func(affect), ctx = context(affect) + function (integ) + # update the to-be-mutated values; this ensures that if you do a no-op then nothing happens + modvals = mod_og_val_fun(integ.u, integ.p, integ.t) + upd_component_array = NamedTuple{mod_names}(modvals) + + # update the observed values + obs_component_array = NamedTuple{obs_sym_tuple}(obs_fun( + integ.u, integ.p, integ.t)) + + # let the user do their thing + upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ) + + # write the new values back to the integrator + _generated_writeback(integ, upd_funs, upd_vals) + + for idx in save_idxs + SciMLBase.save_discretes!(integ, idx) + end + end + end +end + + +scalarize_affects(affects::ImperativeAffect) = affects From 1c78dee591ac5623a240084005545dbca04e2941 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Mon, 9 Dec 2024 03:16:24 -0800 Subject: [PATCH 089/101] Fix tests & update API usage --- src/systems/callbacks.jl | 2 +- src/systems/imperative_affect.jl | 4 ++-- test/symbolic_events.jl | 8 ++++---- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 4533534029..57db5e097c 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -818,7 +818,7 @@ function generate_vector_rootfinding_callback( let save_idxs = save_idxs custom_init = fn.initialize (i) -> begin - isnothing(custom_init) && custom_init(i) + !isnothing(custom_init) && custom_init(i) for idx in save_idxs SciMLBase.save_discretes!(i, idx) end diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index f20ad941da..19f7f7590d 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -175,7 +175,7 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. end obs_fun = build_explicit_observed_function( sys, Symbolics.scalarize.(obs_exprs); - array_type = Tuple) + mkarray = (es,_) -> MakeTuple(es)) obs_sym_tuple = (obs_syms...,) # okay so now to generate the stuff to assign it back into the system @@ -183,7 +183,7 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. mod_names = (mod_syms...,) mod_og_val_fun = build_explicit_observed_function( sys, Symbolics.scalarize.(first.(mod_pairs)); - array_type = Tuple) + mkarray = (es,_) -> MakeTuple(es)) upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,)) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 717d2438ac..858a30f4fd 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -936,7 +936,7 @@ end @named trigsys = ODESystem(eqs, t; continuous_events = [evt1, evt2]) trigsys_ss = structural_simplify(trigsys) prob = ODEProblem(trigsys_ss, [], (0.0, 2π)) - sol = solve(prob, Tsit5()) + sol = solve(prob, Tsit5(); dtmax=0.01) required_crossings_c1 = [π / 2, 3 * π / 2] required_crossings_c2 = [π / 6, π / 2, 5 * π / 6, 7 * π / 6, 3 * π / 2, 11 * π / 6] @test maximum(abs.(first.(cr1) .- required_crossings_c1)) < 1e-4 @@ -1079,8 +1079,8 @@ end @test sort(canonicalize(Discrete(), prob.p)[1]) == [0.0, 1.0, 2.0] sol = solve(prob, Tsit5()) - @test sol[a] == [-1.0] - @test sol[b] == [5.0, 5.0] + @test sol[a] == [1.0,-1.0] + @test sol[b] == [2.0,5.0, 5.0] @test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] end @testset "Heater" begin @@ -1248,7 +1248,7 @@ end ss = structural_simplify(sys) prob = ODEProblem(ss, [theta => 0.0], (0.0, pi)) sol = solve(prob, Tsit5(); dtmax = 0.01) - @test getp(sol, cnt)(sol) == 197 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state + @test getp(sol, cnt)(sol) == 198 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state end @testset "Initialization" begin From aa556d612a12f2cc297b7af7282a959aa8674775 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Mon, 9 Dec 2024 03:59:31 -0800 Subject: [PATCH 090/101] Adjust quadrature test forward a little to avoid numerical issues --- test/symbolic_events.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 858a30f4fd..1138b7c96f 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -1246,7 +1246,7 @@ end @named sys = ODESystem( eqs, t, [theta, omega], params; continuous_events = [qAevt, qBevt]) ss = structural_simplify(sys) - prob = ODEProblem(ss, [theta => 0.0], (0.0, pi)) + prob = ODEProblem(ss, [theta => 1e-5], (0.0, pi)) sol = solve(prob, Tsit5(); dtmax = 0.01) @test getp(sol, cnt)(sol) == 198 # we get 2 pulses per phase cycle (cos 0 crossing) and we go to 100 cycles; we miss a few due to the initial state end From c169b9e97629c3031bc9f0e82f3c0899e6e17f71 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Mon, 9 Dec 2024 04:01:00 -0800 Subject: [PATCH 091/101] Formatter --- src/systems/imperative_affect.jl | 6 ++---- test/symbolic_events.jl | 6 +++--- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/src/systems/imperative_affect.jl b/src/systems/imperative_affect.jl index 19f7f7590d..2f489913ad 100644 --- a/src/systems/imperative_affect.jl +++ b/src/systems/imperative_affect.jl @@ -94,7 +94,6 @@ function Base.hash(a::ImperativeAffect, s::UInt) hash(a.ctx, s) end - namespace_affects(af::ImperativeAffect, s) = namespace_affect(af, s) function namespace_affect(affect::ImperativeAffect, s) ImperativeAffect(func(affect), @@ -175,7 +174,7 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. end obs_fun = build_explicit_observed_function( sys, Symbolics.scalarize.(obs_exprs); - mkarray = (es,_) -> MakeTuple(es)) + mkarray = (es, _) -> MakeTuple(es)) obs_sym_tuple = (obs_syms...,) # okay so now to generate the stuff to assign it back into the system @@ -183,7 +182,7 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. mod_names = (mod_syms...,) mod_og_val_fun = build_explicit_observed_function( sys, Symbolics.scalarize.(first.(mod_pairs)); - mkarray = (es,_) -> MakeTuple(es)) + mkarray = (es, _) -> MakeTuple(es)) upd_funs = NamedTuple{mod_names}((setu.((sys,), first.(mod_pairs))...,)) @@ -216,5 +215,4 @@ function compile_user_affect(affect::ImperativeAffect, cb, sys, dvs, ps; kwargs. end end - scalarize_affects(affects::ImperativeAffect) = affects diff --git a/test/symbolic_events.jl b/test/symbolic_events.jl index 1138b7c96f..ccf0a17a40 100644 --- a/test/symbolic_events.jl +++ b/test/symbolic_events.jl @@ -936,7 +936,7 @@ end @named trigsys = ODESystem(eqs, t; continuous_events = [evt1, evt2]) trigsys_ss = structural_simplify(trigsys) prob = ODEProblem(trigsys_ss, [], (0.0, 2π)) - sol = solve(prob, Tsit5(); dtmax=0.01) + sol = solve(prob, Tsit5(); dtmax = 0.01) required_crossings_c1 = [π / 2, 3 * π / 2] required_crossings_c2 = [π / 6, π / 2, 5 * π / 6, 7 * π / 6, 3 * π / 2, 11 * π / 6] @test maximum(abs.(first.(cr1) .- required_crossings_c1)) < 1e-4 @@ -1079,8 +1079,8 @@ end @test sort(canonicalize(Discrete(), prob.p)[1]) == [0.0, 1.0, 2.0] sol = solve(prob, Tsit5()) - @test sol[a] == [1.0,-1.0] - @test sol[b] == [2.0,5.0, 5.0] + @test sol[a] == [1.0, -1.0] + @test sol[b] == [2.0, 5.0, 5.0] @test sol[c] == [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0] end @testset "Heater" begin From c8a207bb10803ae3248c61245fdf6d7f05e0366e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 10 Dec 2024 16:45:01 +0530 Subject: [PATCH 092/101] build: bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index fe3a9e88e5..534b0bd32b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ModelingToolkit" uuid = "961ee093-0014-501f-94e3-6117800e7a78" authors = ["Yingbo Ma ", "Chris Rackauckas and contributors"] -version = "9.56.0" +version = "9.57.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" From b24e5256cd9d3e15b25e9b9468ee2996af6febdc Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 10 Dec 2024 17:23:24 -0800 Subject: [PATCH 093/101] Update src/systems/callbacks.jl Co-authored-by: Aayush Sabharwal --- src/systems/callbacks.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/callbacks.jl b/src/systems/callbacks.jl index 57db5e097c..eaf31a9c5f 100644 --- a/src/systems/callbacks.jl +++ b/src/systems/callbacks.jl @@ -98,7 +98,7 @@ sharp discontinuity between integrator steps (which in this example would not no guaranteed to be triggered. Once detected the integrator will "wind back" through a root-finding process to identify the point when the condition became active; the method used -is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). If we denote the time when the condition becomes active as `tc``, +is specified by `rootfind` from [`SciMLBase.RootfindOpt`](@ref). If we denote the time when the condition becomes active as `tc`, the value in the integrator after windback will be: * `u[tc-epsilon], p[tc-epsilon], tc` if `LeftRootFind` is used, * `u[tc+epsilon], p[tc+epsilon], tc` if `RightRootFind` is used, From e5b5f61f8dcd4d61c5ae450a82c869ac51462dd8 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 11 Dec 2024 01:43:34 -0100 Subject: [PATCH 094/101] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 534b0bd32b..0b56e53805 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ModelingToolkit" uuid = "961ee093-0014-501f-94e3-6117800e7a78" authors = ["Yingbo Ma ", "Chris Rackauckas and contributors"] -version = "9.57.0" +version = "9.58.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" From 6d6383abf73d16b6c3ecb86bfed3500d70d25257 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 11 Dec 2024 17:11:17 +0530 Subject: [PATCH 095/101] fix: allow interpolating names in `@brownian` --- src/variables.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/variables.jl b/src/variables.jl index d11c6c1834..3b0d64e7ea 100644 --- a/src/variables.jl +++ b/src/variables.jl @@ -489,7 +489,9 @@ $(SIGNATURES) Define one or more Brownian variables. """ macro brownian(xs...) - all(x -> x isa Symbol || Meta.isexpr(x, :call) && x.args[1] == :$, xs) || + all( + x -> x isa Symbol || Meta.isexpr(x, :call) && x.args[1] == :$ || Meta.isexpr(x, :$), + xs) || error("@brownian only takes scalar expressions!") Symbolics._parse_vars(:brownian, Real, From a2acbe59eebeec40b6ddb4355636aedb8ec5134f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 11 Dec 2024 17:25:26 +0530 Subject: [PATCH 096/101] feat: add ability to convert `SDESystem` to equivalent `ODESystem` --- src/systems/diffeqs/sdesystem.jl | 41 ++++++++++++++++++++++++++ test/sdesystem.jl | 49 ++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 011ba6e216..ac47f4c45c 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -263,6 +263,47 @@ function Base.:(==)(sys1::SDESystem, sys2::SDESystem) all(s1 == s2 for (s1, s2) in zip(get_systems(sys1), get_systems(sys2))) end +""" + function ODESystem(sys::SDESystem) + +Convert an `SDESystem` to the equivalent `ODESystem` using `@brownian` variables instead +of noise equations. The returned system will not be `iscomplete` and will not have an +index cache, regardless of `iscomplete(sys)`. +""" +function ODESystem(sys::SDESystem) + neqs = get_noiseeqs(sys) + eqs = equations(sys) + is_scalar_noise = get_is_scalar_noise(sys) + nbrownian = if is_scalar_noise + length(neqs) + else + size(neqs, 2) + end + brownvars = map(1:nbrownian) do i + name = gensym(Symbol(:brown_, i)) + only(@brownian $name) + end + if is_scalar_noise + brownterms = reduce(+, neqs .* brownvars; init = 0) + neweqs = map(eqs) do eq + eq.lhs ~ eq.rhs + brownterms + end + else + if neqs isa AbstractVector + neqs = reshape(neqs, (length(neqs), 1)) + end + brownterms = neqs * brownvars + neweqs = map(eqs, brownterms) do eq, brown + eq.lhs ~ eq.rhs + brown + end + end + newsys = ODESystem(neweqs, get_iv(sys), unknowns(sys), parameters(sys); + parameter_dependencies = parameter_dependencies(sys), defaults = defaults(sys), + continuous_events = continuous_events(sys), discrete_events = discrete_events(sys), + name = nameof(sys), description = description(sys), metadata = get_metadata(sys)) + @set newsys.parent = sys +end + function __num_isdiag_noise(mat) for i in axes(mat, 1) nnz = 0 diff --git a/test/sdesystem.jl b/test/sdesystem.jl index 749aca86a7..036c94b868 100644 --- a/test/sdesystem.jl +++ b/test/sdesystem.jl @@ -809,3 +809,52 @@ end prob = SDEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0)) @test prob[z] ≈ 2.0 end + +@testset "SDESystem to ODESystem" begin + @variables x(t) y(t) z(t) + @testset "Scalar noise" begin + @named sys = SDESystem([D(x) ~ x, D(y) ~ y, z ~ x + y], [x, y, 3], + t, [x, y, z], [], is_scalar_noise = true) + odesys = ODESystem(sys) + @test odesys isa ODESystem + vs = ModelingToolkit.vars(equations(odesys)) + nbrownian = count( + v -> ModelingToolkit.getvariabletype(v) == ModelingToolkit.BROWNIAN, vs) + @test nbrownian == 3 + for eq in equations(odesys) + ModelingToolkit.isdiffeq(eq) || continue + @test length(arguments(eq.rhs)) == 4 + end + end + + @testset "Non-scalar vector noise" begin + @named sys = SDESystem([D(x) ~ x, D(y) ~ y, z ~ x + y], [x, y, 0], + t, [x, y, z], [], is_scalar_noise = false) + odesys = ODESystem(sys) + @test odesys isa ODESystem + vs = ModelingToolkit.vars(equations(odesys)) + nbrownian = count( + v -> ModelingToolkit.getvariabletype(v) == ModelingToolkit.BROWNIAN, vs) + @test nbrownian == 1 + for eq in equations(odesys) + ModelingToolkit.isdiffeq(eq) || continue + @test length(arguments(eq.rhs)) == 2 + end + end + + @testset "Matrix noise" begin + noiseeqs = [x+y y+z z+x + 2y 2z 2x + z+1 x+1 y+1] + @named sys = SDESystem([D(x) ~ x, D(y) ~ y, D(z) ~ z], noiseeqs, t, [x, y, z], []) + odesys = ODESystem(sys) + @test odesys isa ODESystem + vs = ModelingToolkit.vars(equations(odesys)) + nbrownian = count( + v -> ModelingToolkit.getvariabletype(v) == ModelingToolkit.BROWNIAN, vs) + @test nbrownian == 3 + for eq in equations(odesys) + @test length(arguments(eq.rhs)) == 4 + end + end +end From ea9b6bd02f683f7c75ce858187e92293f4c4e72a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 11 Dec 2024 17:25:38 +0530 Subject: [PATCH 097/101] feat: enable `structural_simplify(::SDESystem)` --- src/systems/systems.jl | 4 ++++ test/sdesystem.jl | 10 ++++++++++ 2 files changed, 14 insertions(+) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 862718968d..97d22cf4cf 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -72,6 +72,10 @@ function __structural_simplify(sys::JumpSystem, args...; kwargs...) return sys end +function __structural_simplify(sys::SDESystem, args...; kwargs...) + return __structural_simplify(ODESystem(sys), args...; kwargs...) +end + function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = false, kwargs...) sys = expand_connections(sys) diff --git a/test/sdesystem.jl b/test/sdesystem.jl index 036c94b868..8069581dcc 100644 --- a/test/sdesystem.jl +++ b/test/sdesystem.jl @@ -858,3 +858,13 @@ end end end end + +@testset "`structural_simplify(::SDESystem)`" begin + @variables x(t) y(t) + @mtkbuild sys = SDESystem( + [D(x) ~ x, y ~ 2x], [x, 0], t, [x, y], []; is_scalar_noise = true) + @test sys isa SDESystem + @test length(equations(sys)) == 1 + @test length(ModelingToolkit.get_noiseeqs(sys)) == 1 + @test length(observed(sys)) == 1 +end From 22d70938e2ac432c29c4787aa57c687bf45b76f3 Mon Sep 17 00:00:00 2001 From: CompatHelper Julia Date: Thu, 12 Dec 2024 00:25:48 +0000 Subject: [PATCH 098/101] CompatHelper: add new compat entry for Setfield at version 1 for package docs, (keep existing compat) --- docs/Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/Project.toml b/docs/Project.toml index 24f00b4db2..a358455503 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -38,6 +38,7 @@ OptimizationOptimJL = "0.1, 0.4" OrdinaryDiffEq = "6.31" Plots = "1.36" SciMLStructures = "1.1" +Setfield = "1" StochasticDiffEq = "6" SymbolicIndexingInterface = "0.3.1" SymbolicUtils = "3" From 41600b14e7ddea22cf17e529f005eb85aa98ce31 Mon Sep 17 00:00:00 2001 From: Benjamin Chung Date: Tue, 12 Nov 2024 23:53:53 -0800 Subject: [PATCH 099/101] Add a simple mechanism to add passes to structural simplify --- src/systems/systems.jl | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 97d22cf4cf..29cc96d28f 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -26,7 +26,7 @@ topological sort of the observed equations in `sys`. + `fully_determined=true` controls whether or not an error will be thrown if the number of equations don't match the number of inputs, outputs, and equations. """ function structural_simplify( - sys::AbstractSystem, io = nothing; simplify = false, split = true, + sys::AbstractSystem, io = nothing; additional_passes = [], simplify = false, split = true, allow_symbolic = false, allow_parameter = true, conservative = false, fully_determined = true, kwargs...) isscheduled(sys) && throw(RepeatedStructuralSimplificationError()) @@ -49,6 +49,9 @@ function structural_simplify( if newsys isa ODESystem || has_parent(newsys) @set! newsys.parent = complete(sys; split, flatten = false) end + for pass in additional_passes + newsys = pass(newsys) + end newsys = complete(newsys; split) if has_defaults(newsys) && (defs = get_defaults(newsys)) !== nothing ks = collect(keys(defs)) # take copy to avoid mutating defs while iterating. From c0abc56decc7691b25254af716f20cc9979f73cf Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 12 Dec 2024 14:59:55 +0530 Subject: [PATCH 100/101] fix: run additional passes before setting the parent of the system --- src/systems/systems.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 29cc96d28f..47acd81a82 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -46,12 +46,12 @@ function structural_simplify( not yet supported. """) end + for pass in additional_passes + newsys = pass(newsys) + end if newsys isa ODESystem || has_parent(newsys) @set! newsys.parent = complete(sys; split, flatten = false) end - for pass in additional_passes - newsys = pass(newsys) - end newsys = complete(newsys; split) if has_defaults(newsys) && (defs = get_defaults(newsys)) !== nothing ks = collect(keys(defs)) # take copy to avoid mutating defs while iterating. From daf93edfde0a60ec6fb460f83e977c248cfeb849 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 12 Dec 2024 15:00:08 +0530 Subject: [PATCH 101/101] test: test additional passes mechanism --- test/structural_transformation/utils.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/structural_transformation/utils.jl b/test/structural_transformation/utils.jl index 2704559f72..863e091aad 100644 --- a/test/structural_transformation/utils.jl +++ b/test/structural_transformation/utils.jl @@ -152,3 +152,12 @@ end end end end + +@testset "additional passes" begin + @variables x(t) y(t) + @named sys = ODESystem([D(x) ~ x, y ~ x + t], t) + value = Ref(0) + pass(sys; kwargs...) = (value[] += 1; return sys) + structural_simplify(sys; additional_passes = [pass]) + @test value[] == 1 +end