From d179b4ff30fc41b874ecb6ad5b54aeb260b84fea Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 3 Nov 2023 15:08:51 +0530 Subject: [PATCH] feat: implementation of new SymbolicIndexingInterface --- Project.toml | 4 +- src/ModelingToolkit.jl | 3 +- src/systems/abstractsystem.jl | 135 ++++++++++++++++++----- src/systems/diffeqs/abstractodesystem.jl | 4 +- 4 files changed, 114 insertions(+), 32 deletions(-) diff --git a/Project.toml b/Project.toml index e12c73a538..96d5d0f43c 100644 --- a/Project.toml +++ b/Project.toml @@ -88,7 +88,7 @@ MacroTools = "0.5" NaNMath = "0.3, 1" OrdinaryDiffEq = "6" PrecompileTools = "1" -RecursiveArrayTools = "2.3" +RecursiveArrayTools = "2.3, 3" Reexport = "0.2, 1" RuntimeGeneratedFunctions = "0.5.9" SciMLBase = "2.0.1" @@ -98,7 +98,7 @@ SimpleNonlinearSolve = "0.1.0, 1" SparseArrays = "1" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "0.10, 0.11, 0.12, 1.0" -SymbolicIndexingInterface = "0.1, 0.2" +SymbolicIndexingInterface = "0.3" SymbolicUtils = "1.0" Symbolics = "5.7" URIs = "1" diff --git a/src/ModelingToolkit.jl b/src/ModelingToolkit.jl index 99bf0b015b..57a9477e04 100644 --- a/src/ModelingToolkit.jl +++ b/src/ModelingToolkit.jl @@ -35,8 +35,7 @@ using PrecompileTools, Reexport using RecursiveArrayTools - import SymbolicIndexingInterface - import SymbolicIndexingInterface: independent_variables, states, parameters + using SymbolicIndexingInterface export independent_variables, states, parameters import SymbolicUtils import SymbolicUtils: istree, arguments, operation, similarterm, promote_symtype, diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 36a3d1a124..0cc9811b07 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -161,10 +161,18 @@ function independent_variable(sys::AbstractSystem) isdefined(sys, :iv) ? getfield(sys, :iv) : nothing end -#Treat the result as a vector of symbols always -function SymbolicIndexingInterface.independent_variables(sys::AbstractSystem) - systype = typeof(sys) - @warn "Please declare ($systype) as a subtype of `AbstractTimeDependentSystem`, `AbstractTimeIndependentSystem` or `AbstractMultivariateSystem`." +function independent_variables(sys::AbstractTimeDependentSystem) + return [getfield(sys, :iv)] +end + +independent_variables(::AbstractTimeIndependentSystem) = [] + +function independent_variables(sys::AbstractMultivariateSystem) + return getfield(sys, :ivs) +end + +function independent_variables(sys::AbstractSystem) + @warn "Please declare ($(typeof(sys))) as a subtype of `AbstractTimeDependentSystem`, `AbstractTimeIndependentSystem` or `AbstractMultivariateSystem`." if isdefined(sys, :iv) return [getfield(sys, :iv)] elseif isdefined(sys, :ivs) @@ -174,14 +182,102 @@ function SymbolicIndexingInterface.independent_variables(sys::AbstractSystem) end end -function SymbolicIndexingInterface.independent_variables(sys::AbstractTimeDependentSystem) - [getfield(sys, :iv)] +#Treat the result as a vector of symbols always +function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym) + if unwrap(sym) isa Int # [x, 1] coerces 1 to a Num + return unwrap(sym) in 1:length(unknown_states(sys)) + end + return any(isequal(sym), unknown_states(sys)) || hasname(sym) && is_variable(sys, getname(sym)) +end + +function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol) + return any(isequal(sym), getname.(unknown_states(sys))) || count('₊', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys)))) == 1 end -SymbolicIndexingInterface.independent_variables(sys::AbstractTimeIndependentSystem) = [] -function SymbolicIndexingInterface.independent_variables(sys::AbstractMultivariateSystem) - getfield(sys, :ivs) + +function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym) + if unwrap(sym) isa Int + return unwrap(sym) + end + idx = findfirst(isequal(sym), unknown_states(sys)) + if idx === nothing && hasname(sym) + idx = variable_index(sys, getname(sym)) + end + return idx end +function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym::Symbol) + idx = findfirst(isequal(sym), getname.(unknown_states(sys))) + if idx !== nothing + return idx + elseif count('₊', string(sym)) == 1 + return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys)))) + end + return nothing +end + +function SymbolicIndexingInterface.variable_symbols(sys::AbstractSystem) + return unknown_states(sys) +end + +function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) + if unwrap(sym) isa Int + return unwrap(sym) in 1:length(parameters(sys)) + end + + return any(isequal(sym), parameters(sys)) || hasname(sym) && is_parameter(sys, getname(sym)) +end + +function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol) + return any(isequal(sym), getname.(parameters(sys))) || + count('₊', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys)))) == 1 +end + +function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym) + if unwrap(sym) isa Int + return unwrap(sym) + end + idx = findfirst(isequal(sym), parameters(sys)) + if idx === nothing && hasname(sym) + idx = parameter_index(sys, getname(sym)) + end + return idx +end + +function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol) + idx = findfirst(isequal(sym), getname.(parameters(sys))) + if idx !== nothing + return idx + elseif count('₊', string(sym)) == 1 + return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys)))) + end + return nothing +end + +function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem) + return parameters(sys) +end + +function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym) + return any(isequal(sym), independent_variables(sys)) +end + +function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym::Symbol) + return any(isequal(sym), getname.(independent_variables(sys))) +end + +function SymbolicIndexingInterface.independent_variable_symbols(sys::AbstractSystem) + return independent_variables(sys) +end + +function SymbolicIndexingInterface.is_observed(sys::AbstractSystem, sym) + return !is_variable(sys, sym) && !is_parameter(sys, sym) && !is_independent_variable(sys, sym) && symbolic_type(sym) != NotSymbolic() +end + +SymbolicIndexingInterface.is_time_dependent(::AbstractTimeDependentSystem) = true +SymbolicIndexingInterface.is_time_dependent(::AbstractTimeIndependentSystem) = false + +SymbolicIndexingInterface.constant_structure(::AbstractSystem) = true + iscomplete(sys::AbstractSystem) = isdefined(sys, :complete) && getfield(sys, :complete) """ @@ -534,12 +630,15 @@ function states(sys::AbstractSystem) [sts; reduce(vcat, namespace_variables.(systems))]) end -function SymbolicIndexingInterface.parameters(sys::AbstractSystem) +function parameters(sys::AbstractSystem) ps = get_ps(sys) systems = get_systems(sys) unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))]) end +# required in `src/connectors.jl:437` +parameters(_) = [] + function controls(sys::AbstractSystem) ctrls = get_ctrls(sys) systems = get_systems(sys) @@ -638,8 +737,6 @@ function time_varying_as_func(x, sys::AbstractTimeDependentSystem) return x end -SymbolicIndexingInterface.is_indep_sym(sys::AbstractSystem, sym) = isequal(sym, get_iv(sys)) - """ $(SIGNATURES) @@ -653,20 +750,6 @@ function unknown_states(sys::AbstractSystem) return sts end -function SymbolicIndexingInterface.state_sym_to_index(sys::AbstractSystem, sym) - findfirst(isequal(sym), unknown_states(sys)) -end -function SymbolicIndexingInterface.is_state_sym(sys::AbstractSystem, sym) - !isnothing(SymbolicIndexingInterface.state_sym_to_index(sys, sym)) -end - -function SymbolicIndexingInterface.param_sym_to_index(sys::AbstractSystem, sym) - findfirst(isequal(sym), SymbolicIndexingInterface.parameters(sys)) -end -function SymbolicIndexingInterface.is_param_sym(sys::AbstractSystem, sym) - !isnothing(SymbolicIndexingInterface.param_sym_to_index(sys, sym)) -end - ### ### System utils ### diff --git a/src/systems/diffeqs/abstractodesystem.jl b/src/systems/diffeqs/abstractodesystem.jl index 3ac137f6e3..f79859da57 100644 --- a/src/systems/diffeqs/abstractodesystem.jl +++ b/src/systems/diffeqs/abstractodesystem.jl @@ -516,9 +516,9 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem, dvs = s tgrad = _tgrad === nothing ? nothing : _tgrad, mass_matrix = _M, jac_prototype = jac_prototype, - syms = Symbol.(states(sys)), + syms = collect(Symbol.(states(sys))), indepsym = Symbol(get_iv(sys)), - paramsyms = Symbol.(ps), + paramsyms = collect(Symbol.(ps)), observed = observedfun, sparsity = sparsity ? jacobian_sparsity(sys) : nothing, analytic = analytic)