Skip to content

Commit

Permalink
feat: implementation of new SymbolicIndexingInterface
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Dec 12, 2023
1 parent e7fe1b5 commit d179b4f
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 32 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ MacroTools = "0.5"
NaNMath = "0.3, 1"
OrdinaryDiffEq = "6"
PrecompileTools = "1"
RecursiveArrayTools = "2.3"
RecursiveArrayTools = "2.3, 3"
Reexport = "0.2, 1"
RuntimeGeneratedFunctions = "0.5.9"
SciMLBase = "2.0.1"
Expand All @@ -98,7 +98,7 @@ SimpleNonlinearSolve = "0.1.0, 1"
SparseArrays = "1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "0.10, 0.11, 0.12, 1.0"
SymbolicIndexingInterface = "0.1, 0.2"
SymbolicIndexingInterface = "0.3"
SymbolicUtils = "1.0"
Symbolics = "5.7"
URIs = "1"
Expand Down
3 changes: 1 addition & 2 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ using PrecompileTools, Reexport

using RecursiveArrayTools

import SymbolicIndexingInterface
import SymbolicIndexingInterface: independent_variables, states, parameters
using SymbolicIndexingInterface
export independent_variables, states, parameters
import SymbolicUtils
import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype,
Expand Down
135 changes: 109 additions & 26 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,18 @@ function independent_variable(sys::AbstractSystem)
isdefined(sys, :iv) ? getfield(sys, :iv) : nothing
end

#Treat the result as a vector of symbols always
function SymbolicIndexingInterface.independent_variables(sys::AbstractSystem)
systype = typeof(sys)
@warn "Please declare ($systype) as a subtype of `AbstractTimeDependentSystem`, `AbstractTimeIndependentSystem` or `AbstractMultivariateSystem`."
function independent_variables(sys::AbstractTimeDependentSystem)
return [getfield(sys, :iv)]
end

independent_variables(::AbstractTimeIndependentSystem) = []

function independent_variables(sys::AbstractMultivariateSystem)
return getfield(sys, :ivs)
end

function independent_variables(sys::AbstractSystem)
@warn "Please declare ($(typeof(sys))) as a subtype of `AbstractTimeDependentSystem`, `AbstractTimeIndependentSystem` or `AbstractMultivariateSystem`."
if isdefined(sys, :iv)
return [getfield(sys, :iv)]
elseif isdefined(sys, :ivs)
Expand All @@ -174,14 +182,102 @@ function SymbolicIndexingInterface.independent_variables(sys::AbstractSystem)
end
end

function SymbolicIndexingInterface.independent_variables(sys::AbstractTimeDependentSystem)
[getfield(sys, :iv)]
#Treat the result as a vector of symbols always
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
if unwrap(sym) isa Int # [x, 1] coerces 1 to a Num
return unwrap(sym) in 1:length(unknown_states(sys))
end
return any(isequal(sym), unknown_states(sys)) || hasname(sym) && is_variable(sys, getname(sym))
end

function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
return any(isequal(sym), getname.(unknown_states(sys))) || count('', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys)))) == 1
end
SymbolicIndexingInterface.independent_variables(sys::AbstractTimeIndependentSystem) = []
function SymbolicIndexingInterface.independent_variables(sys::AbstractMultivariateSystem)
getfield(sys, :ivs)

function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
if unwrap(sym) isa Int
return unwrap(sym)
end
idx = findfirst(isequal(sym), unknown_states(sys))
if idx === nothing && hasname(sym)
idx = variable_index(sys, getname(sym))
end
return idx
end

function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym::Symbol)
idx = findfirst(isequal(sym), getname.(unknown_states(sys)))
if idx !== nothing
return idx
elseif count('', string(sym)) == 1
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys))))
end
return nothing
end

function SymbolicIndexingInterface.variable_symbols(sys::AbstractSystem)
return unknown_states(sys)
end

function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
if unwrap(sym) isa Int
return unwrap(sym) in 1:length(parameters(sys))
end

return any(isequal(sym), parameters(sys)) || hasname(sym) && is_parameter(sys, getname(sym))
end

function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
return any(isequal(sym), getname.(parameters(sys))) ||
count('', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys)))) == 1
end

function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
if unwrap(sym) isa Int
return unwrap(sym)
end
idx = findfirst(isequal(sym), parameters(sys))
if idx === nothing && hasname(sym)
idx = parameter_index(sys, getname(sym))
end
return idx
end

function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
idx = findfirst(isequal(sym), getname.(parameters(sys)))
if idx !== nothing
return idx
elseif count('', string(sym)) == 1
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys))))
end
return nothing
end

function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
return parameters(sys)
end

function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
return any(isequal(sym), independent_variables(sys))
end

function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym::Symbol)
return any(isequal(sym), getname.(independent_variables(sys)))
end

function SymbolicIndexingInterface.independent_variable_symbols(sys::AbstractSystem)
return independent_variables(sys)
end

function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym)
return !is_variable(sys, sym) && !is_parameter(sys, sym) && !is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic()
end

SymbolicIndexingInterface.is_time_dependent(::AbstractTimeDependentSystem) = true
SymbolicIndexingInterface.is_time_dependent(::AbstractTimeIndependentSystem) = false

SymbolicIndexingInterface.constant_structure(::AbstractSystem) = true

iscomplete(sys::AbstractSystem) = isdefined(sys, :complete) && getfield(sys, :complete)

"""
Expand Down Expand Up @@ -534,12 +630,15 @@ function states(sys::AbstractSystem)
[sts; reduce(vcat, namespace_variables.(systems))])
end

function SymbolicIndexingInterface.parameters(sys::AbstractSystem)
function parameters(sys::AbstractSystem)
ps = get_ps(sys)
systems = get_systems(sys)
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
end

# required in `src/connectors.jl:437`
parameters(_) = []

function controls(sys::AbstractSystem)
ctrls = get_ctrls(sys)
systems = get_systems(sys)
Expand Down Expand Up @@ -638,8 +737,6 @@ function time_varying_as_func(x, sys::AbstractTimeDependentSystem)
return x
end

SymbolicIndexingInterface.is_indep_sym(sys::AbstractSystem, sym) = isequal(sym, get_iv(sys))

"""
$(SIGNATURES)
Expand All @@ -653,20 +750,6 @@ function unknown_states(sys::AbstractSystem)
return sts
end

function SymbolicIndexingInterface.state_sym_to_index(sys::AbstractSystem, sym)
findfirst(isequal(sym), unknown_states(sys))
end
function SymbolicIndexingInterface.is_state_sym(sys::AbstractSystem, sym)
!isnothing(SymbolicIndexingInterface.state_sym_to_index(sys, sym))
end

function SymbolicIndexingInterface.param_sym_to_index(sys::AbstractSystem, sym)
findfirst(isequal(sym), SymbolicIndexingInterface.parameters(sys))
end
function SymbolicIndexingInterface.is_param_sym(sys::AbstractSystem, sym)
!isnothing(SymbolicIndexingInterface.param_sym_to_index(sys, sym))
end

###
### System utils
###
Expand Down
4 changes: 2 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -516,9 +516,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s
tgrad = _tgrad === nothing ? nothing : _tgrad,
mass_matrix = _M,
jac_prototype = jac_prototype,
syms = Symbol.(states(sys)),
syms = collect(Symbol.(states(sys))),
indepsym = Symbol(get_iv(sys)),
paramsyms = Symbol.(ps),
paramsyms = collect(Symbol.(ps)),
observed = observedfun,
sparsity = sparsity ? jacobian_sparsity(sys) : nothing,
analytic = analytic)
Expand Down

0 comments on commit d179b4f

Please sign in to comment.