From 552b0393f683fb16935c90779296347454cae9c3 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 11 Dec 2024 12:14:27 +0530 Subject: [PATCH 01/38] fix: respect `use_homotopy_continuation` in `NonlinearProblem` and default it to `false` --- src/systems/nonlinear/nonlinearsystem.jl | 10 ++++++---- test/extensions/homotopy_continuation.jl | 2 +- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 3cb68853aa..2c788688bd 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -496,13 +496,15 @@ end function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map, parammap = DiffEqBase.NullParameters(); - check_length = true, use_homotopy_continuation = true, kwargs...) where {iip} + check_length = true, use_homotopy_continuation = false, kwargs...) where {iip} if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblem`") end - prob = safe_HomotopyContinuationProblem(sys, u0map, parammap; check_length, kwargs...) - if prob isa HomotopyContinuationProblem - return prob + if use_homotopy_continuation + prob = safe_HomotopyContinuationProblem(sys, u0map, parammap; check_length, kwargs...) + if prob isa HomotopyContinuationProblem + return prob + end end f, u0, p = process_SciMLProblem(NonlinearFunction{iip}, sys, u0map, parammap; check_length, kwargs...) diff --git a/test/extensions/homotopy_continuation.jl b/test/extensions/homotopy_continuation.jl index 554f9e1e1d..3f4a3fbc71 100644 --- a/test/extensions/homotopy_continuation.jl +++ b/test/extensions/homotopy_continuation.jl @@ -29,7 +29,7 @@ import HomotopyContinuation @test SciMLBase.successful_retcode(sol) @test norm(sol.resid)≈0.0 atol=1e-10 - prob2 = NonlinearProblem(sys, u0) + prob2 = NonlinearProblem(sys, u0; use_homotopy_continuation = true) @test prob2 isa HomotopyContinuationProblem sol = solve(prob2; threading = false) @test SciMLBase.successful_retcode(sol) From e75d06fb2f680ccda550f6ad37cc875497119736 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 11 Dec 2024 14:35:16 +0530 Subject: [PATCH 02/38] fix: properly handle rational functions in HomotopyContinuation --- ext/MTKHomotopyContinuationExt.jl | 5 +- .../nonlinear/homotopy_continuation.jl | 63 +++++++++++++------ src/systems/nonlinear/nonlinearsystem.jl | 3 +- test/extensions/homotopy_continuation.jl | 31 ++++++++- 4 files changed, 79 insertions(+), 23 deletions(-) diff --git a/ext/MTKHomotopyContinuationExt.jl b/ext/MTKHomotopyContinuationExt.jl index c4a090d9a8..8f17c05b18 100644 --- a/ext/MTKHomotopyContinuationExt.jl +++ b/ext/MTKHomotopyContinuationExt.jl @@ -101,7 +101,8 @@ function MTK.HomotopyContinuationProblem( return prob end -function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing; kwargs...) +function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing; + fraction_cancel_fn = SymbolicUtils.simplify_fractions, kwargs...) if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `HomotopyContinuationProblem`") end @@ -109,7 +110,7 @@ function MTK._safe_HomotopyContinuationProblem(sys, u0map, parammap = nothing; k if transformation isa MTK.NotPolynomialError return transformation end - result = MTK.transform_system(sys, transformation) + result = MTK.transform_system(sys, transformation; fraction_cancel_fn) if result isa MTK.NotPolynomialError return result end diff --git a/src/systems/nonlinear/homotopy_continuation.jl b/src/systems/nonlinear/homotopy_continuation.jl index 03aeed1edf..00c6c7f1b0 100644 --- a/src/systems/nonlinear/homotopy_continuation.jl +++ b/src/systems/nonlinear/homotopy_continuation.jl @@ -442,7 +442,8 @@ Transform the system `sys` with `transformation` and return a `PolynomialTransformationResult`, or a `NotPolynomialError` if the system cannot be transformed. """ -function transform_system(sys::NonlinearSystem, transformation::PolynomialTransformation) +function transform_system(sys::NonlinearSystem, transformation::PolynomialTransformation; + fraction_cancel_fn = simplify_fractions) subrules = transformation.substitution_rules dvs = unknowns(sys) eqs = full_equations(sys) @@ -463,7 +464,7 @@ function transform_system(sys::NonlinearSystem, transformation::PolynomialTransf return NotPolynomialError( VariablesAsPolyAndNonPoly(dvs[poly_and_nonpoly]), eqs, polydata) end - num, den = handle_rational_polynomials(t, new_dvs) + num, den = handle_rational_polynomials(t, new_dvs; fraction_cancel_fn) # make factors different elements, otherwise the nonzero factors artificially # inflate the error of the zero factor. if iscall(den) && operation(den) == * @@ -492,43 +493,67 @@ $(TYPEDSIGNATURES) Given a `x`, a polynomial in variables in `wrt` which may contain rational functions, express `x` as a single rational function with polynomial `num` and denominator `den`. Return `(num, den)`. + +Keyword arguments: +- `fraction_cancel_fn`: A function which takes a fraction (`operation(expr) == /`) and returns + a simplified symbolic quantity with common factors in the numerator and denominator are + cancelled. Defaults to `SymbolicUtils.simplify_fractions`, but can be changed to + `nothing` to improve performance on large polynomials at the cost of avoiding non-trivial + cancellation. """ -function handle_rational_polynomials(x, wrt) +function handle_rational_polynomials(x, wrt; fraction_cancel_fn = simplify_fractions) x = unwrap(x) symbolic_type(x) == NotSymbolic() && return x, 1 iscall(x) || return x, 1 contains_variable(x, wrt) || return x, 1 any(isequal(x), wrt) && return x, 1 - # simplify_fractions cancels out some common factors - # and expands (a / b)^c to a^c / b^c, so we only need - # to handle these cases - x = simplify_fractions(x) op = operation(x) args = arguments(x) if op == / # numerator and denominator are trivial num, den = args - # but also search for rational functions in numerator - n, d = handle_rational_polynomials(num, wrt) - num, den = n, den * d - elseif op == + + n1, d1 = handle_rational_polynomials(num, wrt; fraction_cancel_fn) + n2, d2 = handle_rational_polynomials(den, wrt; fraction_cancel_fn) + num, den = n1 * d2, d1 * n2 + elseif (op == +) || (op == -) num = 0 den = 1 - - # we don't need to do common denominator - # because we don't care about cases where denominator - # is zero. The expression is zero when all the numerators - # are zero. + if op == - + args[2] = -args[2] + end + for arg in args + n, d = handle_rational_polynomials(arg, wrt; fraction_cancel_fn) + num = num * d + n * den + den *= d + end + elseif op == ^ + base, pow = args + num, den = handle_rational_polynomials(base, wrt; fraction_cancel_fn) + num ^= pow + den ^= pow + elseif op == * + num = 1 + den = 1 for arg in args - n, d = handle_rational_polynomials(arg, wrt) - num += n + n, d = handle_rational_polynomials(arg, wrt; fraction_cancel_fn) + num *= n den *= d end else - return x, 1 + error("Unhandled operation in `handle_rational_polynomials`. This should never happen. Please open an issue in ModelingToolkit.jl with an MWE.") + end + + if fraction_cancel_fn !== nothing + expr = fraction_cancel_fn(num / den) + if iscall(expr) && operation(expr) == / + num, den = arguments(expr) + else + num, den = expr, 1 + end end + # if the denominator isn't a polynomial in `wrt`, better to not include it # to reduce the size of the gcd polynomial if !contains_variable(den, wrt) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 2c788688bd..2751f4f5cd 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -501,7 +501,8 @@ function DiffEqBase.NonlinearProblem{iip}(sys::NonlinearSystem, u0map, error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearProblem`") end if use_homotopy_continuation - prob = safe_HomotopyContinuationProblem(sys, u0map, parammap; check_length, kwargs...) + prob = safe_HomotopyContinuationProblem( + sys, u0map, parammap; check_length, kwargs...) if prob isa HomotopyContinuationProblem return prob end diff --git a/test/extensions/homotopy_continuation.jl b/test/extensions/homotopy_continuation.jl index 3f4a3fbc71..9e15ea857e 100644 --- a/test/extensions/homotopy_continuation.jl +++ b/test/extensions/homotopy_continuation.jl @@ -1,4 +1,5 @@ using ModelingToolkit, NonlinearSolve, SymbolicIndexingInterface +using SymbolicUtils import ModelingToolkit as MTK using LinearAlgebra using Test @@ -34,6 +35,8 @@ import HomotopyContinuation sol = solve(prob2; threading = false) @test SciMLBase.successful_retcode(sol) @test norm(sol.resid)≈0.0 atol=1e-10 + + @test NonlinearProblem(sys, u0; use_homotopy_continuation = false) isa NonlinearProblem end struct Wrapper @@ -217,7 +220,17 @@ end @mtkbuild sys = NonlinearSystem([x^2 + y^2 - 2x - 2 ~ 0, y ~ (x - 1) / (x - 2)]) prob = HomotopyContinuationProblem(sys, []) @test any(prob.denominator([2.0], parameter_values(prob)) .≈ 0.0) - @test_nowarn solve(prob; threading = false) + @test SciMLBase.successful_retcode(solve(prob; threading = false)) + end + + @testset "Rational function forced to common denominators" begin + @variables x = 1 + @mtkbuild sys = NonlinearSystem([0 ~ 1 / (1 + x) - x]) + prob = HomotopyContinuationProblem(sys, []) + @test any(prob.denominator([-1.0], parameter_values(prob)) .≈ 0.0) + sol = solve(prob; threading = false) + @test SciMLBase.successful_retcode(sol) + @test 1 / (1 + sol.u[1]) - sol.u[1]≈0.0 atol=1e-10 end end @@ -229,3 +242,19 @@ end @test sol[x] ≈ √2.0 @test sol[y] ≈ sin(√2.0) end + +@testset "`fraction_cancel_fn`" begin + @variables x = 1 + @named sys = NonlinearSystem([0 ~ ((x^2 - 5x + 6) / (x - 2) - 1) * (x^2 - 7x + 12) / + (x - 4)^3]) + sys = complete(sys) + + @testset "`simplify_fractions`" begin + prob = HomotopyContinuationProblem(sys, []) + @test prob.denominator([0.0], parameter_values(prob)) ≈ [4.0] + end + @testset "`nothing`" begin + prob = HomotopyContinuationProblem(sys, []; fraction_cancel_fn = nothing) + @test sort(prob.denominator([0.0], parameter_values(prob))) ≈ [2.0, 4.0^3] + end +end From 61b79a849cc643982aa036158d7b5fd5b7e11c7b Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Mon, 16 Dec 2024 10:41:04 +0100 Subject: [PATCH 03/38] add option to include disturbance arguments in `generate_control_function` --- src/inputoutput.jl | 21 +++++++++++++------- test/input_output_handling.jl | 37 +++++++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/src/inputoutput.jl b/src/inputoutput.jl index 4a99ec11f5..89db016795 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -160,7 +160,7 @@ has_var(ex, x) = x ∈ Set(get_variables(ex)) # Build control function """ - (f_oop, f_ip), x_sym, p, io_sys = generate_control_function( + (f_oop, f_ip), x_sym, p_sym, io_sys = generate_control_function( sys::AbstractODESystem, inputs = unbound_inputs(sys), disturbance_inputs = nothing; @@ -177,8 +177,7 @@ f_ip : (xout,x,u,p,t) -> nothing The return values also include the chosen state-realization (the remaining unknowns) `x_sym` and parameters, in the order they appear as arguments to `f`. -If `disturbance_inputs` is an array of variables, the generated dynamics function will preserve any state and dynamics associated with disturbance inputs, but the disturbance inputs themselves will not be included as inputs to the generated function. The use case for this is to generate dynamics for state observers that estimate the influence of unmeasured disturbances, and thus require unknown variables for the disturbance model, but without disturbance inputs since the disturbances are not available for measurement. -See [`add_input_disturbance`](@ref) for a higher-level interface to this functionality. +If `disturbance_inputs` is an array of variables, the generated dynamics function will preserve any state and dynamics associated with disturbance inputs, but the disturbance inputs themselves will (by default) not be included as inputs to the generated function. The use case for this is to generate dynamics for state observers that estimate the influence of unmeasured disturbances, and thus require unknown variables for the disturbance model, but without disturbance inputs since the disturbances are not available for measurement. To add an input argument corresponding to the disturbance inputs, either include the disturbance inputs among the control inputs, or set `disturbance_argument=true`, in which case an additional input argument `w` is added to the generated function `(x,u,p,t,w)->rhs`. !!! note "Un-simplified system" This function expects `sys` to be un-simplified, i.e., `structural_simplify` or `@mtkbuild` should not be called on the system before passing it into this function. `generate_control_function` calls a special version of `structural_simplify` internally. @@ -196,6 +195,7 @@ f[1](x, inputs, p, t) """ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inputs(sys), disturbance_inputs = disturbances(sys); + disturbance_argument = false, implicit_dae = false, simplify = false, eval_expression = false, @@ -219,10 +219,11 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu # ps = [ps; disturbance_inputs] end inputs = map(x -> time_varying_as_func(value(x), sys), inputs) + disturbance_inputs = map(x -> time_varying_as_func(value(x), sys), disturbance_inputs) eqs = [eq for eq in full_equations(sys)] eqs = map(subs_constants, eqs) - if disturbance_inputs !== nothing + if disturbance_inputs !== nothing && !disturbance_argument # Set all disturbance *inputs* to zero (we just want to keep the disturbance state) subs = Dict(disturbance_inputs .=> 0) eqs = [eq.lhs ~ substitute(eq.rhs, subs) for eq in eqs] @@ -239,16 +240,22 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu t = get_iv(sys) # pre = has_difference ? (ex -> ex) : get_postprocess_fbody(sys) - - args = (u, inputs, p..., t) + if disturbance_argument + args = (u, inputs, p..., t, disturbance_inputs) + else + args = (u, inputs, p..., t) + end if implicit_dae ddvs = map(Differential(get_iv(sys)), dvs) args = (ddvs, args...) end process = get_postprocess_fbody(sys) + wrapped_arrays_vars = disturbance_argument ? + wrap_array_vars(sys, rhss; dvs, ps, inputs, disturbance_inputs) : + wrap_array_vars(sys, rhss; dvs, ps, inputs) f = build_function(rhss, args...; postprocess_fbody = process, expression = Val{true}, wrap_code = wrap_mtkparameters(sys, false, 3) .∘ - wrap_array_vars(sys, rhss; dvs, ps, inputs) .∘ + wrapped_arrays_vars .∘ wrap_parameter_dependencies(sys, false), kwargs...) f = eval_or_rgf.(f; eval_expression, eval_module) diff --git a/test/input_output_handling.jl b/test/input_output_handling.jl index 9550a87f31..25434e3dfb 100644 --- a/test/input_output_handling.jl +++ b/test/input_output_handling.jl @@ -170,6 +170,43 @@ x = [rand()] u = [rand()] @test f[1](x, u, p, 1) == -x + u +# With disturbance inputs +@variables x(t)=0 u(t)=0 [input = true] d(t)=0 +eqs = [ + D(x) ~ -x + u + d^2 +] + +@named sys = ODESystem(eqs, t) +f, dvs, ps, io_sys = ModelingToolkit.generate_control_function( + sys, [u], [d], simplify = true) + +@test isequal(dvs[], x) +@test isempty(ps) + +p = nothing +x = [rand()] +u = [rand()] +@test f[1](x, u, p, 1) == -x + u + +# With added d argument +@variables x(t)=0 u(t)=0 [input = true] d(t)=0 +eqs = [ + D(x) ~ -x + u + d^2 +] + +@named sys = ODESystem(eqs, t) +f, dvs, ps, io_sys = ModelingToolkit.generate_control_function( + sys, [u], [d], simplify = true, disturbance_argument = true) + +@test isequal(dvs[], x) +@test isempty(ps) + +p = nothing +x = [rand()] +u = [rand()] +d = [rand()] +@test f[1](x, u, p, 1, d) == -x + u + [d[]^2] + # more complicated system @variables u(t) [input = true] From b7a4cdf71bbca02bb4f4a6f62a6c0d65964188d1 Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Mon, 16 Dec 2024 11:25:51 +0100 Subject: [PATCH 04/38] Apply suggestions from code review Co-authored-by: Aayush Sabharwal --- src/inputoutput.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/inputoutput.jl b/src/inputoutput.jl index 89db016795..f1ea0bc749 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -219,7 +219,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu # ps = [ps; disturbance_inputs] end inputs = map(x -> time_varying_as_func(value(x), sys), inputs) - disturbance_inputs = map(x -> time_varying_as_func(value(x), sys), disturbance_inputs) + disturbance_inputs = unwrap.(disturbance_inputs) eqs = [eq for eq in full_equations(sys)] eqs = map(subs_constants, eqs) @@ -251,7 +251,7 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu end process = get_postprocess_fbody(sys) wrapped_arrays_vars = disturbance_argument ? - wrap_array_vars(sys, rhss; dvs, ps, inputs, disturbance_inputs) : + wrap_array_vars(sys, rhss; dvs, ps, inputs, cachesyms = (disturbance_inputs,)) : wrap_array_vars(sys, rhss; dvs, ps, inputs) f = build_function(rhss, args...; postprocess_fbody = process, expression = Val{true}, wrap_code = wrap_mtkparameters(sys, false, 3) .∘ From d75009ba5fde932da03095edb5b811d3b03d38a6 Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Mon, 16 Dec 2024 12:08:12 +0100 Subject: [PATCH 05/38] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/inputoutput.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/inputoutput.jl b/src/inputoutput.jl index f1ea0bc749..1be1fbb8dd 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -251,7 +251,8 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu end process = get_postprocess_fbody(sys) wrapped_arrays_vars = disturbance_argument ? - wrap_array_vars(sys, rhss; dvs, ps, inputs, cachesyms = (disturbance_inputs,)) : + wrap_array_vars( + sys, rhss; dvs, ps, inputs, cachesyms = (disturbance_inputs,)) : wrap_array_vars(sys, rhss; dvs, ps, inputs) f = build_function(rhss, args...; postprocess_fbody = process, expression = Val{true}, wrap_code = wrap_mtkparameters(sys, false, 3) .∘ From a3789ae17a9f67df87b98179483e044c33e1fe44 Mon Sep 17 00:00:00 2001 From: Fredrik Bagge Carlson Date: Mon, 16 Dec 2024 14:20:27 +0100 Subject: [PATCH 06/38] test with split true and false --- test/input_output_handling.jl | 113 ++++++++++++++++++---------------- 1 file changed, 59 insertions(+), 54 deletions(-) diff --git a/test/input_output_handling.jl b/test/input_output_handling.jl index 25434e3dfb..de6fc92b5c 100644 --- a/test/input_output_handling.jl +++ b/test/input_output_handling.jl @@ -153,61 +153,66 @@ if VERSION >= v"1.8" # :opaque_closure not supported before end ## Code generation with unbound inputs +@testset "generate_control_function with disturbance inputs" begin + for split in [true, false] + simplify = true + + @variables x(t)=0 u(t)=0 [input = true] + eqs = [ + D(x) ~ -x + u + ] + + @named sys = ODESystem(eqs, t) + f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys; simplify, split) + + @test isequal(dvs[], x) + @test isempty(ps) + + p = nothing + x = [rand()] + u = [rand()] + @test f[1](x, u, p, 1) == -x + u + + # With disturbance inputs + @variables x(t)=0 u(t)=0 [input = true] d(t)=0 + eqs = [ + D(x) ~ -x + u + d^2 + ] + + @named sys = ODESystem(eqs, t) + f, dvs, ps, io_sys = ModelingToolkit.generate_control_function( + sys, [u], [d]; simplify, split) + + @test isequal(dvs[], x) + @test isempty(ps) + + p = nothing + x = [rand()] + u = [rand()] + @test f[1](x, u, p, 1) == -x + u + + ## With added d argument + @variables x(t)=0 u(t)=0 [input = true] d(t)=0 + eqs = [ + D(x) ~ -x + u + d^2 + ] + + @named sys = ODESystem(eqs, t) + f, dvs, ps, io_sys = ModelingToolkit.generate_control_function( + sys, [u], [d]; simplify, split, disturbance_argument = true) + + @test isequal(dvs[], x) + @test isempty(ps) + + p = nothing + x = [rand()] + u = [rand()] + d = [rand()] + @test f[1](x, u, p, t, d) == -x + u + [d[]^2] + end +end -@variables x(t)=0 u(t)=0 [input = true] -eqs = [ - D(x) ~ -x + u -] - -@named sys = ODESystem(eqs, t) -f, dvs, ps, io_sys = ModelingToolkit.generate_control_function(sys, simplify = true) - -@test isequal(dvs[], x) -@test isempty(ps) - -p = nothing -x = [rand()] -u = [rand()] -@test f[1](x, u, p, 1) == -x + u - -# With disturbance inputs -@variables x(t)=0 u(t)=0 [input = true] d(t)=0 -eqs = [ - D(x) ~ -x + u + d^2 -] - -@named sys = ODESystem(eqs, t) -f, dvs, ps, io_sys = ModelingToolkit.generate_control_function( - sys, [u], [d], simplify = true) - -@test isequal(dvs[], x) -@test isempty(ps) - -p = nothing -x = [rand()] -u = [rand()] -@test f[1](x, u, p, 1) == -x + u - -# With added d argument -@variables x(t)=0 u(t)=0 [input = true] d(t)=0 -eqs = [ - D(x) ~ -x + u + d^2 -] - -@named sys = ODESystem(eqs, t) -f, dvs, ps, io_sys = ModelingToolkit.generate_control_function( - sys, [u], [d], simplify = true, disturbance_argument = true) - -@test isequal(dvs[], x) -@test isempty(ps) - -p = nothing -x = [rand()] -u = [rand()] -d = [rand()] -@test f[1](x, u, p, 1, d) == -x + u + [d[]^2] - -# more complicated system +## more complicated system @variables u(t) [input = true] From cb6ca4ce4cfd8547117317d23e5b32558d995a68 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 16 Dec 2024 19:25:25 +0530 Subject: [PATCH 07/38] feat: add support for `extra_args` in `wrap_array_vars` --- src/systems/abstractsystem.jl | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index e44f250a7f..6a20ee17dc 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -230,9 +230,33 @@ function wrap_parameter_dependencies(sys::AbstractSystem, isscalar) wrap_assignments(isscalar, [eq.lhs ← eq.rhs for eq in parameter_dependencies(sys)]) end +""" + $(TYPEDSIGNATURES) + +Add the necessary assignment statements to allow use of unscalarized array variables +in the generated code. `expr` is the expression returned by the function. `dvs` and +`ps` are the unknowns and parameters of the system `sys` to use in the generated code. +`inputs` can be specified as an array of symbolics if the generated function has inputs. +If `history == true`, the generated function accepts a history function. `cachesyms` are +extra variables (arrays of variables) stored in the cache array(s) of the parameter +object. `extra_args` are extra arguments appended to the end of the argument list. + +The function is assumed to have the signature `f(du, u, h, x, p, cache_syms..., t, extra_args...)` +Where: +- `du` is the optional buffer to write to for in-place functions. +- `u` is the list of unknowns. This argument is not present if `dvs === nothing`. +- `h` is the optional history function, present if `history == true`. +- `x` is the array of inputs, present only if `inputs !== nothing`. Values are assumed + to be in the order of variables passed to `inputs`. +- `p` is the parameter object. +- `cache_syms` are the cache variables. These are part of the splatted parameter object. +- `t` is time, present only if the system is time dependent. +- `extra_args` are the extra arguments passed to the function, present only if + `extra_args` is non-empty. +""" function wrap_array_vars( sys::AbstractSystem, exprs; dvs = unknowns(sys), ps = parameters(sys), - inputs = nothing, history = false, cachesyms::Tuple = ()) + inputs = nothing, history = false, cachesyms::Tuple = (), extra_args::Tuple = ()) isscalar = !(exprs isa AbstractArray) var_to_arridxs = Dict() @@ -252,6 +276,10 @@ function wrap_array_vars( if inputs !== nothing rps = (inputs, rps...) end + if has_iv(sys) + rps = (rps..., get_iv(sys)) + end + rps = (rps..., extra_args...) for sym in reduce(vcat, rps; init = []) iscall(sym) && operation(sym) == getindex || continue arg = arguments(sym)[1] From 266630dacfa864b79fae6e2eb534b6c4be99001d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 16 Dec 2024 19:44:06 +0530 Subject: [PATCH 08/38] feat: handle additional arguments in `wrap_mtkparameters` --- src/systems/abstractsystem.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 6a20ee17dc..168260ae69 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -360,7 +360,7 @@ end const MTKPARAMETERS_ARG = Sym{Vector{Vector}}(:___mtkparameters___) """ - wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2) + wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2, offset = Int(is_time_dependent(sys))) Return function(s) to be passed to the `wrap_code` keyword of `build_function` which allow the compiled function to be called as `f(u, p, t)` where `p isa MTKParameters` @@ -370,12 +370,14 @@ the first parameter vector in the out-of-place version of the function. For exam if a history function (DDEs) was passed before `p`, then the function before wrapping would have the signature `f(u, h, p..., t)` and hence `p_start` would need to be `3`. +`offset` is the number of arguments at the end of the argument list to ignore. Defaults +to 1 if the system is time-dependent (to ignore `t`) and 0 otherwise. + The returned function is `identity` if the system does not have an `IndexCache`. """ -function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2) +function wrap_mtkparameters(sys::AbstractSystem, isscalar::Bool, p_start = 2, + offset = Int(is_time_dependent(sys))) if has_index_cache(sys) && get_index_cache(sys) !== nothing - offset = Int(is_time_dependent(sys)) - if isscalar function (expr) param_args = expr.args[p_start:(end - offset)] From a79f4c2918b309cf4653f5cb110812cddaa214db Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 16 Dec 2024 19:25:42 +0530 Subject: [PATCH 09/38] fix: use `extra_args` in `generate_control_function` --- src/inputoutput.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/inputoutput.jl b/src/inputoutput.jl index 1be1fbb8dd..6bdcac6dd4 100644 --- a/src/inputoutput.jl +++ b/src/inputoutput.jl @@ -252,10 +252,11 @@ function generate_control_function(sys::AbstractODESystem, inputs = unbound_inpu process = get_postprocess_fbody(sys) wrapped_arrays_vars = disturbance_argument ? wrap_array_vars( - sys, rhss; dvs, ps, inputs, cachesyms = (disturbance_inputs,)) : + sys, rhss; dvs, ps, inputs, extra_args = (disturbance_inputs,)) : wrap_array_vars(sys, rhss; dvs, ps, inputs) f = build_function(rhss, args...; postprocess_fbody = process, - expression = Val{true}, wrap_code = wrap_mtkparameters(sys, false, 3) .∘ + expression = Val{true}, wrap_code = wrap_mtkparameters( + sys, false, 3, Int(disturbance_argument) + 1) .∘ wrapped_arrays_vars .∘ wrap_parameter_dependencies(sys, false), kwargs...) From 21a2a33aa6d505da57c30c8e8326b40201106aae Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 13 Dec 2024 18:21:18 +0530 Subject: [PATCH 10/38] feat: allow passing pre-computed `vars` to `observed_equations_used_by` --- src/utils.jl | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index dd7113cbe6..e9ddad3a07 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1033,19 +1033,24 @@ end $(TYPEDSIGNATURES) Return the indexes of observed equations of `sys` used by expression `exprs`. + +Keyword arguments: +- `involved_vars`: A collection of the variables involved in `exprs`. This is the set of + variables which will be explored to find dependencies on observed equations. Typically, + providing this keyword is not necessary and is only useful to avoid repeatedly calling + `vars(exprs)` """ -function observed_equations_used_by(sys::AbstractSystem, exprs) +function observed_equations_used_by(sys::AbstractSystem, exprs; involved_vars = vars(exprs)) obs = observed(sys) obsvars = getproperty.(obs, :lhs) graph = observed_dependency_graph(obs) - syms = vars(exprs) - obsidxs = BitSet() - for sym in syms + for sym in involved_vars idx = findfirst(isequal(sym), obsvars) idx === nothing && continue + idx in obsidxs && continue parents = dfs_parents(graph, idx) for i in eachindex(parents) parents[i] == 0 && continue From 7deb72ba09091a1324c8efb75dc8b55a5c06906e Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 13 Dec 2024 18:21:33 +0530 Subject: [PATCH 11/38] feat: implement `IfLifting` structural simplification pass Co-authored-by: Benjamin Chung --- src/ModelingToolkit.jl | 1 + src/systems/if_lifting.jl | 511 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 512 insertions(+) create mode 100644 src/systems/if_lifting.jl diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 20b2ada8fa..10aba4d8a9 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -181,6 +181,7 @@ include("discretedomain.jl") include("systems/systemstructure.jl") include("systems/clock_inference.jl") include("systems/systems.jl") +include("systems/if_lifting.jl") include("debugging.jl") include("systems/alias_elimination.jl") diff --git a/src/systems/if_lifting.jl b/src/systems/if_lifting.jl new file mode 100644 index 0000000000..53e1ca4957 --- /dev/null +++ b/src/systems/if_lifting.jl @@ -0,0 +1,511 @@ +""" + struct CondRewriter + +Callable struct used to transform symbolic conditions into conditions involving discrete +variables. +""" +struct CondRewriter + """ + The independent variable which the discrete variables depend on. + """ + iv::BasicSymbolic + """ + A mapping from a discrete variables to a `NamedTuple` containing the condition + determining whether the discrete variable needs to be evaluated and the symbolic + expression the discrete variable represents. The expression is used as a rootfinding + function, and zero-crossings trigger re-evaluation of the condition (if `dependency` + is `true`). `expression < 0` is evaluated on an up-crossing and `expression <= 0` is + evaluated on a down-crossing to get the updated value of the condition variable. + """ + conditions::Dict{Any, @NamedTuple{dependency, expression}} +end + +function CondRewriter(iv) + return CondRewriter(iv, Dict()) +end + +""" + $(TYPEDSIGNATURES) + +Given a symbolic condition `expr` and the condition `dep` it depends on, update the +mapping in `cw` and generate a new discrete variable if necessary. +""" +function new_cond_sym(cw::CondRewriter, expr, dep) + if !iscall(expr) || operation(expr) != Base.:(<) || !iszero(arguments(expr)[2]) + throw(ArgumentError("`expr` passed to `new_cond_sym` must be of the form `f(args...) < 0`. Got $expr.")) + end + # check if the same expression exists in the mapping + existing_var = findfirst(p -> isequal(p.expression, expr), cw.conditions) + if existing_var !== nothing + # cache hit + (existing_dep, _) = cw.conditions[existing_var] + # update the dependency condition + cw.conditions[existing_var] = (dependency = (dep | existing_dep), expression = expr) + return existing_var + end + # generate a new condition variable + cvar = gensym("cond") + st = symtype(expr) + iv = cw.iv + cv = unwrap(first(@parameters $(cvar)(iv)::st = true)) # TODO: real init + cw.conditions[cv] = (dependency = dep, expression = expr) + return cv +end + +""" +Utility function for boolean implication. +""" +implies(a, b) = !a & b + +""" + $(TYPEDSIGNATURES) + +Recursively rewrite conditions into discrete variables. `expr` is the condition to rewrite, +`dep` is a boolean expression/value which determines when the `expr` is to be evaluated. For +example, if `expr = expr1 | expr2` and `dep = dep1`, then `expr` should only be evaluated if +`dep1` evaluates to `true`. Recursively, `expr1` should only be evaluated if `dep1` is `true`, +and `expr2` should only be evaluated if `dep & !expr1`. + +Returns a 3-tuple of the substituted expression, a condition describing when `expr` evaluates +to `true`, and a condition describing when `expr` evaluates to `false`. + +This expects that all expressions with discontinuities or with discontinuous derivatives have +been rewritten into the form of `ifelse(rootfunc(args...) < 0, left(args...), right(args...))`. +The transformation is performed via `discontinuities_to_ifelse` using `Symbolics.rootfunction` +and family. +""" +function (cw::CondRewriter)(expr, dep) + # single variable, trivial case + if issym(expr) || iscall(expr) && issym(operation(expr)) + return (expr, expr, !expr) + # literal boolean or integer + elseif expr isa Bool + return (expr, expr, !expr) + elseif expr isa Int + return (expr, true, true) + # other singleton symbolic variables + elseif !iscall(expr) + @warn "Automatic conversion of if statments to events requires use of a limited conditional grammar; see the documentation. Skipping due to $expr" + return (expr, true, true) # error case => conservative assumption is that both true and false have to be evaluated + elseif operation(expr) == Base.:(|) # OR of two conditions + a, b = arguments(expr) + (rw_conda, truea, falsea) = cw(a, dep) + # only evaluate second if first is false + (rw_condb, trueb, falseb) = cw(b, dep & falsea) + return (rw_conda | rw_condb, truea | trueb, falsea & falseb) + + elseif operation(expr) == Base.:(&) # AND of two conditions + a, b = arguments(expr) + (rw_conda, truea, falsea) = cw(a, dep) + # only evaluate second if first is true + (rw_condb, trueb, falseb) = cw(b, dep & truea) + return (rw_conda & rw_condb, truea & trueb, falsea | falseb) + elseif operation(expr) == ifelse + c, a, b = arguments(expr) + (rw_cond, ctrue, cfalse) = cw(c, dep) + # only evaluate if condition is true + (rw_conda, truea, falsea) = cw(a, dep & ctrue) + # only evaluate if condition is false + (rw_condb, trueb, falseb) = cw(b, dep & cfalse) + # expression is true if condition is true and THEN branch is true, or condition is false + # and ELSE branch is true + # similarly for expression being false + return (ifelse(rw_cond, rw_conda, rw_condb), + implies(ctrue, truea) | implies(cfalse, trueb), + implies(ctrue, falsea) | implies(cfalse, falseb)) + elseif operation(expr) == Base.:(!) # NOT of expression + (a,) = arguments(expr) + (rw, ctrue, cfalse) = cw(a, dep) + return (!rw, cfalse, ctrue) + elseif operation(expr) == Base.:(<) + if !isequal(arguments(expr)[2], 0) + throw(ArgumentError("Expected comparison to be written as `f(args...) < 0`. Found $expr.")) + end + + # if the comparison does not include time-dependent variables, + # don't create a callback for it + + # Calling `expression_is_time_dependent` is `O(d)` where `d` is the depth of the + # expression tree. We only call this in this here to avoid turning this into + # an `O(d^2)` time complexity recursion, which would happen if it were called + # at the beginning of the function. Now, it only happens near the leaves of + # the recursive tree. + if !expression_is_time_dependent(expr, cw.iv) + return (expr, expr, !expr) + end + cv = new_cond_sym(cw, expr, dep) + return (cv, cv, !cv) + elseif operation(expr) == (==) + # we don't touch equality since it's a point discontinuity. It's basically always + # false for continuous variables. In case it's an equality between discrete + # quantities, we don't need to transform it. + return (expr, expr, !expr) + elseif !expression_is_time_dependent(expr, cw.iv) + return (expr, expr, !expr) + end + error(""" + Unsupported expression form in decision variable computation $expr. If the expression + involves a registered function, declare the discontinuity using + `Symbolics.@register_discontinuity`. If this is not meant to be transformed via + `IfLifting`, wrap the parent expression in `ModelingToolkit.no_if_lift`. + """) +end + +""" + $(TYPEDSIGNATURES) + +Acts as the identity function, and prevents transformation of conditional expressions inside it. Useful +if specific `ifelse` or other functions with discontinuous derivatives shouldn't be transformed into +callbacks. +""" +no_if_lift(s) = s +@register_symbolic no_if_lift(s) + +""" + $(TYPEDEF) + +A utility struct to search through an expression specifically for `ifelse` terms, and find +all variables used in the condition of such terms. The variables are stored in a field of +the struct. +""" +struct VarsUsedInCondition + """ + Stores variables used in conditions of `ifelse` statements in the expression. + """ + vars::Set{Any} +end + +VarsUsedInCondition() = VarsUsedInCondition(Set()) + +function (v::VarsUsedInCondition)(expr) + expr = Symbolics.unwrap(expr) + if symbolic_type(expr) == NotSymbolic() + is_array_of_symbolics(expr) || return + foreach(v, expr) + return + end + iscall(expr) || return + op = operation(expr) + + # do not search inside no_if_lift to avoid discovering + # redundant variables + op == no_if_lift && return + + args = arguments(expr) + if op == ifelse + cond, branch_a, branch_b = arguments(expr) + vars!(v.vars, cond) + v(branch_a) + v(branch_b) + end + foreach(v, args) + return +end + +""" + $(TYPEDSIGNATURES) + +Check if `expr` depends on the independent variable `iv`. Return `true` if `iv` is present +in the expression, `Differential(iv)` is in the expression, or a dependent variable such +as `@variables x(iv)` is in the expression. +""" +function expression_is_time_dependent(expr, iv) + any(vars(expr)) do sym + sym = unwrap(sym) + isequal(sym, iv) && return true + iscall(sym) || return false + op = operation(sym) + args = arguments(sym) + op isa Differential && op == Differential(iv) || + issym(op) && length(args) == 1 && expression_is_time_dependent(args[1], iv) + end +end + +""" + $(TYPEDSIGNATURES) + +Given an expression `expr` which is to be evaluated if `dep` evaluates to `true`, transform +the conditions of all all `ifelse` statements in `expr` into functions of new discrete +variables. `cw` is used to store the information relevant to these newly introduced variables. +""" +function rewrite_ifs(cw::CondRewriter, expr, dep) + expr = unwrap(expr) + if symbolic_type(expr) == NotSymbolic() + # non-symbolic expression might still be an array of symbolic expressions + is_array_of_symbolics(expr) || return expr + return map(ex -> rewrite_ifs(cw, ex, dep), expr) + end + + iscall(expr) || return expr + op = operation(expr) + args = arguments(expr) + # do not search into `no_if_lift` + op == no_if_lift && return expr + + # transform `ifelse` + if op == ifelse + cond, iftrue, iffalse = args + + (rw_cond, deptrue, depfalse) = cw(cond, dep) + rw_iftrue = rewrite_ifs(cw, iftrue, deptrue) + rw_iffalse = rewrite_ifs(cw, iffalse, depfalse) + return maketerm( + typeof(expr), ifelse, [unwrap(rw_cond), rw_iftrue, rw_iffalse], metadata(expr)) + end + + # recurse into the rest of the cases + args = map(ex -> rewrite_ifs(cw, ex, dep), args) + return maketerm(typeof(expr), op, args, metadata(expr)) +end + +""" + $(TYPEDSIGNATURES) + +Return a modified `expr` where functions with known discontinuities or discontinuous +derivatives are transformed into `ifelse` statements. Utilizes the discontinuity API +in Symbolics. See [`Symbolics.rootfunction`](@ref), +[`Symbolics.left_continuous_function`](@ref), [`Symbolics.right_continuous_function`](@ref). + +`iv` is the independent variable of the system. Only subexpressions of `expr` which +depend on `iv` are transformed. +""" +function discontinuities_to_ifelse(expr, iv) + expr = unwrap(expr) + if symbolic_type(expr) == NotSymbolic() + # non-symbolic expression might still be an array of symbolic expressions + is_array_of_symbolics(expr) || return expr + return map(ex -> discontinuities_to_ifelse(ex, iv), expr) + end + + iscall(expr) || return expr + op = operation(expr) + args = arguments(expr) + # do not search into `no_if_lift` + op == no_if_lift && return expr + + # Case I: the operation is symbolic. + # We don't actually care if this is a callable parameter or not. + # If it is, we want to search inside and perform if-lifting there. + # If it isn't, either it's `x(t)` in which case this recursion is + # effectively a no-op OR it's `x(f(t))` for DDEs and we want to + # perform if-lifting inside. + # + # Case II: the operation is not symbolic. + # We anyway want to recursively apply the transformation. + # + # Thus, we can do this here regardless of the subsequent checks + args = map(ex -> discontinuities_to_ifelse(ex, iv), args) + + # if the operation is a known discontinuity + if hasmethod(Symbolics.rootfunction, Tuple{typeof(op)}) + rootfn = Symbolics.rootfunction(op) + leftfn = Symbolics.left_continuous_function(op) + rightfn = Symbolics.right_continuous_function(op) + rootexpr = rootfn(args...) < 0 + leftexpr = leftfn(args...) + rightexpr = rightfn(args...) + return maketerm( + typeof(expr), ifelse, [rootexpr, leftexpr, rightexpr], metadata(expr)) + end + + return maketerm(typeof(expr), op, args, metadata(expr)) +end + +""" + $(TYPEDSIGNATURES) + +Generate the symbolic condition for discrete variable `sym`, which represents the condition +of an `ifelse` statement created through [`IfLifting`](@ref). This condition is used to +trigger a callback which updates the value of the condition appropriately. +""" +function generate_condition(cw::CondRewriter, sym) + (dep, expr) = cw.conditions[sym] + + # expr is `f(args...) < 0`, `f(args...)` is the zero-crossing expression + zero_crossing = arguments(expr)[1] + + # if we're meant to evaluate the condition, evaluate it. Otherwise, return `NaN`. + # the solvers don't treat the transition from a number to NaN or back as a zero-crossing, + # so it can be used to effectively disable the affect when the condition is not meant to + # be evaluated. + return ifelse(dep, zero_crossing, NaN) ~ 0 +end + +""" + $(TYPEDSIGNATURES) + +Generate the upcrossing and downcrossing affect functions for discrete variable `sym` involved +in `ifelse` statements that are lifted to callbacks using [`IfLifting`](@ref). `syms` is a +condition variable introduced by `cw`, and is thus a key in `cw.conditions`. `new_cond_vars` +is the list of all such new condition variables, corresponding to the order of vertices in +`new_cond_vars_graph`. `new_cond_vars_graph` is a directed graph where edges denote the +condition variables involved in the dependency expression of the source vertex. +""" +function generate_affects(cw::CondRewriter, sym, new_cond_vars, new_cond_vars_graph) + sym_idx = findfirst(isequal(sym), new_cond_vars) + if sym_idx === nothing + throw(ArgumentError("Expected variable $sym to be a condition variable in $new_cond_vars.")) + end + # use reverse direction of edges because instead of finding the variables it depends + # on, we want the variables that depend on it + parents = bfs_parents(new_cond_vars_graph, sym_idx; dir = :in) + cond_vars_to_update = [new_cond_vars[i] + for i in eachindex(parents) if !iszero(parents[i])] + update_syms = Symbol.(cond_vars_to_update) + modified = NamedTuple{(update_syms...,)}(cond_vars_to_update) + + upcrossing_update_exprs = [arguments(last(cw.conditions[sym]))[1] < 0 + for sym in cond_vars_to_update] + upcrossing = ImperativeAffect( + modified, observed = NamedTuple{(update_syms...,)}(upcrossing_update_exprs), + skip_checks = true) do x, o, c, i + return o + end + downcrossing_update_exprs = [arguments(last(cw.conditions[sym]))[1] <= 0 + for sym in cond_vars_to_update] + downcrossing = ImperativeAffect( + modified, observed = NamedTuple{(update_syms...,)}(downcrossing_update_exprs), + skip_checks = true) do x, o, c, i + return o + end + + return upcrossing, downcrossing +end + +const CONDITION_SIMPLIFIER = Rewriters.Fixpoint(Rewriters.Postwalk(Rewriters.Chain([ + # simple boolean laws + (@rule (!!(~x)) => (~x)) + (@rule ((~x) & true) => (~x)) + (@rule ((~x) & false) => false) + (@rule ((~x) | true) => true) + (@rule ((~x) | false) => (~x)) + (@rule ((~x) & !(~x)) => false) + (@rule ((~x) | !(~x)) => true) + # reversed order of the above, because it matters and `@acrule` refuses to do its job + (@rule (true & (~x)) => (~x)) + (@rule (false & (~x)) => false) + (@rule (true | (~x)) => true) + (@rule (false | (~x)) => (~x)) + (@rule (!(~x) & (~x)) => false) + (@rule (!(~x) | (~x)) => true) + # idempotent + (@rule ((~x) & (~x)) => (~x)) + (@rule ((~x) | (~x)) => (~x)) + # ifelse with determined branches + (@rule ifelse((~x), true, false) => (~x)) + (@rule ifelse((~x), false, true) => !(~x)) + # ifelse with identical branches + (@rule ifelse((~x), (~y), (~y)) => (~y)) + (@rule ifelse((~x), (~y), !(~y)) => ((~x) & + (~y))) + (@rule ifelse((~x), !(~y), (~y)) => ((~x) & + !(~y))) + # ifelse with determined condition + (@rule ifelse(true, (~x), (~y)) => (~x)) + (@rule ifelse(false, (~x), (~y)) => (~y))]))) + +""" +If lifting converts (nested) if statements into a series of continous events + a logically equivalent if statement + parameters. + +Lifting proceeds through the following process: +* rewrite comparisons to be of the form eqn [op] 0; subtract the RHS from the LHS +* replace comparisons with generated parameters; for each comparison eqn [op] 0, generate an event (dependent on op) that sets the parameter +""" +function IfLifting(sys::ODESystem) + cw = CondRewriter(get_iv(sys)) + + eqs = copy(equations(sys)) + obs = copy(observed(sys)) + + # get variables used by `eqs` + syms = vars(eqs) + # get observed equations used by `eqs` + obs_idxs = observed_equations_used_by(sys, eqs; involved_vars = syms) + # and the variables used in those equations + for i in obs_idxs + vars!(syms, obs[i]) + end + + # get all integral variables used in conditions + # this is used when performing the transformation on observed equations + # since they are transformed differently depending on whether they are + # discrete variables involved in a condition or not + condition_vars = Set() + # searcher struct + # we can use the same one since it avoids iterating over duplicates + vars_in_condition! = VarsUsedInCondition() + for i in eachindex(eqs) + eq = eqs[i] + vars_in_condition!(eq.rhs) + # also transform the equation + eqs[i] = eq.lhs ~ rewrite_ifs(cw, discontinuities_to_ifelse(eq.rhs, cw.iv), true) + end + # also search through relevant observed equations + for i in obs_idxs + vars_in_condition!(obs[i].rhs) + end + # add to `condition_vars` after filtering out differential, parameter, independent and + # non-integral variables + for v in vars_in_condition!.vars + v = unwrap(v) + stype = symtype(v) + if isdifferential(v) || isparameter(v) || isequal(v, get_iv(sys)) + continue + end + stype <: Union{Integer, AbstractArray{Integer}} && push!(condition_vars, v) + end + # transform observed equations + for i in obs_idxs + obs[i] = if obs[i].lhs in condition_vars + obs[i].lhs ~ first(cw(discontinuities_to_ifelse(obs[i].rhs, cw.iv), true)) + else + obs[i].lhs ~ rewrite_ifs(cw, discontinuities_to_ifelse(obs[i].rhs, cw.iv), true) + end + end + + # `rewrite_ifs` and calling `cw` generate a lot of redundant code, simplify it + eqs = map(eqs) do eq + eq.lhs ~ CONDITION_SIMPLIFIER(eq.rhs) + end + obs = map(obs) do eq + eq.lhs ~ CONDITION_SIMPLIFIER(eq.rhs) + end + # also simplify dependencies + for (k, v) in cw.conditions + cw.conditions[k] = map(CONDITION_SIMPLIFIER ∘ unwrap, v) + end + + # get directed graph where nodes are the new condition variables and edges from each + # node denote the condition variables used in it's dependency expression + + # so we have an ordering for the vertices + new_cond_vars = collect(keys(cw.conditions)) + # "observed" equations + new_cond_dep_eqs = [v ~ cw.conditions[v] for v in new_cond_vars] + # construct the graph as a `DiCMOBiGraph` + new_cond_vars_graph = observed_dependency_graph(new_cond_dep_eqs) + + new_callbacks = continuous_events(sys) + new_defaults = defaults(sys) + new_ps = Vector{SymbolicParam}(parameters(sys)) + + for var in new_cond_vars + condition = generate_condition(cw, var) + up_affect, down_affect = generate_affects( + cw, var, new_cond_vars, new_cond_vars_graph) + cb = SymbolicContinuousCallback([condition], up_affect; affect_neg = down_affect, + initialize = up_affect, rootfind = SciMLBase.RightRootFind) + + push!(new_callbacks, cb) + new_defaults[var] = getdefault(var) + push!(new_ps, var) + end + + @set! sys.defaults = new_defaults + @set! sys.eqs = eqs + # do not need to topsort because we didn't modify the order + @set! sys.observed = obs + @set! sys.continuous_events = new_callbacks + @set! sys.ps = new_ps + return sys +end From 838ad8034b83bd7e935b64c31e034ecd94bd0fff Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 17 Dec 2024 22:15:26 +0530 Subject: [PATCH 12/38] test: add tests for if-lifting --- test/if_lifting.jl | 110 +++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 1 + 2 files changed, 111 insertions(+) create mode 100644 test/if_lifting.jl diff --git a/test/if_lifting.jl b/test/if_lifting.jl new file mode 100644 index 0000000000..d702355506 --- /dev/null +++ b/test/if_lifting.jl @@ -0,0 +1,110 @@ +using ModelingToolkit, OrdinaryDiffEq +using ModelingToolkit: t_nounits as t, D_nounits as D, IfLifting, no_if_lift + +@testset "Simple `abs(x)`" begin + @mtkmodel SimpleAbs begin + @variables begin + x(t) + y(t) + end + @equations begin + D(x) ~ abs(y) + y ~ sin(t) + end + end + @named sys = SimpleAbs() + ss1 = structural_simplify(sys) + @test length(equations(ss1)) == 1 + ss2 = structural_simplify(sys, additional_passes = [IfLifting]) + @test length(equations(ss2)) == 1 + @test length(parameters(ss2)) == 1 + @test operation(only(equations(ss2)).rhs) === ifelse + + discvar = only(parameters(ss2)) + prob2 = ODEProblem(ss2, [x => 0.0], (0.0, 5.0)) + sol2 = solve(prob2, Tsit5()) + @test count(isapprox(pi), sol2.t) == 2 + @test any(isapprox(pi), sol2.discretes[1].t) + @test !sol2[discvar][1] + @test sol2[discvar][end] + + _t = pi + 1.0 + # x(t) = 1 - cos(t) in [0, pi) + # x(t) = 3 + cos(t) in [pi, 2pi) + _trueval = 3 + cos(_t) + @test !isapprox(sol1(_t)[1], _trueval; rtol = 1e-3) + @test isapprox(sol2(_t)[1], _trueval; rtol = 1e-3) +end + +@testset "Big test case" begin + @mtkmodel BigModel begin + @variables begin + x(t) + y(t) + z(t) + c(t)::Bool + w(t) + q(t) + r(t) + end + @parameters begin + p + end + @equations begin + # ifelse, max, min + D(x) ~ ifelse(c, max(x, y), min(x, y)) + # discrete observed + c ~ x <= y + # observed should also get if-lifting + y ~ abs(sin(t)) + # should be ignored + D(z) ~ no_if_lift(ifelse(x < y, x, y)) + # ignore time-independent ifelse + D(w) ~ ifelse(p < 3, 1.0, 2.0) + # all the boolean operators + D(q) ~ ifelse((x < 1) & ((y < 0.5) | ifelse(y > 0.8, c, !c)), 1.0, 2.0) + # don't touch time-independent condition, but modify time-dependent branches + D(r) ~ ifelse(p < 2, abs(x), max(y, 0.9)) + end + end + + @named sys = BigModel() + ss = structural_simplify(sys, additional_passes = [IfLifting]) + + ps = parameters(ss) + @test length(ps) == 9 + eqs = equations(ss) + obs = observed(ss) + + @testset "no_if_lift is untouched" begin + idx = findfirst(eq -> isequal(eq.lhs, D(ss.z)), eqs) + eq = eqs[idx] + @test isequal(eq.rhs, no_if_lift(ifelse(ss.x < ss.y, ss.x, ss.y))) + end + @testset "time-independent ifelse is untouched" begin + idx = findfirst(eq -> isequal(eq.lhs, D(ss.w)), eqs) + eq = eqs[idx] + @test operation(arguments(eq.rhs)[1]) === Base.:< + end + @testset "time-dependent branch of time-independent condition is modified" begin + idx = findfirst(eq -> isequal(eq.lhs, D(ss.r)), eqs) + eq = eqs[idx] + @test operation(eq.rhs) === ifelse + args = arguments(eq.rhs) + @test operation(args[1]) == Base.:< + @test operation(args[2]) === ifelse + condvars = ModelingToolkit.vars(arguments(args[2])[1]) + @test length(condvars) == 1 && any(isequal(only(condvars)), ps) + @test operation(args[3]) === ifelse + condvars = ModelingToolkit.vars(arguments(args[3])[1]) + @test length(condvars) == 1 && any(isequal(only(condvars)), ps) + end + @testset "Observed variables are modified" begin + idx = findfirst(eq -> isequal(eq.lhs, ss.c), obs) + eq = obs[idx] + @test operation(eq.rhs) === Base.:! && any(isequal(only(arguments(eq.rhs))), ps) + idx = findfirst(eq -> isequal(eq.lhs, ss.y), obs) + eq = obs[idx] + @test operation(eq.rhs) === ifelse + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 677f40c717..4a2ec80e6a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -83,6 +83,7 @@ end @safetestset "JumpSystem Test" include("jumpsystem.jl") @safetestset "print_tree" include("print_tree.jl") @safetestset "Constraints Test" include("constraints.jl") + @safetestset "IfLifting Test" include("if_lifting.jl") end end From e4d5710e74edb0e20b11190f16fcde79d8df97ed Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 18 Dec 2024 11:32:27 +0530 Subject: [PATCH 13/38] fix: retain system metadata when calling `flatten` --- src/systems/diffeqs/odesystem.jl | 1 + .../discrete_system/discrete_system.jl | 1 + src/systems/nonlinear/nonlinearsystem.jl | 1 + .../optimization/optimizationsystem.jl | 1 + test/components.jl | 36 +++++++++++++++++++ 5 files changed, 40 insertions(+) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 2b0bd8c8d7..34003f40c2 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -409,6 +409,7 @@ function flatten(sys::ODESystem, noeqs = false) initialization_eqs = initialization_equations(sys), is_dde = is_dde(sys), tstops = symbolic_tstops(sys), + metadata = get_metadata(sys), checks = false) end end diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 3e220998cb..01fca30235 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -227,6 +227,7 @@ function flatten(sys::DiscreteSystem, noeqs = false) defaults = defaults(sys), name = nameof(sys), description = description(sys), + metadata = get_metadata(sys), checks = false) end end diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 3cb68853aa..b2abac5184 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -859,6 +859,7 @@ function flatten(sys::NonlinearSystem, noeqs = false) defaults = defaults(sys), name = nameof(sys), description = description(sys), + metadata = get_metadata(sys), checks = false) end end diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index 43e9294dd3..0398c892eb 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -184,6 +184,7 @@ function flatten(sys::OptimizationSystem) constraints = constraints(sys), defaults = defaults(sys), name = nameof(sys), + metadata = get_metadata(sys), checks = false ) end diff --git a/test/components.jl b/test/components.jl index 298f9fceb9..8ac40f6fbb 100644 --- a/test/components.jl +++ b/test/components.jl @@ -370,3 +370,39 @@ end ss = structural_simplify(cbar) @test isequal(cbar.foo.x, ss.foo.x) end + +@testset "Issue#3275: Metadata retained on `complete`" begin + @variables x(t) y(t) + @testset "ODESystem" begin + @named inner = ODESystem(D(x) ~ x, t) + @named outer = ODESystem(D(y) ~ y, t; systems = [inner], metadata = "test") + @test ModelingToolkit.get_metadata(outer) == "test" + sys = complete(outer) + @test ModelingToolkit.get_metadata(sys) == "test" + end + @testset "NonlinearSystem" begin + @named inner = NonlinearSystem([0 ~ x^2 + 4x + 4], [x], []) + @named outer = NonlinearSystem( + [0 ~ x^3 - y^3], [x, y], []; systems = [inner], metadata = "test") + @test ModelingToolkit.get_metadata(outer) == "test" + sys = complete(outer) + @test ModelingToolkit.get_metadata(sys) == "test" + end + k = ShiftIndex(t) + @testset "DiscreteSystem" begin + @named inner = DiscreteSystem([x(k) ~ x(k - 1) + x(k - 2)], t, [x], []) + @named outer = DiscreteSystem([y(k) ~ y(k - 1) + y(k - 2)], t, [x, y], + []; systems = [inner], metadata = "test") + @test ModelingToolkit.get_metadata(outer) == "test" + sys = complete(outer) + @test ModelingToolkit.get_metadata(sys) == "test" + end + @testset "OptimizationSystem" begin + @named inner = OptimizationSystem(x^2 + y^2 - 3, [x, y], []) + @named outer = OptimizationSystem( + x^3 - y, [x, y], []; systems = [inner], metadata = "test") + @test ModelingToolkit.get_metadata(outer) == "test" + sys = complete(outer) + @test ModelingToolkit.get_metadata(sys) == "test" + end +end From 4792360ad15eff9bc1a77fc5245bfaaec46c96a0 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 18 Dec 2024 14:49:50 -0100 Subject: [PATCH 14/38] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 0b56e53805..3c6126b7c0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ModelingToolkit" uuid = "961ee093-0014-501f-94e3-6117800e7a78" authors = ["Yingbo Ma ", "Chris Rackauckas and contributors"] -version = "9.58.0" +version = "9.59.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" From 219aee396ed4d29606fcb3fa6922bb140091a0e6 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sun, 1 Dec 2024 13:07:15 +0530 Subject: [PATCH 15/38] refactor: add guesses to `SDESystem`, `NonlinearSystem`, `JumpSystem` --- src/systems/diffeqs/odesystem.jl | 27 +++------ src/systems/diffeqs/sdesystem.jl | 52 ++++++++++++----- .../discrete_system/discrete_system.jl | 47 +++++++++++---- src/systems/jumps/jumpsystem.jl | 43 ++++++++++---- src/systems/nonlinear/nonlinearsystem.jl | 57 ++++++++++++++----- .../optimization/constraints_system.jl | 4 +- .../optimization/optimizationsystem.jl | 4 +- src/utils.jl | 27 +++++++++ 8 files changed, 187 insertions(+), 74 deletions(-) diff --git a/src/systems/diffeqs/odesystem.jl b/src/systems/diffeqs/odesystem.jl index 34003f40c2..61f16fd926 100644 --- a/src/systems/diffeqs/odesystem.jl +++ b/src/systems/diffeqs/odesystem.jl @@ -256,29 +256,16 @@ function ODESystem(deqs::AbstractVector{<:Equation}, iv, dvs, ps; :ODESystem, force = true) end defaults = Dict{Any, Any}(todict(defaults)) + guesses = Dict{Any, Any}(todict(guesses)) var_to_name = Dict() - process_variables!(var_to_name, defaults, dvs′) - process_variables!(var_to_name, defaults, ps′) - process_variables!(var_to_name, defaults, [eq.lhs for eq in parameter_dependencies]) - process_variables!(var_to_name, defaults, [eq.rhs for eq in parameter_dependencies]) + process_variables!(var_to_name, defaults, guesses, dvs′) + process_variables!(var_to_name, defaults, guesses, ps′) + process_variables!( + var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies]) + process_variables!( + var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies]) defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults) if v !== nothing) - - sysdvsguesses = [ModelingToolkit.getguess(st) for st in dvs′] - hasaguess = findall(!isnothing, sysdvsguesses) - var_guesses = dvs′[hasaguess] .=> sysdvsguesses[hasaguess] - sysdvsguesses = isempty(var_guesses) ? Dict() : todict(var_guesses) - syspsguesses = [ModelingToolkit.getguess(st) for st in ps′] - hasaguess = findall(!isnothing, syspsguesses) - ps_guesses = ps′[hasaguess] .=> syspsguesses[hasaguess] - syspsguesses = isempty(ps_guesses) ? Dict() : todict(ps_guesses) - syspdepguesses = [ModelingToolkit.getguess(eq.lhs) for eq in parameter_dependencies] - hasaguess = findall(!isnothing, syspdepguesses) - pdep_guesses = [eq.lhs for eq in parameter_dependencies][hasaguess] .=> - syspdepguesses[hasaguess] - syspdepguesses = isempty(pdep_guesses) ? Dict() : todict(pdep_guesses) - - guesses = merge(sysdvsguesses, syspsguesses, syspdepguesses, todict(guesses)) guesses = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(guesses) if v !== nothing) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index ac47f4c45c..d604863024 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -93,6 +93,19 @@ struct SDESystem <: AbstractODESystem """ defaults::Dict """ + The guesses to use as the initial conditions for the + initialization system. + """ + guesses::Dict + """ + The system for performing the initialization. + """ + initializesystem::Union{Nothing, NonlinearSystem} + """ + Extra equations to be enforced during the initialization sequence. + """ + initialization_eqs::Vector{Equation} + """ Type of the system. """ connector_type::Any @@ -144,9 +157,8 @@ struct SDESystem <: AbstractODESystem isscheduled::Bool function SDESystem(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, - tgrad, - jac, - ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type, + tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, + guesses, initializesystem, initialization_eqs, connector_type, cevents, devents, parameter_dependencies, metadata = nothing, gui_metadata = nothing, complete = false, index_cache = nothing, parent = nothing, is_scalar_noise = false, is_dde = false, @@ -171,9 +183,9 @@ struct SDESystem <: AbstractODESystem check_units(u, deqs, neqs) end new(tag, deqs, neqs, iv, dvs, ps, tspan, var_to_name, ctrls, observed, tgrad, jac, - ctrl_jac, - Wfact, Wfact_t, name, description, systems, - defaults, connector_type, cevents, devents, + ctrl_jac, Wfact, Wfact_t, name, description, systems, + defaults, guesses, initializesystem, initialization_eqs, connector_type, cevents, + devents, parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise, is_dde, isscheduled) end @@ -187,6 +199,9 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv default_u0 = Dict(), default_p = Dict(), defaults = _merge(Dict(default_u0), Dict(default_p)), + guesses = Dict(), + initializesystem = nothing, + initialization_eqs = Equation[], name = nothing, description = "", connector_type = nothing, @@ -207,6 +222,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv dvs′ = value.(dvs) ps′ = value.(ps) ctrl′ = value.(controls) + parameter_dependencies, ps′ = process_parameter_dependencies( + parameter_dependencies, ps′) sysnames = nameof.(systems) if length(unique(sysnames)) != length(sysnames) @@ -217,13 +234,21 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv "`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :SDESystem, force = true) end - defaults = todict(defaults) - defaults = Dict(value(k) => value(v) - for (k, v) in pairs(defaults) if value(v) !== nothing) + defaults = Dict{Any, Any}(todict(defaults)) + guesses = Dict{Any, Any}(todict(guesses)) var_to_name = Dict() - process_variables!(var_to_name, defaults, dvs′) - process_variables!(var_to_name, defaults, ps′) + process_variables!(var_to_name, defaults, guesses, dvs′) + process_variables!(var_to_name, defaults, guesses, ps′) + process_variables!( + var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies]) + process_variables!( + var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies]) + defaults = Dict{Any, Any}(value(k) => value(v) + for (k, v) in pairs(defaults) if v !== nothing) + guesses = Dict{Any, Any}(value(k) => value(v) + for (k, v) in pairs(guesses) if v !== nothing) + isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed)) tgrad = RefValue(EMPTY_TGRAD) @@ -233,14 +258,13 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs::AbstractArray, iv, dv Wfact_t = RefValue(EMPTY_JAC) cont_callbacks = SymbolicContinuousCallbacks(continuous_events) disc_callbacks = SymbolicDiscreteCallbacks(discrete_events) - parameter_dependencies, ps′ = process_parameter_dependencies( - parameter_dependencies, ps′) if is_dde === nothing is_dde = _check_if_dde(deqs, iv′, systems) end SDESystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), deqs, neqs, iv′, dvs′, ps′, tspan, var_to_name, ctrl′, observed, tgrad, jac, - ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, connector_type, + ctrl_jac, Wfact, Wfact_t, name, description, systems, defaults, guesses, + initializesystem, initialization_eqs, connector_type, cont_callbacks, disc_callbacks, parameter_dependencies, metadata, gui_metadata, complete, index_cache, parent, is_scalar_noise, is_dde; checks = checks) end diff --git a/src/systems/discrete_system/discrete_system.jl b/src/systems/discrete_system/discrete_system.jl index 01fca30235..7458237333 100644 --- a/src/systems/discrete_system/discrete_system.jl +++ b/src/systems/discrete_system/discrete_system.jl @@ -55,6 +55,19 @@ struct DiscreteSystem <: AbstractTimeDependentSystem """ defaults::Dict """ + The guesses to use as the initial conditions for the + initialization system. + """ + guesses::Dict + """ + The system for performing the initialization. + """ + initializesystem::Union{Nothing, NonlinearSystem} + """ + Extra equations to be enforced during the initialization sequence. + """ + initialization_eqs::Vector{Equation} + """ Inject assignment statements before the evaluation of the RHS function. """ preface::Any @@ -98,9 +111,8 @@ struct DiscreteSystem <: AbstractTimeDependentSystem isscheduled::Bool function DiscreteSystem(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, - observed, - name, description, - systems, defaults, preface, connector_type, parameter_dependencies = Equation[], + observed, name, description, systems, defaults, guesses, initializesystem, + initialization_eqs, preface, connector_type, parameter_dependencies = Equation[], metadata = nothing, gui_metadata = nothing, tearing_state = nothing, substitutions = nothing, complete = false, index_cache = nothing, parent = nothing, @@ -116,8 +128,7 @@ struct DiscreteSystem <: AbstractTimeDependentSystem check_units(u, discreteEqs) end new(tag, discreteEqs, iv, dvs, ps, tspan, var_to_name, observed, name, description, - systems, - defaults, + systems, defaults, guesses, initializesystem, initialization_eqs, preface, connector_type, parameter_dependencies, metadata, gui_metadata, tearing_state, substitutions, complete, index_cache, parent, isscheduled) end @@ -135,6 +146,9 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps; description = "", default_u0 = Dict(), default_p = Dict(), + guesses = Dict(), + initializesystem = nothing, + initialization_eqs = Equation[], defaults = _merge(Dict(default_u0), Dict(default_p)), preface = nothing, connector_type = nothing, @@ -155,13 +169,21 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps; "`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :DiscreteSystem, force = true) end - defaults = todict(defaults) - defaults = Dict(value(k) => value(v) - for (k, v) in pairs(defaults) if value(v) !== nothing) + defaults = Dict{Any, Any}(todict(defaults)) + guesses = Dict{Any, Any}(todict(guesses)) var_to_name = Dict() - process_variables!(var_to_name, defaults, dvs′) - process_variables!(var_to_name, defaults, ps′) + process_variables!(var_to_name, defaults, guesses, dvs′) + process_variables!(var_to_name, defaults, guesses, ps′) + process_variables!( + var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies]) + process_variables!( + var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies]) + defaults = Dict{Any, Any}(value(k) => value(v) + for (k, v) in pairs(defaults) if v !== nothing) + guesses = Dict{Any, Any}(value(k) => value(v) + for (k, v) in pairs(guesses) if v !== nothing) + isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed)) sysnames = nameof.(systems) @@ -170,7 +192,8 @@ function DiscreteSystem(eqs::AbstractVector{<:Equation}, iv, dvs, ps; end DiscreteSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), eqs, iv′, dvs′, ps′, tspan, var_to_name, observed, name, description, systems, - defaults, preface, connector_type, parameter_dependencies, metadata, gui_metadata, kwargs...) + defaults, guesses, initializesystem, initialization_eqs, preface, connector_type, + parameter_dependencies, metadata, gui_metadata, kwargs...) end function DiscreteSystem(eqs, iv; kwargs...) @@ -225,6 +248,8 @@ function flatten(sys::DiscreteSystem, noeqs = false) parameters(sys), observed = observed(sys), defaults = defaults(sys), + guesses = guesses(sys), + initialization_eqs = initialization_equations(sys), name = nameof(sys), description = description(sys), metadata = get_metadata(sys), diff --git a/src/systems/jumps/jumpsystem.jl b/src/systems/jumps/jumpsystem.jl index e5e17fb5f9..efc5a9be7d 100644 --- a/src/systems/jumps/jumpsystem.jl +++ b/src/systems/jumps/jumpsystem.jl @@ -84,6 +84,19 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem """ defaults::Dict """ + The guesses to use as the initial conditions for the + initialization system. + """ + guesses::Dict + """ + The system for performing the initialization. + """ + initializesystem::Union{Nothing, NonlinearSystem} + """ + Extra equations to be enforced during the initialization sequence. + """ + initialization_eqs::Vector{Equation} + """ Type of the system. """ connector_type::Any @@ -125,8 +138,9 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem function JumpSystem{U}( tag, ap::U, iv, unknowns, ps, var_to_name, observed, name, description, - systems, defaults, connector_type, cevents, devents, parameter_dependencies, - metadata = nothing, gui_metadata = nothing, + systems, defaults, guesses, initializesystem, initialization_eqs, connector_type, + cevents, devents, + parameter_dependencies, metadata = nothing, gui_metadata = nothing, complete = false, index_cache = nothing, isscheduled = false; checks::Union{Bool, Int} = true) where {U <: ArrayPartition} if checks == true || (checks & CheckComponents) > 0 @@ -139,7 +153,8 @@ struct JumpSystem{U <: ArrayPartition} <: AbstractTimeDependentSystem check_units(u, ap, iv) end new{U}(tag, ap, iv, unknowns, ps, var_to_name, - observed, name, description, systems, defaults, + observed, name, description, systems, defaults, guesses, initializesystem, + initialization_eqs, connector_type, cevents, devents, parameter_dependencies, metadata, gui_metadata, complete, index_cache, isscheduled) end @@ -154,6 +169,9 @@ function JumpSystem(eqs, iv, unknowns, ps; default_u0 = Dict(), default_p = Dict(), defaults = _merge(Dict(default_u0), Dict(default_p)), + guesses = Dict(), + initializesystem = nothing, + initialization_eqs = Equation[], name = nothing, description = "", connector_type = nothing, @@ -179,13 +197,17 @@ function JumpSystem(eqs, iv, unknowns, ps; :JumpSystem, force = true) end defaults = Dict{Any, Any}(todict(defaults)) + guesses = Dict{Any, Any}(todict(guesses)) var_to_name = Dict() - process_variables!(var_to_name, defaults, us′) - process_variables!(var_to_name, defaults, ps′) - process_variables!(var_to_name, defaults, [eq.lhs for eq in parameter_dependencies]) - process_variables!(var_to_name, defaults, [eq.rhs for eq in parameter_dependencies]) + process_variables!(var_to_name, defaults, guesses, us′) + process_variables!(var_to_name, defaults, guesses, ps′) + process_variables!( + var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies]) + process_variables!( + var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies]) #! format: off defaults = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(defaults) if value(v) !== nothing) + guesses = Dict{Any, Any}(value(k) => value(v) for (k, v) in pairs(guesses) if v !== nothing) #! format: on isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed)) @@ -219,8 +241,9 @@ function JumpSystem(eqs, iv, unknowns, ps; JumpSystem{typeof(ap)}(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), ap, iv′, us′, ps′, var_to_name, observed, name, description, systems, - defaults, connector_type, cont_callbacks, disc_callbacks, parameter_dependencies, - metadata, gui_metadata, checks = checks) + defaults, guesses, initializesystem, initialization_eqs, connector_type, + cont_callbacks, disc_callbacks, + parameter_dependencies, metadata, gui_metadata, checks = checks) end ##### MTK dispatches for JumpSystems ##### @@ -494,7 +517,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi if has_equations(sys) osys = ODESystem(equations(sys).x[4], get_iv(sys), unknowns(sys), parameters(sys); observed = observed(sys), name = nameof(sys), description = description(sys), - systems = get_systems(sys), defaults = defaults(sys), + systems = get_systems(sys), defaults = defaults(sys), guesses = guesses(sys), parameter_dependencies = parameter_dependencies(sys), metadata = get_metadata(sys), gui_metadata = get_gui_metadata(sys)) osys = complete(osys) diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index b2abac5184..46d39822bf 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -57,6 +57,19 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem """ defaults::Dict """ + The guesses to use as the initial conditions for the + initialization system. + """ + guesses::Dict + """ + The system for performing the initialization. + """ + initializesystem::Union{Nothing, NonlinearSystem} + """ + Extra equations to be enforced during the initialization sequence. + """ + initialization_eqs::Vector{Equation} + """ Type of the system. """ connector_type::Any @@ -97,9 +110,8 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem function NonlinearSystem( tag, eqs, unknowns, ps, var_to_name, observed, jac, name, description, - systems, - defaults, connector_type, parameter_dependencies = Equation[], metadata = nothing, - gui_metadata = nothing, + systems, defaults, guesses, initializesystem, initialization_eqs, connector_type, + parameter_dependencies = Equation[], metadata = nothing, gui_metadata = nothing, tearing_state = nothing, substitutions = nothing, complete = false, index_cache = nothing, parent = nothing, isscheduled = false; checks::Union{Bool, Int} = true) @@ -107,8 +119,8 @@ struct NonlinearSystem <: AbstractTimeIndependentSystem u = __get_unit_type(unknowns, ps) check_units(u, eqs) end - new(tag, eqs, unknowns, ps, var_to_name, observed, - jac, name, description, systems, defaults, + new(tag, eqs, unknowns, ps, var_to_name, observed, jac, name, description, + systems, defaults, guesses, initializesystem, initialization_eqs, connector_type, parameter_dependencies, metadata, gui_metadata, tearing_state, substitutions, complete, index_cache, parent, isscheduled) end @@ -121,6 +133,9 @@ function NonlinearSystem(eqs, unknowns, ps; default_u0 = Dict(), default_p = Dict(), defaults = _merge(Dict(default_u0), Dict(default_p)), + guesses = Dict(), + initializesystem = nothing, + initialization_eqs = Equation[], systems = NonlinearSystem[], connector_type = nothing, continuous_events = nothing, # this argument is only required for ODESystems, but is added here for the constructor to accept it without error @@ -151,21 +166,32 @@ function NonlinearSystem(eqs, unknowns, ps; eqs = [wrap(eq.lhs) isa Symbolics.Arr ? eq : 0 ~ eq.rhs - eq.lhs for eq in eqs] jac = RefValue{Any}(EMPTY_JAC) - defaults = todict(defaults) - defaults = Dict{Any, Any}(value(k) => value(v) - for (k, v) in pairs(defaults) if value(v) !== nothing) - unknowns, ps = value.(unknowns), value.(ps) + ps′ = value.(ps) + dvs′ = value.(unknowns) + parameter_dependencies, ps′ = process_parameter_dependencies( + parameter_dependencies, ps′) + + defaults = Dict{Any, Any}(todict(defaults)) + guesses = Dict{Any, Any}(todict(guesses)) var_to_name = Dict() - process_variables!(var_to_name, defaults, unknowns) - process_variables!(var_to_name, defaults, ps) + process_variables!(var_to_name, defaults, guesses, dvs′) + process_variables!(var_to_name, defaults, guesses, ps′) + process_variables!( + var_to_name, defaults, guesses, [eq.lhs for eq in parameter_dependencies]) + process_variables!( + var_to_name, defaults, guesses, [eq.rhs for eq in parameter_dependencies]) + defaults = Dict{Any, Any}(value(k) => value(v) + for (k, v) in pairs(defaults) if v !== nothing) + guesses = Dict{Any, Any}(value(k) => value(v) + for (k, v) in pairs(guesses) if v !== nothing) + isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed)) - parameter_dependencies, ps = process_parameter_dependencies( - parameter_dependencies, ps) NonlinearSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), - eqs, unknowns, ps, var_to_name, observed, jac, name, description, systems, defaults, - connector_type, parameter_dependencies, metadata, gui_metadata, checks = checks) + eqs, dvs′, ps′, var_to_name, observed, jac, name, description, systems, defaults, + guesses, initializesystem, initialization_eqs, connector_type, parameter_dependencies, + metadata, gui_metadata, checks = checks) end function NonlinearSystem(eqs; kwargs...) @@ -857,6 +883,7 @@ function flatten(sys::NonlinearSystem, noeqs = false) parameters(sys), observed = observed(sys), defaults = defaults(sys), + guesses = guesses(sys), name = nameof(sys), description = description(sys), metadata = get_metadata(sys), diff --git a/src/systems/optimization/constraints_system.jl b/src/systems/optimization/constraints_system.jl index a2756994ac..03225fc900 100644 --- a/src/systems/optimization/constraints_system.jl +++ b/src/systems/optimization/constraints_system.jl @@ -143,8 +143,8 @@ function ConstraintsSystem(constraints, unknowns, ps; for (k, v) in pairs(defaults) if value(v) !== nothing) var_to_name = Dict() - process_variables!(var_to_name, defaults, unknowns′) - process_variables!(var_to_name, defaults, ps′) + process_variables!(var_to_name, defaults, Dict(), unknowns′) + process_variables!(var_to_name, defaults, Dict(), ps′) isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed)) ConstraintsSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index 0398c892eb..0b20fdef79 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -132,8 +132,8 @@ function OptimizationSystem(op, unknowns, ps; for (k, v) in pairs(defaults) if value(v) !== nothing) var_to_name = Dict() - process_variables!(var_to_name, defaults, unknowns′) - process_variables!(var_to_name, defaults, ps′) + process_variables!(var_to_name, defaults, Dict(), unknowns′) + process_variables!(var_to_name, defaults, Dict(), ps′) isempty(observed) || collect_var_to_name!(var_to_name, (eq.lhs for eq in observed)) OptimizationSystem(Threads.atomic_add!(SYSTEM_COUNT, UInt(1)), diff --git a/src/utils.jl b/src/utils.jl index e9ddad3a07..c3011c2a79 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -244,6 +244,13 @@ function setdefault(v, val) val === nothing ? v : wrap(setdefaultval(unwrap(v), value(val))) end +function process_variables!(var_to_name, defs, guesses, vars) + collect_defaults!(defs, vars) + collect_guesses!(guesses, vars) + collect_var_to_name!(var_to_name, vars) + return nothing +end + function process_variables!(var_to_name, defs, vars) collect_defaults!(defs, vars) collect_var_to_name!(var_to_name, vars) @@ -261,6 +268,17 @@ function collect_defaults!(defs, vars) return defs end +function collect_guesses!(guesses, vars) + for v in vars + symbolic_type(v) == NotSymbolic() && continue + if haskey(guesses, v) || !hasguess(unwrap(v)) || (def = getguess(v)) === nothing + continue + end + guesses[v] = getguess(v) + end + return guesses +end + function collect_var_to_name!(vars, xs) for x in xs symbolic_type(x) == NotSymbolic() && continue @@ -1146,3 +1164,12 @@ function similar_variable(var::BasicSymbolic, name = :anon) end return sym end + +function guesses_from_metadata!(guesses, vars) + varguesses = [getguess(v) for v in vars] + hasaguess = findall(!isnothing, varguesses) + for i in hasaguess + haskey(guesses, vars[i]) && continue + guesses[vars[i]] = varguesses[i] + end +end From af8cd679571d6c2d078fbebdd9219536b270f335 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sun, 1 Dec 2024 13:07:33 +0530 Subject: [PATCH 16/38] feat: support arbitrary systems in `generate_initializesystem` --- src/systems/nonlinear/initializesystem.jl | 65 +++++++++++++---------- 1 file changed, 38 insertions(+), 27 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 2344727920..229d462911 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -3,7 +3,7 @@ $(TYPEDSIGNATURES) Generate `NonlinearSystem` which initializes an ODE problem from specified initial conditions of an `ODESystem`. """ -function generate_initializesystem(sys::ODESystem; +function generate_initializesystem(sys::AbstractSystem; u0map = Dict(), pmap = Dict(), initialization_eqs = [], @@ -12,28 +12,36 @@ function generate_initializesystem(sys::ODESystem; algebraic_only = false, check_units = true, check_defguess = false, name = nameof(sys), extra_metadata = (;), kwargs...) - trueobs, eqs = unhack_observed(observed(sys), equations(sys)) + eqs = equations(sys) + eqs = filter(x -> x isa Equation, eqs) + trueobs, eqs = unhack_observed(observed(sys), eqs) vars = unique([unknowns(sys); getfield.(trueobs, :lhs)]) vars_set = Set(vars) # for efficient in-lookup - idxs_diff = isdiffeq.(eqs) - idxs_alge = .!idxs_diff - - # prepare map for dummy derivative substitution - eqs_diff = eqs[idxs_diff] - D = Differential(get_iv(sys)) - diffmap = merge( - Dict(eq.lhs => eq.rhs for eq in eqs_diff), - Dict(D(eq.lhs) => D(eq.rhs) for eq in trueobs) - ) - - # 1) process dummy derivatives and u0map into initialization system - eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations + eqs_ics = Equation[] defs = copy(defaults(sys)) # copy so we don't modify sys.defaults additional_guesses = anydict(guesses) guesses = merge(get_guesses(sys), additional_guesses) - schedule = getfield(sys, :schedule) - if !isnothing(schedule) + idxs_diff = isdiffeq.(eqs) + + # 1) Use algebraic equations of time-dependent systems as initialization constraints + if has_iv(sys) + idxs_alge = .!idxs_diff + append!(eqs_ics, eqs[idxs_alge]) # start equation list with algebraic equations + + eqs_diff = eqs[idxs_diff] + D = Differential(get_iv(sys)) + diffmap = merge( + Dict(eq.lhs => eq.rhs for eq in eqs_diff), + Dict(D(eq.lhs) => D(eq.rhs) for eq in trueobs) + ) + else + diffmap = Dict() + end + + if has_schedule(sys) && (schedule = get_schedule(sys); !isnothing(schedule)) + # 2) process dummy derivatives and u0map into initialization system + # prepare map for dummy derivative substitution for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub) # set dummy derivatives to default_dd_guess unless specified push!(defs, x[1] => get(guesses, x[1], default_dd_guess)) @@ -61,9 +69,14 @@ function generate_initializesystem(sys::ODESystem; process_u0map_with_dummysubs(y, x) end end + else + # 2) System doesn't have a schedule, so dummy derivatives don't exist/aren't handled (SDESystem) + for (k, v) in u0map + defs[k] = v + end end - # 2) process other variables + # 3) process other variables for var in vars if var ∈ keys(defs) push!(eqs_ics, var ~ defs[var]) @@ -74,7 +87,7 @@ function generate_initializesystem(sys::ODESystem; end end - # 3) process explicitly provided initialization equations + # 4) process explicitly provided initialization equations if !algebraic_only initialization_eqs = [get_initialization_eqs(sys); initialization_eqs] for eq in initialization_eqs @@ -83,7 +96,7 @@ function generate_initializesystem(sys::ODESystem; end end - # 4) process parameters as initialization unknowns + # 5) process parameters as initialization unknowns paramsubs = Dict() if pmap isa SciMLBase.NullParameters pmap = Dict() @@ -138,7 +151,7 @@ function generate_initializesystem(sys::ODESystem; end end - # 5) parameter dependencies become equations, their LHS become unknowns + # 6) parameter dependencies become equations, their LHS become unknowns # non-numeric dependent parameters stay as parameter dependencies new_parameter_deps = Equation[] for eq in parameter_dependencies(sys) @@ -153,7 +166,7 @@ function generate_initializesystem(sys::ODESystem; push!(defs, varp => guessval) end - # 6) handle values provided for dependent parameters similar to values for observed variables + # 7) handle values provided for dependent parameters similar to values for observed variables for (k, v) in merge(defaults(sys), pmap) if is_variable_floatingpoint(k) && has_parameter_dependency_with_lhs(sys, k) push!(eqs_ics, paramsubs[k] ~ v) @@ -161,12 +174,10 @@ function generate_initializesystem(sys::ODESystem; end # parameters do not include ones that became initialization unknowns - pars = vcat( - [get_iv(sys)], # include independent variable as pseudo-parameter - [p for p in parameters(sys) if !haskey(paramsubs, p)] - ) + pars = Vector{SymbolicParam}(filter(p -> !haskey(paramsubs, p), parameters(sys))) + is_time_dependent(sys) && push!(pars, get_iv(sys)) - # 7) use observed equations for guesses of observed variables if not provided + # 8) use observed equations for guesses of observed variables if not provided for eq in trueobs haskey(defs, eq.lhs) && continue any(x -> isequal(default_toterm(x), eq.lhs), keys(defs)) && continue From 39281945c17108c2b4a3434d89a4e57911905927 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 2 Dec 2024 11:32:03 +0530 Subject: [PATCH 17/38] refactor: use `initialization_data` in SciMLFunction constructors --- src/systems/diffeqs/abstractodesystem.jl | 28 ++++++++--------------- src/systems/diffeqs/sdesystem.jl | 6 ++--- src/systems/nonlinear/initializesystem.jl | 8 +------ src/systems/nonlinear/nonlinearsystem.jl | 12 ++++++---- 4 files changed, 21 insertions(+), 33 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index f4e29346ff..7d9369d441 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -359,10 +359,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, sparsity = false, analytic = nothing, split_idxs = nothing, - initializeprob = nothing, - update_initializeprob! = nothing, - initializeprobmap = nothing, - initializeprobpmap = nothing, + initialization_data = nothing, kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `ODEFunction`") @@ -463,10 +460,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, observed = observedfun, sparsity = sparsity ? jacobian_sparsity(sys) : nothing, analytic = analytic, - initializeprob = initializeprob, - update_initializeprob! = update_initializeprob!, - initializeprobmap = initializeprobmap, - initializeprobpmap = initializeprobpmap) + initialization_data) end """ @@ -496,10 +490,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) sparse = false, simplify = false, eval_module = @__MODULE__, checkbounds = false, - initializeprob = nothing, - initializeprobmap = nothing, - initializeprobpmap = nothing, - update_initializeprob! = nothing, + initialization_data = nothing, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating a `DAEFunction`") @@ -547,15 +538,12 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) nothing end - DAEFunction{iip}(f, + DAEFunction{iip}(f; sys = sys, jac = _jac === nothing ? nothing : _jac, jac_prototype = jac_prototype, observed = observedfun, - initializeprob = initializeprob, - initializeprobmap = initializeprobmap, - initializeprobpmap = initializeprobpmap, - update_initializeprob! = update_initializeprob!) + initialization_data) end function DiffEqBase.DDEFunction(sys::AbstractODESystem, args...; kwargs...) @@ -567,6 +555,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) eval_expression = false, eval_module = @__MODULE__, checkbounds = false, + initialization_data = nothing, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `DDEFunction`") @@ -579,7 +568,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys) f(u, h, p, t) = f_oop(u, h, p, t) f(du, u, h, p, t) = f_iip(du, u, h, p, t) - DDEFunction{iip}(f, sys = sys) + DDEFunction{iip}(f; sys = sys, initialization_data) end function DiffEqBase.SDDEFunction(sys::AbstractODESystem, args...; kwargs...) @@ -591,6 +580,7 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys eval_expression = false, eval_module = @__MODULE__, checkbounds = false, + initialization_data = nothing, kwargs...) where {iip} if !iscomplete(sys) error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `SDDEFunction`") @@ -609,7 +599,7 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys g(u, h, p, t) = g_oop(u, h, p, t) g(du, u, h, p, t) = g_iip(du, u, h, p, t) - SDDEFunction{iip}(f, g, sys = sys) + SDDEFunction{iip}(f, g; sys = sys, initialization_data) end """ diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index d604863024..37e743218a 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -544,7 +544,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns( version = nothing, tgrad = false, sparse = false, jac = false, Wfact = false, eval_expression = false, eval_module = @__MODULE__, - checkbounds = false, + checkbounds = false, initialization_data = nothing, kwargs...) where {iip, specialize} if !iscomplete(sys) error("A completed `SDESystem` is required. Call `complete` or `structural_simplify` on the system before creating an `SDEFunction`") @@ -615,13 +615,13 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns( observedfun = ObservedFunctionCache(sys; eval_expression, eval_module) - SDEFunction{iip, specialize}(f, g, + SDEFunction{iip, specialize}(f, g; sys = sys, jac = _jac === nothing ? nothing : _jac, tgrad = _tgrad === nothing ? nothing : _tgrad, Wfact = _Wfact === nothing ? nothing : _Wfact, Wfact_t = _Wfact_t === nothing ? nothing : _Wfact_t, - mass_matrix = _M, + mass_matrix = _M, initialization_data, observed = observedfun) end diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 229d462911..624ee1bd71 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -346,13 +346,7 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, u0map, pmap, defs, cmap, dvs, ps) kws = maybe_build_initialization_problem( sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; use_scc) - initprob = get(kws, :initializeprob, nothing) - if initprob === nothing - return nothing - end - return SciMLBase.OverrideInitData(initprob, get(kws, :update_initializeprob!, nothing), - get(kws, :initializeprobmap, nothing), - get(kws, :initializeprobpmap, nothing)) + return get(kws, :initialization_data, nothing) end """ diff --git a/src/systems/nonlinear/nonlinearsystem.jl b/src/systems/nonlinear/nonlinearsystem.jl index 46d39822bf..8cd8175668 100644 --- a/src/systems/nonlinear/nonlinearsystem.jl +++ b/src/systems/nonlinear/nonlinearsystem.jl @@ -344,6 +344,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s eval_expression = false, eval_module = @__MODULE__, sparse = false, simplify = false, + initialization_data = nothing, kwargs...) where {iip} if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `NonlinearFunction`") @@ -376,14 +377,14 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s resid_prototype = calculate_resid_prototype(length(equations(sys)), u0, p) end - NonlinearFunction{iip}(f, + NonlinearFunction{iip}(f; sys = sys, jac = _jac === nothing ? nothing : _jac, resid_prototype = resid_prototype, jac_prototype = sparse ? similar(calculate_jacobian(sys, sparse = sparse), Float64) : nothing, - observed = observedfun) + observed = observedfun, initialization_data) end """ @@ -395,7 +396,8 @@ respectively. """ function SciMLBase.IntervalNonlinearFunction( sys::NonlinearSystem, dvs = unknowns(sys), ps = parameters(sys), u0 = nothing; - p = nothing, eval_expression = false, eval_module = @__MODULE__, kwargs...) + p = nothing, eval_expression = false, eval_module = @__MODULE__, + initialization_data = nothing, kwargs...) if !iscomplete(sys) error("A completed `NonlinearSystem` is required. Call `complete` or `structural_simplify` on the system before creating a `IntervalNonlinearFunction`") end @@ -411,7 +413,8 @@ function SciMLBase.IntervalNonlinearFunction( observedfun = ObservedFunctionCache(sys; eval_expression, eval_module) - IntervalNonlinearFunction{false}(f; observed = observedfun, sys = sys) + IntervalNonlinearFunction{false}( + f; observed = observedfun, sys = sys, initialization_data) end """ @@ -884,6 +887,7 @@ function flatten(sys::NonlinearSystem, noeqs = false) observed = observed(sys), defaults = defaults(sys), guesses = guesses(sys), + initialization_eqs = initialization_equations(sys), name = nameof(sys), description = description(sys), metadata = get_metadata(sys), From 4044317b17b3f239f0a4fa6a87f16f368f8f28aa Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 2 Dec 2024 11:33:14 +0530 Subject: [PATCH 18/38] fix: don't build initializeprob for initializeprob --- src/systems/diffeqs/abstractodesystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 7d9369d441..b3d4903191 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1395,5 +1395,5 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, else NonlinearLeastSquaresProblem end - TProb(isys, u0map, parammap; kwargs...) + TProb(isys, u0map, parammap; kwargs..., build_initializeprob = false) end From 180b97801da7d134754c16becc9daebbfb688912 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 2 Dec 2024 11:34:02 +0530 Subject: [PATCH 19/38] feat: build initialization system for all system types in `process_SciMLProblem` --- src/systems/problem_utils.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 6b75fb2c6d..6ce5704daa 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -540,8 +540,9 @@ function maybe_build_initialization_problem( end if (((implicit_dae || has_observed_u0s || !isempty(missing_unknowns) || !isempty(solvablepars) || has_dependent_unknowns) && - get_tearing_state(sys) !== nothing) || - !isempty(initialization_equations(sys))) && t !== nothing + (!has_tearing_state(sys) || get_tearing_state(sys) !== nothing)) || + !isempty(initialization_equations(sys))) && + (!is_time_dependent(sys) || t !== nothing) initializeprob = ModelingToolkit.InitializationProblem( sys, t, u0map, pmap; guesses, kwargs...) initializeprobmap = getu(initializeprob, unknowns(sys)) @@ -567,7 +568,9 @@ function maybe_build_initialization_problem( end empty!(missing_unknowns) return (; - initializeprob, initializeprobmap, initializeprobpmap, update_initializeprob!) + initialization_data = SciMLBase.OverrideInitData( + initializeprob, update_initializeprob!, initializeprobmap, + initializeprobpmap)) end return (;) end @@ -662,7 +665,7 @@ function process_SciMLProblem( op, missing_unknowns, missing_pars = build_operating_point( u0map, pmap, defs, cmap, dvs, ps) - if sys isa ODESystem && build_initializeprob + if build_initializeprob kws = maybe_build_initialization_problem( sys, op, u0map, pmap, t, defs, guesses, missing_unknowns; implicit_dae, warn_initialize_determined, initialization_eqs, From c9c613f314cddb9ade459920cbb102a3b57b22da Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 2 Dec 2024 17:18:07 +0530 Subject: [PATCH 20/38] fix: retain system data on `structural_simplify` of `SDESystem` --- src/systems/systems.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/systems/systems.jl b/src/systems/systems.jl index 47acd81a82..04c50bc766 100644 --- a/src/systems/systems.jl +++ b/src/systems/systems.jl @@ -164,6 +164,7 @@ function __structural_simplify(sys::AbstractSystem, io = nothing; simplify = fal return SDESystem(Vector{Equation}(full_equations(ode_sys)), noise_eqs, get_iv(ode_sys), unknowns(ode_sys), parameters(ode_sys); name = nameof(ode_sys), is_scalar_noise, observed = observed(ode_sys), defaults = defaults(sys), - parameter_dependencies = parameter_dependencies(sys)) + parameter_dependencies = parameter_dependencies(sys), + guesses = guesses(sys), initialization_eqs = initialization_equations(sys)) end end From 4d5daa3caeaad6ee4ed6f9698194aa467d1c4ea4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 16:32:11 +0530 Subject: [PATCH 21/38] fix: pass `t` to `process_SciMLProblem` in `SDEProblem` --- src/systems/diffeqs/sdesystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 37e743218a..5e0d0e3208 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -738,7 +738,7 @@ function DiffEqBase.SDEProblem{iip, specialize}( end f, u0, p = process_SciMLProblem( SDEFunction{iip, specialize}, sys, u0map, parammap; check_length, - kwargs...) + t = tspan === nothing ? nothing : tspan[1], kwargs...) cbs = process_events(sys; callback, kwargs...) sparsenoise === nothing && (sparsenoise = get(kwargs, :sparse, false)) From 8d09409a011a4d91dfe6d22d81b255f9a0da753f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 16:33:03 +0530 Subject: [PATCH 22/38] feat: support arbitrary systems in `remake_initialization_data` --- src/systems/nonlinear/initializesystem.jl | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index 624ee1bd71..e32f967f4a 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -223,16 +223,19 @@ function is_parameter_solvable(p, pmap, defs, guesses) _val1 === nothing && _val2 !== nothing)) && _val3 !== nothing end -function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, newu0, newp) +function SciMLBase.remake_initialization_data( + sys::AbstractSystem, odefn, u0, t0, p, newu0, newp) if u0 === missing && p === missing return odefn.initialization_data end if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair) - oldinitprob = odefn.initializeprob + oldinitdata = odefn.initialization_data + oldinitdata === nothing && return nothing + + oldinitprob = oldinitdata.initializeprob oldinitprob === nothing && return nothing if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem) - return SciMLBase.OverrideInitData(oldinitprob, odefn.update_initializeprob!, - odefn.initializeprobmap, odefn.initializeprobpmap) + return oldinitdata end pidxs = ParameterIndex[] pvals = [] @@ -254,7 +257,7 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, if p !== missing for sym in parameter_symbols(oldinitprob) push!(pidxs, parameter_index(oldinitprob, sym)) - if isequal(sym, get_iv(sys)) + if is_time_dependent(sys) && isequal(sym, get_iv(sys)) push!(pvals, t0) else push!(pvals, getp(sys, sym)(p)) @@ -283,8 +286,8 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, length(oldinitprob.f.resid_prototype), newu0, newp)) end initprob = remake(oldinitprob; f = newf, u0 = newu0, p = newp) - return SciMLBase.OverrideInitData(initprob, odefn.update_initializeprob!, - odefn.initializeprobmap, odefn.initializeprobpmap) + return SciMLBase.OverrideInitData(initprob, oldinitdata.update_initializeprob!, + oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap) end dvs = unknowns(sys) ps = parameters(sys) @@ -298,7 +301,7 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, use_scc = true if SciMLBase.has_initializeprob(odefn) - oldsys = odefn.initializeprob.f.sys + oldsys = odefn.initialization_data.initializeprob.f.sys meta = get_metadata(oldsys) if meta isa InitializationSystemMetadata u0map = merge(meta.u0map, u0map) @@ -336,7 +339,7 @@ function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, pmap[p] = getp(sys, p)(newp) end end - if t0 === nothing + if t0 === nothing && is_time_dependent(sys) t0 = 0.0 end filter_missing_values!(u0map) From def207b92a83baf5363ebaded009c5b95706d645 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 16:33:13 +0530 Subject: [PATCH 23/38] fix: fix type promotion bug in `remake_buffer` --- src/systems/parameter_buffer.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/systems/parameter_buffer.jl b/src/systems/parameter_buffer.jl index 7d64054acc..bc4e62a773 100644 --- a/src/systems/parameter_buffer.jl +++ b/src/systems/parameter_buffer.jl @@ -574,6 +574,10 @@ function _remake_buffer(indp, oldbuf::MTKParameters, idxs, vals; validate = true @set! newbuf.tunable = narrow_buffer_type_and_fallback_undefs( oldbuf.tunable, newbuf.tunable) + if eltype(newbuf.tunable) <: Integer + T = promote_type(eltype(newbuf.tunable), Float64) + @set! newbuf.tunable = T.(newbuf.tunable) + end @set! newbuf.discrete = narrow_buffer_type_and_fallback_undefs.( oldbuf.discrete, newbuf.discrete) @set! newbuf.constant = narrow_buffer_type_and_fallback_undefs.( From d971b185ad5d4d9f1714448131b2bb02aff58b2d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 3 Dec 2024 16:33:51 +0530 Subject: [PATCH 24/38] test: test initialization on `SDEProblem`, `DDEProblem`, `SDDEProblem` --- test/initializationsystem.jl | 428 ++++++++++++++++++++--------------- 1 file changed, 247 insertions(+), 181 deletions(-) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 9c06dd2030..045b7d2d2c 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -1,4 +1,5 @@ using ModelingToolkit, OrdinaryDiffEq, NonlinearSolve, Test +using StochasticDiffEq, DelayDiffEq, StochasticDelayDiffEq using ForwardDiff using SymbolicIndexingInterface, SciMLStructures using SciMLStructures: Tunable @@ -583,108 +584,123 @@ sol = solve(oprob_2nd_order_2, Rosenbrock23()) # retcode: Success end @testset "Initialization of parameters" begin - function test_parameter(prob, sym, val, initialval = zero(val)) - @test prob.ps[sym] ≈ initialval - @test init(prob, Tsit5()).ps[sym] ≈ val - @test solve(prob, Tsit5()).ps[sym] ≈ val - end - function test_initializesystem(sys, u0map, pmap, p, equation) - isys = ModelingToolkit.generate_initializesystem( - sys; u0map, pmap, guesses = ModelingToolkit.guesses(sys)) - @test is_variable(isys, p) - @test equation in equations(isys) || (0 ~ -equation.rhs) in equations(isys) - end - @variables x(t) y(t) + @variables _x(..) y(t) @parameters p q - u0map = Dict(x => 1.0, y => 1.0) - pmap = Dict() - pmap[q] = 1.0 - # `missing` default, equation from ODEProblem - @mtkbuild sys = ODESystem( - [D(x) ~ x * q, D(y) ~ y * p], t; defaults = [p => missing], guesses = [p => 1.0]) - pmap[p] = 2q - prob = ODEProblem(sys, u0map, (0.0, 1.0), pmap) - test_parameter(prob, p, 2.0) - prob2 = remake(prob; u0 = u0map, p = pmap) - prob2.ps[p] = 0.0 - test_parameter(prob2, p, 2.0) - # `missing` default, provided guess - @mtkbuild sys = ODESystem( - [D(x) ~ x, p ~ x + y], t; defaults = [p => missing], guesses = [p => 0.0]) - prob = ODEProblem(sys, u0map, (0.0, 1.0)) - test_parameter(prob, p, 2.0) - test_initializesystem(sys, u0map, pmap, p, 0 ~ p - x - y) - prob2 = remake(prob; u0 = u0map) - prob2.ps[p] = 0.0 - test_parameter(prob2, p, 2.0) - - # `missing` to ODEProblem, equation from default - @mtkbuild sys = ODESystem( - [D(x) ~ x * q, D(y) ~ y * p], t; defaults = [p => 2q], guesses = [p => 1.0]) - pmap[p] = missing - prob = ODEProblem(sys, u0map, (0.0, 1.0), pmap) - test_parameter(prob, p, 2.0) - test_initializesystem(sys, u0map, pmap, p, 0 ~ 2q - p) - prob2 = remake(prob; u0 = u0map, p = pmap) - prob2.ps[p] = 0.0 - test_parameter(prob2, p, 2.0) - # `missing` to ODEProblem, provided guess - @mtkbuild sys = ODESystem( - [D(x) ~ x, p ~ x + y], t; guesses = [p => 0.0]) - prob = ODEProblem(sys, u0map, (0.0, 1.0), pmap) - test_parameter(prob, p, 2.0) - test_initializesystem(sys, u0map, pmap, p, 0 ~ x + y - p) - prob2 = remake(prob; u0 = u0map, p = pmap) - prob2.ps[p] = 0.0 - test_parameter(prob2, p, 2.0) - - # No `missing`, default and guess - @mtkbuild sys = ODESystem( - [D(x) ~ x * q, D(y) ~ y * p], t; defaults = [p => 2q], guesses = [p => 0.0]) - delete!(pmap, p) - prob = ODEProblem(sys, u0map, (0.0, 1.0), pmap) - test_parameter(prob, p, 2.0) - test_initializesystem(sys, u0map, pmap, p, 0 ~ 2q - p) - prob2 = remake(prob; u0 = u0map, p = pmap) - prob2.ps[p] = 0.0 - test_parameter(prob2, p, 2.0) - - # Default overridden by ODEProblem, guess provided - @mtkbuild sys = ODESystem( - [D(x) ~ q * x, D(y) ~ y * p], t; defaults = [p => 2q], guesses = [p => 1.0]) - _pmap = merge(pmap, Dict(p => q)) - prob = ODEProblem(sys, u0map, (0.0, 1.0), _pmap) - test_parameter(prob, p, _pmap[q]) - test_initializesystem(sys, u0map, _pmap, p, 0 ~ q - p) - - # ODEProblem dependent value with guess, no `missing` - @mtkbuild sys = ODESystem([D(x) ~ x * q, D(y) ~ y * p], t; guesses = [p => 0.0]) - _pmap = merge(pmap, Dict(p => 3q)) - prob = ODEProblem(sys, u0map, (0.0, 1.0), _pmap) - test_parameter(prob, p, 3pmap[q]) - - # Should not be solved for: - - # Override dependent default with direct value - @mtkbuild sys = ODESystem( - [D(x) ~ q * x, D(y) ~ y * p], t; defaults = [p => 2q], guesses = [p => 1.0]) - _pmap = merge(pmap, Dict(p => 1.0)) - prob = ODEProblem(sys, u0map, (0.0, 1.0), _pmap) - @test prob.ps[p] ≈ 1.0 - @test prob.f.initializeprob === nothing - - # Non-floating point - @parameters r::Int s::Int - @mtkbuild sys = ODESystem( - [D(x) ~ s * x, D(y) ~ y * r], t; defaults = [s => 2r], guesses = [s => 1.0]) - prob = ODEProblem(sys, u0map, (0.0, 1.0), [r => 1]) - @test prob.ps[r] == 1 - @test prob.ps[s] == 2 - @test prob.f.initializeprob === nothing - - @mtkbuild sys = ODESystem([D(x) ~ x, p ~ x + y], t; guesses = [p => 0.0]) - @test_throws ModelingToolkit.MissingParametersError ODEProblem( - sys, [x => 1.0, y => 1.0], (0.0, 1.0)) + @brownian a b + x = _x(t) + + # `System` constructor creates appropriate type with mtkbuild + # `Problem` and `alg` create the problem to test and allow calling `init` with + # the correct solver. + # `rhss` allows adding terms to the end of equations (only 2 equations allowed) to influence + # the system type (brownian vars to turn it into an SDE). + @testset "$Problem" for (Problem, alg, rhss) in [ + (ODEProblem, Tsit5(), zeros(2)), (SDEProblem, ImplicitEM(), [a, b]), + (DDEProblem, MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]), + (SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b])] + function test_parameter(prob, sym, val, initialval = zero(val)) + @test prob.ps[sym] ≈ initialval + @test init(prob, alg).ps[sym] ≈ val + @test solve(prob, alg).ps[sym] ≈ val + end + function test_initializesystem(sys, u0map, pmap, p, equation) + isys = ModelingToolkit.generate_initializesystem( + sys; u0map, pmap, guesses = ModelingToolkit.guesses(sys)) + @test is_variable(isys, p) + @test equation in equations(isys) || (0 ~ -equation.rhs) in equations(isys) + end + u0map = Dict(x => 1.0, y => 1.0) + pmap = Dict() + pmap[q] = 1.0 + # `missing` default, equation from Problem + @mtkbuild sys = System( + [D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => missing], guesses = [p => 1.0]) + pmap[p] = 2q + prob = Problem(sys, u0map, (0.0, 1.0), pmap) + test_parameter(prob, p, 2.0) + prob2 = remake(prob; u0 = u0map, p = pmap) + prob2.ps[p] = 0.0 + test_parameter(prob2, p, 2.0) + # `missing` default, provided guess + @mtkbuild sys = System( + [D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; defaults = [p => missing], guesses = [p => 0.0]) + prob = Problem(sys, u0map, (0.0, 1.0)) + test_parameter(prob, p, 2.0) + test_initializesystem(sys, u0map, pmap, p, 0 ~ p - x - y) + prob2 = remake(prob; u0 = u0map) + prob2.ps[p] = 0.0 + test_parameter(prob2, p, 2.0) + + # `missing` to Problem, equation from default + @mtkbuild sys = System( + [D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0]) + pmap[p] = missing + prob = Problem(sys, u0map, (0.0, 1.0), pmap) + test_parameter(prob, p, 2.0) + test_initializesystem(sys, u0map, pmap, p, 0 ~ 2q - p) + prob2 = remake(prob; u0 = u0map, p = pmap) + prob2.ps[p] = 0.0 + test_parameter(prob2, p, 2.0) + # `missing` to Problem, provided guess + @mtkbuild sys = System( + [D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; guesses = [p => 0.0]) + prob = Problem(sys, u0map, (0.0, 1.0), pmap) + test_parameter(prob, p, 2.0) + test_initializesystem(sys, u0map, pmap, p, 0 ~ x + y - p) + prob2 = remake(prob; u0 = u0map, p = pmap) + prob2.ps[p] = 0.0 + test_parameter(prob2, p, 2.0) + + # No `missing`, default and guess + @mtkbuild sys = System( + [D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 0.0]) + delete!(pmap, p) + prob = Problem(sys, u0map, (0.0, 1.0), pmap) + test_parameter(prob, p, 2.0) + test_initializesystem(sys, u0map, pmap, p, 0 ~ 2q - p) + prob2 = remake(prob; u0 = u0map, p = pmap) + prob2.ps[p] = 0.0 + test_parameter(prob2, p, 2.0) + + # Default overridden by Problem, guess provided + @mtkbuild sys = System( + [D(x) ~ q * x + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0]) + _pmap = merge(pmap, Dict(p => q)) + prob = Problem(sys, u0map, (0.0, 1.0), _pmap) + test_parameter(prob, p, _pmap[q]) + test_initializesystem(sys, u0map, _pmap, p, 0 ~ q - p) + + # Problem dependent value with guess, no `missing` + @mtkbuild sys = System( + [D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; guesses = [p => 0.0]) + _pmap = merge(pmap, Dict(p => 3q)) + prob = Problem(sys, u0map, (0.0, 1.0), _pmap) + test_parameter(prob, p, 3pmap[q]) + + # Should not be solved for: + + # Override dependent default with direct value + @mtkbuild sys = System( + [D(x) ~ q * x + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0]) + _pmap = merge(pmap, Dict(p => 1.0)) + prob = Problem(sys, u0map, (0.0, 1.0), _pmap) + @test prob.ps[p] ≈ 1.0 + @test prob.f.initialization_data === nothing + + # Non-floating point + @parameters r::Int s::Int + @mtkbuild sys = System( + [D(x) ~ s * x + rhss[1], D(y) ~ y * r + rhss[2]], t; defaults = [s => 2r], guesses = [s => 1.0]) + prob = Problem(sys, u0map, (0.0, 1.0), [r => 1]) + @test prob.ps[r] == 1 + @test prob.ps[s] == 2 + @test prob.f.initialization_data === nothing + + @mtkbuild sys = System( + [D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; guesses = [p => 0.0]) + @test_throws ModelingToolkit.MissingParametersError Problem( + sys, [x => 1.0, y => 1.0], (0.0, 1.0)) + end @testset "Null system" begin @variables x(t) y(t) s(t) @@ -718,103 +734,153 @@ end end @testset "Update initializeprob parameters" begin - @variables x(t) y(t) + @variables _x(..) y(t) @parameters p q - @mtkbuild sys = ODESystem( - [D(x) ~ x, p ~ x + y], t; guesses = [x => 0.0, p => 0.0]) - prob = ODEProblem(sys, [y => 1.0], (0.0, 1.0), [p => 3.0]) - @test prob.f.initializeprob.ps[p] ≈ 3.0 - @test init(prob, Tsit5())[x] ≈ 2.0 - prob.ps[p] = 2.0 - @test prob.f.initializeprob.ps[p] ≈ 3.0 - @test init(prob, Tsit5())[x] ≈ 1.0 - ModelingToolkit.defaults(prob.f.sys)[p] = missing - prob2 = remake(prob; u0 = [y => 1.0], p = [p => 3x]) - @test !is_variable(prob2.f.initializeprob, p) && - !is_parameter(prob2.f.initializeprob, p) - @test init(prob2, Tsit5())[x] ≈ 0.5 - @test_nowarn solve(prob2, Tsit5()) + @brownian a b + x = _x(t) + + @testset "$Problem" for (Problem, alg, rhss) in [ + (ODEProblem, Tsit5(), zeros(2)), (SDEProblem, ImplicitEM(), [a, b]), + (DDEProblem, MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]), + (SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b])] + @mtkbuild sys = System( + [D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; guesses = [x => 0.0, p => 0.0]) + prob = Problem(sys, [y => 1.0], (0.0, 1.0), [p => 3.0]) + @test prob.f.initialization_data.initializeprob.ps[p] ≈ 3.0 + @test init(prob, alg)[x] ≈ 2.0 + prob.ps[p] = 2.0 + @test prob.f.initialization_data.initializeprob.ps[p] ≈ 3.0 + @test init(prob, alg)[x] ≈ 1.0 + ModelingToolkit.defaults(prob.f.sys)[p] = missing + prob2 = remake(prob; u0 = [y => 1.0], p = [p => 3x]) + @test !is_variable(prob2.f.initialization_data.initializeprob, p) && + !is_parameter(prob2.f.initialization_data.initializeprob, p) + @test init(prob2, alg)[x] ≈ 0.5 + @test_nowarn solve(prob2, alg) + end end @testset "Equations for dependent parameters" begin - @variables x(t) + @variables _x(..) @parameters p q=5 r - @mtkbuild sys = ODESystem( - D(x) ~ 2x + r, t; parameter_dependencies = [r ~ p + 2q, q ~ p + 3], - guesses = [p => 1.0]) - prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => missing]) - @test length(equations(ModelingToolkit.get_parent(prob.f.initializeprob.f.sys))) == 4 - integ = init(prob, Tsit5()) - @test integ.ps[p] ≈ 2 + @brownian a + x = _x(t) + + @testset "$Problem" for (Problem, alg, rhss) in [ + (ODEProblem, Tsit5(), 0), (SDEProblem, ImplicitEM(), a), + (DDEProblem, MethodOfSteps(Tsit5()), _x(t - 0.1)), + (SDDEProblem, ImplicitEM(), _x(t - 0.1) + a)] + @mtkbuild sys = System( + [D(x) ~ 2x + r + rhss], t; parameter_dependencies = [r ~ p + 2q, q ~ p + 3], + guesses = [p => 1.0]) + prob = Problem(sys, [x => 1.0], (0.0, 1.0), [p => missing]) + @test length(equations(ModelingToolkit.get_parent(prob.f.initialization_data.initializeprob.f.sys))) == + 4 + integ = init(prob, alg) + @test integ.ps[p] ≈ 2 + end end @testset "Re-creating initialization problem on remake" begin - @variables x(t) y(t) + @variables _x(..) y(t) @parameters p q - @mtkbuild sys = ODESystem( - [D(x) ~ x, p ~ x + y], t; defaults = [p => missing], guesses = [x => 0.0, p => 0.0]) - prob = ODEProblem(sys, [x => 1.0, y => 1.0], (0.0, 1.0)) - @test init(prob, Tsit5()).ps[p] ≈ 2.0 - # nonsensical value for y just to test that equations work - prob2 = remake(prob; u0 = [x => 1.0, y => 2x + exp(t)]) - @test init(prob2, Tsit5()).ps[p] ≈ 4.0 - # solve for `x` given `p` and `y` - prob3 = remake(prob; u0 = [x => nothing, y => 1.0], p = [p => 2x + exp(t)]) - @test init(prob3, Tsit5())[x] ≈ 0.0 - @test_logs (:warn, r"overdetermined") remake( - prob; u0 = [x => 1.0, y => 2.0], p = [p => 4.0]) - prob4 = remake(prob; u0 = [x => 1.0, y => 2.0], p = [p => 4.0]) - @test solve(prob4, Tsit5()).retcode == ReturnCode.InitialFailure - prob5 = remake(prob) - @test init(prob, Tsit5()).ps[p] ≈ 2.0 + @brownian a b + x = _x(t) + + @testset "$Problem" for (Problem, alg, rhss) in [ + (ODEProblem, Tsit5(), zeros(2)), (SDEProblem, ImplicitEM(), [a, b]), + (DDEProblem, MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]), + (SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b])] + @mtkbuild sys = System( + [D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; defaults = [p => missing], guesses = [ + x => 0.0, p => 0.0]) + prob = Problem(sys, [x => 1.0, y => 1.0], (0.0, 1.0)) + @test init(prob, alg).ps[p] ≈ 2.0 + # nonsensical value for y just to test that equations work + prob2 = remake(prob; u0 = [x => 1.0, y => 2x + exp(t)]) + @test init(prob2, alg).ps[p] ≈ 4.0 + # solve for `x` given `p` and `y` + prob3 = remake(prob; u0 = [x => nothing, y => 1.0], p = [p => 2x + exp(t)]) + @test init(prob3, alg)[x] ≈ 0.0 + @test_logs (:warn, r"overdetermined") remake( + prob; u0 = [x => 1.0, y => 2.0], p = [p => 4.0]) + prob4 = remake(prob; u0 = [x => 1.0, y => 2.0], p = [p => 4.0]) + @test solve(prob4, alg).retcode == ReturnCode.InitialFailure + prob5 = remake(prob) + @test init(prob, alg).ps[p] ≈ 2.0 + end end @testset "`remake` changes initialization problem types" begin - @variables x(t) y(t) z(t) + @variables _x(..) y(t) z(t) @parameters p q - @mtkbuild sys = ODESystem( - [D(x) ~ x * p + y * q, y^2 * q + q^2 * x ~ 0, z * p - p^2 * x * z ~ 0], - t; guesses = [x => 0.0, y => 0.0, z => 0.0, p => 0.0, q => 0.0]) - prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0, q => missing]) - @test is_variable(prob.f.initializeprob, q) - ps = prob.p - newps = SciMLStructures.replace(Tunable(), ps, ForwardDiff.Dual.(ps.tunable)) - prob2 = remake(prob; p = newps) - @test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual - @test eltype(prob2.f.initializeprob.p.tunable) <: ForwardDiff.Dual - @test prob2.f.initializeprob.u0 ≈ prob.f.initializeprob.u0 - - prob2 = remake(prob; u0 = ForwardDiff.Dual.(prob.u0)) - @test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual - @test eltype(prob2.f.initializeprob.p.tunable) <: Float64 - @test prob2.f.initializeprob.u0 ≈ prob.f.initializeprob.u0 - - prob2 = remake(prob; u0 = ForwardDiff.Dual.(prob.u0), p = newps) - @test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual - @test eltype(prob2.f.initializeprob.p.tunable) <: ForwardDiff.Dual - @test prob2.f.initializeprob.u0 ≈ prob.f.initializeprob.u0 - - prob2 = remake(prob; u0 = [x => ForwardDiff.Dual(1.0)], - p = [p => ForwardDiff.Dual(1.0), q => missing]) - @test eltype(prob2.f.initializeprob.u0) <: ForwardDiff.Dual - @test eltype(prob2.f.initializeprob.p.tunable) <: ForwardDiff.Dual - @test prob2.f.initializeprob.u0 ≈ prob.f.initializeprob.u0 + @brownian a + x = _x(t) + + @testset "$Problem" for (Problem, alg, rhss) in [ + (ODEProblem, Tsit5(), 0), (SDEProblem, ImplicitEM(), a), + (DDEProblem, MethodOfSteps(Tsit5()), _x(t - 0.1)), + (SDDEProblem, ImplicitEM(), _x(t - 0.1) + a)] + @mtkbuild sys = System( + [D(x) ~ x * p + y * q + rhss, y^2 * q + q^2 * x ~ 0, z * p - p^2 * x * z ~ 0], + t; guesses = [x => 0.0, y => 0.0, z => 0.0, p => 0.0, q => 0.0]) + prob = Problem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0, q => missing]) + @test is_variable(prob.f.initialization_data.initializeprob, q) + ps = prob.p + newps = SciMLStructures.replace(Tunable(), ps, ForwardDiff.Dual.(ps.tunable)) + prob2 = remake(prob; p = newps) + @test eltype(prob2.f.initialization_data.initializeprob.u0) <: ForwardDiff.Dual + @test eltype(prob2.f.initialization_data.initializeprob.p.tunable) <: + ForwardDiff.Dual + @test prob2.f.initialization_data.initializeprob.u0 ≈ + prob.f.initialization_data.initializeprob.u0 + + prob2 = remake(prob; u0 = ForwardDiff.Dual.(prob.u0)) + @test eltype(prob2.f.initialization_data.initializeprob.u0) <: ForwardDiff.Dual + @test eltype(prob2.f.initialization_data.initializeprob.p.tunable) <: Float64 + @test prob2.f.initialization_data.initializeprob.u0 ≈ + prob.f.initialization_data.initializeprob.u0 + + prob2 = remake(prob; u0 = ForwardDiff.Dual.(prob.u0), p = newps) + @test eltype(prob2.f.initialization_data.initializeprob.u0) <: ForwardDiff.Dual + @test eltype(prob2.f.initialization_data.initializeprob.p.tunable) <: + ForwardDiff.Dual + @test prob2.f.initialization_data.initializeprob.u0 ≈ + prob.f.initialization_data.initializeprob.u0 + + prob2 = remake(prob; u0 = [x => ForwardDiff.Dual(1.0)], + p = [p => ForwardDiff.Dual(1.0), q => missing]) + @test eltype(prob2.f.initialization_data.initializeprob.u0) <: ForwardDiff.Dual + @test eltype(prob2.f.initialization_data.initializeprob.p.tunable) <: + ForwardDiff.Dual + @test prob2.f.initialization_data.initializeprob.u0 ≈ + prob.f.initialization_data.initializeprob.u0 + end end @testset "`remake` preserves old u0map and pmap" begin - @variables x(t) y(t) + @variables _x(..) y(t) @parameters p - @mtkbuild sys = ODESystem( - [D(x) ~ x + p * y, y^2 + 4y * p^2 ~ x], t; guesses = [y => 1.0, p => 1.0]) - prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0]) - @test is_variable(prob.f.initializeprob, y) - prob2 = @test_nowarn remake(prob; p = [p => 3.0]) # ensure no over/under-determined warning - @test is_variable(prob.f.initializeprob, y) - - prob = ODEProblem(sys, [y => 1.0, x => 2.0], (0.0, 1.0), [p => missing]) - @test is_variable(prob.f.initializeprob, p) - prob2 = @test_nowarn remake(prob; u0 = [y => 0.5]) - @test is_variable(prob.f.initializeprob, p) + @brownian a + x = _x(t) + + @testset "$Problem" for (Problem, alg, rhss) in [ + (ODEProblem, Tsit5(), 0), (SDEProblem, ImplicitEM(), a), + (DDEProblem, MethodOfSteps(Tsit5()), _x(t - 0.1)), + (SDDEProblem, ImplicitEM(), _x(t - 0.1) + a)] + @mtkbuild sys = System( + [D(x) ~ x + p * y + rhss, y^2 + 4y * p^2 ~ x], t; guesses = [ + y => 1.0, p => 1.0]) + prob = Problem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0]) + @test is_variable(prob.f.initialization_data.initializeprob, y) + prob2 = @test_nowarn remake(prob; p = [p => 3.0]) # ensure no over/under-determined warning + @test is_variable(prob.f.initialization_data.initializeprob, y) + + prob = Problem(sys, [y => 1.0, x => 2.0], (0.0, 1.0), [p => missing]) + @test is_variable(prob.f.initialization_data.initializeprob, p) + prob2 = @test_nowarn remake(prob; u0 = [y => 0.5]) + @test is_variable(prob.f.initialization_data.initializeprob, p) + end end struct Multiplier{T} From 39bb59c3ccfcd8efee174829c58b2d546864cea9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Dec 2024 14:12:35 +0530 Subject: [PATCH 25/38] fix: handle integer `u0` in `DDEProblem` --- src/systems/diffeqs/abstractodesystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index b3d4903191..40c44b7bae 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -923,7 +923,7 @@ function DiffEqBase.DDEProblem{iip}(sys::AbstractODESystem, u0map = [], h_oop, h_iip = eval_or_rgf.(h_gen; eval_expression, eval_module) h(p, t) = h_oop(p, t) h(p::MTKParameters, t) = h_oop(p..., t) - u0 = h(p, tspan[1]) + u0 = float.(h(p, tspan[1])) if u0 !== nothing u0 = u0_constructor(u0) end From 2f2e62524f42d521e69b692b2801d5ca5e90cfc1 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Dec 2024 14:13:00 +0530 Subject: [PATCH 26/38] feat: enable creating `InitializationProblem` for non-`AbstractODESystem`s --- src/systems/diffeqs/abstractodesystem.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 40c44b7bae..a6607e0c55 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1247,11 +1247,11 @@ Generates a NonlinearProblem or NonlinearLeastSquaresProblem from an ODESystem which represents the initialization, i.e. the calculation of the consistent initial conditions for the given DAE. """ -function InitializationProblem(sys::AbstractODESystem, args...; kwargs...) +function InitializationProblem(sys::AbstractSystem, args...; kwargs...) InitializationProblem{true}(sys, args...; kwargs...) end -function InitializationProblem(sys::AbstractODESystem, t, +function InitializationProblem(sys::AbstractSystem, t, u0map::StaticArray, args...; kwargs...) @@ -1259,11 +1259,11 @@ function InitializationProblem(sys::AbstractODESystem, t, sys, t, u0map, args...; kwargs...) end -function InitializationProblem{true}(sys::AbstractODESystem, args...; kwargs...) +function InitializationProblem{true}(sys::AbstractSystem, args...; kwargs...) InitializationProblem{true, SciMLBase.AutoSpecialize}(sys, args...; kwargs...) end -function InitializationProblem{false}(sys::AbstractODESystem, args...; kwargs...) +function InitializationProblem{false}(sys::AbstractSystem, args...; kwargs...) InitializationProblem{false, SciMLBase.FullSpecialize}(sys, args...; kwargs...) end @@ -1282,8 +1282,8 @@ function Base.showerror(io::IO, e::IncompleteInitializationError) println(io, e.uninit) end -function InitializationProblem{iip, specialize}(sys::AbstractODESystem, - t::Number, u0map = [], +function InitializationProblem{iip, specialize}(sys::AbstractSystem, + t, u0map = [], parammap = DiffEqBase.NullParameters(); guesses = [], check_length = true, @@ -1347,13 +1347,13 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, @warn "Initialization system is underdetermined. $neqs equations for $nunknown unknowns. Initialization will default to using least squares. $(scc_message)To suppress this warning pass warn_initialize_determined = false. To make this warning into an error, pass fully_determined = true" end - parammap = parammap isa DiffEqBase.NullParameters || isempty(parammap) ? - [get_iv(sys) => t] : - merge(todict(parammap), Dict(get_iv(sys) => t)) - parammap = Dict(k => v for (k, v) in parammap if v !== missing) - if isempty(u0map) - u0map = Dict() + parammap = recursive_unwrap(anydict(parammap)) + if t !== nothing + parammap[get_iv(sys)] = t end + filter!(kvp -> kvp[2] !== missing, parammap) + + u0map = to_varmap(u0map, unknowns(sys)) if isempty(guesses) guesses = Dict() end From 671b93feea590a1a7756fd469f9426cc6105766d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Dec 2024 14:13:10 +0530 Subject: [PATCH 27/38] fix: filter kwargs in `SDEProblem` --- src/systems/diffeqs/sdesystem.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/systems/diffeqs/sdesystem.jl b/src/systems/diffeqs/sdesystem.jl index 5e0d0e3208..0d600cfb5f 100644 --- a/src/systems/diffeqs/sdesystem.jl +++ b/src/systems/diffeqs/sdesystem.jl @@ -760,6 +760,8 @@ function DiffEqBase.SDEProblem{iip, specialize}( noise = nothing end + kwargs = filter_kwargs(kwargs) + SDEProblem{iip}(f, u0, tspan, p; callback = cbs, noise, noise_rate_prototype = noise_rate_prototype, kwargs...) end From 9987da0b19e5e2c41aed9af214a215e7dc9b714a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Dec 2024 14:15:05 +0530 Subject: [PATCH 28/38] test: test initialization on `NonlinearProblem` and `NonlinearLeastSquaresProblem` --- test/initializationsystem.jl | 269 ++++++++++++++++++++++++++++------- 1 file changed, 218 insertions(+), 51 deletions(-) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 045b7d2d2c..2745b9d81f 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -583,6 +583,14 @@ sol = solve(oprob_2nd_order_2, Rosenbrock23()) # retcode: Success @test all(sol(1.0, idxs = sys.x) .≈ +exp(1)) && all(sol(1.0, idxs = sys.y) .≈ -exp(1)) end +NonlinearSystemWrapper(eqs, t; kws...) = NonlinearSystem(eqs; kws...) +function NonlinearProblemWrapper(sys, u0, tspan, args...; kwargs...) + NonlinearProblem(sys, u0, args...; kwargs...) +end +function NLLSProblemWrapper(sys, u0, tspan, args...; kwargs...) + NonlinearLeastSquaresProblem(sys, u0, args...; kwargs...) +end + @testset "Initialization of parameters" begin @variables _x(..) y(t) @parameters p q @@ -594,13 +602,37 @@ end # the correct solver. # `rhss` allows adding terms to the end of equations (only 2 equations allowed) to influence # the system type (brownian vars to turn it into an SDE). - @testset "$Problem" for (Problem, alg, rhss) in [ - (ODEProblem, Tsit5(), zeros(2)), (SDEProblem, ImplicitEM(), [a, b]), - (DDEProblem, MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]), - (SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b])] + @testset "$Problem with $(SciMLBase.parameterless_type(alg))" for (System, Problem, alg, rhss) in [ + (ModelingToolkit.System, ODEProblem, Tsit5(), zeros(2)), + (ModelingToolkit.System, SDEProblem, ImplicitEM(), [a, b]), + (ModelingToolkit.System, DDEProblem, MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]), + (ModelingToolkit.System, SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b]), + # polyalg cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, + FastShortcutNonlinearPolyalg(), zeros(2)), + # generalized first order cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, NewtonRaphson(), zeros(2)), + # quasi newton cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, Klement(), zeros(2)), + # noinit cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, SimpleNewtonRaphson(), zeros(2)), + # DFSane cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, DFSane(), zeros(2)), + # Least squares + # polyalg cache + (NonlinearSystemWrapper, NLLSProblemWrapper, FastShortcutNLLSPolyalg(), zeros(2)), + # generalized first order cache + (NonlinearSystemWrapper, NLLSProblemWrapper, LevenbergMarquardt(), zeros(2)), + # noinit cache + (NonlinearSystemWrapper, NLLSProblemWrapper, SimpleGaussNewton(), zeros(2)) + ] + is_nlsolve = alg isa SciMLBase.AbstractNonlinearAlgorithm + function test_parameter(prob, sym, val, initialval = zero(val)) @test prob.ps[sym] ≈ initialval - @test init(prob, alg).ps[sym] ≈ val + if !is_nlsolve || prob.u0 !== nothing + @test init(prob, alg).ps[sym] ≈ val + end @test solve(prob, alg).ps[sym] ≈ val end function test_initializesystem(sys, u0map, pmap, p, equation) @@ -609,6 +641,8 @@ end @test is_variable(isys, p) @test equation in equations(isys) || (0 ~ -equation.rhs) in equations(isys) end + D = is_nlsolve ? v -> v^3 : Differential(t) + u0map = Dict(x => 1.0, y => 1.0) pmap = Dict() pmap[q] = 1.0 @@ -669,16 +703,14 @@ end prob = Problem(sys, u0map, (0.0, 1.0), _pmap) test_parameter(prob, p, _pmap[q]) test_initializesystem(sys, u0map, _pmap, p, 0 ~ q - p) - # Problem dependent value with guess, no `missing` @mtkbuild sys = System( - [D(x) ~ x * q + rhss[1], D(y) ~ y * p + rhss[2]], t; guesses = [p => 0.0]) + [D(x) ~ y * q + p + rhss[1], D(y) ~ x * p + q + rhss[2]], t; guesses = [p => 0.0]) _pmap = merge(pmap, Dict(p => 3q)) prob = Problem(sys, u0map, (0.0, 1.0), _pmap) test_parameter(prob, p, 3pmap[q]) # Should not be solved for: - # Override dependent default with direct value @mtkbuild sys = System( [D(x) ~ q * x + rhss[1], D(y) ~ y * p + rhss[2]], t; defaults = [p => 2q], guesses = [p => 1.0]) @@ -700,6 +732,14 @@ end [D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; guesses = [p => 0.0]) @test_throws ModelingToolkit.MissingParametersError Problem( sys, [x => 1.0, y => 1.0], (0.0, 1.0)) + + # Unsatisfiable initialization + prob = Problem(sys, [x => 1.0, y => 1.0], (0.0, 1.0), + [p => 2.0]; initialization_eqs = [x^2 + y^2 ~ 3]) + @test prob.f.initialization_data !== nothing + @test solve(prob, alg).retcode == ReturnCode.InitialFailure + cache = init(prob, alg) + @test solve!(cache).retcode == ReturnCode.InitialFailure end @testset "Null system" begin @@ -707,7 +747,9 @@ end @parameters x0 y0 @mtkbuild sys = ODESystem([x ~ x0, y ~ y0, s ~ x + y], t; guesses = [y0 => 0.0]) prob = ODEProblem(sys, [s => 1.0], (0.0, 1.0), [x0 => 0.3, y0 => missing]) - test_parameter(prob, y0, 0.7) + @test prob.ps[y0] ≈ 0.0 + @test init(prob, Tsit5()).ps[y0] ≈ 0.7 + @test solve(prob, Tsit5()).ps[y0] ≈ 0.7 end using ModelingToolkitStandardLibrary.Mechanical.TranslationalModelica: Fixed, Mass, @@ -730,7 +772,9 @@ end systems = [fixed, spring, mass, gravity, constant, damper], guesses = [spring.s_rel0 => 1.0]) prob = ODEProblem(sys, [], (0.0, 1.0), [spring.s_rel0 => missing]) - test_parameter(prob, spring.s_rel0, -3.905) + @test prob.ps[spring.s_rel0] ≈ 0.0 + @test init(prob, Tsit5()).ps[spring.s_rel0] ≈ -3.905 + @test solve(prob, Tsit5()).ps[spring.s_rel0] ≈ -3.905 end @testset "Update initializeprob parameters" begin @@ -739,10 +783,33 @@ end @brownian a b x = _x(t) - @testset "$Problem" for (Problem, alg, rhss) in [ - (ODEProblem, Tsit5(), zeros(2)), (SDEProblem, ImplicitEM(), [a, b]), - (DDEProblem, MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]), - (SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b])] + @testset "$Problem with $(SciMLBase.parameterless_type(typeof(alg)))" for (System, Problem, alg, rhss) in [ + (ModelingToolkit.System, ODEProblem, Tsit5(), zeros(2)), + (ModelingToolkit.System, SDEProblem, ImplicitEM(), [a, b]), + (ModelingToolkit.System, DDEProblem, MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]), + (ModelingToolkit.System, SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b]), + # polyalg cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, + FastShortcutNonlinearPolyalg(), zeros(2)), + # generalized first order cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, NewtonRaphson(), zeros(2)), + # quasi newton cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, Klement(), zeros(2)), + # noinit cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, SimpleNewtonRaphson(), zeros(2)), + # DFSane cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, DFSane(), zeros(2)), + # Least squares + # polyalg cache + (NonlinearSystemWrapper, NLLSProblemWrapper, FastShortcutNLLSPolyalg(), zeros(2)), + # generalized first order cache + (NonlinearSystemWrapper, NLLSProblemWrapper, LevenbergMarquardt(), zeros(2)), + # noinit cache + (NonlinearSystemWrapper, NLLSProblemWrapper, SimpleGaussNewton(), zeros(2)) + ] + is_nlsolve = alg isa SciMLBase.AbstractNonlinearAlgorithm + D = is_nlsolve ? v -> v^3 : Differential(t) + @mtkbuild sys = System( [D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; guesses = [x => 0.0, p => 0.0]) prob = Problem(sys, [y => 1.0], (0.0, 1.0), [p => 3.0]) @@ -766,10 +833,33 @@ end @brownian a x = _x(t) - @testset "$Problem" for (Problem, alg, rhss) in [ - (ODEProblem, Tsit5(), 0), (SDEProblem, ImplicitEM(), a), - (DDEProblem, MethodOfSteps(Tsit5()), _x(t - 0.1)), - (SDDEProblem, ImplicitEM(), _x(t - 0.1) + a)] + @testset "$Problem with $(SciMLBase.parameterless_type(typeof(alg)))" for (System, Problem, alg, rhss) in [ + (ModelingToolkit.System, ODEProblem, Tsit5(), 0), + (ModelingToolkit.System, SDEProblem, ImplicitEM(), a), + (ModelingToolkit.System, DDEProblem, MethodOfSteps(Tsit5()), _x(t - 0.1)), + (ModelingToolkit.System, SDDEProblem, ImplicitEM(), _x(t - 0.1) + a), + # polyalg cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, + FastShortcutNonlinearPolyalg(), 0), + # generalized first order cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, NewtonRaphson(), 0), + # quasi newton cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, Klement(), 0), + # noinit cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, SimpleNewtonRaphson(), 0), + # DFSane cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, DFSane(), 0), + # Least squares + # polyalg cache + (NonlinearSystemWrapper, NLLSProblemWrapper, FastShortcutNLLSPolyalg(), 0), + # generalized first order cache + (NonlinearSystemWrapper, NLLSProblemWrapper, LevenbergMarquardt(), 0), + # noinit cache + (NonlinearSystemWrapper, NLLSProblemWrapper, SimpleGaussNewton(), 0) + ] + is_nlsolve = alg isa SciMLBase.AbstractNonlinearAlgorithm + D = is_nlsolve ? v -> v^3 : Differential(t) + @mtkbuild sys = System( [D(x) ~ 2x + r + rhss], t; parameter_dependencies = [r ~ p + 2q, q ~ p + 3], guesses = [p => 1.0]) @@ -787,21 +877,44 @@ end @brownian a b x = _x(t) - @testset "$Problem" for (Problem, alg, rhss) in [ - (ODEProblem, Tsit5(), zeros(2)), (SDEProblem, ImplicitEM(), [a, b]), - (DDEProblem, MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]), - (SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b])] + @testset "$Problem with $(SciMLBase.parameterless_type(typeof(alg)))" for (System, Problem, alg, rhss) in [ + (ModelingToolkit.System, ODEProblem, Tsit5(), zeros(2)), + (ModelingToolkit.System, SDEProblem, ImplicitEM(), [a, b]), + (ModelingToolkit.System, DDEProblem, MethodOfSteps(Tsit5()), [_x(t - 0.1), 0.0]), + (ModelingToolkit.System, SDDEProblem, ImplicitEM(), [_x(t - 0.1) + a, b]), + # polyalg cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, + FastShortcutNonlinearPolyalg(), zeros(2)), + # generalized first order cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, NewtonRaphson(), zeros(2)), + # quasi newton cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, Klement(), zeros(2)), + # noinit cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, SimpleNewtonRaphson(), zeros(2)), + # DFSane cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, DFSane(), zeros(2)), + # Least squares + # polyalg cache + (NonlinearSystemWrapper, NLLSProblemWrapper, FastShortcutNLLSPolyalg(), zeros(2)), + # generalized first order cache + (NonlinearSystemWrapper, NLLSProblemWrapper, LevenbergMarquardt(), zeros(2)), + # noinit cache + (NonlinearSystemWrapper, NLLSProblemWrapper, SimpleGaussNewton(), zeros(2)) + ] + is_nlsolve = alg isa SciMLBase.AbstractNonlinearAlgorithm + D = is_nlsolve ? v -> v^3 : Differential(t) + @mtkbuild sys = System( [D(x) ~ x + rhss[1], p ~ x + y + rhss[2]], t; defaults = [p => missing], guesses = [ x => 0.0, p => 0.0]) prob = Problem(sys, [x => 1.0, y => 1.0], (0.0, 1.0)) @test init(prob, alg).ps[p] ≈ 2.0 # nonsensical value for y just to test that equations work - prob2 = remake(prob; u0 = [x => 1.0, y => 2x + exp(t)]) - @test init(prob2, alg).ps[p] ≈ 4.0 + prob2 = remake(prob; u0 = [x => 1.0, y => 2x + exp(x)]) + @test init(prob2, alg).ps[p] ≈ 3 + exp(1) # solve for `x` given `p` and `y` - prob3 = remake(prob; u0 = [x => nothing, y => 1.0], p = [p => 2x + exp(t)]) - @test init(prob3, alg)[x] ≈ 0.0 + prob3 = remake(prob; u0 = [x => nothing, y => 1.0], p = [p => 2x + exp(y)]) + @test init(prob3, alg)[x] ≈ 1 - exp(1) @test_logs (:warn, r"overdetermined") remake( prob; u0 = [x => 1.0, y => 2.0], p = [p => 4.0]) prob4 = remake(prob; u0 = [x => 1.0, y => 2.0], p = [p => 4.0]) @@ -817,44 +930,73 @@ end @brownian a x = _x(t) - @testset "$Problem" for (Problem, alg, rhss) in [ - (ODEProblem, Tsit5(), 0), (SDEProblem, ImplicitEM(), a), - (DDEProblem, MethodOfSteps(Tsit5()), _x(t - 0.1)), - (SDDEProblem, ImplicitEM(), _x(t - 0.1) + a)] + @testset "$Problem with $(SciMLBase.parameterless_type(typeof(alg)))" for (System, Problem, alg, rhss) in [ + (ModelingToolkit.System, ODEProblem, Tsit5(), 0), + (ModelingToolkit.System, SDEProblem, ImplicitEM(), a), + (ModelingToolkit.System, DDEProblem, MethodOfSteps(Tsit5()), _x(t - 0.1)), + (ModelingToolkit.System, SDDEProblem, ImplicitEM(), _x(t - 0.1) + a), + # polyalg cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, + FastShortcutNonlinearPolyalg(), 0), + # generalized first order cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, NewtonRaphson(), 0), + # quasi newton cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, Klement(), 0), + # noinit cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, SimpleNewtonRaphson(), 0), + # DFSane cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, DFSane(), 0), + # Least squares + # polyalg cache + (NonlinearSystemWrapper, NLLSProblemWrapper, FastShortcutNLLSPolyalg(), 0), + # generalized first order cache + (NonlinearSystemWrapper, NLLSProblemWrapper, LevenbergMarquardt(), 0), + # noinit cache + (NonlinearSystemWrapper, NLLSProblemWrapper, SimpleGaussNewton(), 0) + ] + is_nlsolve = alg isa SciMLBase.AbstractNonlinearAlgorithm + D = is_nlsolve ? v -> v^3 : Differential(t) + alge_eqs = [y^2 * q + q^2 * x ~ 0, z * p - p^2 * x * z ~ 0] + @mtkbuild sys = System( - [D(x) ~ x * p + y * q + rhss, y^2 * q + q^2 * x ~ 0, z * p - p^2 * x * z ~ 0], + [D(x) ~ x * p + y^2 * q + rhss; alge_eqs], t; guesses = [x => 0.0, y => 0.0, z => 0.0, p => 0.0, q => 0.0]) - prob = Problem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0, q => missing]) + prob = Problem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0, q => missing]; + initialization_eqs = is_nlsolve ? alge_eqs : []) @test is_variable(prob.f.initialization_data.initializeprob, q) ps = prob.p newps = SciMLStructures.replace(Tunable(), ps, ForwardDiff.Dual.(ps.tunable)) prob2 = remake(prob; p = newps) - @test eltype(prob2.f.initialization_data.initializeprob.u0) <: ForwardDiff.Dual + @test eltype(state_values(prob2.f.initialization_data.initializeprob)) <: + ForwardDiff.Dual @test eltype(prob2.f.initialization_data.initializeprob.p.tunable) <: ForwardDiff.Dual - @test prob2.f.initialization_data.initializeprob.u0 ≈ - prob.f.initialization_data.initializeprob.u0 + @test state_values(prob2.f.initialization_data.initializeprob) ≈ + state_values(prob.f.initialization_data.initializeprob) prob2 = remake(prob; u0 = ForwardDiff.Dual.(prob.u0)) - @test eltype(prob2.f.initialization_data.initializeprob.u0) <: ForwardDiff.Dual + @test eltype(state_values(prob2.f.initialization_data.initializeprob)) <: + ForwardDiff.Dual @test eltype(prob2.f.initialization_data.initializeprob.p.tunable) <: Float64 - @test prob2.f.initialization_data.initializeprob.u0 ≈ - prob.f.initialization_data.initializeprob.u0 + @test state_values(prob2.f.initialization_data.initializeprob) ≈ + state_values(prob.f.initialization_data.initializeprob) prob2 = remake(prob; u0 = ForwardDiff.Dual.(prob.u0), p = newps) - @test eltype(prob2.f.initialization_data.initializeprob.u0) <: ForwardDiff.Dual + @test eltype(state_values(prob2.f.initialization_data.initializeprob)) <: + ForwardDiff.Dual @test eltype(prob2.f.initialization_data.initializeprob.p.tunable) <: ForwardDiff.Dual - @test prob2.f.initialization_data.initializeprob.u0 ≈ - prob.f.initialization_data.initializeprob.u0 + @test state_values(prob2.f.initialization_data.initializeprob) ≈ + state_values(prob.f.initialization_data.initializeprob) prob2 = remake(prob; u0 = [x => ForwardDiff.Dual(1.0)], p = [p => ForwardDiff.Dual(1.0), q => missing]) - @test eltype(prob2.f.initialization_data.initializeprob.u0) <: ForwardDiff.Dual + @test eltype(state_values(prob2.f.initialization_data.initializeprob)) <: + ForwardDiff.Dual @test eltype(prob2.f.initialization_data.initializeprob.p.tunable) <: ForwardDiff.Dual - @test prob2.f.initialization_data.initializeprob.u0 ≈ - prob.f.initialization_data.initializeprob.u0 + @test state_values(prob2.f.initialization_data.initializeprob) ≈ + state_values(prob.f.initialization_data.initializeprob) end end @@ -864,19 +1006,44 @@ end @brownian a x = _x(t) - @testset "$Problem" for (Problem, alg, rhss) in [ - (ODEProblem, Tsit5(), 0), (SDEProblem, ImplicitEM(), a), - (DDEProblem, MethodOfSteps(Tsit5()), _x(t - 0.1)), - (SDDEProblem, ImplicitEM(), _x(t - 0.1) + a)] + @testset "$Problem with $(SciMLBase.parameterless_type(typeof(alg)))" for (System, Problem, alg, rhss) in [ + (ModelingToolkit.System, ODEProblem, Tsit5(), 0), + (ModelingToolkit.System, SDEProblem, ImplicitEM(), a), + (ModelingToolkit.System, DDEProblem, MethodOfSteps(Tsit5()), _x(t - 0.1)), + (ModelingToolkit.System, SDDEProblem, ImplicitEM(), _x(t - 0.1) + a), + # polyalg cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, + FastShortcutNonlinearPolyalg(), 0), + # generalized first order cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, NewtonRaphson(), 0), + # quasi newton cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, Klement(), 0), + # noinit cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, SimpleNewtonRaphson(), 0), + # DFSane cache + (NonlinearSystemWrapper, NonlinearProblemWrapper, DFSane(), 0), + # Least squares + # polyalg cache + (NonlinearSystemWrapper, NLLSProblemWrapper, FastShortcutNLLSPolyalg(), 0), + # generalized first order cache + (NonlinearSystemWrapper, NLLSProblemWrapper, LevenbergMarquardt(), 0), + # noinit cache + (NonlinearSystemWrapper, NLLSProblemWrapper, SimpleGaussNewton(), 0) + ] + is_nlsolve = alg isa SciMLBase.AbstractNonlinearAlgorithm + D = is_nlsolve ? v -> v^3 : Differential(t) + alge_eqs = [y^2 + 4y * p^2 ~ x^3] @mtkbuild sys = System( - [D(x) ~ x + p * y + rhss, y^2 + 4y * p^2 ~ x], t; guesses = [ + [D(x) ~ x + p * y^2 + rhss; alge_eqs], t; guesses = [ y => 1.0, p => 1.0]) - prob = Problem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0]) + prob = Problem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0]; + initialization_eqs = is_nlsolve ? alge_eqs : []) @test is_variable(prob.f.initialization_data.initializeprob, y) prob2 = @test_nowarn remake(prob; p = [p => 3.0]) # ensure no over/under-determined warning @test is_variable(prob.f.initialization_data.initializeprob, y) - prob = Problem(sys, [y => 1.0, x => 2.0], (0.0, 1.0), [p => missing]) + prob = Problem(sys, [y => 1.0, x => 2.0], (0.0, 1.0), [p => missing]; + initialization_eqs = is_nlsolve ? alge_eqs : []) @test is_variable(prob.f.initialization_data.initializeprob, p) prob2 = @test_nowarn remake(prob; u0 = [y => 0.5]) @test is_variable(prob.f.initialization_data.initializeprob, p) From 51eeeeb27ee5094424bd92eb589f8b99f843c592 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 6 Dec 2024 17:43:57 +0530 Subject: [PATCH 29/38] fix: store and propagate `initialization_eqs` provided to Problem --- src/systems/nonlinear/initializesystem.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index e32f967f4a..cf6c6d4d42 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -21,6 +21,7 @@ function generate_initializesystem(sys::AbstractSystem; eqs_ics = Equation[] defs = copy(defaults(sys)) # copy so we don't modify sys.defaults additional_guesses = anydict(guesses) + additional_initialization_eqs = Vector{Equation}(initialization_eqs) guesses = merge(get_guesses(sys), additional_guesses) idxs_diff = isdiffeq.(eqs) @@ -191,7 +192,7 @@ function generate_initializesystem(sys::AbstractSystem; defs[k] = substitute(defs[k], paramsubs) end meta = InitializationSystemMetadata( - anydict(u0map), anydict(pmap), additional_guesses, extra_metadata) + anydict(u0map), anydict(pmap), additional_guesses, additional_initialization_eqs, extra_metadata) return NonlinearSystem(eqs_ics, vars, pars; @@ -207,6 +208,7 @@ struct InitializationSystemMetadata u0map::Dict{Any, Any} pmap::Dict{Any, Any} additional_guesses::Dict{Any, Any} + additional_initialization_eqs::Vector{Equation} extra_metadata::NamedTuple end @@ -299,6 +301,7 @@ function SciMLBase.remake_initialization_data( defs = defaults(sys) cmap, cs = get_cmap(sys) use_scc = true + initialization_eqs = Equation[] if SciMLBase.has_initializeprob(odefn) oldsys = odefn.initialization_data.initializeprob.f.sys @@ -308,6 +311,7 @@ function SciMLBase.remake_initialization_data( pmap = merge(meta.pmap, pmap) merge!(guesses, meta.additional_guesses) use_scc = get(meta.extra_metadata, :use_scc, true) + initialization_eqs = meta.additional_initialization_eqs end else # there is no initializeprob, so the original problem construction @@ -348,7 +352,7 @@ function SciMLBase.remake_initialization_data( op, missing_unknowns, missing_pars = build_operating_point( u0map, pmap, defs, cmap, dvs, ps) kws = maybe_build_initialization_problem( - sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; use_scc) + sys, op, u0map, pmap, t0, defs, guesses, missing_unknowns; use_scc, initialization_eqs) return get(kws, :initialization_data, nothing) end From 96f8d5de87813b300b61e3d8359b7231ec661529 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 14 Dec 2024 13:18:38 +0530 Subject: [PATCH 30/38] build: bump compats --- Project.toml | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 3c6126b7c0..d672c007c2 100644 --- a/Project.toml +++ b/Project.toml @@ -89,6 +89,7 @@ ConstructionBase = "1" DataInterpolations = "6.4" DataStructures = "0.17, 0.18" DeepDiffs = "1" +DelayDiffEq = "5.50" DiffEqBase = "6.157" DiffEqCallbacks = "2.16, 3, 4" DiffEqNoiseProcess = "5" @@ -117,7 +118,7 @@ Libdl = "1" LinearAlgebra = "1" MLStyle = "0.4.17" NaNMath = "0.3, 1" -NonlinearSolve = "3.14, 4" +NonlinearSolve = "4.3" OffsetArrays = "1" OrderedCollections = "1" OrdinaryDiffEq = "6.82.0" @@ -129,7 +130,7 @@ RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" SCCNonlinearSolve = "1.0.0" -SciMLBase = "2.66" +SciMLBase = "2.68.1" SciMLStructures = "1.0" Serialization = "1" Setfield = "0.7, 0.8, 1" @@ -137,7 +138,9 @@ SimpleNonlinearSolve = "0.1.0, 1, 2" SparseArrays = "1" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" -SymbolicIndexingInterface = "0.3.35" +StochasticDiffEq = "6.72.1" +StochasticDelayDiffEq = "1.8.1" +SymbolicIndexingInterface = "0.3.36" SymbolicUtils = "3.7" Symbolics = "6.19" URIs = "1" From 0a881e71d5f1b5fe77e170713e78b5a5ab3ec023 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 16 Dec 2024 16:22:15 +0530 Subject: [PATCH 31/38] fix: better handle reconstructing initializeprob with new types --- src/systems/diffeqs/abstractodesystem.jl | 5 ++ src/systems/nonlinear/initializesystem.jl | 69 ++++++++++------------- 2 files changed, 34 insertions(+), 40 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index a6607e0c55..dd4165cb57 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1310,6 +1310,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem, pmap = parammap, guesses, extra_metadata = (; use_scc)); fully_determined) end + meta = get_metadata(isys) + if meta isa InitializationSystemMetadata + @set! isys.metadata.oop_reconstruct_u0_p = ReconstructInitializeprob(sys, isys) + end + ts = get_tearing_state(isys) unassigned_vars = StructuralTransformations.singular_check(ts) if warn_initialize_determined && !isempty(unassigned_vars) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index cf6c6d4d42..602fc7cac9 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -192,7 +192,8 @@ function generate_initializesystem(sys::AbstractSystem; defs[k] = substitute(defs[k], paramsubs) end meta = InitializationSystemMetadata( - anydict(u0map), anydict(pmap), additional_guesses, additional_initialization_eqs, extra_metadata) + anydict(u0map), anydict(pmap), additional_guesses, + additional_initialization_eqs, extra_metadata, nothing) return NonlinearSystem(eqs_ics, vars, pars; @@ -204,12 +205,30 @@ function generate_initializesystem(sys::AbstractSystem; kwargs...) end +struct ReconstructInitializeprob + getter::Any + setter::Any +end + +function ReconstructInitializeprob(srcsys::AbstractSystem, dstsys::AbstractSystem) + syms = [unknowns(dstsys); + reduce(vcat, reorder_parameters(dstsys, parameters(dstsys)); init = [])] + getter = getu(srcsys, syms) + setter = setsym_oop(dstsys, syms) + return ReconstructInitializeprob(getter, setter) +end + +function (rip::ReconstructInitializeprob)(srcvalp, dstvalp) + rip.setter(dstvalp, rip.getter(srcvalp)) +end + struct InitializationSystemMetadata u0map::Dict{Any, Any} pmap::Dict{Any, Any} additional_guesses::Dict{Any, Any} additional_initialization_eqs::Vector{Equation} extra_metadata::NamedTuple + oop_reconstruct_u0_p::Union{Nothing, ReconstructInitializeprob} end function is_parameter_solvable(p, pmap, defs, guesses) @@ -239,45 +258,15 @@ function SciMLBase.remake_initialization_data( if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem) return oldinitdata end - pidxs = ParameterIndex[] - pvals = [] - u0idxs = Int[] - u0vals = [] - for sym in variable_symbols(oldinitprob) - if is_variable(sys, sym) || has_observed_with_lhs(sys, sym) - u0 !== missing || continue - idx = variable_index(oldinitprob, sym) - push!(u0idxs, idx) - push!(u0vals, eltype(u0)(state_values(oldinitprob, idx))) - else - p !== missing || continue - idx = variable_index(oldinitprob, sym) - push!(u0idxs, idx) - push!(u0vals, typeof(getp(sys, sym)(p))(state_values(oldinitprob, idx))) - end - end - if p !== missing - for sym in parameter_symbols(oldinitprob) - push!(pidxs, parameter_index(oldinitprob, sym)) - if is_time_dependent(sys) && isequal(sym, get_iv(sys)) - push!(pvals, t0) - else - push!(pvals, getp(sys, sym)(p)) - end - end - end - if isempty(u0idxs) - newu0 = state_values(oldinitprob) - else - newu0 = remake_buffer( - oldinitprob.f.sys, state_values(oldinitprob), u0idxs, u0vals) - end - if isempty(pidxs) - newp = parameter_values(oldinitprob) + oldinitsys = oldinitprob.f.sys + meta = get_metadata(oldinitsys) + if meta isa InitializationSystemMetadata && meta.oop_reconstruct_u0_p !== nothing + reconstruct_fn = meta.oop_reconstruct_u0_p else - newp = remake_buffer( - oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals) + reconstruct_fn = ReconstructInitializeprob(sys, oldinitsys) end + new_initu0, new_initp = reconstruct_fn( + ProblemState(; u = newu0, p = newp, t = t0), oldinitprob) if oldinitprob.f.resid_prototype === nothing newf = oldinitprob.f else @@ -285,9 +274,9 @@ function SciMLBase.remake_initialization_data( SciMLBase.isinplace(oldinitprob.f), SciMLBase.specialization(oldinitprob.f)}( oldinitprob.f; resid_prototype = calculate_resid_prototype( - length(oldinitprob.f.resid_prototype), newu0, newp)) + length(oldinitprob.f.resid_prototype), new_initu0, new_initp)) end - initprob = remake(oldinitprob; f = newf, u0 = newu0, p = newp) + initprob = remake(oldinitprob; f = newf, u0 = new_initu0, p = new_initp) return SciMLBase.OverrideInitData(initprob, oldinitdata.update_initializeprob!, oldinitdata.initializeprobmap, oldinitdata.initializeprobpmap) end From 2e07200e03af428b127814af85f48fdeccdaf574 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 16 Dec 2024 16:22:51 +0530 Subject: [PATCH 32/38] test: fix incorrect initial values in tests --- test/nonlinearsystem.jl | 2 +- test/reduction.jl | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/test/nonlinearsystem.jl b/test/nonlinearsystem.jl index f766ca0131..cbd95a50d3 100644 --- a/test/nonlinearsystem.jl +++ b/test/nonlinearsystem.jl @@ -293,7 +293,7 @@ sys = structural_simplify(ns; conservative = true) eqs = [0 ~ σ * (y - x) 0 ~ x * (ρ - z) - y 0 ~ x * y - β * z] - guesses = [x => 1.0, y => 0.0, z => 0.0] + guesses = [x => 1.0, z => 0.0] ps = [σ => 10.0, ρ => 26.0, β => 8 / 3] @mtkbuild ns = NonlinearSystem(eqs) diff --git a/test/reduction.jl b/test/reduction.jl index 6d7a05b99e..fa9029a652 100644 --- a/test/reduction.jl +++ b/test/reduction.jl @@ -158,9 +158,7 @@ eqs = [u1 ~ u2 reducedsys = structural_simplify(sys) @test length(observed(reducedsys)) == 2 -u0 = [u1 => 1 - u2 => 1 - u3 => 0.3] +u0 = [u2 => 1] pp = [2] nlprob = NonlinearProblem(reducedsys, u0, [p => pp[1]]) reducedsol = solve(nlprob, NewtonRaphson()) From 878122305e616c0d0b4a40df493d58deb15d8227 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 24 Dec 2024 15:33:10 +0530 Subject: [PATCH 33/38] fix: fix `flatten_equations` for higher-dimension array equations --- src/systems/diffeqs/abstractodesystem.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index f4e29346ff..8e7b2ac97e 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1230,7 +1230,7 @@ function flatten_equations(eqs) error("LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must either both be array expressions or both scalar") size(eq.lhs) == size(eq.rhs) || error("Size of LHS ($(eq.lhs)) and RHS ($(eq.rhs)) must match: got $(size(eq.lhs)) and $(size(eq.rhs))") - return collect(eq.lhs) .~ collect(eq.rhs) + return vec(collect(eq.lhs) .~ collect(eq.rhs)) else eq end From 994b2db06507c26826e4950d3f70e19b69d6e09c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 24 Dec 2024 15:34:38 +0530 Subject: [PATCH 34/38] fix: fix `timeseries_parameter_index` for array symbolics --- src/systems/index_cache.jl | 4 +++- test/symbolic_indexing_interface.jl | 11 +++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/src/systems/index_cache.jl b/src/systems/index_cache.jl index 4fc2f18bfd..623623da42 100644 --- a/src/systems/index_cache.jl +++ b/src/systems/index_cache.jl @@ -404,6 +404,7 @@ function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sy sym = get(ic.symbol_to_variable, sym, nothing) sym === nothing && return nothing end + sym = unwrap(sym) idx = check_index_map(ic.discrete_idx, sym) idx === nothing || return ParameterTimeseriesIndex(idx.clock_idx, (idx.buffer_idx, idx.idx_in_clock)) @@ -411,7 +412,8 @@ function SymbolicIndexingInterface.timeseries_parameter_index(ic::IndexCache, sy args = arguments(sym) idx = timeseries_parameter_index(ic, args[1]) idx === nothing && return nothing - ParameterIndex(idx.portion, (idx.idx..., args[2:end]...), idx.validate_size) + return ParameterTimeseriesIndex( + idx.timeseries_idx, (idx.parameter_idx..., args[2:end]...)) end function check_index_map(idxmap, sym) diff --git a/test/symbolic_indexing_interface.jl b/test/symbolic_indexing_interface.jl index e46c197dde..4a7cd926b4 100644 --- a/test/symbolic_indexing_interface.jl +++ b/test/symbolic_indexing_interface.jl @@ -224,3 +224,14 @@ end end @test isempty(get_all_timeseries_indexes(sys, a)) end + +@testset "`timeseries_parameter_index` on unwrapped scalarized timeseries parameter" begin + @variables x(t)[1:2] + @parameters p(t)[1:2, 1:2] + ev = [x[1] ~ 2.0] => [p ~ -ones(2, 2)] + @mtkbuild sys = ODESystem(D(x) ~ p * x, t; continuous_events = [ev]) + p = ModelingToolkit.unwrap(p) + @test timeseries_parameter_index(sys, p) === ParameterTimeseriesIndex(1, (1, 1)) + @test timeseries_parameter_index(sys, p[1, 1]) === + ParameterTimeseriesIndex(1, (1, 1, 1, 1)) +end From 1835a56c04e138a2b4afd07bf3f25994bfb19e4c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 24 Dec 2024 16:10:52 +0530 Subject: [PATCH 35/38] test: fix `IfLifting` test --- test/if_lifting.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/if_lifting.jl b/test/if_lifting.jl index d702355506..b72b468bf1 100644 --- a/test/if_lifting.jl +++ b/test/if_lifting.jl @@ -21,7 +21,9 @@ using ModelingToolkit: t_nounits as t, D_nounits as D, IfLifting, no_if_lift @test operation(only(equations(ss2)).rhs) === ifelse discvar = only(parameters(ss2)) - prob2 = ODEProblem(ss2, [x => 0.0], (0.0, 5.0)) + prob1 = ODEProblem(ss1, [ss1.x => 0.0], (0.0, 5.0)) + sol1 = solve(prob1, Tsit5()) + prob2 = ODEProblem(ss2, [ss2.x => 0.0], (0.0, 5.0)) sol2 = solve(prob2, Tsit5()) @test count(isapprox(pi), sol2.t) == 2 @test any(isapprox(pi), sol2.discretes[1].t) From 127cf3c36831595d57e423ff729fe600b367bcea Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Jan 2025 18:44:38 +0530 Subject: [PATCH 36/38] build: bump OrdinaryDiffEqDefault compat --- Project.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Project.toml b/Project.toml index d672c007c2..b3ffc46164 100644 --- a/Project.toml +++ b/Project.toml @@ -123,6 +123,7 @@ OffsetArrays = "1" OrderedCollections = "1" OrdinaryDiffEq = "6.82.0" OrdinaryDiffEqCore = "1.13.0" +OrdinaryDiffEqDefault = "1.2" OrdinaryDiffEqNonlinearSolve = "1.3.0" PrecompileTools = "1" REPL = "1" @@ -166,6 +167,7 @@ OptimizationMOI = "fd9f6733-72f4-499f-8506-86b2bdd0dea1" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" +OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7" OrdinaryDiffEqNonlinearSolve = "127b3ac7-2247-4354-8eb6-78cf4e7c58e8" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" From da03e00bc5499d6b58417fbcae0c7c4229dc8227 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 2 Jan 2025 20:24:57 +0530 Subject: [PATCH 37/38] test: add `OrdinaryDiffEqDefault` to the test environment --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b3ffc46164..3d14a2dc65 100644 --- a/Project.toml +++ b/Project.toml @@ -183,4 +183,4 @@ Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"] +test = ["AmplNLWriter", "BenchmarkTools", "ControlSystemsBase", "DataInterpolations", "DelayDiffEq", "NonlinearSolve", "ForwardDiff", "Ipopt", "Ipopt_jll", "ModelingToolkitStandardLibrary", "Optimization", "OptimizationOptimJL", "OptimizationMOI", "OrdinaryDiffEq", "OrdinaryDiffEqCore", "OrdinaryDiffEqDefault", "REPL", "Random", "ReferenceTests", "SafeTestsets", "StableRNGs", "Statistics", "SteadyStateDiffEq", "Test", "StochasticDiffEq", "Sundials", "StochasticDelayDiffEq", "Pkg", "JET", "OrdinaryDiffEqNonlinearSolve"] From 59a41b1581a86d407fcdb8489510af8807335cea Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sun, 5 Jan 2025 22:41:36 +0530 Subject: [PATCH 38/38] build: bump minor version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3d14a2dc65..7b0997823c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ModelingToolkit" uuid = "961ee093-0014-501f-94e3-6117800e7a78" authors = ["Yingbo Ma ", "Chris Rackauckas and contributors"] -version = "9.59.0" +version = "9.60.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"