From 9a07bb472eb65e09a76c9b1b8c551d060990c35b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 28 Dec 2023 22:51:56 +0530 Subject: [PATCH] fix: check hasname before using getname --- src/systems/abstractsystem.jl | 28 ++++++++++++++++------------ test/symbolic_indexing_interface.jl | 22 ++++++++++++++++++++++ 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/src/systems/abstractsystem.jl b/src/systems/abstractsystem.jl index 2c9f1884fa..ffcc0fc5e7 100644 --- a/src/systems/abstractsystem.jl +++ b/src/systems/abstractsystem.jl @@ -191,11 +191,15 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym) hasname(sym) && is_variable(sys, getname(sym)) end +function _syms_as_symbols(syms) + return [hasname(sym) ? getname(sym) : Symbol(sym) for sym in syms] +end + function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol) - return any(isequal(sym), getname.(variable_symbols(sys))) || + vars = _syms_as_symbols(variable_symbols(sys)) + return any(isequal(sym), vars) || count('₊', string(sym)) == 1 && - count(isequal(sym), Symbol.(nameof(sys), :₊, getname.(variable_symbols(sys)))) == - 1 + count(isequal(sym), Symbol.(nameof(sys), :₊, vars)) == 1 end function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym) @@ -210,12 +214,12 @@ function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym) end function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym::Symbol) - idx = findfirst(isequal(sym), getname.(variable_symbols(sys))) + vars = _syms_as_symbols(variable_symbols(sys)) + idx = findfirst(isequal(sym), vars) if idx !== nothing return idx elseif count('₊', string(sym)) == 1 - return findfirst(isequal(sym), - Symbol.(nameof(sys), :₊, getname.(variable_symbols(sys)))) + return findfirst(isequal(sym), Symbol.(nameof(sys), :₊, vars)) end return nothing end @@ -236,10 +240,10 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym) end function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol) - return any(isequal(sym), getname.(parameter_symbols(sys))) || + vars = _syms_as_symbols(parameter_symbols(sys)) + return any(isequal(sym), vars) || count('₊', string(sym)) == 1 && - count(isequal(sym), - Symbol.(nameof(sys), :₊, getname.(parameter_symbols(sys)))) == 1 + count(isequal(sym), Symbol.(nameof(sys), :₊, vars)) == 1 end function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym) @@ -254,12 +258,12 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym) end function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol) - idx = findfirst(isequal(sym), getname.(parameter_symbols(sys))) + vars = _syms_as_symbols(parameter_symbols(sys)) + idx = findfirst(isequal(sym), vars) if idx !== nothing return idx elseif count('₊', string(sym)) == 1 - return findfirst(isequal(sym), - Symbol.(nameof(sys), :₊, getname.(parameter_symbols(sys)))) + return findfirst(isequal(sym), Symbol.(nameof(sys), :₊, vars)) end return nothing end diff --git a/test/symbolic_indexing_interface.jl b/test/symbolic_indexing_interface.jl index 3d0ab8f7c1..5913736a7a 100644 --- a/test/symbolic_indexing_interface.jl +++ b/test/symbolic_indexing_interface.jl @@ -77,3 +77,25 @@ analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1] @test isequal(pdesys.ps, [h => 1]) @test isequal(parameter_symbols(pdesys), [h]) @test isequal(parameters(pdesys), [h]) + +# Test systems with symbols that don't have a valid `getname` +dt = 0.1 +@variables t x(t) y(t) u(t) yd(t) ud(t) r(t) +@parameters kp +D = Differential(t) +# u(n + 1) := f(u(n)) + +eqs = [yd ~ Sample(t, dt)(y) + ud ~ kp * (r - yd) + r ~ 1.0 + +# plant (time continuous part) + u ~ Hold(ud) + D(x) ~ -x + u + y ~ x] +@named sys = ODESystem(eqs) +ss = structural_simplify(sys) # has `Hold()(ud(t))` as a parameter + +@test parameter_index(ss, :kp) == 1 +@test parameter_index(ss, Symbol(Hold(ud))) == 2 +@test parameter_index(ss, Sample(t, dt)(y)) == 3