Skip to content

Commit

Permalink
feat: add support for causal connections of variables
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 9, 2025
1 parent 2092c4b commit 1659d31
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 2 deletions.
9 changes: 8 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1875,6 +1875,13 @@ Equivalent to `length(equations(expand_connections(sys))) - length(filter(eq ->
function n_expanded_connection_equations(sys::AbstractSystem)
# TODO: what about inputs?
isconnector(sys) && return length(get_unknowns(sys))
sys = remove_analysis_points(sys)
n_variable_connect_eqs = 0
for eq in equations(sys)
is_causal_variable_connection(eq.rhs) || continue
n_variable_connect_eqs += length(get_systems(eq.rhs)) - 1
end

sys, (csets, _) = generate_connection_set(sys)
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
n_outer_stream_variables = 0
Expand All @@ -1897,7 +1904,7 @@ function n_expanded_connection_equations(sys::AbstractSystem)
# n_toplevel_unused_flows += count(x->get_connection_type(x) === Flow && !(x in toplevel_flows), get_unknowns(m))
#end

nextras = n_outer_stream_variables + length(ceqs)
nextras = n_outer_stream_variables + length(ceqs) + n_variable_connect_eqs
end

function Base.show(
Expand Down
17 changes: 16 additions & 1 deletion src/systems/analysis_points.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ function Symbolics.connect(in::AbstractSystem, name::Symbol, out, outs...; verbo
return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose)
end

function Symbolics.connect(in::ConnectableSymbolicT, name::Symbol, out::ConnectableSymbolicT, outs::ConnectableSymbolicT...; verbose = true)
allvars = (in, out, outs...)
validate_causal_variables_connection(allvars)
return AnalysisPoint() ~ AnalysisPoint(in, name, [out; collect(outs)]; verbose)
end

"""
$(TYPEDSIGNATURES)
Expand Down Expand Up @@ -240,7 +246,7 @@ connection. This is the variable named `u` if present, and otherwise the only
variable in the system. If the system does not have a variable named `u` and
contains multiple variables, throw an error.
"""
function ap_var(sys)
function ap_var(sys::AbstractSystem)
if hasproperty(sys, :u)
return sys.u
end
Expand All @@ -249,6 +255,15 @@ function ap_var(sys)
error("Could not determine the analysis-point variable in system $(nameof(sys)). To use an analysis point, apply it to a connection between causal blocks which have a variable named `u` or a single unknown of the same size.")
end

"""
$(TYPEDSIGNATURES)
For an `AnalysisPoint` involving causal variables. Simply return the variable.
"""
function ap_var(var::ConnectableSymbolicT)
return var
end

"""
$(TYPEDEF)
Expand Down
119 changes: 119 additions & 0 deletions src/systems/connectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,89 @@ SymbolicUtils.promote_symtype(::typeof(instream), _) = Real

isconnector(s::AbstractSystem) = has_connector_type(s) && get_connector_type(s) !== nothing

"""
$(TYPEDEF)
Utility struct which wraps a symbolic variable used in a `Connection` to enable `Base.show`
to work.
"""
struct SymbolicWithNameof
var::Any
end

function Base.nameof(x::SymbolicWithNameof)
return Symbol(x.var)
end

is_causal_variable_connection(c) = false
is_causal_variable_connection(c::Connection) = all(x -> x isa SymbolicWithNameof, get_systems(c))

const ConnectableSymbolicT = Union{BasicSymbolic, Num, Symbolics.Arr}

const CAUSAL_CONNECTION_ERR = """
Only causal variables can be used in a `connect` statement. The first argument must \
be a single output variable and all subsequent variables must be input variables.
"""

function VariableNotOutputError(var)
ArgumentError("""
$CAUSAL_CONNECTION_ERR Expected $var to be marked as an output with `[output = true]` \
in the variable metadata.
""")
end

function VariableNotInputError(var)
ArgumentError("""
$CAUSAL_CONNECTION_ERR Expected $var to be marked an input with `[input = true]` \
in the variable metadata.
""")
end

"""
$(TYPEDSIGNATURES)
Perform validation for a connect statement involving causal variables.
"""
function validate_causal_variables_connection(allvars)
var1 = allvars[1]
var2 = allvars[2]
vars = Base.tail(Base.tail(allvars))
for var in allvars
vtype = getvariabletype(var)
vtype === VARIABLE || throw(ArgumentError("Expected $var to be of kind `$VARIABLE`. Got `$vtype`."))
end
if length(unique(allvars)) !== length(allvars)
throw(ArgumentError("Expected all connection variables to be unique. Got variables $allvars which contains duplicate entries."))
end
allsizes = map(size, allvars)
if !allequal(allsizes)
throw(ArgumentError("Expected all connection variables to have the same size. Got variables $allvars with sizes $allsizes respectively."))
end
isoutput(var1) || throw(VariableNotOutputError(var1))
isinput(var2) || throw(VariableNotInputError(var2))
for var in vars
isinput(var) || throw(VariableNotInputError(var))
end
end

"""
$(TYPEDSIGNATURES)
Connect multiple causal variables. The first variable must be an output, and all subsequent
variables must be inputs. The statement `connect(var1, var2, var3, ...)` expands to:
```julia
var1 ~ var2
var1 ~ var3
# ...
```
"""
function Symbolics.connect(var1::ConnectableSymbolicT, var2::ConnectableSymbolicT, vars::ConnectableSymbolicT...)
allvars = (var1, var2, vars...)
validate_causal_variables_connection(allvars)
return Equation(Connection(), Connection(map(SymbolicWithNameof, allvars)))
end

function flowvar(sys::AbstractSystem)
sts = get_unknowns(sys)
for s in sts
Expand Down Expand Up @@ -329,6 +412,10 @@ function generate_connection_set!(connectionsets, domain_csets,
for eq in eqs′
lhs = eq.lhs
rhs = eq.rhs

# causal variable connections will be expanded before we get here,
# but this guard is useful for `n_expanded_connection_equations`.
is_causal_variable_connection(rhs) && continue
if find !== nothing && find(rhs, _getname(namespace))
neweq, extra_unknown = replace(rhs, _getname(namespace))
if extra_unknown isa AbstractArray
Expand Down Expand Up @@ -479,9 +566,41 @@ function domain_defaults(sys, domain_csets)
def
end

"""
$(TYPEDSIGNATURES)
Recursively descend through the hierarchy of `sys` and expand all connection equations
of causal variables. Return the modified system.
"""
function expand_variable_connections(sys::AbstractSystem)
eqs = copy(get_eqs(sys))
valid_idxs = trues(length(eqs))
additional_eqs = Equation[]

for (i, eq) in enumerate(eqs)
eq.lhs isa Connection || continue
connection = eq.rhs
elements = connection.systems
is_causal_variable_connection(connection) || continue

valid_idxs[i] = false
elements = map(x -> x.var, elements)
outvar = first(elements)
for invar in Iterators.drop(elements, 1)
push!(additional_eqs, outvar ~ invar)
end
end
eqs = [eqs[valid_idxs]; additional_eqs]
subsystems = map(expand_variable_connections, get_systems(sys))
@set! sys.eqs = eqs
@set! sys.systems = subsystems
return sys
end

function expand_connections(sys::AbstractSystem, find = nothing, replace = nothing;
debug = false, tol = 1e-10, scalarize = true)
sys = remove_analysis_points(sys)
sys = expand_variable_connections(sys)
sys, (csets, domain_csets) = generate_connection_set(sys, find, replace; scalarize)
ceqs, instream_csets = generate_connection_equations_and_stream_connections(csets)
_sys = expand_instream(instream_csets, sys; debug = debug, tol = tol)
Expand Down

0 comments on commit 1659d31

Please sign in to comment.