Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add support for causal connections of variables #3304

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ include("systems/index_cache.jl")
include("systems/parameter_buffer.jl")
include("systems/abstractsystem.jl")
include("systems/model_parsing.jl")
include("systems/analysis_points.jl")
include("systems/connectors.jl")
include("systems/analysis_points.jl")
include("systems/imperative_affect.jl")
include("systems/callbacks.jl")
include("systems/problem_utils.jl")
Expand Down
9 changes: 8 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1875,6 +1875,13 @@ Equivalent to `length(equations(expand_connections(sys))) - length(filter(eq ->
function n_expanded_connection_equations(sys::AbstractSystem)
# TODO: what about inputs?
isconnector(sys) && return length(get_unknowns(sys))
sys = remove_analysis_points(sys)
n_variable_connect_eqs = 0
for eq in equations(sys)
is_causal_variable_connection(eq.rhs) || continue
n_variable_connect_eqs += length(get_systems(eq.rhs)) - 1
end

sys, (csets, _) = generate_connection_set(sys)
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
n_outer_stream_variables = 0
Expand All @@ -1897,7 +1904,7 @@ function n_expanded_connection_equations(sys::AbstractSystem)
# n_toplevel_unused_flows += count(x->get_connection_type(x) === Flow && !(x in toplevel_flows), get_unknowns(m))
#end

nextras = n_outer_stream_variables + length(ceqs)
nextras = n_outer_stream_variables + length(ceqs) + n_variable_connect_eqs
end

function Base.show(
Expand Down
17 changes: 16 additions & 1 deletion src/systems/analysis_points.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@
return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose)
end

function Symbolics.connect(in::ConnectableSymbolicT, name::Symbol, out::ConnectableSymbolicT, outs::ConnectableSymbolicT...; verbose = true)
allvars = (in, out, outs...)
validate_causal_variables_connection(allvars)
return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose)
end

"""
$(TYPEDSIGNATURES)

Expand Down Expand Up @@ -240,7 +246,7 @@
variable in the system. If the system does not have a variable named `u` and
contains multiple variables, throw an error.
"""
function ap_var(sys)
function ap_var(sys::AbstractSystem)
if hasproperty(sys, :u)
return sys.u
end
Expand All @@ -249,6 +255,15 @@
error("Could not determine the analysis-point variable in system $(nameof(sys)). To use an analysis point, apply it to a connection between causal blocks which have a variable named `u` or a single unknown of the same size.")
end

"""
$(TYPEDSIGNATURES)

For an `AnalysisPoint` involving causal variables. Simply return the variable.
"""
function ap_var(var::ConnectableSymbolicT)
return var
end

"""
$(TYPEDEF)

Expand Down Expand Up @@ -493,7 +508,7 @@
ap_idx = analysis_point_index(ap_sys, tf.ap)
ap_idx === nothing &&
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
# get the anlysis point

Check warning on line 511 in src/systems/analysis_points.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"anlysis" should be "analysis".
ap_sys_eqs = copy(get_eqs(ap_sys))
ap = ap_sys_eqs[ap_idx].rhs

Expand Down Expand Up @@ -547,7 +562,7 @@
ap_idx = analysis_point_index(ap_sys, tf.ap)
ap_idx === nothing &&
error("Analysis point $(nameof(tf.ap)) not found in system $(nameof(sys)).")
# modified quations

Check warning on line 565 in src/systems/analysis_points.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"quations" should be "equations".
ap_sys_eqs = copy(get_eqs(ap_sys))
@set! ap_sys.eqs = ap_sys_eqs
ap = ap_sys_eqs[ap_idx].rhs
Expand Down Expand Up @@ -859,7 +874,7 @@
# Keyword Arguments

- `system_modifier`: a function which takes the modified system and returns a new system
with any required further modifications peformed.

Check warning on line 877 in src/systems/analysis_points.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"peformed" should be "performed".
"""
function open_loop(sys, ap::Union{Symbol, AnalysisPoint}; system_modifier = identity)
ap = only(canonicalize_ap(sys, ap))
Expand Down
119 changes: 119 additions & 0 deletions src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,89 @@ SymbolicUtils.promote_symtype(::typeof(instream), _) = Real

isconnector(s::AbstractSystem) = has_connector_type(s) && get_connector_type(s) !== nothing

"""
$(TYPEDEF)

Utility struct which wraps a symbolic variable used in a `Connection` to enable `Base.show`
to work.
"""
struct SymbolicWithNameof
var::Any
end

function Base.nameof(x::SymbolicWithNameof)
return Symbol(x.var)
end

is_causal_variable_connection(c) = false
is_causal_variable_connection(c::Connection) = all(x -> x isa SymbolicWithNameof, get_systems(c))

const ConnectableSymbolicT = Union{BasicSymbolic, Num, Symbolics.Arr}

const CAUSAL_CONNECTION_ERR = """
Only causal variables can be used in a `connect` statement. The first argument must \
be a single output variable and all subsequent variables must be input variables.
"""

function VariableNotOutputError(var)
ArgumentError("""
$CAUSAL_CONNECTION_ERR Expected $var to be marked as an output with `[output = true]` \
in the variable metadata.
""")
end

function VariableNotInputError(var)
ArgumentError("""
$CAUSAL_CONNECTION_ERR Expected $var to be marked an input with `[input = true]` \
in the variable metadata.
""")
end

"""
$(TYPEDSIGNATURES)

Perform validation for a connect statement involving causal variables.
"""
function validate_causal_variables_connection(allvars)
var1 = allvars[1]
var2 = allvars[2]
vars = Base.tail(Base.tail(allvars))
for var in allvars
vtype = getvariabletype(var)
vtype === VARIABLE || throw(ArgumentError("Expected $var to be of kind `$VARIABLE`. Got `$vtype`."))
end
if length(unique(allvars)) !== length(allvars)
throw(ArgumentError("Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries."))
end
allsizes = map(size, allvars)
if !allequal(allsizes)
throw(ArgumentError("Expected all connection variables to have the same size. Got variables $allvars with sizes $allsizes respectively."))
end
isoutput(var1) || throw(VariableNotOutputError(var1))
isinput(var2) || throw(VariableNotInputError(var2))
for var in vars
isinput(var) || throw(VariableNotInputError(var))
end
end

"""
$(TYPEDSIGNATURES)

Connect multiple causal variables. The first variable must be an output, and all subsequent
variables must be inputs. The statement `connect(var1, var2, var3, ...)` expands to:

```julia
var1 ~ var2
var1 ~ var3
# ...
```
"""
function Symbolics.connect(var1::ConnectableSymbolicT, var2::ConnectableSymbolicT, vars::ConnectableSymbolicT...)
allvars = (var1, var2, vars...)
validate_causal_variables_connection(allvars)
return Equation(Connection(), Connection(map(SymbolicWithNameof, allvars)))
end

function flowvar(sys::AbstractSystem)
sts = get_unknowns(sys)
for s in sts
Expand Down Expand Up @@ -329,6 +412,10 @@ function generate_connection_set!(connectionsets, domain_csets,
for eq in eqs′
lhs = eq.lhs
rhs = eq.rhs

# causal variable connections will be expanded before we get here,
# but this guard is useful for `n_expanded_connection_equations`.
is_causal_variable_connection(rhs) && continue
if find !== nothing && find(rhs, _getname(namespace))
neweq, extra_unknown = replace(rhs, _getname(namespace))
if extra_unknown isa AbstractArray
Expand Down Expand Up @@ -479,9 +566,41 @@ function domain_defaults(sys, domain_csets)
def
end

"""
$(TYPEDSIGNATURES)

Recursively descend through the hierarchy of `sys` and expand all connection equations
of causal variables. Return the modified system.
"""
function expand_variable_connections(sys::AbstractSystem)
eqs = copy(get_eqs(sys))
valid_idxs = trues(length(eqs))
additional_eqs = Equation[]

for (i, eq) in enumerate(eqs)
eq.lhs isa Connection || continue
connection = eq.rhs
elements = connection.systems
is_causal_variable_connection(connection) || continue

valid_idxs[i] = false
elements = map(x -> x.var, elements)
outvar = first(elements)
for invar in Iterators.drop(elements, 1)
push!(additional_eqs, outvar ~ invar)
end
end
eqs = [eqs[valid_idxs]; additional_eqs]
subsystems = map(expand_variable_connections, get_systems(sys))
@set! sys.eqs = eqs
@set! sys.systems = subsystems
return sys
end

function expand_connections(sys::AbstractSystem, find = nothing, replace = nothing;
debug = false, tol = 1e-10, scalarize = true)
sys = remove_analysis_points(sys)
sys = expand_variable_connections(sys)
sys, (csets, domain_csets) = generate_connection_set(sys, find, replace; scalarize)
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
_sys = expand_instream(instream_csets, sys; debug = debug, tol = tol)
Expand Down
98 changes: 98 additions & 0 deletions test/causal_variables_connection.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
using ModelingToolkit, ModelingToolkitStandardLibrary.Blocks
using ModelingToolkit: t_nounits as t, D_nounits as D

@testset "Error checking" begin
@variables begin
x(t)
y(t), [input = true]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧌 😎

z(t), [output = true]
w(t)
v(t), [input = true]
u(t), [output = true]
xarr(t)[1:4], [output = true]
yarr(t)[1:2, 1:2], [input = true]
end
@parameters begin
p, [input = true]
q, [output = true]
end

@test_throws ["p", "kind", "VARIABLE", "PARAMETER"] connect(z, p)
@test_throws ["q", "kind", "VARIABLE", "PARAMETER"] connect(q, y)
@test_throws ["p", "kind", "VARIABLE", "PARAMETER"] connect(z, y, p)

@test_throws ["unique"] connect(z, y, y)

@test_throws ["same size"] connect(xarr, yarr)

@test_throws ["Expected", "x", "output = true", "metadata"] connect(x, y)
@test_throws ["Expected", "y", "output = true", "metadata"] connect(y, v)

@test_throws ["Expected", "x", "input = true", "metadata"] connect(z, x)
@test_throws ["Expected", "x", "input = true", "metadata"] connect(z, y, x)
@test_throws ["Expected", "u", "input = true", "metadata"] connect(z, u)
@test_throws ["Expected", "u", "input = true", "metadata"] connect(z, y, u)
end

@testset "Connection expansion" begin
@named P = FirstOrder(k = 1, T = 1)
@named C = Gain(; k = -1)

eqs = [connect(P.output.u, C.input.u)
connect(C.output.u, P.input.u)]
sys1 = ODESystem(eqs, t, systems = [P, C], name = :hej)
sys = expand_connections(sys1)
@test any(isequal(P.output.u ~ C.input.u), equations(sys))
@test any(isequal(C.output.u ~ P.input.u), equations(sys))

@named sysouter = ODESystem(Equation[], t; systems = [sys1])
sys = expand_connections(sysouter)
@test any(isequal(sys1.P.output.u ~ sys1.C.input.u), equations(sys))
@test any(isequal(sys1.C.output.u ~ sys1.P.input.u), equations(sys))
end

@testset "With Analysis Points" begin
@named P = FirstOrder(k = 1, T = 1)
@named C = Gain(; k = -1)

ap = AnalysisPoint(:plant_input)
eqs = [connect(P.output, C.input), connect(C.output.u, ap, P.input.u)]
sys = ODESystem(eqs, t, systems = [P, C], name = :hej)
@named nested_sys = ODESystem(Equation[], t; systems = [sys])

test_cases = [
("inner", sys, sys.plant_input),
("nested", nested_sys, nested_sys.hej.plant_input),
("inner - Symbol", sys, :plant_input),
("nested - Symbol", nested_sys, nameof(sys.plant_input))
]

@testset "get_sensitivity - $name" for (name, sys, ap) in test_cases
matrices, _ = get_sensitivity(sys, ap)
@test matrices.A[] == -2
@test matrices.B[] * matrices.C[] == -1 # either one negative
@test matrices.D[] == 1
end

@testset "get_comp_sensitivity - $name" for (name, sys, ap) in test_cases
matrices, _ = get_comp_sensitivity(sys, ap)
@test matrices.A[] == -2
@test matrices.B[] * matrices.C[] == 1 # both positive or negative
@test matrices.D[] == 0
end

@testset "get_looptransfer - $name" for (name, sys, ap) in test_cases
matrices, _ = get_looptransfer(sys, ap)
@test matrices.A[] == -1
@test matrices.B[] * matrices.C[] == -1 # either one negative
@test matrices.D[] == 0
end

@testset "open_loop - $name" for (name, sys, ap) in test_cases
open_sys, (du, u) = open_loop(sys, ap)
matrices, _ = linearize(open_sys, [du], [u])
@test matrices.A[] == -1
@test matrices.B[] * matrices.C[] == -1 # either one negative
@test matrices.D[] == 0
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ end
@safetestset "Constraints Test" include("constraints.jl")
@safetestset "IfLifting Test" include("if_lifting.jl")
@safetestset "Analysis Points Test" include("analysis_points.jl")
@safetestset "Causal Variables Connection Test" include("causal_variables_connection.jl")
end
end

Expand Down
Loading