diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 7462122998..b63112bde1 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -145,8 +145,8 @@ include("systems/index_cache.jl") include("systems/parameter_buffer.jl") include("systems/abstractsystem.jl") include("systems/model_parsing.jl") -include("systems/analysis_points.jl") include("systems/connectors.jl") +include("systems/analysis_points.jl") include("systems/imperative_affect.jl") include("systems/callbacks.jl") include("systems/problem_utils.jl") diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index ccc749baeb..3dc9cda3ba 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1875,6 +1875,13 @@ Equivalent to `length(equations(expand_connections(sys))) - length(filter(eq -> function n_expanded_connection_equations(sys::AbstractSystem) # TODO: what about inputs? isconnector(sys) && return length(get_unknowns(sys)) + sys = remove_analysis_points(sys) + n_variable_connect_eqs = 0 + for eq in equations(sys) + is_causal_variable_connection(eq.rhs) || continue + n_variable_connect_eqs += length(get_systems(eq.rhs)) - 1 + end + sys, (csets, _) = generate_connection_set(sys) ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets) n_outer_stream_variables = 0 @@ -1897,7 +1904,7 @@ function n_expanded_connection_equations(sys::AbstractSystem) # n_toplevel_unused_flows += count(x->get_connection_type(x) === Flow && !(x in toplevel_flows), get_unknowns(m)) #end - nextras = n_outer_stream_variables + length(ceqs) + nextras = n_outer_stream_variables + length(ceqs) + n_variable_connect_eqs end function Base.show( diff --git a/src/systems/analysis_points.jl b/src/systems/analysis_points.jl index 87f278d4dc..d1080c083b 100644 --- a/src/systems/analysis_points.jl +++ b/src/systems/analysis_points.jl @@ -208,6 +208,12 @@ function Symbolics.connect(in::AbstractSystem, name::Symbol, out, outs...; verbo return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose) end +function Symbolics.connect(in::ConnectableSymbolicT, name::Symbol, out::ConnectableSymbolicT, outs::ConnectableSymbolicT...; verbose = true) + allvars = (in, out, outs...) + validate_causal_variables_connection(allvars) + return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose) +end + """ $(TYPEDSIGNATURES) @@ -240,7 +246,7 @@ connection. This is the variable named `u` if present, and otherwise the only variable in the system. If the system does not have a variable named `u` and contains multiple variables, throw an error. """ -function ap_var(sys) +function ap_var(sys::AbstractSystem) if hasproperty(sys, :u) return sys.u end @@ -249,6 +255,15 @@ function ap_var(sys) error("Could not determine the analysis-point variable in system $(nameof(sys)). To use an analysis point, apply it to a connection between causal blocks which have a variable named `u` or a single unknown of the same size.") end +""" + $(TYPEDSIGNATURES) + +For an `AnalysisPoint` involving causal variables. Simply return the variable. +""" +function ap_var(var::ConnectableSymbolicT) + return var +end + """ $(TYPEDEF) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index f0f49acb85..3a2ffa9a9d 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -68,6 +68,89 @@ SymbolicUtils.promote_symtype(::typeof(instream), _) = Real isconnector(s::AbstractSystem) = has_connector_type(s) && get_connector_type(s) !== nothing +""" + $(TYPEDEF) + +Utility struct which wraps a symbolic variable used in a `Connection` to enable `Base.show` +to work. +""" +struct SymbolicWithNameof + var::Any +end + +function Base.nameof(x::SymbolicWithNameof) + return Symbol(x.var) +end + +is_causal_variable_connection(c) = false +is_causal_variable_connection(c::Connection) = all(x -> x isa SymbolicWithNameof, get_systems(c)) + +const ConnectableSymbolicT = Union{BasicSymbolic, Num, Symbolics.Arr} + +const CAUSAL_CONNECTION_ERR = """ +Only causal variables can be used in a `connect` statement. The first argument must \ +be a single output variable and all subsequent variables must be input variables. +""" + +function VariableNotOutputError(var) + ArgumentError(""" + $CAUSAL_CONNECTION_ERR Expected $var to be marked as an output with `[output = true]` \ + in the variable metadata. + """) +end + +function VariableNotInputError(var) + ArgumentError(""" + $CAUSAL_CONNECTION_ERR Expected $var to be marked an input with `[input = true]` \ + in the variable metadata. + """) +end + +""" + $(TYPEDSIGNATURES) + +Perform validation for a connect statement involving causal variables. +""" +function validate_causal_variables_connection(allvars) + var1 = allvars[1] + var2 = allvars[2] + vars = Base.tail(Base.tail(allvars)) + for var in allvars + vtype = getvariabletype(var) + vtype === VARIABLE || throw(ArgumentError("Expected $var to be of kind `$VARIABLE`. Got `$vtype`.")) + end + if length(unique(allvars)) !== length(allvars) + throw(ArgumentError("Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries.")) + end + allsizes = map(size, allvars) + if !allequal(allsizes) + throw(ArgumentError("Expected all connection variables to have the same size. Got variables $allvars with sizes $allsizes respectively.")) + end + isoutput(var1) || throw(VariableNotOutputError(var1)) + isinput(var2) || throw(VariableNotInputError(var2)) + for var in vars + isinput(var) || throw(VariableNotInputError(var)) + end +end + +""" + $(TYPEDSIGNATURES) + +Connect multiple causal variables. The first variable must be an output, and all subsequent +variables must be inputs. The statement `connect(var1, var2, var3, ...)` expands to: + +```julia +var1 ~ var2 +var1 ~ var3 +# ... +``` +""" +function Symbolics.connect(var1::ConnectableSymbolicT, var2::ConnectableSymbolicT, vars::ConnectableSymbolicT...) + allvars = (var1, var2, vars...) + validate_causal_variables_connection(allvars) + return Equation(Connection(), Connection(map(SymbolicWithNameof, allvars))) +end + function flowvar(sys::AbstractSystem) sts = get_unknowns(sys) for s in sts @@ -329,6 +412,10 @@ function generate_connection_set!(connectionsets, domain_csets, for eq in eqs′ lhs = eq.lhs rhs = eq.rhs + + # causal variable connections will be expanded before we get here, + # but this guard is useful for `n_expanded_connection_equations`. + is_causal_variable_connection(rhs) && continue if find !== nothing && find(rhs, _getname(namespace)) neweq, extra_unknown = replace(rhs, _getname(namespace)) if extra_unknown isa AbstractArray @@ -479,9 +566,41 @@ function domain_defaults(sys, domain_csets) def end +""" + $(TYPEDSIGNATURES) + +Recursively descend through the hierarchy of `sys` and expand all connection equations +of causal variables. Return the modified system. +""" +function expand_variable_connections(sys::AbstractSystem) + eqs = copy(get_eqs(sys)) + valid_idxs = trues(length(eqs)) + additional_eqs = Equation[] + + for (i, eq) in enumerate(eqs) + eq.lhs isa Connection || continue + connection = eq.rhs + elements = connection.systems + is_causal_variable_connection(connection) || continue + + valid_idxs[i] = false + elements = map(x -> x.var, elements) + outvar = first(elements) + for invar in Iterators.drop(elements, 1) + push!(additional_eqs, outvar ~ invar) + end + end + eqs = [eqs[valid_idxs]; additional_eqs] + subsystems = map(expand_variable_connections, get_systems(sys)) + @set! sys.eqs = eqs + @set! sys.systems = subsystems + return sys +end + function expand_connections(sys::AbstractSystem, find = nothing, replace = nothing; debug = false, tol = 1e-10, scalarize = true) sys = remove_analysis_points(sys) + sys = expand_variable_connections(sys) sys, (csets, domain_csets) = generate_connection_set(sys, find, replace; scalarize) ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets) _sys = expand_instream(instream_csets, sys; debug = debug, tol = tol) diff --git a/test/causal_variables_connection.jl b/test/causal_variables_connection.jl new file mode 100644 index 0000000000..c22e8319d4 --- /dev/null +++ b/test/causal_variables_connection.jl @@ -0,0 +1,98 @@ +using ModelingToolkit, ModelingToolkitStandardLibrary.Blocks +using ModelingToolkit: t_nounits as t, D_nounits as D + +@testset "Error checking" begin + @variables begin + x(t) + y(t), [input = true] + z(t), [output = true] + w(t) + v(t), [input = true] + u(t), [output = true] + xarr(t)[1:4], [output = true] + yarr(t)[1:2, 1:2], [input = true] + end + @parameters begin + p, [input = true] + q, [output = true] + end + + @test_throws ["p", "kind", "VARIABLE", "PARAMETER"] connect(z, p) + @test_throws ["q", "kind", "VARIABLE", "PARAMETER"] connect(q, y) + @test_throws ["p", "kind", "VARIABLE", "PARAMETER"] connect(z, y, p) + + @test_throws ["unique"] connect(z, y, y) + + @test_throws ["same size"] connect(xarr, yarr) + + @test_throws ["Expected", "x", "output = true", "metadata"] connect(x, y) + @test_throws ["Expected", "y", "output = true", "metadata"] connect(y, v) + + @test_throws ["Expected", "x", "input = true", "metadata"] connect(z, x) + @test_throws ["Expected", "x", "input = true", "metadata"] connect(z, y, x) + @test_throws ["Expected", "u", "input = true", "metadata"] connect(z, u) + @test_throws ["Expected", "u", "input = true", "metadata"] connect(z, y, u) +end + +@testset "Connection expansion" begin + @named P = FirstOrder(k = 1, T = 1) + @named C = Gain(; k = -1) + + eqs = [connect(P.output.u, C.input.u) + connect(C.output.u, P.input.u)] + sys1 = ODESystem(eqs, t, systems = [P, C], name = :hej) + sys = expand_connections(sys1) + @test any(isequal(P.output.u ~ C.input.u), equations(sys)) + @test any(isequal(C.output.u ~ P.input.u), equations(sys)) + + @named sysouter = ODESystem(Equation[], t; systems = [sys1]) + sys = expand_connections(sysouter) + @test any(isequal(sys1.P.output.u ~ sys1.C.input.u), equations(sys)) + @test any(isequal(sys1.C.output.u ~ sys1.P.input.u), equations(sys)) +end + +@testset "With Analysis Points" begin + @named P = FirstOrder(k = 1, T = 1) + @named C = Gain(; k = -1) + + ap = AnalysisPoint(:plant_input) + eqs = [connect(P.output, C.input), connect(C.output.u, ap, P.input.u)] + sys = ODESystem(eqs, t, systems = [P, C], name = :hej) + @named nested_sys = ODESystem(Equation[], t; systems = [sys]) + + test_cases = [ + ("inner", sys, sys.plant_input), + ("nested", nested_sys, nested_sys.hej.plant_input), + ("inner - Symbol", sys, :plant_input), + ("nested - Symbol", nested_sys, nameof(sys.plant_input)) + ] + + @testset "get_sensitivity - $name" for (name, sys, ap) in test_cases + matrices, _ = get_sensitivity(sys, ap) + @test matrices.A[] == -2 + @test matrices.B[] * matrices.C[] == -1 # either one negative + @test matrices.D[] == 1 + end + + @testset "get_comp_sensitivity - $name" for (name, sys, ap) in test_cases + matrices, _ = get_comp_sensitivity(sys, ap) + @test matrices.A[] == -2 + @test matrices.B[] * matrices.C[] == 1 # both positive or negative + @test matrices.D[] == 0 + end + + @testset "get_looptransfer - $name" for (name, sys, ap) in test_cases + matrices, _ = get_looptransfer(sys, ap) + @test matrices.A[] == -1 + @test matrices.B[] * matrices.C[] == -1 # either one negative + @test matrices.D[] == 0 + end + + @testset "open_loop - $name" for (name, sys, ap) in test_cases + open_sys, (du, u) = open_loop(sys, ap) + matrices, _ = linearize(open_sys, [du], [u]) + @test matrices.A[] == -1 + @test matrices.B[] * matrices.C[] == -1 # either one negative + @test matrices.D[] == 0 + end +end diff --git a/test/runtests.jl b/test/runtests.jl index d38aa86bb2..95154e550e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -85,6 +85,7 @@ end @safetestset "Constraints Test" include("constraints.jl") @safetestset "IfLifting Test" include("if_lifting.jl") @safetestset "Analysis Points Test" include("analysis_points.jl") + @safetestset "Causal Variables Connection Test" include("causal_variables_connection.jl") end end