Skip to content

Commit

Permalink
refactor: check if algorithm supports late binding tstops
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 12, 2024
1 parent b533d73 commit 354cda6
Showing 1 changed file with 27 additions and 0 deletions.
27 changes: 27 additions & 0 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,19 @@ function Base.showerror(io::IO, e::IncompatibleMassMatrixError)
println(io, TruncatedStacktraces.VERBOSE_MSG)
end

const LATE_BINDING_TSTOPS_ERROR_MESSAGE = """
This solver does not support providing `tstops` as a function.
Consider using a different solver or providing `tstops` as an array
of times.
"""

struct LateBindingTstopsNotSupportedError <: Exception end

function Base.showerror(io::IO, e::LateBindingTstopsNotSupportedError)
println(io, LATE_BINDING_TSTOPS_ERROR_MESSAGE)
println(io, TruncatedStacktraces.VERBOSE_MSG)
end

function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing,
kwargs...)
kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle
Expand Down Expand Up @@ -555,6 +568,13 @@ function init_up(prob::AbstractDEProblem, sensealg, u0, p, args...; kwargs...)
p = p, kwargs...)
init_call(_prob, args...; kwargs...)
else
tstops = get(kwargs, :tstops, nothing)
if tstops === nothing && has_kwargs(prob)
tstops = get(prob.kwargs, :tstops, nothing)
end
if !(tstops isa Union{Nothing, AbstractArray, Tuple}) && !SciMLBase.allows_late_binding_tstops(alg)
throw(LateBindingTstopsNotSupportedError())
end
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
_alg = prepare_alg(alg, _prob.u0, _prob.p, _prob)
check_prob_alg_pairing(_prob, alg) # alg for improved inference
Expand Down Expand Up @@ -1084,6 +1104,13 @@ function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0
p = p, kwargs...)
solve_call(_prob, args...; kwargs...)
else
tstops = get(kwargs, :tstops, nothing)
if tstops === nothing && has_kwargs(prob)
tstops = get(prob.kwargs, :tstops, nothing)
end
if !(tstops isa Union{Nothing, AbstractArray, Tuple}) && !SciMLBase.allows_late_binding_tstops(alg)
throw(LateBindingTstopsNotSupportedError())
end
_prob = get_concrete_problem(prob, isadaptive(alg); u0 = u0, p = p, kwargs...)
_alg = prepare_alg(alg, _prob.u0, _prob.p, _prob)
check_prob_alg_pairing(_prob, alg) # use alg for improved inference
Expand Down

0 comments on commit 354cda6

Please sign in to comment.