Skip to content

Commit

Permalink
refactor: handle NullParameters, fix independent_variable_symbols for…
Browse files Browse the repository at this point in the history
… multivariate systems
  • Loading branch information
AayushSabharwal committed Dec 20, 2023
1 parent 4a5d676 commit b32b123
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -185,70 +185,72 @@ end
#Treat the result as a vector of symbols always
function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym)
if unwrap(sym) isa Int # [x, 1] coerces 1 to a Num
return unwrap(sym) in 1:length(unknown_states(sys))
return unwrap(sym) in 1:length(variable_symbols(sys))

Check warning on line 188 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L188

Added line #L188 was not covered by tests
end
return any(isequal(sym), unknown_states(sys)) || hasname(sym) && is_variable(sys, getname(sym))
return any(isequal(sym), variable_symbols(sys)) || hasname(sym) && is_variable(sys, getname(sym))

Check warning on line 190 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L190

Added line #L190 was not covered by tests
end

function SymbolicIndexingInterface.is_variable(sys::AbstractSystem, sym::Symbol)
return any(isequal(sym), getname.(unknown_states(sys))) || count('', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys)))) == 1
return any(isequal(sym), getname.(variable_symbols(sys))) || count('', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(variable_symbols(sys)))) == 1

Check warning on line 194 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L194

Added line #L194 was not covered by tests
end

function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym)
if unwrap(sym) isa Int
return unwrap(sym)
end
idx = findfirst(isequal(sym), unknown_states(sys))
idx = findfirst(isequal(sym), variable_symbols(sys))

Check warning on line 201 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L201

Added line #L201 was not covered by tests
if idx === nothing && hasname(sym)
idx = variable_index(sys, getname(sym))
end
return idx
end

function SymbolicIndexingInterface.variable_index(sys::AbstractSystem, sym::Symbol)
idx = findfirst(isequal(sym), getname.(unknown_states(sys)))
idx = findfirst(isequal(sym), getname.(variable_symbols(sys)))

Check warning on line 209 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L209

Added line #L209 was not covered by tests
if idx !== nothing
return idx
elseif count('', string(sym)) == 1
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(unknown_states(sys))))
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(variable_symbols(sys))))

Check warning on line 213 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L213

Added line #L213 was not covered by tests
end
return nothing
end

SymbolicIndexingInterface.variable_symbols(sys::AbstractMultivariateSystem) = sys.dvs

Check warning on line 218 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L218

Added line #L218 was not covered by tests

function SymbolicIndexingInterface.variable_symbols(sys::AbstractSystem)
return unknown_states(sys)
end

function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym)
if unwrap(sym) isa Int
return unwrap(sym) in 1:length(parameters(sys))
return unwrap(sym) in 1:length(parameter_symbols(sys))

Check warning on line 226 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L226

Added line #L226 was not covered by tests
end

return any(isequal(sym), parameters(sys)) || hasname(sym) && is_parameter(sys, getname(sym))
return any(isequal(sym), parameter_symbols(sys)) || hasname(sym) && is_parameter(sys, getname(sym))

Check warning on line 229 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L229

Added line #L229 was not covered by tests
end

function SymbolicIndexingInterface.is_parameter(sys::AbstractSystem, sym::Symbol)
return any(isequal(sym), getname.(parameters(sys))) ||
count('', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys)))) == 1
return any(isequal(sym), getname.(parameter_symbols(sys))) ||

Check warning on line 233 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L233

Added line #L233 was not covered by tests
count('', string(sym)) == 1 && count(isequal(sym), Symbol.(sys.name, :₊, getname.(parameter_symbols(sys)))) == 1
end

function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym)
if unwrap(sym) isa Int
return unwrap(sym)
end
idx = findfirst(isequal(sym), parameters(sys))
idx = findfirst(isequal(sym), parameter_symbols(sys))

Check warning on line 241 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L241

Added line #L241 was not covered by tests
if idx === nothing && hasname(sym)
idx = parameter_index(sys, getname(sym))
end
return idx
end

function SymbolicIndexingInterface.parameter_index(sys::AbstractSystem, sym::Symbol)
idx = findfirst(isequal(sym), getname.(parameters(sys)))
idx = findfirst(isequal(sym), getname.(parameter_symbols(sys)))

Check warning on line 249 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L249

Added line #L249 was not covered by tests
if idx !== nothing
return idx
elseif count('', string(sym)) == 1
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(parameters(sys))))
return findfirst(isequal(sym), Symbol.(sys.name, :₊, getname.(parameter_symbols(sys))))

Check warning on line 253 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L253

Added line #L253 was not covered by tests
end
return nothing
end
Expand All @@ -258,7 +260,7 @@ function SymbolicIndexingInterface.parameter_symbols(sys::AbstractSystem)
end

function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym)
return any(isequal(sym), independent_variables(sys))
return any(isequal(sym), independent_variable_symbols(sys))

Check warning on line 263 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L263

Added line #L263 was not covered by tests
end

function SymbolicIndexingInterface.is_independent_variable(sys::AbstractSystem, sym::Symbol)
Expand Down Expand Up @@ -632,6 +634,9 @@ end

function parameters(sys::AbstractSystem)
ps = get_ps(sys)
if ps == SciMLBase.NullParameters()
return []

Check warning on line 638 in src/systems/abstractsystem.jl

View check run for this annotation

Codecov / codecov/patch

src/systems/abstractsystem.jl#L637-L638

Added lines #L637 - L638 were not covered by tests
end
systems = get_systems(sys)
unique(isempty(systems) ? ps : [ps; reduce(vcat, namespace_parameters.(systems))])
end
Expand Down

0 comments on commit b32b123

Please sign in to comment.