diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 6cdcd72855..2798349ae6 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -2521,18 +2521,16 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol = nam eqs = union(get_eqs(basesys), get_eqs(sys)) sts = union(get_unknowns(basesys), get_unknowns(sys)) ps = union(get_ps(basesys), get_ps(sys)) - base_deps = parameter_dependencies(basesys) - deps = parameter_dependencies(sys) - dep_ps = isnothing(base_deps) ? deps : - isnothing(deps) ? base_deps : union(base_deps, deps) + dep_ps = union_nothing(parameter_dependencies(basesys), parameter_dependencies(sys)) obs = union(get_observed(basesys), get_observed(sys)) cevs = union(get_continuous_events(basesys), get_continuous_events(sys)) devs = union(get_discrete_events(basesys), get_discrete_events(sys)) defs = merge(get_defaults(basesys), get_defaults(sys)) # prefer `sys` + meta = union_nothing(get_metadata(basesys), get_metadata(sys)) syss = union(get_systems(basesys), get_systems(sys)) args = length(ivs) == 0 ? (eqs, sts, ps) : (eqs, ivs[1], sts, ps) kwargs = (parameter_dependencies = dep_ps, observed = obs, continuous_events = cevs, - discrete_events = devs, defaults = defs, systems = syss, + discrete_events = devs, defaults = defs, systems = syss, metadata = meta, name = name, gui_metadata = gui_metadata) # collect fields specific to some system types diff --git a/src/utils.jl b/src/utils.jl index 4ee221e14f..67c5ce278b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,3 +1,14 @@ +""" + union_nothing(x::Union{T1, Nothing}, y::Union{T2, Nothing}) where {T1, T2} + +Unite x and y gracefully when they could be nothing. If neither is nothing, x and y are united normally. If one is nothing, the other is returned unmodified. If both are nothing, nothing is returned. +""" +function union_nothing(x::Union{T1, Nothing}, y::Union{T2, Nothing}) where {T1, T2} + isnothing(x) && return y # y can be nothing or something + isnothing(y) && return x # x can be nothing or something + return union(x, y) # both x and y are something and can be united normally +end + get_iv(D::Differential) = D.x function make_operation(@nospecialize(op), args) diff --git a/test/odesystem.jl b/test/odesystem.jl index c7da5ba702..d54baa8a93 100644 --- a/test/odesystem.jl +++ b/test/odesystem.jl @@ -1195,6 +1195,20 @@ end @test buffer ≈ [2.0, 3.0, 4.0] end +# https://github.com/SciML/ModelingToolkit.jl/issues/2502 +@testset "Extend systems with a field that can be nothing" begin + A = Dict(:a => 1) + B = Dict(:b => 2) + @named A1 = ODESystem(Equation[], t, [], []) + @named B1 = ODESystem(Equation[], t, [], []) + @named A2 = ODESystem(Equation[], t, [], []; metadata = A) + @named B2 = ODESystem(Equation[], t, [], []; metadata = B) + @test ModelingToolkit.get_metadata(extend(A1, B1)) == nothing + @test ModelingToolkit.get_metadata(extend(A1, B2)) == B + @test ModelingToolkit.get_metadata(extend(A2, B1)) == A + @test Set(ModelingToolkit.get_metadata(extend(A2, B2))) == Set(A ∪ B) +end + # https://github.com/SciML/ModelingToolkit.jl/issues/2859 @testset "Initialization with defaults from observed equations (edge case)" begin @variables x(t) y(t) z(t)