From 2f6508478e639266c5747e936e924c31303a7168 Mon Sep 17 00:00:00 2001 From: JuliusMartensen Date: Tue, 6 Feb 2024 13:08:08 +0100 Subject: [PATCH 1/4] Filter for target variables in observed equations --- src/augment.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/augment.jl b/src/augment.jl index abd1076..e673c4a 100644 --- a/src/augment.jl +++ b/src/augment.jl @@ -119,7 +119,7 @@ function construct_jacobians(::MTKBackend, fx = ModelingToolkit.jacobian(eqs, states(sys)) dfddx = ModelingToolkit.jacobian(eqs, D.(states(sys))) fp = ModelingToolkit.jacobian(eqs, p) - obs = observed(sys) + obs = filter(x -> is_measured(x.lhs), observed(sys)) obs = isempty(obs) ? states(sys) : map(x -> x.rhs, obs) hx = ModelingToolkit.jacobian(obs, states(sys)) return dfddx, fx, fp, hx @@ -163,10 +163,10 @@ function build_augmented_system(sys::ModelingToolkit.AbstractODESystem, t = ModelingToolkit.get_iv(sys) delta_t = Differential(t) # The observed equations - obs = observed(sys) - + obs = filter(x -> is_measured(x.lhs), observed(sys)) + @info obs # Check if all observed equations and controls have measurement rates associated - @assert all(x -> is_measured(x.lhs), obs) "Not all observed equations have measurement rates associated to them!" + @assert !isempty(obs) "None of the observed equations have measurement rates associated to them! Please provide at least one observable measurement." @assert all(is_measured, c) "Not all controls have rates associated to them! If you mean to apply continuous controls, please adjust your model before passing it." @assert !isempty(obs) "Please defined `observed` equations to use optimal experimental design." From 90f2165d13adf2136d3b08aee4632b01267837bd Mon Sep 17 00:00:00 2001 From: JuliusMartensen Date: Tue, 6 Feb 2024 13:08:22 +0100 Subject: [PATCH 2/4] Add urethan example --- test/references/urethan.jl | 102 +++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 test/references/urethan.jl diff --git a/test/references/urethan.jl b/test/references/urethan.jl new file mode 100644 index 0000000..0c79173 --- /dev/null +++ b/test/references/urethan.jl @@ -0,0 +1,102 @@ +using ModelingToolkit +using LinearAlgebra +using DynamicOED + +# Define the system +T_min, T_max = 293.16, 473.16 + +const M = [0.11911, 0.07412, 0.19323, 0.31234, 0.35733, 0.07806] +const rho = [1095.0, 809.0, 1415.0, 1528.0, 1451.0, 1101.0] + +K_REF1 = 5.0E-4 +E_A1 = 35240.0 +K_REF2 = 8.0E-8 +E_A2 = 85000.0 +K_REF4 = 1.0E-8 +E_A4 = 35000.0 +DH_2 = -17031.0 +K_C2 = 0.17 + +R = 8.314; +T1 = 363.16; + +@variables t +@variables V(t)=0.5 [description = "Reaction volume"] +@variables r1(t)=0.0 [description = "Reaction rate 1"] +@variables r2(t)=0.0 [description = "Reaction rate 2"] +@variables r3(t)=0.0 [description = "Reaction rate 3"] +@variables r4(t)=0.0 [description = "Reaction rate 4"] +@variables feed1(t)=0 [description = "State for feed 1"] +@variables feed2(t)=0 [description = "State for feed 2"] +@variables temperature(t)=293.15 [ + bounds = [293.16, 473.16], + description = "State for temperature", +] +@variables n(t)[1:6]=[0.0; 0.0; 0.0; 0.0; 0.0; 0.0] [description = "States"] +@parameters u1 [ + description = "Control feed 1", + bounds = [0, Inf], + input = true, + measurement_rate = 10, +] +@parameters u2 [ + description = "Control feed 2", + bounds = [0, Inf], + input = true, + measurement_rate = 10, +] +@parameters u3 [ + description = "Control temperature", + bounds = [-40, 40], + input = true, + measurement_rate = 10, +] +@variables h₁(t) [description = "Observed", measurement_rate = 20] +@variables h₂(t) [description = "Observed", measurement_rate = 20] +h = vcat(h₁, h₂) +@parameters p[1:6]=[1.0; 1.0; 1.0; 1.0; 1.0; 1.0; 1.0; 1.0] [ + description = "Scaling parameters", + tunable = true, +] +@parameters n0[1:3]=[0.12; 0.0; 0.0] [description = "Initial mole numbers", tunable = false] +D = Differential(t) + +## Define the control function, returns feed1, feed2, T +eqs_ = [V * (r1 - r2 + r3); + V * (r2 - r3); + V * r4] + +subs_ = Dict(r1 => p[1] * K_REF1 * exp(-p[2] * E_A1 / R * (1 / temperature - 1 / T1)) * + n[1] * n[2] / (V * V), + r2 => p[3] * K_REF2 * exp(-p[4] * E_A2 / R * (1 / temperature - 1 / T1)) * n[1] * n[3] / + (V * V), + r3 => p[3] * K_REF2 * exp(-p[4] * E_A2 / R * (1 / temperature - 1 / T1)) * + inv(K_C2 * exp(-(-DH_2 / R) * (1 / temperature - 1 / T1))) * n[4] / V, + r4 => p[5] * K_REF4 * exp(-p[6] * E_A4 / R * (1 / temperature - 1 / T1)) * (n[1] / V)^2) + +eqs = map(Base.Fix2(substitute, subs_), eqs_) + +# Define the eqs +@named urethan = ODESystem([D(feed1) ~ u1; + D(feed2) ~ u2; + D(temperature) ~ u3; + D(n[3]) ~ eqs[1]; #n_C + D(n[4]) ~ eqs[2]; #n_D + D(n[5]) ~ eqs[3]; #n_E + n[1] ~ n0[1] + feed1 - n[3] - 2 * n[4] - 3 * n[5]; #n_A + n[2] ~ n0[2] + feed2 - n[3] - n[4]; #n_B + n[6] ~ n0[3] + feed1 + u2; #n_L + V ~ sum(n .* M ./ rho)], tspan = (0.0, 80.0), + observed = h .~ [100 * n[1] * M[1] / sum([ni .* M[i] for (i, ni) in enumerate(n)]); + #100 * n[3]*M[3]/sum([ni .* M[i] for (i,ni) in enumerate(n)]); + #100 * n[4]*M[4]/sum([ni .* M[i] for (i,ni) in enumerate(n)]); + 100 * n[5] * M[5] / sum([ni .* M[i] for (i, ni) in enumerate(n)])]) + +## Build the OED System + +@named urethan_oed = OEDSystem(structural_simplify(urethan)) + +oed_problem = DynamicOED.OEDProblem(urethan_oed, DCriterion()) + +optimization_variables = states(oed_problem) +timegrids = DynamicOED.get_timegrids(oed_problem) From 9823c242b1f7dc4507f2c88fb0f8cc79ae44a1dc Mon Sep 17 00:00:00 2001 From: JuliusMartensen Date: Tue, 6 Feb 2024 13:10:37 +0100 Subject: [PATCH 3/4] Update src/augment.jl --- src/augment.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/augment.jl b/src/augment.jl index e673c4a..2269cfb 100644 --- a/src/augment.jl +++ b/src/augment.jl @@ -164,7 +164,6 @@ function build_augmented_system(sys::ModelingToolkit.AbstractODESystem, delta_t = Differential(t) # The observed equations obs = filter(x -> is_measured(x.lhs), observed(sys)) - @info obs # Check if all observed equations and controls have measurement rates associated @assert !isempty(obs) "None of the observed equations have measurement rates associated to them! Please provide at least one observable measurement." @assert all(is_measured, c) "Not all controls have rates associated to them! If you mean to apply continuous controls, please adjust your model before passing it." From ae83e8301ace2c8d2da0986e0723b667a7eb5c16 Mon Sep 17 00:00:00 2001 From: chplate Date: Wed, 15 May 2024 14:09:49 +0200 Subject: [PATCH 4/4] added urethan example in docs, WIP --- docs/make.jl | 1 + docs/src/examples/urethan.md | 228 +++++++++++++++++++++++++++++++++++ src/augment.jl | 53 ++++---- src/discretize.jl | 8 +- 4 files changed, 259 insertions(+), 31 deletions(-) create mode 100644 docs/src/examples/urethan.md diff --git a/docs/make.jl b/docs/make.jl index f409d0d..52f3c44 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -13,6 +13,7 @@ makedocs(modules = [DynamicOED], "Examples" => [ "Linear System" => "examples/1D.md", "Lotka-Volterra" => "examples/lotka.md", + "Urethan" => "examples/urethan.md" ], "Theory" => "theory.md", "API" => "api.md", diff --git a/docs/src/examples/urethan.md b/docs/src/examples/urethan.md new file mode 100644 index 0000000..613d15b --- /dev/null +++ b/docs/src/examples/urethan.md @@ -0,0 +1,228 @@ +# Design of Experiments for Urethan + +In this example we take a look at the reaction of Urethane from the educts Butanol and Isocyanate. The reaction scheme is + +```math +\begin{aligned} +A + B &\rightarrow C\\  +A + C &\rightleftarrow D\\ +3 A &\rightarrow E +\end{aligned} +``` +with reactants Isocyanate A, Butanol B, Urethane C, Allophanate D and Isocyanurate E. + +We start by using [ModelingToolkit.jl](https://github.com/SciML/ModelingToolkit.jl) and DynamicOED.jl to model the system. + +```@example urethan +using DynamicOED +using ModelingToolkit +using Optimization, OptimizationMOI, Ipopt +using Plots + +const M = [ 0.11911, 0.07412, 0.19323, 0.31234, 0.35733, 0.07806 ] +const rho = [1095.0, 809.0, 1415.0, 1528.0, 1451.0, 1101.0] + +Rg = 8.314; +T1 = 363.16; + +n_A0, n_B0, n_L0 = 0.1, 0.05, 0.01 + +# Set up variables +@variables t +D = Differential(t) +@variables h(t)[1:4] [description = "Observed", measurement_rate=10] +h = collect(h) + +@variables n_C(t)=0.0 [description = "Molar numbers for C"] +@variables n_D(t)=0.0 [description = "Molar numbers for D"] +@variables n_E(t)=0.0 [description = "Molar numbers for E"] +@variables n_A(t)=n_A0 [description = "Molar numbers for C"] +@variables n_B(t)=n_B0 [description = "Molar numbers for D"] +@variables n_L(t)=n_L0 [description = "Molar numbers for E"] + +@parameters begin + p1=1.0, [description = "Scaling parameter 1", tunable = true] + p2=1.0, [description = "Scaling parameter 2", tunable = true] + p3=1.0, [description = "Scaling parameter 3", tunable = true] + p4=1.0, [description = "Scaling parameter 4", tunable = false] + p5=1.0, [description = "Scaling parameter 5", tunable = false] + p6=1.0, [description = "Scaling parameter 6", tunable = false] +end + +# Variables for temperature and feed +@variables feed1(t)=0 [description = "State for feed 1"] +@variables feed2(t)=0 [description = "State for feed 2"] +@variables temperature(t)=373.15 [description = "State for temperature"] + +# Controls for temperature and feed +@parameters begin + u1=0.0125, [description = "RHS of feed1", bounds=(0.0,0.0125), input=true, measurement_rate=10] + u2=0.0125, [description = "RHS of feed2", bounds=(0.0,0.0125), input=true, measurement_rate=10] + u_temp=0.0, [description = "RHS of temperature", bounds=(-15,15), input=false] +end + + +# Write system of equations +k_ref1 = p1 * 5.0E-4 +E_a1 = p2 * 35240.0 +k_ref2 = p3 * 8.0E-8 +E_a2 = p4 * 85000.0 +k_ref4 = p5 * 1.0E-8 +E_a4 = p6 * 35000.0 +dH_2 = -17031.0 +K_C2 = 0.17 + +# Arrhenius equations for the reaction rates +fac_T = 1.0 / (Rg*temperature) - 1.0 / (Rg*T1); +k1 = k_ref1 * exp(- E_a1 * fac_T); +k2 = k_ref2 * exp(- E_a2 * fac_T); +k4 = k_ref4 * exp(- E_a4 * fac_T); +K_C = K_C2 * exp(- dH_2 * fac_T); +k3 = k2/K_C; + +# Reaction volume +V = n_A * M[1] / rho[1] + n_B * M[2] /rho[2] + n_C * M[3] / rho[3] + n_D * M[4] / rho[4] + + n_E * M[5] / rho[5] + n_L * M[6] / rho[6] + +# Reaction rates +r1 = k1 * n_A/V * n_B/V; +r2 = k2 * n_A/V * n_C/V; +r3 = k3 * n_D/V; +r4 = k4 * (n_A / V)*(n_A / V); + +sum_observed = n_A * M[1] + n_B * M[2] + n_C * M[3] + n_D * M[4] + n_E * M[5] + n_L * M[6] +# Define the eqs +@named urethan = ODESystem( + [ + D(feed1) ~ u1; + D(feed2) ~ u2; + D(temperature) ~ u_temp; + D(n_C) ~ V * (r1 - r2 + r3); #n_C + D(n_D) ~ V * (r2 - r3); #n_D + D(n_E) ~ V * r4; #n_E + 0 ~ n_A0 + feed1 - n_C - 2n_D - 3n_E - n_A; #n_A + 0 ~ n_B0 + feed2 - n_C - n_D - n_B; #n_B + 0 ~ n_L0 + (feed1 + feed2) - n_L; #n_L + ], tspan = (0.0, 80.0), + observed = h .~ [100 * n_A*M[1]/sum_observed; + 100 * n_C*M[3]/sum_observed; + 100 * n_D*M[4]/sum_observed; + 100 * n_E*M[5]/sum_observed + ] +) +``` + +Like in the [Design of Experiments for a simple system](@ref), we added important information: + +- The observed variables are initialized with a [`measurement_rate`](@ref DynamicOED.VariableRate). This time we use an integer measurement rate, resulting in $10$ subintervals of equal length. +- The parameters $p_1$ t0 $p_3$ that enter the first and second Arrhenius equations are marked as `tunable`. These are the parameters we want to estimate. + +Now we can build the `OEDSystem`, which includes the necessary equations for the sensitivities and the Fisher information matrix. For this, we only consider the differential equations of the species $C, D$ and $E$. + +```@example urethan +relevant_equations = equations(urethan)[4:6] +relevant_states = states(urethan)[4:6] + +@named oed = OEDSystem(urethan, equationset = relevant_equations, stateset = relevant_states) +``` + +With this augmented `ODESystem` we can set up the `OEDProblem` by specifying the criterion we want to optimize. + +```@example urethan +oed_simp = structural_simplify(oed) +oed_problem = OEDProblem(oed_simp, ACriterion()) +``` +We choose the [`ACriterion`](@ref), which minimizes the trace of the inverse of the Fisher information matrix. For constraining the time we can measure by defining a `ConstraintSystem` from ModelingToolkit on the optimization variables. We allow measurements of each species in only 2 of the 10 subintervals. Also, we limit the amount of energy we give into the system via an upper limit on the sum of the temperature controls. + +```@example urethan +optimization_variables = states(oed_problem) +w1, w2, w3, w4 = keys(optimization_variables.measurements) +u1, u2 = keys(optimization_variables.controls) + +constraint_equations = [ + sum(optimization_variables.measurements[w1]) ≲ 2, + sum(optimization_variables.measurements[w2]) ≲ 2, + sum(optimization_variables.measurements[w3]) ≲ 2, + sum(optimization_variables.measurements[w4]) ≲ 2, +] + + +@named constraint_system = ConstraintsSystem( + constraint_equations, collect(optimization_variables), [] +) +nothing # hide +``` +!!! note + The `optimization_variables` contain several groups of variables, namely `measurements`, `controls`, `initial_conditions`, and `regularization`. `measurements` represent the decision to observe at a specific time point at the grid. We currently work with the naming convention `w_i` for the i-th observed equation. Currently we need to `collect` the states before passing them into the `ConstraintsSystem`! + + +Finally, we are now able to convert our [`OEDProblem`](@ref) into an `OptimizationProblem` and `solve` it. + +!!! note + Currently we only support `AutoForwardDiff()` as an AD backend. + + +```@example urethan +optimization_problem = OptimizationProblem( + oed_problem, AutoForwardDiff(), constraints = constraint_system, + integer_constraints = false +) + +optimal_design = solve(optimization_problem, Ipopt.Optimizer(); hessian_approximation="limited-memory") + +u_opt = optimal_design.u + optimization_problem.u0 +``` + +Now we want to visualize the found solution. +```@example urethan + +predictor = DynamicOED.build_predictor(oed_problem) +x_opt, t_opt = predictor(u_opt) +timegrid = oed_problem.timegrid + +np = sum(istunable.(parameters(urethan))) +nx = length(relevant_equations) +sts = states(oed_simp) + + +states_plot1 = plot(t_opt, x_opt[4:5,:]', label=hcat(string.(sts[4:5])...), xlabel="Time", ylabel="Concentrations") +states_plot2 = plot(t_opt, x_opt[6,:], label=string(sts[6]), xlabel="Time", ylabel="Concentrations", color=3) + +feed_plot = plot(t_opt, x_opt[1:2, :]', xlabel = "Time", ylabel = "Feed", label = hcat([string(x) for x in sts[1:2]]...)) +temp_plot = plot(t_opt, x_opt[3, :], xlabel="Time", ylabel="Temperature", label=string(sts[3])) +hspan!([293, 473], alpha=.3, label=nothing) + + +states_plot1 = plot(t_opt, x_opt[4:5,:]', label=hcat(string.(sts[4:5])...), xlabel="Time", ylabel="Concentrations") +states_plot2 = plot(t_opt, x_opt[6,:], label=string(sts[6]), xlabel="Time", ylabel="Concentrations", color=3) + +sens_vars = startswith.(string.(sts), "(G") + +sensitivities_plot = plot(t_opt, x_opt[sens_vars,:]', label=hcat(string.(sts[sens_vars])...), legend_font_pointsize=6, legend_columns=np, xlabel="Time", ylabel="dx/dp") + +u1_, u2_, = keys(optimization_variables.controls) + +repfirst(x) = [x[1]; x] + +control_feed_plot = plot(t_opt, repfirst(u_opt.controls[u1_]), label="u1(t)", xlabel="Time", linetype=:steppre) +plot!(t_opt, repfirst(u_opt.controls[u2_]), label="u2(t)", xlabel="Time", linetype=:steppre) + +ws = keys(optimization_variables.measurements) +sampling_plot = plot() +for wi in ws + w_i = u_opt.measurements[wi] + plot!(t_opt, repfirst(w_i), linetype=:steppre, xlabel="Time", ylabel="Sampling", label=string(wi)) +end + + + +l = @layout [ + grid(2,3) + a{0.3h} + b{0.2h} +] + +plot(feed_plot, temp_plot, states_plot1, control_feed_plot, states_plot2, plot(), sensitivities_plot, sampling_plot, layout=l, size=(900,600)) + +``` + diff --git a/src/augment.jl b/src/augment.jl index 2269cfb..7805708 100644 --- a/src/augment.jl +++ b/src/augment.jl @@ -2,7 +2,7 @@ """ $(TYPEDEF) -Indicator that a given variable is a measurement function. +Indicator that a given variable is a measurement function. ``` @variables w=1.0 [measurement_function=true] @@ -19,7 +19,7 @@ set_measurement_function(x) = Symbolics.setmetadata(x, MeasurementFunction, true """ $(TYPEDEF) -Indicator that a given variable is a state of the fisher information matrix. +Indicator that a given variable is a state of the fisher information matrix. ``` @variables F[1:3, 1:3] [fisher_state=true] @@ -66,8 +66,8 @@ $(TYPEDEF) Indicator that a given state is subject to a fixed rate. Is used for modeling the rate of observation of observed variables and the rate of control for control variables. -If the provided rate is a `Real`, it is assumed that the resulting time grid is given in fractions of the time unit. -If the provided rate is a `Int`, it is assumed that the resulting time grid is divided in `rate` equidistant intervals. +If the provided rate is a `Real`, it is assumed that the resulting time grid is given in fractions of the time unit. +If the provided rate is a `Int`, it is assumed that the resulting time grid is divided in `rate` equidistant intervals. ``` @variables y(t) [measurement_rate=0.1] # Create a variable measured every 0.1 t @@ -112,37 +112,36 @@ end function construct_jacobians(::MTKBackend, sys::ModelingToolkit.AbstractODESystem, - p = parameters(sys)) - eqs = map(x -> x.rhs - x.lhs, equations(sys)) + p = parameters(sys), stateset = states(sys), equationset = equations(sys)) + eqs = map(x -> x.rhs - x.lhs, equationset) t = ModelingToolkit.get_iv(sys) D = Differential(t) - fx = ModelingToolkit.jacobian(eqs, states(sys)) - dfddx = ModelingToolkit.jacobian(eqs, D.(states(sys))) + fx = ModelingToolkit.jacobian(eqs, stateset) + dfddx = ModelingToolkit.jacobian(eqs, D.(stateset)) fp = ModelingToolkit.jacobian(eqs, p) obs = filter(x -> is_measured(x.lhs), observed(sys)) - obs = isempty(obs) ? states(sys) : map(x -> x.rhs, obs) - hx = ModelingToolkit.jacobian(obs, states(sys)) + obs = isempty(obs) ? sts : map(x -> x.rhs, obs) + hx = ModelingToolkit.jacobian(obs, stateset) return dfddx, fx, fp, hx end function build_augmented_system(sys::ModelingToolkit.AbstractODESystem, - backend::AbstractAugmentationBackened; + backend::AbstractAugmentationBackened; stateset = states(sys), + equationset = equations(sys), name::Symbol, kwargs...) T = Float64 # The set of tuneable parameters p = parameters(sys) - # The set of controls + # The set of controls c = filter(ModelingToolkit.isinput, p) # The set of tuneable parameters p_tuneable = setdiff(filter(ModelingToolkit.istunable, p), c) - # The states - x = states(sys) # The unknown initial conditions - x_ic = eltype(x)[] + x_ic = eltype(stateset)[] unknown_initial_conditions = Int[] - @inbounds for i in axes(x, 1) - xi = getindex(x, i) + @inbounds for i in axes(stateset, 1) + xi = getindex(stateset, i) if istunable(xi) xi_0 = Symbolics.variable(Symbol(xi, "₀"), T = Symbolics.symtype(xi)) xi_0 = setmetadata(xi_0, @@ -159,10 +158,10 @@ function build_augmented_system(sys::ModelingToolkit.AbstractODESystem, end end - # The independent variable + # The independent variable t = ModelingToolkit.get_iv(sys) delta_t = Differential(t) - # The observed equations + # The observed equations obs = filter(x -> is_measured(x.lhs), observed(sys)) # Check if all observed equations and controls have measurement rates associated @assert !isempty(obs) "None of the observed equations have measurement rates associated to them! Please provide at least one observable measurement." @@ -170,12 +169,12 @@ function build_augmented_system(sys::ModelingToolkit.AbstractODESystem, @assert !isempty(obs) "Please defined `observed` equations to use optimal experimental design." - np, nx, n_obs = length(p_tuneable), length(x), length(obs) + np, nx, n_obs = length(p_tuneable), length(stateset), length(obs) ## build the jacobians - dfx, fx, fp, hx = construct_jacobians(backend, sys, p_tuneable) + dfx, fx, fp, hx = construct_jacobians(backend, sys, p_tuneable, stateset, equationset) # Check the size of the equations - @assert size(fx, 1)==size(fp, 1)==size(x, 1) "The size of the state equations and the jacobian does not match" + @assert size(fx, 1)==size(fp, 1)==size(stateset, 1) "The size of the state equations and the jacobian does not match" @assert size(fp, 2)==size(p_tuneable, 1) "The size of the state equations and the jacobian does not match" G_init = zeros(T, (nx, np)) @@ -218,7 +217,7 @@ function build_augmented_system(sys::ModelingToolkit.AbstractODESystem, G = collect(G) Q = collect(Q) - # Create new observed function + # Create new observed function idx = triu(trues(np, np)) new_obs = delta_t.(F[idx]) .~ (sum(enumerate(w)) do (i, wi) wi * ((hx[i:i, :] * G)' * (hx[i:i, :] * G))[idx] @@ -228,7 +227,6 @@ function build_augmented_system(sys::ModelingToolkit.AbstractODESystem, # We always assume DAE form here. Results in a stable system # 60 % of the time it works everytime! This is an easteregg, do not take it seriously! sens_eqs = vec(zeros(T, nx, np) .~ dfx * delta_t.(G) .+ fx * G .+ fp) - # We do not need to do anything more than push this into the equations # Simplify will figure out the rest ODESystem([vec(dynamic_eqs); @@ -236,12 +234,13 @@ function build_augmented_system(sys::ModelingToolkit.AbstractODESystem, vec(new_obs); vec(Q .~ hx * G)], t, - vcat(x, vec(G), vec(F[idx]), vec(Q)), + vcat(states(sys), vec(G), vec(F[idx]), vec(Q)), vcat(union(p, p_tuneable), w), tspan = ModelingToolkit.get_tspan(sys), observed = observed(sys), name = name) end -function OEDSystem(sys::ModelingToolkit.AbstractODESystem; kwargs...) - build_augmented_system(sys, MTKBackend(); kwargs...) +function OEDSystem(sys::ModelingToolkit.AbstractODESystem; stateset = states(sys), + equationset = equations(sys), kwargs...) + build_augmented_system(sys, MTKBackend(); stateset=stateset, equationset = equationset, kwargs...) end diff --git a/src/discretize.jl b/src/discretize.jl index 3ea6191..50866e9 100644 --- a/src/discretize.jl +++ b/src/discretize.jl @@ -6,7 +6,7 @@ end function _generate_timegrid(Δt::T, tspan::Tuple{Real, Real}) where {T <: Real} @assert Δt>0 "Stepsize must be greater than 0." - @assert Δt<-(reverse(tspan)...) "Stepsize must be smaller than total time interval." + @assert Δt<=-(reverse(tspan)...) "Stepsize must be smaller than total time interval." t0, tinf = tspan timepoints = collect(T, t0:Δt:tinf) if timepoints[end] != tinf @@ -23,7 +23,7 @@ end """ $(TYPEDEF) -A structure for holding a multi-variable time grid. +A structure for holding a multi-variable time grid. # Fields @@ -86,7 +86,7 @@ end function get_variable_idx(grid::Timegrid, var::Symbol, i::Int) id = _get_variable_idx(grid, var) - isnothing(id) && return 1 # We always assume here that this will work + isnothing(id) && return 1 # We always assume here that this will work return grid.indicators[id, i] end @@ -251,7 +251,7 @@ function (remaker::OEDRemake)(i::Int, ics = getproperty(parameters, :initial_conditions) |> NamedTuple controls = getproperty(parameters, :controls) |> NamedTuple measurements = getproperty(parameters, :measurements) |> NamedTuple - # Get the right controls + # Get the right controls controls = get_vars_from_grid(remaker.grid, i, controls) measurements = get_vars_from_grid(remaker.grid, i, measurements) p0_ = remaker.parameter_remake(p0, measurements, controls, ics)