Skip to content

Commit

Permalink
fix: check hasname before using getname
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 9, 2024
1 parent 3fbbdb0 commit 9a07bb4
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 12 deletions.
28 changes: 16 additions & 12 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,15 @@ function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
hasname(sym) && is_variable(sys, getname(sym))
end

function _syms_as_symbols(syms)
return [hasname(sym) ? getname(sym) : Symbol(sym) for sym in syms]

Check warning on line 195 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L194-L195

Added lines #L194 - L195 were not covered by tests
end

function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
return any(isequal(sym), getname.(variable_symbols(sys))) ||
vars = _syms_as_symbols(variable_symbols(sys))
return any(isequal(sym), vars) ||

Check warning on line 200 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L199-L200

Added lines #L199 - L200 were not covered by tests
count('', string(sym)) == 1 &&
count(isequal(sym), Symbol.(nameof(sys), :₊, getname.(variable_symbols(sys)))) ==
1
count(isequal(sym), Symbol.(nameof(sys), :₊, vars)) == 1
end

function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
Expand All @@ -210,12 +214,12 @@ function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
end

function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym::Symbol)
idx = findfirst(isequal(sym), getname.(variable_symbols(sys)))
vars = _syms_as_symbols(variable_symbols(sys))
idx = findfirst(isequal(sym), vars)

Check warning on line 218 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L217-L218

Added lines #L217 - L218 were not covered by tests
if idx !== nothing
return idx
elseif count('', string(sym)) == 1
return findfirst(isequal(sym),
Symbol.(nameof(sys), :₊, getname.(variable_symbols(sys))))
return findfirst(isequal(sym), Symbol.(nameof(sys), :₊, vars))

Check warning on line 222 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L222

Added line #L222 was not covered by tests
end
return nothing
end
Expand All @@ -236,10 +240,10 @@ function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
end

function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
return any(isequal(sym), getname.(parameter_symbols(sys))) ||
vars = _syms_as_symbols(parameter_symbols(sys))
return any(isequal(sym), vars) ||

Check warning on line 244 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L243-L244

Added lines #L243 - L244 were not covered by tests
count('', string(sym)) == 1 &&
count(isequal(sym),
Symbol.(nameof(sys), :₊, getname.(parameter_symbols(sys)))) == 1
count(isequal(sym), Symbol.(nameof(sys), :₊, vars)) == 1
end

function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
Expand All @@ -254,12 +258,12 @@ function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
end

function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))
vars = _syms_as_symbols(parameter_symbols(sys))
idx = findfirst(isequal(sym), vars)

Check warning on line 262 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L261-L262

Added lines #L261 - L262 were not covered by tests
if idx !== nothing
return idx
elseif count('', string(sym)) == 1
return findfirst(isequal(sym),
Symbol.(nameof(sys), :₊, getname.(parameter_symbols(sys))))
return findfirst(isequal(sym), Symbol.(nameof(sys), :₊, vars))

Check warning on line 266 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L266

Added line #L266 was not covered by tests
end
return nothing
end
Expand Down
22 changes: 22 additions & 0 deletions test/symbolic_indexing_interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,25 @@ analytic_function = (ps, t, x) -> -ps[1] * x * (x - 1) * sin(x) * exp(-2 * ps[1]
@test isequal(pdesys.ps, [h => 1])
@test isequal(parameter_symbols(pdesys), [h])
@test isequal(parameters(pdesys), [h])

# Test systems with symbols that don't have a valid `getname`
dt = 0.1
@variables t x(t) y(t) u(t) yd(t) ud(t) r(t)
@parameters kp
D = Differential(t)
# u(n + 1) := f(u(n))

eqs = [yd ~ Sample(t, dt)(y)
ud ~ kp * (r - yd)
r ~ 1.0

# plant (time continuous part)
u ~ Hold(ud)
D(x) ~ -x + u
y ~ x]
@named sys = ODESystem(eqs)
ss = structural_simplify(sys) # has `Hold()(ud(t))` as a parameter

@test parameter_index(ss, :kp) == 1
@test parameter_index(ss, Symbol(Hold(ud))) == 2
@test parameter_index(ss, Sample(t, dt)(y)) == 3

0 comments on commit 9a07bb4

Please sign in to comment.