Skip to content

Commit

Permalink
fix: propagate checkbounds information to observed function generation
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 6, 2025
1 parent 0648e5c commit 8c55fe0
Show file tree
Hide file tree
Showing 7 changed files with 15 additions and 14 deletions.
11 changes: 6 additions & 5 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ end
SymbolicIndexingInterface.supports_tuple_observed(::AbstractSystem) = true

function SymbolicIndexingInterface.observed(
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__)
sys::AbstractSystem, sym; eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
if has_index_cache(sys) && (ic = get_index_cache(sys)) !== nothing
if sym isa Symbol
_sym = get(ic.symbol_to_variable, sym, nothing)
Expand All @@ -808,7 +808,7 @@ function SymbolicIndexingInterface.observed(
end
end
end
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module)
_fn = build_explicit_observed_function(sys, sym; eval_expression, eval_module, checkbounds)

if is_time_dependent(sys)
return _fn
Expand Down Expand Up @@ -1671,11 +1671,12 @@ struct ObservedFunctionCache{S}
steady_state::Bool
eval_expression::Bool
eval_module::Module
checkbounds::Bool
end

function ObservedFunctionCache(
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__)
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module)
sys; steady_state = false, eval_expression = false, eval_module = @__MODULE__, checkbounds = true)
return ObservedFunctionCache(sys, Dict(), steady_state, eval_expression, eval_module, checkbounds)
end

# This is hit because ensemble problems do a deepcopy
Expand All @@ -1694,7 +1695,7 @@ function (ofc::ObservedFunctionCache)(obsvar, args...)
obs = get!(ofc.dict, value(obsvar)) do
SymbolicIndexingInterface.observed(
ofc.sys, obsvar; eval_expression = ofc.eval_expression,
eval_module = ofc.eval_module)
eval_module = ofc.eval_module, checkbounds = ofc.checkbounds)
end
if ofc.steady_state
obs = let fn = obs
Expand Down
4 changes: 2 additions & 2 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ function DiffEqBase.ODEFunction{iip, specialize}(sys::AbstractODESystem,
ArrayInterface.restructure(u0 .* u0', M)
end

observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; steady_state, eval_expression, eval_module, checkbounds)

jac_prototype = if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
Expand Down Expand Up @@ -522,7 +522,7 @@ function DiffEqBase.DAEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
_jac = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

jac_prototype = if sparse
uElType = u0 === nothing ? Float64 : eltype(u0)
Expand Down
2 changes: 1 addition & 1 deletion src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ function DiffEqBase.SDEFunction{iip, specialize}(sys::SDESystem, dvs = unknowns(
M = calculate_massmatrix(sys)
_M = (u0 === nothing || M == I) ? M : ArrayInterface.restructure(u0 .* u0', M)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

SDEFunction{iip, specialize}(f, g;
sys = sys,
Expand Down
2 changes: 1 addition & 1 deletion src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ function SciMLBase.DiscreteFunction{iip, specialize}(
f = SciMLBase.wrapfun_iip(f, (u0, u0, p, t))
end

observedfun = ObservedFunctionCache(sys)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

DiscreteFunction{iip, specialize}(f;
sys = sys,
Expand Down
4 changes: 2 additions & 2 deletions src/systems/jumps/jumpsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ function DiffEqBase.DiscreteProblem(sys::JumpSystem, u0map, tspan::Union{Tuple,
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false, check_length = false)
f = DiffEqBase.DISCRETE_INPLACE_DEFAULT

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

df = DiscreteFunction{true, true}(f; sys = sys, observed = observedfun)
DiscreteProblem(df, u0, tspan, p; kwargs...)
Expand Down Expand Up @@ -527,7 +527,7 @@ function DiffEqBase.ODEProblem(sys::JumpSystem, u0map, tspan::Union{Tuple, Nothi
t = tspan === nothing ? nothing : tspan[1], use_union, tofloat = false,
check_length = false)
f = (du, u, p, t) -> (du .= 0; nothing)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))
df = ODEFunction(f; sys, observed = observedfun)
return ODEProblem(df, u0, tspan, p; kwargs...)
end
Expand Down
4 changes: 2 additions & 2 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ function SciMLBase.NonlinearFunction{iip}(sys::NonlinearSystem, dvs = unknowns(s
_jac = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

if length(dvs) == length(equations(sys))
resid_prototype = nothing
Expand Down Expand Up @@ -411,7 +411,7 @@ function SciMLBase.IntervalNonlinearFunction(
f(u, p) = f_oop(u, p)
f(u, p::MTKParameters) = f_oop(u, p...)

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds = get(kwargs, :checkbounds, false))

IntervalNonlinearFunction{false}(
f; observed = observedfun, sys = sys, initialization_data)
Expand Down
2 changes: 1 addition & 1 deletion src/systems/optimization/optimizationsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ function DiffEqBase.OptimizationProblem{iip}(sys::OptimizationSystem, u0map,
hess_prototype = nothing
end

observedfun = ObservedFunctionCache(sys; eval_expression, eval_module)
observedfun = ObservedFunctionCache(sys; eval_expression, eval_module, checkbounds)

if length(cstr) > 0
@named cons_sys = ConstraintsSystem(cstr, dvs, ps; checks)
Expand Down

0 comments on commit 8c55fe0

Please sign in to comment.