From 9589a1f347673cb85b639fbccdab508201d64402 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 21 Nov 2024 12:50:57 +0530 Subject: [PATCH 1/7] feat: store args and kwargs in `EmptySciMLFunction` --- src/systems/problem_utils.jl | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 1cf110df13..1cfb8eb40a 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -431,12 +431,16 @@ end $(TYPEDEF) A simple utility meant to be used as the `constructor` passed to `process_SciMLProblem` in -case constructing a SciMLFunction is not required. +case constructing a SciMLFunction is not required. The arguments passed to it are available +in the `args` field, and the keyword arguments in the `kwargs` field. """ -struct EmptySciMLFunction end +struct EmptySciMLFunction{A, K} + args::A + kwargs::K +end function EmptySciMLFunction(args...; kwargs...) - return nothing + return EmptySciMLFunction{typeof(args), typeof(kwargs)}(args, kwargs) end """ From 8d4f5423e424279ba0855bf79e008a25e6544a77 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 20 Nov 2024 12:54:33 +0530 Subject: [PATCH 2/7] feat: propagate `ODEProblem` guesses to `remake` --- src/systems/diffeqs/abstractodesystem.jl | 4 +- src/systems/nonlinear/initializesystem.jl | 90 ++++++----------------- src/systems/problem_utils.jl | 14 +++- test/initializationsystem.jl | 23 ++++++ 4 files changed, 60 insertions(+), 71 deletions(-) diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index cbf835f48d..af96c4fcfe 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -1310,11 +1310,11 @@ function InitializationProblem{iip, specialize}(sys::AbstractODESystem, elseif isempty(u0map) && get_initializesystem(sys) === nothing isys = structural_simplify( generate_initializesystem( - sys; initialization_eqs, check_units, pmap = parammap); fully_determined) + sys; initialization_eqs, check_units, pmap = parammap, guesses); fully_determined) else isys = structural_simplify( generate_initializesystem( - sys; u0map, initialization_eqs, check_units, pmap = parammap); fully_determined) + sys; u0map, initialization_eqs, check_units, pmap = parammap, guesses); fully_determined) end ts = get_tearing_state(isys) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index eff19afb07..fe41e44d84 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -30,7 +30,8 @@ function generate_initializesystem(sys::ODESystem; # 1) process dummy derivatives and u0map into initialization system eqs_ics = eqs[idxs_alge] # start equation list with algebraic equations defs = copy(defaults(sys)) # copy so we don't modify sys.defaults - guesses = merge(get_guesses(sys), todict(guesses)) + additional_guesses = anydict(guesses) + guesses = merge(get_guesses(sys), additional_guesses) schedule = getfield(sys, :schedule) if !isnothing(schedule) for x in filter(x -> !isnothing(x[1]), schedule.dummy_sub) @@ -178,7 +179,7 @@ function generate_initializesystem(sys::ODESystem; for k in keys(defs) defs[k] = substitute(defs[k], paramsubs) end - meta = InitializationSystemMetadata(Dict{Any, Any}(u0map), Dict{Any, Any}(pmap)) + meta = InitializationSystemMetadata(anydict(u0map), anydict(pmap), additional_guesses) return NonlinearSystem(eqs_ics, vars, pars; @@ -193,6 +194,7 @@ end struct InitializationSystemMetadata u0map::Dict{Any, Any} pmap::Dict{Any, Any} + additional_guesses::Dict{Any, Any} end function is_parameter_solvable(p, pmap, defs, guesses) @@ -263,75 +265,29 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) return initprob, odefn.update_initializeprob!, odefn.initializeprobmap, odefn.initializeprobpmap end - if u0 === missing || isempty(u0) - u0 = Dict() - elseif !(eltype(u0) <: Pair) - u0 = Dict(unknowns(sys) .=> u0) - end - if p === missing - p = Dict() + dvs = unknowns(sys) + ps = parameters(sys) + u0map = to_varmap(u0, dvs) + pmap = to_varmap(p, ps) + guesses = Dict() + if SciMLBase.has_initializeprob(odefn) + oldsys = odefn.initializeprob.f.sys + meta = get_metadata(oldsys) + if meta isa InitializationSystemMetadata + u0map = merge(meta.u0map, u0map) + pmap = merge(meta.pmap, pmap) + merge!(guesses, meta.additional_guesses) + end end if t0 === nothing t0 = 0.0 end - u0 = todict(u0) - defs = defaults(sys) - varmap = merge(defs, u0) - for k in collect(keys(varmap)) - if varmap[k] === nothing - delete!(varmap, k) - end - end - varmap = canonicalize_varmap(varmap) - missingvars = setdiff(unknowns(sys), collect(keys(varmap))) - setobserved = filter(keys(varmap)) do var - has_observed_with_lhs(sys, var) || has_observed_with_lhs(sys, default_toterm(var)) - end - p = todict(p) - guesses = ModelingToolkit.guesses(sys) - solvablepars = [par - for par in parameters(sys) - if is_parameter_solvable(par, p, defs, guesses)] - pvarmap = merge(defs, p) - setparobserved = filter(keys(pvarmap)) do var - has_parameter_dependency_with_lhs(sys, var) - end - if (((!isempty(missingvars) || !isempty(solvablepars) || - !isempty(setobserved) || !isempty(setparobserved)) && - ModelingToolkit.get_tearing_state(sys) !== nothing) || - !isempty(initialization_equations(sys))) - if SciMLBase.has_initializeprob(odefn) - oldsys = odefn.initializeprob.f.sys - meta = get_metadata(oldsys) - if meta isa InitializationSystemMetadata - u0 = merge(meta.u0map, u0) - p = merge(meta.pmap, p) - end - end - for k in collect(keys(u0)) - if u0[k] === nothing - delete!(u0, k) - end - end - for k in collect(keys(p)) - if p[k] === nothing - delete!(p, k) - end - end - - initprob = InitializationProblem(sys, t0, u0, p) - initprobmap = getu(initprob, unknowns(sys)) - punknowns = [p for p in all_variable_symbols(initprob) if is_parameter(sys, p)] - getpunknowns = getu(initprob, punknowns) - setpunknowns = setp(sys, punknowns) - initprobpmap = GetUpdatedMTKParameters(getpunknowns, setpunknowns) - reqd_syms = parameter_symbols(initprob) - update_initializeprob! = UpdateInitializeprob( - getu(sys, reqd_syms), setu(initprob, reqd_syms)) - return initprob, update_initializeprob!, initprobmap, initprobpmap - else - return nothing, nothing, nothing, nothing - end + filter_missing_values!(u0map) + filter_missing_values!(pmap) + f, _ = process_SciMLProblem(EmptySciMLFunction, sys, u0map, pmap; guesses, t = t0) + kws = f.kwargs + return get(kws, :initializeprob, nothing), get(kws, :update_initializeprob!, nothing), get(kws, :initializeprobmap, nothing), + get(kws, :initializeprobpmap, nothing) end """ diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 1cfb8eb40a..9521df7872 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -4,12 +4,13 @@ const AnyDict = Dict{Any, Any} $(TYPEDSIGNATURES) If called without arguments, return `Dict{Any, Any}`. Otherwise, interpret the input -as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters` -and `nothing`. +as a symbolic map and turn it into a `Dict{Any, Any}`. Handles `SciMLBase.NullParameters`, +`missing` and `nothing`. """ anydict() = AnyDict() anydict(::SciMLBase.NullParameters) = AnyDict() anydict(::Nothing) = AnyDict() +anydict(::Missing) = AnyDict() anydict(x::AnyDict) = x anydict(x) = AnyDict(x) @@ -388,6 +389,15 @@ function evaluate_varmap!(varmap::AbstractDict, vars; limit = 100) end end +""" + $(TYPEDSIGNATURES) + +Remove keys in `varmap` whose values are `nothing`. +""" +function filter_missing_values!(varmap::AbstractDict) + filter!(kvp -> kvp[2] !== nothing, varmap) +end + struct GetUpdatedMTKParameters{G, S} # `getu` functor which gets parameters that are unknowns during initialization getpunknowns::G diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index 770f9f12bd..edb4199356 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -975,3 +975,26 @@ end @test integ.ps[p] ≈ 1.0 @test integ.ps[q]≈cbrt(2) rtol=1e-6 end + +@testset "Guesses provided to `ODEProblem` are used in `remake`" begin + @variables x(t) y(t)=2x + @parameters p q=3x + @mtkbuild sys = ODESystem([D(x) ~ x * p + q, x^3 + y^3 ~ 3], t) + prob = ODEProblem( + sys, [], (0.0, 1.0), [p => 1.0]; guesses = [x => 1.0, y => 1.0, q => 1.0]) + @test prob[x] == 0.0 + @test prob[y] == 0.0 + @test prob.ps[p] == 1.0 + @test prob.ps[q] == 0.0 + integ = init(prob) + @test integ[x] ≈ 1 / cbrt(3) + @test integ[y] ≈ 2 / cbrt(3) + @test integ.ps[p] == 1.0 + @test integ.ps[q] ≈ 3 / cbrt(3) + prob2 = remake(prob; u0 = [y => 3x], p = [q => 2x]) + integ2 = init(prob2) + @test integ2[x] ≈ cbrt(3 / 28) + @test integ2[y] ≈ 3cbrt(3 / 28) + @test integ2.ps[p] == 1.0 + @test integ2.ps[q] ≈ 2cbrt(3 / 28) +end From 2641fd8a97cfc785d340e2fa7cea80a3a3815182 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 22 Nov 2024 14:04:13 +0530 Subject: [PATCH 3/7] fix: handle initial values passed as `Symbol`s --- src/systems/problem_utils.jl | 38 ++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/src/systems/problem_utils.jl b/src/systems/problem_utils.jl index 9521df7872..b837eef98e 100644 --- a/src/systems/problem_utils.jl +++ b/src/systems/problem_utils.jl @@ -52,6 +52,42 @@ function add_toterms(varmap::AbstractDict; toterm = default_toterm) return cp end +""" + $(TYPEDSIGNATURES) + +Turn any `Symbol` keys in `varmap` to the appropriate symbolic variables in `sys`. Any +symbols that cannot be converted are ignored. +""" +function symbols_to_symbolics!(sys::AbstractSystem, varmap::AbstractDict) + if is_split(sys) + ic = get_index_cache(sys) + for k in collect(keys(varmap)) + k isa Symbol || continue + newk = get(ic.symbol_to_variable, k, nothing) + newk === nothing && continue + varmap[newk] = varmap[k] + delete!(varmap, k) + end + else + syms = all_symbols(sys) + for k in collect(keys(varmap)) + k isa Symbol || continue + idx = findfirst(syms) do sym + hasname(sym) || return false + name = getname(sym) + return name == k + end + idx === nothing && continue + newk = syms[idx] + if iscall(newk) && operation(newk) === getindex + newk = arguments(newk)[1] + end + varmap[newk] = varmap[k] + delete!(varmap, k) + end + end +end + """ $(TYPEDSIGNATURES) @@ -530,8 +566,10 @@ function process_SciMLProblem( pType = typeof(pmap) _u0map = u0map u0map = to_varmap(u0map, dvs) + symbols_to_symbolics!(sys, u0map) _pmap = pmap pmap = to_varmap(pmap, ps) + symbols_to_symbolics!(sys, pmap) defs = add_toterms(recursive_unwrap(defaults(sys))) cmap, cs = get_cmap(sys) kwargs = NamedTuple(kwargs) From 26980151eacd8d3b1977f544f4a284022ce71ad9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 22 Nov 2024 14:04:34 +0530 Subject: [PATCH 4/7] fix: handle `remake` with no pre-existing initializeprob --- src/systems/nonlinear/initializesystem.jl | 58 ++++++++++++++++++----- 1 file changed, 47 insertions(+), 11 deletions(-) diff --git a/src/systems/nonlinear/initializesystem.jl b/src/systems/nonlinear/initializesystem.jl index fe41e44d84..4abb345822 100644 --- a/src/systems/nonlinear/initializesystem.jl +++ b/src/systems/nonlinear/initializesystem.jl @@ -210,17 +210,16 @@ function is_parameter_solvable(p, pmap, defs, guesses) _val1 === nothing && _val2 !== nothing)) && _val3 !== nothing end -function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) +function SciMLBase.remake_initialization_data(sys::ODESystem, odefn, u0, t0, p, newu0, newp) if u0 === missing && p === missing - return odefn.initializeprob, odefn.update_initializeprob!, odefn.initializeprobmap, - odefn.initializeprobpmap + return odefn.initialization_data end if !(eltype(u0) <: Pair) && !(eltype(p) <: Pair) oldinitprob = odefn.initializeprob - if oldinitprob === nothing || !SciMLBase.has_sys(oldinitprob.f) || - !(oldinitprob.f.sys isa NonlinearSystem) - return oldinitprob, odefn.update_initializeprob!, odefn.initializeprobmap, - odefn.initializeprobpmap + oldinitprob === nothing && return nothing + if !SciMLBase.has_sys(oldinitprob.f) || !(oldinitprob.f.sys isa NonlinearSystem) + return SciMLBase.OverrideInitData(oldinitprob, odefn.update_initializeprob!, + odefn.initializeprobmap, odefn.initializeprobpmap) end pidxs = ParameterIndex[] pvals = [] @@ -262,14 +261,17 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) oldinitprob.f.sys, parameter_values(oldinitprob), pidxs, pvals) end initprob = remake(oldinitprob; u0 = newu0, p = newp) - return initprob, odefn.update_initializeprob!, odefn.initializeprobmap, - odefn.initializeprobpmap + return SciMLBase.OverrideInitData(initprob, odefn.update_initializeprob!, + odefn.initializeprobmap, odefn.initializeprobpmap) end dvs = unknowns(sys) ps = parameters(sys) u0map = to_varmap(u0, dvs) + symbols_to_symbolics!(sys, u0map) pmap = to_varmap(p, ps) + symbols_to_symbolics!(sys, pmap) guesses = Dict() + defs = defaults(sys) if SciMLBase.has_initializeprob(odefn) oldsys = odefn.initializeprob.f.sys meta = get_metadata(oldsys) @@ -278,6 +280,35 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) pmap = merge(meta.pmap, pmap) merge!(guesses, meta.additional_guesses) end + else + # there is no initializeprob, so the original problem construction + # had no solvable parameters and had the differential variables + # specified in `u0map`. + if u0 === missing + # the user didn't pass `u0` to `remake`, so they want to retain + # existing values. Fill the differential variables in `u0map`, + # initialization will either be elided or solve for the algebraic + # variables + diff_idxs = isdiffeq.(equations(sys)) + for i in eachindex(dvs) + diff_idxs[i] || continue + u0map[dvs[i]] = newu0[i] + end + end + if p === missing + # the user didn't pass `p` to `remake`, so they want to retain + # existing values. Fill all parameters in `pmap` so that none of + # them are solvable. + for p in ps + pmap[p] = getp(sys, p)(newp) + end + end + # all non-solvable parameters need values regardless + for p in ps + haskey(pmap, p) && continue + is_parameter_solvable(p, pmap, defs, guesses) && continue + pmap[p] = getp(sys, p)(newp) + end end if t0 === nothing t0 = 0.0 @@ -286,8 +317,13 @@ function SciMLBase.remake_initializeprob(sys::ODESystem, odefn, u0, t0, p) filter_missing_values!(pmap) f, _ = process_SciMLProblem(EmptySciMLFunction, sys, u0map, pmap; guesses, t = t0) kws = f.kwargs - return get(kws, :initializeprob, nothing), get(kws, :update_initializeprob!, nothing), get(kws, :initializeprobmap, nothing), - get(kws, :initializeprobpmap, nothing) + initprob = get(kws, :initializeprob, nothing) + if initprob === nothing + return nothing + end + return SciMLBase.OverrideInitData(initprob, get(kws, :update_initializeprob!, nothing), + get(kws, :initializeprobmap, nothing), + get(kws, :initializeprobpmap, nothing)) end """ From 3383bcede7c841530a9d82f601ff70e40f32deca Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 22 Nov 2024 14:05:11 +0530 Subject: [PATCH 5/7] test: test `remake` without initializeprob and `Symbol` values --- test/initializationsystem.jl | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/test/initializationsystem.jl b/test/initializationsystem.jl index edb4199356..0b0dc42c1e 100644 --- a/test/initializationsystem.jl +++ b/test/initializationsystem.jl @@ -998,3 +998,37 @@ end @test integ2.ps[p] == 1.0 @test integ2.ps[q] ≈ 2cbrt(3 / 28) end + +@testset "Remake problem with no initializeprob" begin + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] + @parameters p [guess = 1.0] q [guess = 1.0] + @mtkbuild sys = ODESystem( + [D(x) ~ p * x + q * y, y ~ 2x], t; parameter_dependencies = [q ~ 2p]) + prob = ODEProblem(sys, [x => 1.0], (0.0, 1.0), [p => 1.0]) + @test prob.f.initialization_data === nothing + prob2 = remake(prob; u0 = [x => 2.0]) + @test prob2[x] == 2.0 + @test prob2.f.initialization_data === nothing + prob3 = remake(prob; u0 = [y => 2.0]) + @test prob3.f.initialization_data !== nothing + @test init(prob3)[x] ≈ 1.0 + prob4 = remake(prob; p = [p => 1.0]) + @test prob4.f.initialization_data === nothing + prob5 = remake(prob; p = [p => missing, q => 2.0]) + @test prob5.f.initialization_data !== nothing + @test init(prob5).ps[p] ≈ 1.0 +end + +@testset "Variables provided as symbols" begin + @variables x(t) [guess = 1.0] y(t) [guess = 1.0] + @parameters p [guess = 1.0] q [guess = 1.0] + @mtkbuild sys = ODESystem( + [D(x) ~ p * x + q * y, y ~ 2x], t; parameter_dependencies = [q ~ 2p]) + prob = ODEProblem(sys, [:x => 1.0], (0.0, 1.0), [p => 1.0]) + @test prob.f.initialization_data === nothing + prob2 = remake(prob; u0 = [:x => 2.0]) + @test prob2.f.initialization_data === nothing + prob3 = remake(prob; u0 = [:y => 1.0]) + @test prob3.f.initialization_data !== nothing + @test init(prob3)[x] ≈ 0.5 +end From 83ed891cd7cf2319ffbef0c5bb870424de5ab6ff Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 30 Nov 2024 23:37:12 +0530 Subject: [PATCH 6/7] build: bump SciMLBase compat --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 093c632c42..1621669c2c 100644 --- a/Project.toml +++ b/Project.toml @@ -126,7 +126,7 @@ REPL = "1" RecursiveArrayTools = "3.26" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" -SciMLBase = "2.57.1" +SciMLBase = "2.64" SciMLStructures = "1.0" Serialization = "1" Setfield = "0.7, 0.8, 1" From 10cc9c16854fc277b88967bbcd709bf42a3b14d2 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Sat, 30 Nov 2024 23:40:22 +0530 Subject: [PATCH 7/7] refactor: format --- src/systems/optimization/optimizationsystem.jl | 3 ++- src/utils.jl | 1 - 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/systems/optimization/optimizationsystem.jl b/src/systems/optimization/optimizationsystem.jl index 3425216b5f..43e9294dd3 100644 --- a/src/systems/optimization/optimizationsystem.jl +++ b/src/systems/optimization/optimizationsystem.jl @@ -168,7 +168,8 @@ function OptimizationSystem(objective; constraints = [], kwargs...) push!(new_ps, p) end end - return OptimizationSystem(objective, collect(allunknowns), collect(new_ps); constraints, kwargs...) + return OptimizationSystem( + objective, collect(allunknowns), collect(new_ps); constraints, kwargs...) end function flatten(sys::OptimizationSystem) diff --git a/src/utils.jl b/src/utils.jl index 1555cd624e..416efd8f2c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -569,7 +569,6 @@ function collect_vars!(unknowns, parameters, p::Pair, iv; depth = 0, op = Differ return nothing end - function collect_var!(unknowns, parameters, var, iv; depth = 0) isequal(var, iv) && return nothing check_scope_depth(getmetadata(var, SymScope, LocalScope()), depth) || return nothing