Skip to content

Commit

Permalink
feat: support array of components in @mtkmodel
Browse files Browse the repository at this point in the history
- for loop or a list comprehension can be used to declare component arrays
  • Loading branch information
ven-k committed Dec 5, 2023
1 parent cf282d2 commit 89e99f7
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 48 deletions.
7 changes: 5 additions & 2 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ function _named(name, call, runtime = false)
end
end

function _named_idxs(name::Symbol, idxs, call)
function _named_idxs(name::Symbol, idxs, call; extra_args = "")
if call.head !== :->
throw(ArgumentError("Not an anonymous function"))
end
Expand All @@ -1015,7 +1015,10 @@ function _named_idxs(name::Symbol, idxs, call)
ex = Base.Cartesian.poplinenum(ex)
ex = _named(:(Symbol($(Meta.quot(name)), :_, $sym)), ex, true)
ex = Base.Cartesian.poplinenum(ex)
:($name = $map($sym -> $ex, $idxs))
:($name = map($sym -> begin
$extra_args
$ex
end, $idxs))
end

function single_named_expr(expr)
Expand Down
132 changes: 86 additions & 46 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ function _model_macro(mod, name, expr, isconnector)
exprs = Expr(:block)
dict = Dict{Symbol, Any}()
dict[:kwargs] = Dict{Symbol, Any}()
comps = Symbol[]
comps = Union{Symbol, Expr}[]
ext = Ref{Any}(nothing)
eqs = Expr[]
icon = Ref{Union{String, URI}}()
Expand Down Expand Up @@ -639,7 +639,7 @@ end

### Parsing Components:

function component_args!(a, b, expr, varexpr, kwargs)
function component_args!(a, b, varexpr, kwargs; index_name = nothing)
# Whenever `b` is a function call, skip the first arg aka the function name.
# Whenever it is a kwargs list, include it.
start = b.head == :call ? 2 : 1
Expand All @@ -648,71 +648,113 @@ function component_args!(a, b, expr, varexpr, kwargs)
arg isa LineNumberNode && continue
MLStyle.@match arg begin
x::Symbol || Expr(:kw, x) => begin
_v = _rename(a, x)
b.args[i] = Expr(:kw, x, _v)
push!(varexpr.args, :((@isdefined $x) && ($_v = $x)))
push!(kwargs, Expr(:kw, _v, nothing))
# dict[:kwargs][_v] = nothing
varname, _varname = _rename(a, x)
b.args[i] = Expr(:kw, x, _varname)
push!(varexpr.args, :((if $varname !== nothing
$_varname = $varname
elseif @isdefined $x
# Allow users to define a var in `structural_parameters` and set
# that as positional arg of subcomponents; it is useful for cases
# where it needs to be passed to multiple subcomponents.
$_varname = $x
end)))
push!(kwargs, Expr(:kw, varname, nothing))
# dict[:kwargs][varname] = nothing
end
Expr(:parameters, x...) => begin
component_args!(a, arg, expr, varexpr, kwargs)
component_args!(a, arg, varexpr, kwargs)
end
Expr(:kw, x, y) => begin
_v = _rename(a, x)
b.args[i] = Expr(:kw, x, _v)
push!(varexpr.args, :($_v = $_v === nothing ? $y : $_v))
push!(kwargs, Expr(:kw, _v, nothing))
# dict[:kwargs][_v] = nothing
varname, _varname = _rename(a, x)
b.args[i] = Expr(:kw, x, _varname)
if isnothing(index_name)
push!(varexpr.args, :($_varname = $varname === nothing ? $y : $varname))
else
push!(varexpr.args,
:($_varname = $varname === nothing ? $y : $varname[$index_name]))
end
push!(kwargs, Expr(:kw, varname, nothing))
# dict[:kwargs][varname] = nothing
end
_ => error("Could not parse $arg of component $a")
end
end
end

function _parse_components!(exprs, body, kwargs)
expr = Expr(:block)
model_name(name, range) = Symbol.(name, :_, collect(range))

function _parse_components!(body, kwargs)
local expr
varexpr = Expr(:block)
# push!(exprs, varexpr)
comps = Vector{Symbol}[]
comps = Vector{Union{Expr, Symbol}}[]
comp_names = []

for arg in body.args
arg isa LineNumberNode && continue
MLStyle.@match arg begin
Expr(:block) => begin
# TODO: Do we need this?
error("Multiple `@components` block detected within a single block")
end
Expr(:(=), a, b) => begin
arg = deepcopy(arg)
b = deepcopy(arg.args[2])
Base.remove_linenums!(body)
arg = body.args[end]

component_args!(a, b, expr, varexpr, kwargs)
MLStyle.@match arg begin
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:(=), c, d)))) => begin
array_varexpr = Expr(:block)

arg.args[2] = b
push!(expr.args, arg)
push!(comp_names, a)
push!(comps, [a, b.args[1]])
end
_ => error("Couldn't parse the component body: $arg")
push!(comp_names, :($a...))
push!(comps, [a, b.args[1], d])
b = deepcopy(b)

component_args!(a, b, array_varexpr, kwargs; index_name = c)

expr = _named_idxs(a, d, :($c -> $b); extra_args = array_varexpr)
end
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:filter, e, Expr(:(=), c, d))))) => begin
error("List comprehensions with conditional statements aren't supported.")
end
Expr(:(=), a, Expr(:comprehension, Expr(:generator, b, Expr(:(=), c, d), e...))) => begin
# Note that `e` is of the form `Tuple{Expr(:(=), c, d)}`
error("More than one index isn't supported while building component array")
end
Expr(:block) => begin
# TODO: Do we need this?
error("Multiple `@components` block detected within a single block")
end
Expr(:(=), a, Expr(:for, Expr(:(=), c, d), b)) => begin
Base.remove_linenums!(b)
array_varexpr = Expr(:block)
push!(array_varexpr.args, b.args[1:(end - 1)]...)
push!(comp_names, :($a...))
push!(comps, [a, b.args[end].args[1], d])
b = deepcopy(b)

component_args!(a, b.args[end], array_varexpr, kwargs; index_name = c)

expr = _named_idxs(a, d, :($c -> $(b.args[end])); extra_args = array_varexpr)
end
Expr(:(=), a, b) => begin
arg = deepcopy(arg)
b = deepcopy(arg.args[2])

component_args!(a, b, varexpr, kwargs)

arg.args[2] = b
expr = :(@named $arg)
push!(comp_names, a)
push!(comps, [a, b.args[1]])
end
_ => error("Couldn't parse the component body: $arg")
end

return comp_names, comps, expr, varexpr
end

function push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
blk = Expr(:block)
push!(blk.args, varexpr)
push!(blk.args, :(@named begin
$(expr_vec.args...)
end))
push!(blk.args, expr_vec)
push!(blk.args, :($push!(systems, $(comp_names...))))
push!(ifexpr.args, blk)
end

function handle_if_x!(mod, exprs, ifexpr, x, kwargs, condition = nothing)
push!(ifexpr.args, condition)
comp_names, comps, expr_vec, varexpr = _parse_components!(ifexpr, x, kwargs)
comp_names, comps, expr_vec, varexpr = _parse_components!(x, kwargs)
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
comps
end
Expand All @@ -728,7 +770,7 @@ function handle_if_y!(exprs, ifexpr, y, kwargs)
push!(ifexpr.args, elseifexpr)
(comps...,)
else
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs, y, kwargs)
comp_names, comps, expr_vec, varexpr = _parse_components!(y, kwargs)
push_conditional_component!(ifexpr, expr_vec, comp_names, varexpr)
comps
end
Expand All @@ -753,25 +795,23 @@ function parse_components!(exprs, cs, dict, compbody, kwargs)
Expr(:if, condition, x, y) => begin
handle_conditional_components(condition, dict, exprs, kwargs, x, y)
end
Expr(:(=), a, b) => begin
comp_names, comps, expr_vec, varexpr = _parse_components!(exprs,
:(begin
# Either the arg is top level component declaration or an invalid cause - both are handled by `_parse_components`
_ => begin
comp_names, comps, expr_vec, varexpr = _parse_components!(:(begin
$arg
end),
kwargs)
push!(cs, comp_names...)
push!(dict[:components], comps...)
push!(exprs, varexpr, :(@named begin
$(expr_vec.args...)
end))
push!(exprs, varexpr, expr_vec)
end
_ => error("Couldn't parse the component body $compbody")
end
end
end

function _rename(compname, varname)
compname = Symbol(compname, :__, varname)
(compname, Symbol(:_, compname))
end

# Handle top level branching
Expand Down
38 changes: 38 additions & 0 deletions test/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -508,3 +508,41 @@ end
@test Equation[ternary_true.ternary_parameter_true ~ 0] == equations(ternary_true)
@test Equation[ternary_false.ternary_parameter_false ~ 0] == equations(ternary_false)
end

@testset "Component array" begin
@mtkmodel SubComponent begin
@parameters begin
sc
end
end

@mtkmodel Component begin
@structural_parameters begin
N = 2
end
@components begin
comprehension = [SubComponent(sc = i) for i in 1:N]
written_out_for = for i in 1:N
sc = i + 1
SubComponent(; sc)
end
single_sub_component = SubComponent()
end
end

@named component = Component()
component = complete(component)

@test nameof.(ModelingToolkit.get_systems(component)) == [
:comprehension_1,
:comprehension_2,
:written_out_for_1,
:written_out_for_2,
:single_sub_component,
]

@test getdefault(component.comprehension_1.sc) == 1
@test getdefault(component.comprehension_2.sc) == 2
@test getdefault(component.written_out_for_1.sc) == 2
@test getdefault(component.written_out_for_2.sc) == 3
end

0 comments on commit 89e99f7

Please sign in to comment.