Skip to content

Commit

Permalink
feat: add support for conditional parameters and variables
Browse files Browse the repository at this point in the history
  • Loading branch information
ven-k committed Oct 23, 2023
1 parent 97186f6 commit 6f59697
Show file tree
Hide file tree
Showing 2 changed files with 197 additions and 85 deletions.
150 changes: 127 additions & 23 deletions src/systems/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,17 @@ function _model_macro(mod, name, expr, isconnector)
exprs = Expr(:block)
dict = Dict{Symbol, Any}()
dict[:kwargs] = Dict{Symbol, Any}()
dict[:parameters] = Any[Dict{Symbol, Dict{Symbol, Any}}()]
dict[:variables] = Any[Dict{Symbol, Dict{Symbol, Any}}()]
comps = Symbol[]
ext = Ref{Any}(nothing)
eqs = Expr[]
icon = Ref{Union{String, URI}}()
ps, sps, vs, = [], [], []
kwargs = Set()

push!(exprs.args, :(variables = []))
push!(exprs.args, :(parameters = []))
push!(exprs.args, :(systems = ODESystem[]))
push!(exprs.args, :(equations = Equation[]))

Expand All @@ -59,7 +63,24 @@ function _model_macro(mod, name, expr, isconnector)
Expr(:if, condition, x) => begin
component_blk, equations_blk, parameter_blk, variable_blk = parse_top_level_branch(condition,
x.args)

parameter_blk !== nothing && parse_variables!(exprs.args,
ps,
dict,
mod,
:(begin
$parameter_blk
end),
:parameters,
kwargs)
variable_blk !== nothing && parse_variables!(exprs.args,
ps,
dict,
mod,
:(begin
$variable_blk
end),
:variables,
kwargs)
component_blk !== nothing &&
parse_components!(exprs.args,
comps,
Expand All @@ -72,14 +93,30 @@ function _model_macro(mod, name, expr, isconnector)
parse_equations!(exprs.args, eqs, dict, :(begin
$equations_blk
end))
# parameter_blk !== nothing && parse_variables!(exprs.args, ps, dict, mod, :(begin $parameter_blk end), :parameters, kwargs)
# variable_blk !== nothing && parse_variables!(exprs.args, ps, dict, mod, :(begin $variable_blk end), :variables, kwargs)
end
Expr(:if, condition, x, y) => begin
component_blk, equations_blk, parameter_blk, variable_blk = parse_top_level_branch(condition,
x.args,
y)

parameter_blk !== nothing && parse_variables!(exprs.args,
ps,
dict,
mod,
:(begin
$parameter_blk
end),
:parameters,
kwargs)
variable_blk !== nothing && parse_variables!(exprs.args,
ps,
dict,
mod,
:(begin
$variable_blk
end),
:variables,
kwargs)
component_blk !== nothing &&
parse_components!(exprs.args,
comps, dict, :(begin
Expand All @@ -89,8 +126,6 @@ function _model_macro(mod, name, expr, isconnector)
parse_equations!(exprs.args, eqs, dict, :(begin
$equations_blk
end))
# parameter_blk !== nothing && parse_variables!(exprs.args, ps, dict, mod, :(begin $parameter_blk end), :parameters, kwargs)
# variable_blk !== nothing && parse_variables!(exprs.args, ps, dict, mod, :(begin $variable_blk end), :variables, kwargs)
end
_ => error("Got an invalid argument: $arg")
end
Expand All @@ -108,13 +143,15 @@ function _model_macro(mod, name, expr, isconnector)
iv = dict[:independent_variable] = variable(:t)
end

push!(exprs.args, :(push!(systems, $(comps...))))
push!(exprs.args, :(push!(equations, $(eqs...))))
push!(exprs.args, :(push!(parameters, $(ps...))))
push!(exprs.args, :(push!(systems, $(comps...))))
push!(exprs.args, :(push!(variables, $(vs...))))

gui_metadata = isassigned(icon) > 0 ? GUIMetadata(GlobalRef(mod, name), icon[]) :
GUIMetadata(GlobalRef(mod, name))

sys = :($ODESystem($Equation[equations...], $iv, [$(vs...)], [$(ps...)];
sys = :($ODESystem($Equation[equations...], $iv, variables, parameters;
name, systems, gui_metadata = $gui_metadata))

if ext[] === nothing
Expand Down Expand Up @@ -165,12 +202,12 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, def = nothing)
Base.remove_linenums!(b)
def, meta = parse_default(mod, b)
var, _ = parse_variable_def!(dict, mod, a, varclass, kwargs, def)
dict[varclass][getname(var)][:default] = def
dict[varclass][1][getname(var)][:default] = def
if meta !== nothing
for (type, key) in metatypes
if (mt = get(meta, key, nothing)) !== nothing
key == VariableConnectType && (mt = nameof(mt))
dict[varclass][getname(var)][type] = mt
dict[varclass][1][getname(var)][type] = mt
end
end
var = set_var_metadata(var, meta)
Expand All @@ -184,13 +221,19 @@ function parse_variable_def!(dict, mod, arg, varclass, kwargs, def = nothing)
for (type, key) in metatypes
if (mt = get(meta, key, nothing)) !== nothing
key == VariableConnectType && (mt = nameof(mt))
dict[varclass][getname(var)][type] = mt
dict[varclass][1][getname(var)][type] = mt
end
end
var = set_var_metadata(var, meta)
end
(set_var_metadata(var, meta), def)
end
#= Expr(:if, condition, a) => begin
var, def = [], []
for var_def in a.args
parse_variable_def!(dict, mod, var_def, varclass, kwargs)
end
end =#
_ => error("$arg cannot be parsed")
end
end
Expand All @@ -205,9 +248,7 @@ end

function generate_var!(dict, a, varclass)
var = generate_var(a, varclass)
vd = get!(dict, varclass) do
Dict{Symbol, Dict{Symbol, Any}}()
end
vd = first(dict[varclass])
vd[a] = Dict{Symbol, Any}()
var
end
Expand All @@ -218,9 +259,7 @@ function generate_var!(dict, a, b, varclass)
iv
end
@assert isequal(iv, prev_iv)
vd = get!(dict, varclass) do
Dict{Symbol, Dict{Symbol, Any}}()
end
vd = first(dict[varclass])
vd[a] = Dict{Symbol, Any}()
var = Symbolics.variable(a, T = SymbolicUtils.FnType{Tuple{Real}, Real})(iv)
if varclass == :parameters
Expand Down Expand Up @@ -416,22 +455,87 @@ function parse_extend!(exprs, ext, dict, mod, body, kwargs)
return nothing
end

function parse_variable_arg!(expr, vs, dict, mod, arg, varclass, kwargs)
function parse_variable_arg!(exprs, vs, dict, mod, arg, varclass, kwargs)
name, ex = parse_variable_arg(dict, mod, arg, varclass, kwargs)
push!(vs, name)
push!(exprs, ex)
end

function parse_variable_arg(dict, mod, arg, varclass, kwargs)
vv, def = parse_variable_def!(dict, mod, arg, varclass, kwargs)
v = Num(vv)
name = getname(v)
push!(vs, name)
push!(expr.args,
:($name = $name === nothing ? $setdefault($vv, $def) : $setdefault($vv, $name)))
# push!(vs, name)
return name,
:($name = $name === nothing ? $setdefault($vv, $def) : $setdefault($vv, $name))
end

function handle_conditional_vars!(arg, conditional_branch, mod, varclass, kwargs)
conditional_dict = Dict(:kwargs => Dict(),
:parameters => Any[Dict{Symbol, Dict{Symbol, Any}}()],
:variables => Any[Dict{Symbol, Dict{Symbol, Any}}()])
for _arg in arg.args
name, ex = parse_variable_arg(conditional_dict, mod, _arg, varclass, kwargs)
push!(conditional_branch.args, ex)
push!(conditional_branch.args, :(push!($varclass, $name)))
end
conditional_dict
end

function parse_variables!(exprs, vs, dict, mod, body, varclass, kwargs)
expr = Expr(:block)
push!(exprs, expr)
for arg in body.args
arg isa LineNumberNode && continue
parse_variable_arg!(expr, vs, dict, mod, arg, varclass, kwargs)
MLStyle.@match arg begin
Expr(:if, condition, x) => begin
conditional_expr = Expr(:if, condition, Expr(:block))
conditional_dict = handle_conditional_vars!(x,
conditional_expr.args[2],
mod,
varclass,
kwargs)
push!(expr.args, conditional_expr)
push!(dict[varclass], (:if, condition, conditional_dict, nothing))
end
Expr(:if, condition, x, y) => begin
conditional_expr = Expr(:if, condition, Expr(:block))
conditional_dict = handle_conditional_vars!(x,
conditional_expr.args[2],
mod,
varclass,
kwargs)
conditional_y_expr, conditional_y_dict = handle_y_vars(y,
conditional_dict,
mod,
varclass,
kwargs)
push!(conditional_expr.args, conditional_y_expr)
push!(expr.args, conditional_expr)
push!(dict[varclass],
(:if, condition, conditional_dict, conditional_y_dict))
end
_ => parse_variable_arg!(exprs, vs, dict, mod, arg, varclass, kwargs)
end
end
end

function handle_y_vars(y, dict, mod, varclass, kwargs)
conditional_dict = if Meta.isexpr(y, :elseif)
conditional_y_expr = Expr(:elseif, y.args[1], Expr(:block))
conditional_dict = handle_conditional_vars!(y.args[2],
conditional_y_expr.args[2],
mod,
varclass,
kwargs)
_y_expr, _conditional_dict = handle_y_vars(y.args[end], dict, mod, varclass, kwargs)
push!(conditional_y_expr.args, _y_expr)
(:elseif, y.args[1], conditional_dict, _conditional_dict)
else
conditional_y_expr = Expr(:block)
handle_conditional_vars!(y, conditional_y_expr, mod, varclass, kwargs)
end
conditional_y_expr, conditional_dict
end

function handle_if_x_equations!(ifexpr, condition, x)
Expand Down Expand Up @@ -550,7 +654,7 @@ function _parse_components!(exprs, body, kwargs)
arg isa LineNumberNode && continue
MLStyle.@match arg begin
Expr(:block) => begin
# TODO: Do we need this?
# TODO: Do we need this?
error("Multiple `@components` block detected within a single block")
end
Expr(:(=), a, b) => begin
Expand Down Expand Up @@ -634,7 +738,7 @@ function parse_components!(exprs, cs, dict, compbody, kwargs)
$(expr_vec.args...)
end))
end
_ => @info "410 Couldn't parse the component body $compbody" @__LINE__
_ => error("Couldn't parse the component body $compbody")
end
end
end
Expand Down
Loading

0 comments on commit 6f59697

Please sign in to comment.