Skip to content

Commit

Permalink
Make extend() handle all fields of arbitrary system types
Browse files Browse the repository at this point in the history
  • Loading branch information
hersle committed Jul 18, 2024
1 parent cf82212 commit 319d890
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 36 deletions.
38 changes: 13 additions & 25 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -731,9 +731,6 @@ $(TYPEDSIGNATURES)
Remake the system `sys` with every field replaced by the value in `kwargs`.
If `skip_nonexisting`, fields that do not exist in `sys` are *silently skipped*!
This means that typos in field names will not error, so use this with caution!
```julia
@variables x(t) y(t)
@named sysx = ODESystem([x ~ 0], t)
Expand All @@ -746,7 +743,6 @@ WARNING: intended for internal use; does not perform any sanity checks.
# TODO: use SciMLBase's generic remake()? doesn't work out of the box, though
function remake(sys::AbstractSystem; skip_nonexisting = false, kwargs...)
for (field, value) in kwargs
!hasfield(typeof(sys), field) && skip_nonexisting && continue
# like `Setfield.@set! sys.field = value`, but with `field` replaced by an arbitrarily named symbol
# (e.g. https://discourse.julialang.org/t/accessing-struct-via-symbol/58809/4)
sys = Setfield.set(sys, Setfield.PropertyLens{field}(), value)
Expand Down Expand Up @@ -2623,27 +2619,19 @@ function extend(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol = nam
end
end

return remake(sys; skip_nonexisting = true,
# fields provided directly to extend()
name, gui_metadata,

# fields common to all system types
eqs = union(get_eqs(basesys), get_eqs(sys)),
unknowns = union(get_unknowns(basesys), get_unknowns(sys)),
ps = union(get_ps(basesys), get_ps(sys)),
iv = isempty(ivs) ? missing : ivs[1], # TODO: handle generally
parameter_dependencies = union_nothing(get_parameter_dependencies(basesys), get_parameter_dependencies(sys)),
observed = union(get_observed(basesys), get_observed(sys)),
continuous_events = union(get_continuous_events(basesys), get_continuous_events(sys)),
discrete_events = union(get_discrete_events(basesys), get_discrete_events(sys)),
defaults = merge(get_defaults(basesys), get_defaults(sys)), # prefer `sys`,
systems = union(get_systems(basesys), get_systems(sys)),
metadata = union_nothing(get_metadata(basesys), get_metadata(sys)),

# fields specific to some system types
initialization_eqs = hasfield(T, :initialization_eqs) ? union(get_initialization_eqs(basesys), get_initialization_eqs(sys)) : missing,
guesses = hasfield(T, :guesses) ? merge(get_guesses(basesys), get_guesses(sys)) : missing, # prefer `sys`
)
# gracefully fields, being nice if one or both are nothing
ext(x, y) = y # prefer sys (y) over basesys (x)
ext(x, y::Nothing) = x
ext(x::Nothing, y) = ext(y, x)
ext(x::Nothing, y::Nothing) = nothing
ext(x::AbstractDict, y::AbstractDict) = merge(x, y) # prefer sys (y) over basesys (x)
ext(x::AbstractVector, y::AbstractVector) = union(x, y)
ext(field::Symbol) = ext(getfield(sys, field), getfield(basesys, field)) # TODO: use get_...?

# both systems were individually sanity-checked upon construction,
# so it should be fine to merge their fields without further checking
kwargs = Dict(field => ext(field) for field in fieldnames(T))
return remake(sys; kwargs..., name, gui_metadata)
end

function Base.:(&)(sys::AbstractSystem, basesys::AbstractSystem; name::Symbol = nameof(sys))
Expand Down
11 changes: 0 additions & 11 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,3 @@
"""
union_nothing(x::Union{T1, Nothing}, y::Union{T2, Nothing}) where {T1, T2}
Unite x and y gracefully when they could be nothing. If neither is nothing, x and y are united normally. If one is nothing, the other is returned unmodified. If both are nothing, nothing is returned.
"""
function union_nothing(x::Union{T1, Nothing}, y::Union{T2, Nothing}) where {T1, T2}
isnothing(x) && return y # y can be nothing or something
isnothing(y) && return x # x can be nothing or something
return union(x, y) # both x and y are something and can be united normally
end

get_iv(D::Differential) = D.x

function make_operation(@nospecialize(op), args)
Expand Down

0 comments on commit 319d890

Please sign in to comment.