From 1659d310eeadeb87ba852ccb8b07a894bc8e4652 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 9 Jan 2025 16:56:02 +0530 Subject: [PATCH] feat: add support for causal connections of variables --- src/systems/abstractsystem.jl | 9 ++- src/systems/analysis_points.jl | 17 ++++- src/systems/connectors.jl | 119 +++++++++++++++++++++++++++++++++ 3 files changed, 143 insertions(+), 2 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index ccc749baeb..3dc9cda3ba 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -1875,6 +1875,13 @@ Equivalent to `length(equations(expand_connections(sys))) - length(filter(eq -> function n_expanded_connection_equations(sys::AbstractSystem) # TODO: what about inputs? isconnector(sys) && return length(get_unknowns(sys)) + sys = remove_analysis_points(sys) + n_variable_connect_eqs = 0 + for eq in equations(sys) + is_causal_variable_connection(eq.rhs) || continue + n_variable_connect_eqs += length(get_systems(eq.rhs)) - 1 + end + sys, (csets, _) = generate_connection_set(sys) ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets) n_outer_stream_variables = 0 @@ -1897,7 +1904,7 @@ function n_expanded_connection_equations(sys::AbstractSystem) # n_toplevel_unused_flows += count(x->get_connection_type(x) === Flow && !(x in toplevel_flows), get_unknowns(m)) #end - nextras = n_outer_stream_variables + length(ceqs) + nextras = n_outer_stream_variables + length(ceqs) + n_variable_connect_eqs end function Base.show( diff --git a/src/systems/analysis_points.jl b/src/systems/analysis_points.jl index 87f278d4dc..d1080c083b 100644 --- a/src/systems/analysis_points.jl +++ b/src/systems/analysis_points.jl @@ -208,6 +208,12 @@ function Symbolics.connect(in::AbstractSystem, name::Symbol, out, outs...; verbo return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose) end +function Symbolics.connect(in::ConnectableSymbolicT, name::Symbol, out::ConnectableSymbolicT, outs::ConnectableSymbolicT...; verbose = true) + allvars = (in, out, outs...) + validate_causal_variables_connection(allvars) + return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose) +end + """ $(TYPEDSIGNATURES) @@ -240,7 +246,7 @@ connection. This is the variable named `u` if present, and otherwise the only variable in the system. If the system does not have a variable named `u` and contains multiple variables, throw an error. """ -function ap_var(sys) +function ap_var(sys::AbstractSystem) if hasproperty(sys, :u) return sys.u end @@ -249,6 +255,15 @@ function ap_var(sys) error("Could not determine the analysis-point variable in system $(nameof(sys)). To use an analysis point, apply it to a connection between causal blocks which have a variable named `u` or a single unknown of the same size.") end +""" + $(TYPEDSIGNATURES) + +For an `AnalysisPoint` involving causal variables. Simply return the variable. +""" +function ap_var(var::ConnectableSymbolicT) + return var +end + """ $(TYPEDEF) diff --git a/src/systems/connectors.jl b/src/systems/connectors.jl index f0f49acb85..3a2ffa9a9d 100644 --- a/src/systems/connectors.jl +++ b/src/systems/connectors.jl @@ -68,6 +68,89 @@ SymbolicUtils.promote_symtype(::typeof(instream), _) = Real isconnector(s::AbstractSystem) = has_connector_type(s) && get_connector_type(s) !== nothing +""" + $(TYPEDEF) + +Utility struct which wraps a symbolic variable used in a `Connection` to enable `Base.show` +to work. +""" +struct SymbolicWithNameof + var::Any +end + +function Base.nameof(x::SymbolicWithNameof) + return Symbol(x.var) +end + +is_causal_variable_connection(c) = false +is_causal_variable_connection(c::Connection) = all(x -> x isa SymbolicWithNameof, get_systems(c)) + +const ConnectableSymbolicT = Union{BasicSymbolic, Num, Symbolics.Arr} + +const CAUSAL_CONNECTION_ERR = """ +Only causal variables can be used in a `connect` statement. The first argument must \ +be a single output variable and all subsequent variables must be input variables. +""" + +function VariableNotOutputError(var) + ArgumentError(""" + $CAUSAL_CONNECTION_ERR Expected $var to be marked as an output with `[output = true]` \ + in the variable metadata. + """) +end + +function VariableNotInputError(var) + ArgumentError(""" + $CAUSAL_CONNECTION_ERR Expected $var to be marked an input with `[input = true]` \ + in the variable metadata. + """) +end + +""" + $(TYPEDSIGNATURES) + +Perform validation for a connect statement involving causal variables. +""" +function validate_causal_variables_connection(allvars) + var1 = allvars[1] + var2 = allvars[2] + vars = Base.tail(Base.tail(allvars)) + for var in allvars + vtype = getvariabletype(var) + vtype === VARIABLE || throw(ArgumentError("Expected $var to be of kind `$VARIABLE`. Got `$vtype`.")) + end + if length(unique(allvars)) !== length(allvars) + throw(ArgumentError("Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries.")) + end + allsizes = map(size, allvars) + if !allequal(allsizes) + throw(ArgumentError("Expected all connection variables to have the same size. Got variables $allvars with sizes $allsizes respectively.")) + end + isoutput(var1) || throw(VariableNotOutputError(var1)) + isinput(var2) || throw(VariableNotInputError(var2)) + for var in vars + isinput(var) || throw(VariableNotInputError(var)) + end +end + +""" + $(TYPEDSIGNATURES) + +Connect multiple causal variables. The first variable must be an output, and all subsequent +variables must be inputs. The statement `connect(var1, var2, var3, ...)` expands to: + +```julia +var1 ~ var2 +var1 ~ var3 +# ... +``` +""" +function Symbolics.connect(var1::ConnectableSymbolicT, var2::ConnectableSymbolicT, vars::ConnectableSymbolicT...) + allvars = (var1, var2, vars...) + validate_causal_variables_connection(allvars) + return Equation(Connection(), Connection(map(SymbolicWithNameof, allvars))) +end + function flowvar(sys::AbstractSystem) sts = get_unknowns(sys) for s in sts @@ -329,6 +412,10 @@ function generate_connection_set!(connectionsets, domain_csets, for eq in eqs′ lhs = eq.lhs rhs = eq.rhs + + # causal variable connections will be expanded before we get here, + # but this guard is useful for `n_expanded_connection_equations`. + is_causal_variable_connection(rhs) && continue if find !== nothing && find(rhs, _getname(namespace)) neweq, extra_unknown = replace(rhs, _getname(namespace)) if extra_unknown isa AbstractArray @@ -479,9 +566,41 @@ function domain_defaults(sys, domain_csets) def end +""" + $(TYPEDSIGNATURES) + +Recursively descend through the hierarchy of `sys` and expand all connection equations +of causal variables. Return the modified system. +""" +function expand_variable_connections(sys::AbstractSystem) + eqs = copy(get_eqs(sys)) + valid_idxs = trues(length(eqs)) + additional_eqs = Equation[] + + for (i, eq) in enumerate(eqs) + eq.lhs isa Connection || continue + connection = eq.rhs + elements = connection.systems + is_causal_variable_connection(connection) || continue + + valid_idxs[i] = false + elements = map(x -> x.var, elements) + outvar = first(elements) + for invar in Iterators.drop(elements, 1) + push!(additional_eqs, outvar ~ invar) + end + end + eqs = [eqs[valid_idxs]; additional_eqs] + subsystems = map(expand_variable_connections, get_systems(sys)) + @set! sys.eqs = eqs + @set! sys.systems = subsystems + return sys +end + function expand_connections(sys::AbstractSystem, find = nothing, replace = nothing; debug = false, tol = 1e-10, scalarize = true) sys = remove_analysis_points(sys) + sys = expand_variable_connections(sys) sys, (csets, domain_csets) = generate_connection_set(sys, find, replace; scalarize) ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets) _sys = expand_instream(instream_csets, sys; debug = debug, tol = tol)