Skip to content

Commit

Permalink
fix: fix hierarchical discrete systems (#2593)
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal authored Apr 1, 2024
1 parent 7fb1d99 commit ab61554
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 0 deletions.
16 changes: 16 additions & 0 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,22 @@ function DiscreteSystem(eqs, iv; kwargs...)
collect(allunknowns), collect(new_ps); kwargs...)
end

function flatten(sys::DiscreteSystem, noeqs = false)
systems = get_systems(sys)
if isempty(systems)
return sys
else
return DiscreteSystem(noeqs ? Equation[] : equations(sys),
get_iv(sys),
unknowns(sys),
parameters(sys),
observed = observed(sys),
defaults = defaults(sys),
name = nameof(sys),
checks = false)
end
end

function generate_function(
sys::DiscreteSystem, dvs = unknowns(sys), ps = full_parameters(sys); kwargs...)
generate_custom_function(sys, [eq.rhs for eq in equations(sys)], dvs, ps; kwargs...)
Expand Down
31 changes: 31 additions & 0 deletions test/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,34 @@ prob = DiscreteProblem(de, [], (0, 10))
prob = DiscreteProblem(de, [x(k - 1) => 2.0], (0, 10))
@test prob[x] == 3.0
@test prob[x(k - 1)] == 2.0

# Issue#2585
getdata(buffer, t) = buffer[mod1(Int(t), length(buffer))]
@register_symbolic getdata(buffer::Vector, t)
k = ShiftIndex(t)
function SampledData(; name, buffer)
L = length(buffer)
pars = @parameters begin
buffer[1:L] = buffer
end
@variables output(t) time(t)
eqs = [time ~ time(k - 1) + 1
output ~ getdata(buffer, time)]
return DiscreteSystem(eqs, t; name)
end
function System(; name, buffer)
@named y_sys = SampledData(; buffer = buffer)
pars = @parameters begin
α = 0.5, [description = "alpha"]
β = 0.5, [description = "beta"]
end
vars = @variables y(t)=0.0 y_shk(t)=0.0

eqs = [y_shk ~ y_sys.output
# y[t] = 0.5 * y[t - 1] + 0.5 * y[t + 1] + y_shk[t]
y(k - 1) ~ α * y(k - 2) +* y(k) + y_shk(k - 1))]

DiscreteSystem(eqs, t, vars, pars; systems = [y_sys], name = name)
end

@test_nowarn @mtkbuild sys = System(; buffer = ones(10))

0 comments on commit ab61554

Please sign in to comment.