Skip to content

Commit

Permalink
add getters for noise, unit, connect, misc
Browse files Browse the repository at this point in the history
  • Loading branch information
vyudu committed Jan 16, 2025
1 parent 143d5c3 commit 80b2549
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 13 deletions.
4 changes: 3 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,9 @@ export initial_state, transition, activeState, entry, ticksInState, timeInState
export @component, @mtkmodel, @mtkbuild
export isinput, isoutput, getbounds, hasbounds, getguess, hasguess, isdisturbance,
istunable, getdist, hasdist,
tunable_parameters, isirreducible, getdescription, hasdescription
tunable_parameters, isirreducible, getdescription, hasdescription,
hasnoise, getnoise, hasunit, getunit, hasconnect, getconnect,
hasmisc, getmisc
export ode_order_lowering, dae_order_lowering, liouville_transform
export PDESystem
export Differential, expand_derivatives, @derivatives
Expand Down
93 changes: 81 additions & 12 deletions src/variables.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ ModelingToolkit.dump_variable_metadata(p)
"""
function dump_variable_metadata(var)
uvar = unwrap(var)
vartype, name = get(uvar.metadata, VariableSource, (:unknown, :unknown))
vartype, name = Symbolics.getmetadata(uvar, VariableSource, (:unknown, :unknown))
type = symtype(uvar)
if type <: AbstractArray
shape = Symbolics.shape(var)
Expand All @@ -39,14 +39,14 @@ function dump_variable_metadata(var)
else
shape = nothing
end
unit = get(uvar.metadata, VariableUnit, nothing)
connect = get(uvar.metadata, VariableConnectType, nothing)
noise = get(uvar.metadata, VariableNoiseType, nothing)
unit = getunit(uvar)
connect = getconnect(uvar)
noise = getnoise(uvar)
input = isinput(uvar) || nothing
output = isoutput(uvar) || nothing
irreducible = get(uvar.metadata, VariableIrreducible, nothing)
state_priority = get(uvar.metadata, VariableStatePriority, nothing)
misc = get(uvar.metadata, VariableMisc, nothing)
irreducible = isirreducible(var)
state_priority = Symbolics.getmetadata(uvar, VariableStatePriority, nothing)
misc = getmisc(uvar)
bounds = hasbounds(uvar) ? getbounds(uvar) : nothing
desc = getdescription(var)
if desc == ""
Expand All @@ -57,12 +57,13 @@ function dump_variable_metadata(var)
disturbance = isdisturbance(uvar) || nothing
tunable = istunable(uvar, isparameter(uvar))
dist = getdist(uvar)
type = symtype(uvar)
variable_type = getvariabletype(uvar)

meta = (
var = var,
vartype,
name,
variable_type,
shape,
unit,
connect,
Expand All @@ -85,11 +86,28 @@ function dump_variable_metadata(var)
return NamedTuple(k => v for (k, v) in pairs(meta) if v !== nothing)
end

### Connect
abstract type AbstractConnectType end
struct Equality <: AbstractConnectType end # Equality connection
struct Flow <: AbstractConnectType end # sum to 0
struct Stream <: AbstractConnectType end # special stream connector

"""
getconnect(x)
Get the connect type of x. See also [`hasconnect`](@ref).
"""
getconnect(x) = getconnect(unwrap(x))
getconnect(x::Symbolic) = Symbolics.getmetadata(x, VariableConnectType, nothing)
"""
hasconnect(x)
Determine whether variable `x` has a connect type. See also [`getconnect`](@ref).
"""
hasconnect(x) = getconnect(x) !== nothing
setconnect(x, t::Type{T}) where T <: AbstractConnectType = setmetadata(x, VariableConnectType, t)

### Input, Output, Irreducible
isvarkind(m, x::Union{Num, Symbolics.Arr}) = isvarkind(m, value(x))
function isvarkind(m, x)
iskind = getmetadata(x, m, nothing)
Expand All @@ -98,15 +116,17 @@ function isvarkind(m, x)
getmetadata(x, m, false)
end

setinput(x, v) = setmetadata(x, VariableInput, v)
setoutput(x, v) = setmetadata(x, VariableOutput, v)
setio(x, i, o) = setoutput(setinput(x, i), o)
setinput(x, v::Bool) = setmetadata(x, VariableInput, v)
setoutput(x, v::Bool) = setmetadata(x, VariableOutput, v)
setio(x, i::Bool, o::Bool) = setoutput(setinput(x, i), o)

isinput(x) = isvarkind(VariableInput, x)
isoutput(x) = isvarkind(VariableOutput, x)

# Before the solvability check, we already have handled IO variables, so
# irreducibility is independent from IO.
isirreducible(x) = isvarkind(VariableIrreducible, x)
setirreducible(x, v) = setmetadata(x, VariableIrreducible, v)
setirreducible(x, v::Bool) = setmetadata(x, VariableIrreducible, v)
state_priority(x) = convert(Float64, getmetadata(x, VariableStatePriority, 0.0))::Float64

function default_toterm(x)
Expand Down Expand Up @@ -545,3 +565,52 @@ function get_default_or_guess(x)
return getguess(x)
end
end

## Miscellaneous metadata ======================================================================
"""
getmisc(x)
Fetch any miscellaneous data associated with symbolic variable `x`.
See also [`hasmisc(x)`](@ref).
"""
getmisc(x) = getmisc(unwrap(x))
getmisc(x::Symbolic) = Symbolics.getmetadata(x, VariableMisc, nothing)
"""
hasmisc(x)
Determine whether a symbolic variable `x` has misc
metadata associated with it.
See also [`getmisc(x)`](@ref).
"""
hasmisc(x) = getmisc(x) !== nothing
setmisc(x, miscdata) = setmetadata(x, VariableMisc, miscdata)

## Units ======================================================================
"""
getunit(x)
Alias for [`get_unit(x)`](@ref).
"""
getunit(x) = get_unit(x)
"""
hasunit(x)
Check if the variable `x` has a unit.
"""
hasunit(x) = getunit(x) !== nothing

## Noise ======================================================================
"""
getnoise(x)
Get the noise type of variable `x`.
"""
getnoise(x) = getnoise(unwrap(x))
getnoise(x::Symbolic) = Symbolics.getmetadata(x, VariableNoiseType, nothing)
"""
hasnoise(x)
Determine if variable `x` has a noise type.
"""
hasnoise(x) = getnoise(x) !== nothing
44 changes: 44 additions & 0 deletions test/test_variable_metadata.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using ModelingToolkit
using DynamicQuantities

# Bounds
@variables u [bounds = (-1, 1)]
Expand Down Expand Up @@ -185,3 +186,46 @@ params_meta = ModelingToolkit.dump_parameters(sys)
params_meta = Dict([ModelingToolkit.getname(meta.var) => meta for meta in params_meta])
@test params_meta[:p].default == 3.0
@test isequal(params_meta[:q].dependency, 2p)

# Noise
@variables x [noise = 1]
@test hasnoise(x)
@test getnoise(x) == 1
@test ModelingToolkit.dump_variable_metadata(x).noise == 1

# Connect
@variables x [connect = Flow]
@test hasconnect(x)
@test getconnect(x) == Flow
@test ModelingToolkit.dump_variable_metadata(x).connect == Flow
x = ModelingToolkit.setconnect(x, ModelingToolkit.Stream)
@test getconnect(x) == ModelingToolkit.Stream

struct BadConnect end
@test_throws Exception ModelingToolkit.setconnect(x, BadConnect)

# Unit
@variables x [unit = u"s"]
@test hasunit(x)
@test getunit(x) == u"s"
@test ModelingToolkit.dump_variable_metadata(x).unit == u"s"

# Misc data
@variables x [misc = [:good]]
@test hasmisc(x)
@test getmisc(x) == [:good]
x = ModelingToolkit.setmisc(x, "okay")
@test getmisc(x) == "okay"

# Variable Type
@variables x
@test ModelingToolkit.getvariabletype(x) == ModelingToolkit.VARIABLE
@test ModelingToolkit.dump_variable_metadata(x).variable_type == ModelingToolkit.VARIABLE
x = ModelingToolkit.toparam(x)
@test ModelingToolkit.getvariabletype(x) == ModelingToolkit.PARAMETER

@parameters y
@test ModelingToolkit.getvariabletype(y) == ModelingToolkit.PARAMETER

@brownian z
@test ModelingToolkit.getvariabletype(z) == ModelingToolkit.BROWNIAN

0 comments on commit 80b2549

Please sign in to comment.