diff --git a/ext/DiffEqBaseChainRulesCoreExt.jl b/ext/DiffEqBaseChainRulesCoreExt.jl index 5a47ffb4a..162f0594e 100644 --- a/ext/DiffEqBaseChainRulesCoreExt.jl +++ b/ext/DiffEqBaseChainRulesCoreExt.jl @@ -11,19 +11,21 @@ ChainRulesCore.rrule(::typeof(numargs), f) = (numargs(f), df -> (NoTangent(), No ChainRulesCore.@non_differentiable DiffEqBase.checkkwargs(kwargshandle) function ChainRulesCore.frule(::typeof(DiffEqBase.solve_up), prob, - sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, - u0, p, args...; - kwargs...) - DiffEqBase._solve_forward(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...; + sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, + u0, p, args...; + kwargs...) + DiffEqBase._solve_forward( + prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...; kwargs...) end function ChainRulesCore.rrule(::typeof(DiffEqBase.solve_up), prob::AbstractDEProblem, - sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, - u0, p, args...; - kwargs...) - DiffEqBase._solve_adjoint(prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...; + sensealg::Union{Nothing, AbstractSensitivityAlgorithm}, + u0, p, args...; + kwargs...) + DiffEqBase._solve_adjoint( + prob, sensealg, u0, p, SciMLBase.ChainRulesOriginator(), args...; kwargs...) end -end \ No newline at end of file +end diff --git a/ext/DiffEqBaseEnzymeExt.jl b/ext/DiffEqBaseEnzymeExt.jl index c66d7e268..c9e3ff187 100644 --- a/ext/DiffEqBaseEnzymeExt.jl +++ b/ext/DiffEqBaseEnzymeExt.jl @@ -6,7 +6,10 @@ using Enzyme import Enzyme: Const using ChainRulesCore -function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) where RT +function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWidth{1}, + func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, prob, + sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, + u0, p, args...; kwargs...) where {RT} @inline function copy_or_reuse(val, idx) if Enzyme.EnzymeRules.overwritten(config)[idx] && ismutable(val) return deepcopy(val) @@ -16,24 +19,30 @@ function Enzyme.EnzymeRules.augmented_primal(config::Enzyme.EnzymeRules.ConfigWi end @inline function arg_copy(i) - copy_or_reuse(args[i].val, i+5) + copy_or_reuse(args[i].val, i + 5) end - - res = DiffEqBase._solve_adjoint(copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), SciMLBase.ChainRulesOriginator(), ntuple(arg_copy, Val(length(args)))...; + + res = DiffEqBase._solve_adjoint( + copy_or_reuse(prob.val, 2), copy_or_reuse(sensealg.val, 3), + copy_or_reuse(u0.val, 4), copy_or_reuse(p.val, 5), + SciMLBase.ChainRulesOriginator(), ntuple(arg_copy, Val(length(args)))...; kwargs...) dres = deepcopy(res[1])::RT for v in dres.u - v.= 0 + v .= 0 end tup = (dres, res[2]) return Enzyme.EnzymeRules.AugmentedReturn{RT, RT, Any}(res[1], dres, tup::Any) end -function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob, sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, u0, p, args...; kwargs...) 
where RT - dres, clos = tape +function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, + func::Const{typeof(DiffEqBase.solve_up)}, ::Type{Duplicated{RT}}, tape, prob, + sensealg::Union{Const{Nothing}, Const{<:DiffEqBase.AbstractSensitivityAlgorithm}}, + u0, p, args...; kwargs...) where {RT} + dres, clos = tape dres = dres::RT - dargs = clos(dres) + dargs = clos(dres) for (darg, ptr) in zip(dargs, (func, prob, sensealg, u0, p, args...)) if ptr isa Enzyme.Const continue @@ -44,9 +53,9 @@ function Enzyme.EnzymeRules.reverse(config::Enzyme.EnzymeRules.ConfigWidth{1}, f ptr.dval .+= darg end for v in dres.u - v.= 0 + v .= 0 end - return ntuple(_ -> nothing, Val(length(args)+4)) + return ntuple(_ -> nothing, Val(length(args) + 4)) end end diff --git a/ext/DiffEqBaseMPIExt.jl b/ext/DiffEqBaseMPIExt.jl index 2e081f25d..a5c0b2d82 100644 --- a/ext/DiffEqBaseMPIExt.jl +++ b/ext/DiffEqBaseMPIExt.jl @@ -10,7 +10,7 @@ end if isdefined(MPI, :AbstractMultiRequest) function DiffEqBase.anyeltypedual(::Type{T}, - counter = 0) where {T <: MPI.AbstractMultiRequest} + counter = 0) where {T <: MPI.AbstractMultiRequest} Any end end diff --git a/ext/DiffEqBaseMeasurementsExt.jl b/ext/DiffEqBaseMeasurementsExt.jl index fc4a1ab02..528392e23 100644 --- a/ext/DiffEqBaseMeasurementsExt.jl +++ b/ext/DiffEqBaseMeasurementsExt.jl @@ -11,7 +11,7 @@ else end function DiffEqBase.promote_u0(u0::AbstractArray{<:Measurements.Measurement}, - p::AbstractArray{<:Measurements.Measurement}, t0) + p::AbstractArray{<:Measurements.Measurement}, t0) u0 end DiffEqBase.promote_u0(u0, p::AbstractArray{<:Measurements.Measurement}, t0) = eltype(p).(u0) @@ -22,16 +22,17 @@ value(x::Measurements.Measurement) = Measurements.value(x) @inline DiffEqBase.fastpow(x::Measurements.Measurement, y::Measurements.Measurement) = x^y # Support adaptive steps should be errorless -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{ - <:Measurements.Measurement, - N, - }, - t) where {N} +@inline function DiffEqBase.ODE_DEFAULT_NORM( + u::AbstractArray{ + <:Measurements.Measurement, + N + }, + t) where {N} sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((value(x) for x in u), Iterators.repeated(t))) / length(u)) end @inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Measurements.Measurement, N}, - t) where {N} + t) where {N} sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((value(x) for x in u), Iterators.repeated(t))) / length(u)) end diff --git a/ext/DiffEqBaseMonteCarloMeasurementsExt.jl b/ext/DiffEqBaseMonteCarloMeasurementsExt.jl index 746dbf4a0..179ff5b65 100644 --- a/ext/DiffEqBaseMonteCarloMeasurementsExt.jl +++ b/ext/DiffEqBaseMonteCarloMeasurementsExt.jl @@ -10,16 +10,17 @@ else using ..MonteCarloMeasurements end -function DiffEqBase.promote_u0(u0::AbstractArray{ - <:MonteCarloMeasurements.AbstractParticles, - }, - p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, - t0) +function DiffEqBase.promote_u0( + u0::AbstractArray{ + <:MonteCarloMeasurements.AbstractParticles, + }, + p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, + t0) u0 end function DiffEqBase.promote_u0(u0, - p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, - t0) + p::AbstractArray{<:MonteCarloMeasurements.AbstractParticles}, + t0) eltype(p).(u0) end @@ -27,23 +28,25 @@ DiffEqBase.value(x::Type{MonteCarloMeasurements.AbstractParticles{T, N}}) where DiffEqBase.value(x::MonteCarloMeasurements.AbstractParticles) = mean(x.particles) @inline function 
DiffEqBase.fastpow(x::MonteCarloMeasurements.AbstractParticles, - y::MonteCarloMeasurements.AbstractParticles) + y::MonteCarloMeasurements.AbstractParticles) x^y end # Support adaptive steps should be errorless -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{ - <:MonteCarloMeasurements.AbstractParticles, - N}, t) where {N} +@inline function DiffEqBase.ODE_DEFAULT_NORM( + u::AbstractArray{ + <:MonteCarloMeasurements.AbstractParticles, + N}, t) where {N} sqrt(mean(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((value(x) for x in u), Iterators.repeated(t)))) end -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{ - <:MonteCarloMeasurements.AbstractParticles, - N}, - t::AbstractArray{ - <:MonteCarloMeasurements.AbstractParticles, - N}) where {N} +@inline function DiffEqBase.ODE_DEFAULT_NORM( + u::AbstractArray{ + <:MonteCarloMeasurements.AbstractParticles, + N}, + t::AbstractArray{ + <:MonteCarloMeasurements.AbstractParticles, + N}) where {N} sqrt(mean(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((value(x) for x in u), Iterators.repeated(value.(t))))) end diff --git a/ext/DiffEqBaseReverseDiffExt.jl b/ext/DiffEqBaseReverseDiffExt.jl index 656a39d80..1319b9643 100644 --- a/ext/DiffEqBaseReverseDiffExt.jl +++ b/ext/DiffEqBaseReverseDiffExt.jl @@ -12,10 +12,10 @@ end DiffEqBase.value(x::Type{ReverseDiff.TrackedReal{V, D, O}}) where {V, D, O} = V function DiffEqBase.value(x::Type{ - ReverseDiff.TrackedArray{V, D, N, VA, DA}, + ReverseDiff.TrackedArray{V, D, N, VA, DA}, }) where {V, D, - N, VA, - DA} + N, VA, + DA} Array{V, N} end DiffEqBase.value(x::ReverseDiff.TrackedReal) = x.value @@ -26,15 +26,15 @@ DiffEqBase._reshape(v::AbstractVector{<:ReverseDiff.TrackedReal}, siz) = reduce( DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, p::ReverseDiff.TrackedArray, t0) = u0 function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, - p::ReverseDiff.TrackedArray, t0) + p::ReverseDiff.TrackedArray, t0) u0 end function DiffEqBase.promote_u0(u0::ReverseDiff.TrackedArray, - p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) + p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) u0 end function DiffEqBase.promote_u0(u0::AbstractArray{<:ReverseDiff.TrackedReal}, - p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) + p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) u0 end DiffEqBase.promote_u0(u0, p::ReverseDiff.TrackedArray, t0) = ReverseDiff.track(u0) @@ -44,13 +44,14 @@ DiffEqBase.promote_u0(u0, p::AbstractArray{<:ReverseDiff.TrackedReal}, t0) = elt @inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, t) sqrt(sum(abs2, DiffEqBase.value(u)) / length(u)) end -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N}, - t) where {N} +@inline function DiffEqBase.ODE_DEFAULT_NORM( + u::AbstractArray{<:ReverseDiff.TrackedReal, N}, + t) where {N} sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) end @inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N}, - t) where {N} + t) where {N} sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) end @@ -60,94 +61,95 @@ end # Support TrackedReal time, don't drop tracking on the adaptivity there @inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedArray, - t::ReverseDiff.TrackedReal) + t::ReverseDiff.TrackedReal) sqrt(sum(abs2, u) / length(u)) end -@inline function 
DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:ReverseDiff.TrackedReal, N}, - t::ReverseDiff.TrackedReal) where {N} +@inline function DiffEqBase.ODE_DEFAULT_NORM( + u::AbstractArray{<:ReverseDiff.TrackedReal, N}, + t::ReverseDiff.TrackedReal) where {N} sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u)) end @inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:ReverseDiff.TrackedReal, N}, - t::ReverseDiff.TrackedReal) where {N} + t::ReverseDiff.TrackedReal) where {N} sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u)) end @inline function DiffEqBase.ODE_DEFAULT_NORM(u::ReverseDiff.TrackedReal, - t::ReverseDiff.TrackedReal) + t::ReverseDiff.TrackedReal) abs(u) end # `ReverseDiff.TrackedArray` function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, - sensealg::Union{ - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - Nothing}, u0::ReverseDiff.TrackedArray, - p::ReverseDiff.TrackedArray, args...; kwargs...) + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::ReverseDiff.TrackedArray, + p::ReverseDiff.TrackedArray, args...; kwargs...) ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, - sensealg::Union{ - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - Nothing}, u0, p::ReverseDiff.TrackedArray, - args...; kwargs...) + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0, p::ReverseDiff.TrackedArray, + args...; kwargs...) ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, - sensealg::Union{ - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - Nothing}, u0::ReverseDiff.TrackedArray, p, - args...; kwargs...) + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::ReverseDiff.TrackedArray, p, + args...; kwargs...) ReverseDiff.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end # `AbstractArray{<:ReverseDiff.TrackedReal}` function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, - sensealg::Union{ - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - Nothing}, - u0::AbstractArray{<:ReverseDiff.TrackedReal}, - p::AbstractArray{<:ReverseDiff.TrackedReal}, args...; - kwargs...) + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, + u0::AbstractArray{<:ReverseDiff.TrackedReal}, + p::AbstractArray{<:ReverseDiff.TrackedReal}, args...; + kwargs...) DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), reduce(vcat, p), args...; kwargs...) end function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, - sensealg::Union{ - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - Nothing}, u0, - p::AbstractArray{<:ReverseDiff.TrackedReal}, - args...; kwargs...) + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0, + p::AbstractArray{<:ReverseDiff.TrackedReal}, + args...; kwargs...) DiffEqBase.solve_up(prob, sensealg, u0, reduce(vcat, p), args...; kwargs...) end function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, - sensealg::Union{ - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - Nothing}, u0::ReverseDiff.TrackedArray, - p::AbstractArray{<:ReverseDiff.TrackedReal}, - args...; kwargs...) 
+ sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::ReverseDiff.TrackedArray, + p::AbstractArray{<:ReverseDiff.TrackedReal}, + args...; kwargs...) DiffEqBase.solve_up(prob, sensealg, u0, reduce(vcat, p), args...; kwargs...) end function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, - sensealg::Union{ - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - Nothing}, - u0::AbstractArray{<:ReverseDiff.TrackedReal}, p, - args...; kwargs...) + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, + u0::AbstractArray{<:ReverseDiff.TrackedReal}, p, + args...; kwargs...) DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), p, args...; kwargs...) end function DiffEqBase.solve_up(prob::DiffEqBase.DEProblem, - sensealg::Union{ - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - Nothing}, - u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray, - args...; kwargs...) + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, + u0::AbstractArray{<:ReverseDiff.TrackedReal}, p::ReverseDiff.TrackedArray, + args...; kwargs...) DiffEqBase.solve_up(prob, sensealg, reduce(vcat, u0), p, args...; kwargs...) end diff --git a/ext/DiffEqBaseTrackerExt.jl b/ext/DiffEqBaseTrackerExt.jl index f92e2a7da..b7a67c815 100644 --- a/ext/DiffEqBaseTrackerExt.jl +++ b/ext/DiffEqBaseTrackerExt.jl @@ -17,15 +17,15 @@ DiffEqBase.value(x::Tracker.TrackedArray) = x.data DiffEqBase.promote_u0(u0::Tracker.TrackedArray, p::Tracker.TrackedArray, t0) = u0 function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, - p::Tracker.TrackedArray, t0) + p::Tracker.TrackedArray, t0) u0 end function DiffEqBase.promote_u0(u0::Tracker.TrackedArray, - p::AbstractArray{<:Tracker.TrackedReal}, t0) + p::AbstractArray{<:Tracker.TrackedReal}, t0) u0 end function DiffEqBase.promote_u0(u0::AbstractArray{<:Tracker.TrackedReal}, - p::AbstractArray{<:Tracker.TrackedReal}, t0) + p::AbstractArray{<:Tracker.TrackedReal}, t0) u0 end DiffEqBase.promote_u0(u0, p::Tracker.TrackedArray, t0) = Tracker.track(u0) @@ -39,12 +39,12 @@ DiffEqBase.promote_u0(u0, p::AbstractArray{<:Tracker.TrackedReal}, t0) = eltype( sqrt(sum(abs2, DiffEqBase.value(u)) / length(u)) end @inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N}, - t) where {N} + t) where {N} sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) end @inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N}, - t) where {N} + t) where {N} sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((DiffEqBase.value(x) for x in u), Iterators.repeated(t))) / length(u)) end @@ -52,51 +52,51 @@ end # Support TrackedReal time, don't drop tracking on the adaptivity there @inline function DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedArray, - t::Tracker.TrackedReal) + t::Tracker.TrackedReal) sqrt(sum(abs2, u) / length(u)) end @inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{<:Tracker.TrackedReal, N}, - t::Tracker.TrackedReal) where {N} + t::Tracker.TrackedReal) where {N} sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u)) end @inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Tracker.TrackedReal, N}, - t::Tracker.TrackedReal) where {N} + t::Tracker.TrackedReal) where {N} sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip(u, Iterators.repeated(t))) / length(u)) end @inline 
DiffEqBase.ODE_DEFAULT_NORM(u::Tracker.TrackedReal, t::Tracker.TrackedReal) = abs(u) function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, - sensealg::Union{ - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - Nothing}, u0::Tracker.TrackedArray, - p::Tracker.TrackedArray, args...; kwargs...) + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::Tracker.TrackedArray, + p::Tracker.TrackedArray, args...; kwargs...) Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, - sensealg::Union{ - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - Nothing}, u0::Tracker.TrackedArray, p, args...; - kwargs...) + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0::Tracker.TrackedArray, p, args...; + kwargs...) Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end function DiffEqBase.solve_up(prob::DiffEqBase.AbstractDEProblem, - sensealg::Union{ - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - Nothing}, u0, p::Tracker.TrackedArray, args...; - kwargs...) + sensealg::Union{ + SciMLBase.AbstractOverloadingSensitivityAlgorithm, + Nothing}, u0, p::Tracker.TrackedArray, args...; + kwargs...) Tracker.track(DiffEqBase.solve_up, prob, sensealg, u0, p, args...; kwargs...) end Tracker.@grad function DiffEqBase.solve_up(prob, - sensealg::Union{Nothing, - SciMLBase.AbstractOverloadingSensitivityAlgorithm, - }, - u0, p, args...; - kwargs...) + sensealg::Union{Nothing, + SciMLBase.AbstractOverloadingSensitivityAlgorithm + }, + u0, p, args...; + kwargs...) out = DiffEqBase._solve_adjoint(prob, sensealg, Tracker.data(u0), Tracker.data(p), SciMLBase.TrackerOriginator(), args...; kwargs...) 
Array(out[1]), out[2] diff --git a/ext/DiffEqBaseUnitfulExt.jl b/ext/DiffEqBaseUnitfulExt.jl index 757ee6e85..d74525804 100644 --- a/ext/DiffEqBaseUnitfulExt.jl +++ b/ext/DiffEqBaseUnitfulExt.jl @@ -12,16 +12,17 @@ end # Support adaptive errors should be errorless for exponentiation value(x::Type{Unitful.AbstractQuantity{T, D, U}}) where {T, D, U} = T value(x::Unitful.AbstractQuantity) = x.val -@inline function DiffEqBase.ODE_DEFAULT_NORM(u::AbstractArray{ - <:Unitful.AbstractQuantity, - N, - }, - t) where {N} +@inline function DiffEqBase.ODE_DEFAULT_NORM( + u::AbstractArray{ + <:Unitful.AbstractQuantity, + N + }, + t) where {N} sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((value(x) for x in u), Iterators.repeated(t))) / length(u)) end @inline function DiffEqBase.ODE_DEFAULT_NORM(u::Array{<:Unitful.AbstractQuantity, N}, - t) where {N} + t) where {N} sqrt(sum(x -> DiffEqBase.ODE_DEFAULT_NORM(x[1], x[2]), zip((value(x) for x in u), Iterators.repeated(t))) / length(u)) end diff --git a/src/DiffEqBase.jl b/src/DiffEqBase.jl index c5c348d6f..ffb4eee25 100644 --- a/src/DiffEqBase.jl +++ b/src/DiffEqBase.jl @@ -47,52 +47,56 @@ PrecompileTools.@recompile_invalidations begin using SciMLOperators: AbstractSciMLOperator, AbstractSciMLScalarOperator using SciMLBase: @def, DEIntegrator, AbstractDEProblem, - AbstractDiffEqInterpolation, - DECallback, AbstractDEOptions, DECache, AbstractContinuousCallback, - AbstractDiscreteCallback, AbstractLinearProblem, AbstractNonlinearProblem, - AbstractOptimizationProblem, AbstractSteadyStateProblem, - AbstractJumpProblem, - AbstractNoiseProblem, AbstractEnsembleProblem, AbstractDynamicalODEProblem, - AbstractDEAlgorithm, StandardODEProblem, AbstractIntegralProblem, - AbstractSensitivityAlgorithm, AbstractODEAlgorithm, - AbstractSDEAlgorithm, AbstractDDEAlgorithm, AbstractDAEAlgorithm, - AbstractSDDEAlgorithm, AbstractRODEAlgorithm, DAEInitializationAlgorithm, - AbstractSteadyStateAlgorithm, AbstractODEProblem, AbstractDiscreteProblem, - AbstractSDEProblem, AbstractRODEProblem, AbstractDDEProblem, - AbstractDAEProblem, AbstractSDDEProblem, AbstractBVProblem, - AbstractTimeseriesSolution, AbstractNoTimeSolution, numargs, - AbstractODEFunction, AbstractSDEFunction, AbstractRODEFunction, - AbstractDDEFunction, AbstractSDDEFunction, AbstractDAEFunction, - AbstractNonlinearFunction, AbstractEnsembleSolution, - AbstractODESolution, AbstractRODESolution, AbstractDAESolution, - AbstractDDESolution, - EnsembleAlgorithm, EnsembleSolution, EnsembleSummary, - NonlinearSolution, - TimeGradientWrapper, TimeDerivativeWrapper, UDerivativeWrapper, - UJacobianWrapper, ParamJacobianWrapper, JacobianWrapper, - check_error!, has_jac, has_tgrad, has_Wfact, has_Wfact_t, has_paramjac, - AbstractODEIntegrator, AbstractSDEIntegrator, AbstractRODEIntegrator, - AbstractDDEIntegrator, AbstractSDDEIntegrator, - AbstractDAEIntegrator, unwrap_cache, has_reinit, reinit!, - postamble!, last_step_failed, islinear, has_stats, - initialize_dae!, build_solution, solution_new_retcode, - solution_new_tslocation, plot_indices, - NullParameters, isinplace, AbstractADType, AbstractDiscretization, - DISCRETE_OUTOFPLACE_DEFAULT, DISCRETE_INPLACE_DEFAULT, - has_analytic, calculate_solution_errors!, AbstractNoiseProcess, - has_colorvec, parameterless_type, undefined_exports, - is_diagonal_noise, AbstractDiffEqFunction, sensitivity_solution, - interp_summary, AbstractHistoryFunction, LinearInterpolation, - ConstantInterpolation, HermiteInterpolation, SensitivityInterpolation, - NoAD, @add_kwonly, - 
calculate_ensemble_errors, DEFAULT_UPDATE_FUNC, isconstant, - DEFAULT_REDUCTION, isautodifferentiable, - isadaptive, isdiscrete, has_syms, AbstractAnalyticalSolution, - RECOMPILE_BY_DEFAULT, wrap_sol, has_destats + AbstractDiffEqInterpolation, + DECallback, AbstractDEOptions, DECache, AbstractContinuousCallback, + AbstractDiscreteCallback, AbstractLinearProblem, + AbstractNonlinearProblem, + AbstractOptimizationProblem, AbstractSteadyStateProblem, + AbstractJumpProblem, + AbstractNoiseProblem, AbstractEnsembleProblem, + AbstractDynamicalODEProblem, + AbstractDEAlgorithm, StandardODEProblem, AbstractIntegralProblem, + AbstractSensitivityAlgorithm, AbstractODEAlgorithm, + AbstractSDEAlgorithm, AbstractDDEAlgorithm, AbstractDAEAlgorithm, + AbstractSDDEAlgorithm, AbstractRODEAlgorithm, + DAEInitializationAlgorithm, + AbstractSteadyStateAlgorithm, AbstractODEProblem, + AbstractDiscreteProblem, + AbstractSDEProblem, AbstractRODEProblem, AbstractDDEProblem, + AbstractDAEProblem, AbstractSDDEProblem, AbstractBVProblem, + AbstractTimeseriesSolution, AbstractNoTimeSolution, numargs, + AbstractODEFunction, AbstractSDEFunction, AbstractRODEFunction, + AbstractDDEFunction, AbstractSDDEFunction, AbstractDAEFunction, + AbstractNonlinearFunction, AbstractEnsembleSolution, + AbstractODESolution, AbstractRODESolution, AbstractDAESolution, + AbstractDDESolution, + EnsembleAlgorithm, EnsembleSolution, EnsembleSummary, + NonlinearSolution, + TimeGradientWrapper, TimeDerivativeWrapper, UDerivativeWrapper, + UJacobianWrapper, ParamJacobianWrapper, JacobianWrapper, + check_error!, has_jac, has_tgrad, has_Wfact, has_Wfact_t, has_paramjac, + AbstractODEIntegrator, AbstractSDEIntegrator, AbstractRODEIntegrator, + AbstractDDEIntegrator, AbstractSDDEIntegrator, + AbstractDAEIntegrator, unwrap_cache, has_reinit, reinit!, + postamble!, last_step_failed, islinear, has_stats, + initialize_dae!, build_solution, solution_new_retcode, + solution_new_tslocation, plot_indices, + NullParameters, isinplace, AbstractADType, AbstractDiscretization, + DISCRETE_OUTOFPLACE_DEFAULT, DISCRETE_INPLACE_DEFAULT, + has_analytic, calculate_solution_errors!, AbstractNoiseProcess, + has_colorvec, parameterless_type, undefined_exports, + is_diagonal_noise, AbstractDiffEqFunction, sensitivity_solution, + interp_summary, AbstractHistoryFunction, LinearInterpolation, + ConstantInterpolation, HermiteInterpolation, SensitivityInterpolation, + NoAD, @add_kwonly, + calculate_ensemble_errors, DEFAULT_UPDATE_FUNC, isconstant, + DEFAULT_REDUCTION, isautodifferentiable, + isadaptive, isdiscrete, has_syms, AbstractAnalyticalSolution, + RECOMPILE_BY_DEFAULT, wrap_sol, has_destats import SciMLBase: solve, init, step!, solve!, __init, __solve, update_coefficients!, - update_coefficients, isadaptive, wrapfun_oop, wrapfun_iip, - unwrap_fw, promote_tspan, set_u!, set_t!, set_ut! + update_coefficients, isadaptive, wrapfun_oop, wrapfun_iip, + unwrap_fw, promote_tspan, set_u!, set_t!, set_ut! import SciMLBase: AbstractDiffEqLinearOperator # deprecation path @@ -162,12 +166,13 @@ export initialize!, finalize! 
export SensitivityADPassThrough export SteadyStateDiffEqTerminationMode, SimpleNonlinearSolveTerminationMode, - NormTerminationMode, RelTerminationMode, RelNormTerminationMode, AbsTerminationMode, - AbsNormTerminationMode, RelSafeTerminationMode, AbsSafeTerminationMode, - RelSafeBestTerminationMode, AbsSafeBestTerminationMode + NormTerminationMode, RelTerminationMode, RelNormTerminationMode, AbsTerminationMode, + AbsNormTerminationMode, RelSafeTerminationMode, AbsSafeTerminationMode, + RelSafeBestTerminationMode, AbsSafeBestTerminationMode # Deprecated API export NLSolveTerminationMode, - NLSolveSafeTerminationOptions, NLSolveTerminationCondition, NLSolveSafeTerminationResult + NLSolveSafeTerminationOptions, NLSolveTerminationCondition, + NLSolveSafeTerminationResult export KeywordArgError, KeywordArgWarn, KeywordArgSilent diff --git a/src/calculate_residuals.jl b/src/calculate_residuals.jl index dbcfdaf7a..f2453c20c 100644 --- a/src/calculate_residuals.jl +++ b/src/calculate_residuals.jl @@ -7,14 +7,14 @@ Calculate element-wise residuals ``` """ @inline @muladd function calculate_residuals(ũ::Number, u₀::Number, u₁::Number, - α, ρ, internalnorm, t) + α, ρ, internalnorm, t) @fastmath ũ / (α + max(internalnorm(u₀, t), internalnorm(u₁, t)) * ρ) end @inline function calculate_residuals(ũ::Array{T}, u₀::Array{T}, u₁::Array{T}, α::T2, - ρ::Real, internalnorm, - t) where - {T <: Number, T2 <: Number} + ρ::Real, internalnorm, + t) where + {T <: Number, T2 <: Number} out = similar(ũ) calculate_residuals!(out, ũ, u₀, u₁, α, ρ, internalnorm, t) out @@ -34,13 +34,13 @@ Calculate element-wise residuals """ @inline @muladd function calculate_residuals(u₀::Number, u₁::Number, - α, ρ, internalnorm, t) + α, ρ, internalnorm, t) @fastmath (u₁ - u₀) / (α + max(internalnorm(u₀, t), internalnorm(u₁, t)) * ρ) end @inline function calculate_residuals(u₀::Array{T}, u₁::Array{T}, α::T2, - ρ::Real, internalnorm, - t) where {T <: Number, T2 <: Number} + ρ::Real, internalnorm, + t) where {T <: Number, T2 <: Number} out = similar(u₀) calculate_residuals!(out, u₀, u₁, α, ρ, internalnorm, t) out @@ -58,14 +58,15 @@ Return element-wise residuals \\frac{δ E₁ + E₂}{α+\\max{scalarnorm(u₀),scalarnorm(u₁)}*ρ}. ``` """ -@inline @muladd function calculate_residuals(E₁::Number, E₂::Number, u₀::Number, u₁::Number, - α::Real, ρ::Real, δ::Number, scalarnorm, t) +@inline @muladd function calculate_residuals( + E₁::Number, E₂::Number, u₀::Number, u₁::Number, + α::Real, ρ::Real, δ::Number, scalarnorm, t) @fastmath (δ * E₁ + E₂) / (α + max(scalarnorm(u₀, t), scalarnorm(u₁, t)) * ρ) end @inline function calculate_residuals(E₁::Array{<:Number}, E₂::Array{<:Number}, - u₀::Array{<:Number}, u₁::Array{<:Number}, α::Real, - ρ::Real, δ::Number, scalarnorm, t) + u₀::Array{<:Number}, u₁::Array{<:Number}, α::Real, + ρ::Real, δ::Number, scalarnorm, t) out = similar(u₀) calculate_residuals!(out, E₁, E₂, u₀, u₁, α, ρ, δ, scalarnorm, t) out @@ -92,14 +93,16 @@ or use multiple threads (`thread = True()`) when Julia is started with multiple threads. """ @inline function calculate_residuals!(out, ũ, u₀, u₁, α, ρ, internalnorm, t, - thread::Union{False, True} = False()) - @.. broadcast=false thread=thread out=calculate_residuals(ũ, u₀, u₁, α, ρ, internalnorm, + thread::Union{False, True} = False()) + @.. 
broadcast=false thread=thread out=calculate_residuals( + ũ, u₀, u₁, α, ρ, internalnorm, t) nothing end -@inline function calculate_residuals!(out::Array, ũ::Array, u₀::Array, u₁::Array, α::Number, - ρ::Number, internalnorm::F, t, ::False) where {F} +@inline function calculate_residuals!( + out::Array, ũ::Array, u₀::Array, u₁::Array, α::Number, + ρ::Number, internalnorm::F, t, ::False) where {F} @inbounds @simd ivdep for i in eachindex(out, ũ, u₀, u₁) out[i] = calculate_residuals(ũ[i], u₀[i], u₁[i], α, ρ, internalnorm, t) end @@ -121,7 +124,7 @@ or use multiple threads (`thread = True()`) when Julia is started with multiple threads. """ @inline function calculate_residuals!(out, u₀, u₁, α, ρ, internalnorm, t, - thread::Union{False, True} = False()) + thread::Union{False, True} = False()) @.. broadcast=false thread=thread out=calculate_residuals(u₀, u₁, α, ρ, internalnorm, t) end @@ -139,7 +142,7 @@ or use multiple threads (`thread = True()`) when Julia is started with multiple threads. """ @inline function calculate_residuals!(out, E₁, E₂, u₀, u₁, α, ρ, δ, scalarnorm, t, - thread::Union{False, True} = False()) + thread::Union{False, True} = False()) @.. broadcast=false thread=thread out=calculate_residuals(E₁, E₂, u₀, u₁, α, ρ, δ, scalarnorm, t) out diff --git a/src/callbacks.jl b/src/callbacks.jl index 2267a490e..8ee11aaa3 100644 --- a/src/callbacks.jl +++ b/src/callbacks.jl @@ -9,12 +9,12 @@ function initialize!(cb::CallbackSet, u, t, integrator::DEIntegrator) end initialize!(cb::CallbackSet{Tuple{}, Tuple{}}, u, t, integrator::DEIntegrator) = false function initialize!(u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback, cs::DECallback...) + c::DECallback, cs::DECallback...) c.initialize(c, u, t, integrator) initialize!(u, t, integrator, any_modified || integrator.u_modified, cs...) end function initialize!(u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback) + c::DECallback) c.initialize(c, u, t, integrator) any_modified || integrator.u_modified end @@ -29,12 +29,12 @@ function finalize!(cb::CallbackSet, u, t, integrator::DEIntegrator) end finalize!(cb::CallbackSet{Tuple{}, Tuple{}}, u, t, integrator::DEIntegrator) = false function finalize!(u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback, cs::DECallback...) + c::DECallback, cs::DECallback...) c.finalize(c, u, t, integrator) finalize!(u, t, integrator, any_modified || integrator.u_modified, cs...) 
end function finalize!(u, t, integrator::DEIntegrator, any_modified::Bool, - c::DECallback) + c::DECallback) c.finalize(c, u, t, integrator) any_modified || integrator.u_modified end @@ -108,7 +108,8 @@ function get_condition(integrator::DEIntegrator, callback, abst) end integrator.sol.stats.ncondition += 1 if callback isa VectorContinuousCallback - callback.condition(@view(integrator.callback_cache.tmp_condition[1:(callback.len)]), + callback.condition( + @view(integrator.callback_cache.tmp_condition[1:(callback.len)]), tmp, abst, integrator) return @view(integrator.callback_cache.tmp_condition[1:(callback.len)]) else @@ -118,15 +119,15 @@ end # Use a generated function for type stability even when many callbacks are given @inline function find_first_continuous_callback(integrator, - callbacks::Vararg{ - AbstractContinuousCallback, - N}) where {N} + callbacks::Vararg{ + AbstractContinuousCallback, + N}) where {N} find_first_continuous_callback(integrator, tuple(callbacks...)) end @generated function find_first_continuous_callback(integrator, - callbacks::NTuple{N, - AbstractContinuousCallback, - }) where {N} + callbacks::NTuple{N, + AbstractContinuousCallback + }) where {N} ex = quote tmin, upcrossing, event_occurred, event_idx = find_callback_time(integrator, callbacks[1], 1) @@ -135,7 +136,8 @@ end for i in 2:N ex = quote $ex - tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time(integrator, + tmin2, upcrossing2, event_occurred2, event_idx2 = find_callback_time( + integrator, callbacks[$i], $i) if event_occurred2 && (tmin2 < tmin || !event_occurred) @@ -155,7 +157,7 @@ end end @inline function determine_event_occurance(integrator, callback::VectorContinuousCallback, - counter) + counter) event_occurred = false if callback.interp_points != 0 addsteps!(integrator) @@ -191,7 +193,8 @@ end next_sign = @view(integrator.callback_cache.next_sign[1:(callback.len)]) if integrator.event_last_time == counter && - minimum(ODE_DEFAULT_NORM(ArrayInterface.allowed_getindex(previous_condition, + minimum(ODE_DEFAULT_NORM( + ArrayInterface.allowed_getindex(previous_condition, ivec), integrator.t)) <= 100ODE_DEFAULT_NORM(integrator.last_event_error, integrator.t) @@ -265,7 +268,7 @@ end end @inline function determine_event_occurance(integrator, callback::ContinuousCallback, - counter) + counter) event_occurred = false if callback.interp_points != 0 addsteps!(integrator) @@ -356,8 +359,9 @@ end # always ensures that if r = bisection(f, (x0, x1)) # then either f(nextfloat(r)) == 0 or f(nextfloat(r)) * f(r) < 0 # note: not really using bisection - uses the ITP method -function bisection(f, tup, t_forward::Bool, rootfind::SciMLBase.RootfindOpt, abstol, reltol; - maxiters = 1000) +function bisection( + f, tup, t_forward::Bool, rootfind::SciMLBase.RootfindOpt, abstol, reltol; + maxiters = 1000) if rootfind == SciMLBase.LeftRootFind solve(IntervalNonlinearProblem{false}(f, tup), InternalITP(), abstol = abstol, @@ -376,7 +380,7 @@ Modifies `next_sign` to be an array of booleans for if there is a sign change in the interval between prev_sign and next_sign """ function findall_events!(next_sign::Union{Array, SubArray}, affect!::F1, affect_neg!::F2, - prev_sign::Union{Array, SubArray}) where {F1, F2} + prev_sign::Union{Array, SubArray}) where {F1, F2} @inbounds for i in 1:length(prev_sign) next_sign[i] = ((prev_sign[i] < 0 && affect! !== nothing) || (prev_sign[i] > 0 && affect_neg! 
!== nothing)) && @@ -394,7 +398,8 @@ function findall_events!(next_sign, affect!::F1, affect_neg!::F2, prev_sign) whe end function find_callback_time(integrator, callback::ContinuousCallback, counter) - event_occurred, interp_index, ts, prev_sign, prev_sign_index, event_idx = determine_event_occurance(integrator, + event_occurred, interp_index, ts, prev_sign, prev_sign_index, event_idx = determine_event_occurance( + integrator, callback, counter) if event_occurred @@ -453,7 +458,8 @@ function find_callback_time(integrator, callback::ContinuousCallback, counter) end function find_callback_time(integrator, callback::VectorContinuousCallback, counter) - event_occurred, interp_index, ts, prev_sign, prev_sign_index, event_idx = determine_event_occurance(integrator, + event_occurred, interp_index, ts, prev_sign, prev_sign_index, event_idx = determine_event_occurance( + integrator, callback, counter) if event_occurred @@ -474,7 +480,8 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun for idx in 1:length(event_idx) if ArrayInterface.allowed_getindex(event_idx, idx) != 0 function zero_func(abst, p = nothing) - ArrayInterface.allowed_getindex(get_condition(integrator, + ArrayInterface.allowed_getindex( + get_condition(integrator, callback, abst), idx) end @@ -500,7 +507,8 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun isone(integrator.tdir), callback.rootfind, callback.abstol, callback.reltol) if integrator.tdir * Θ < integrator.tdir * min_t - integrator.last_event_error = ODE_DEFAULT_NORM(zero_func(Θ), + integrator.last_event_error = ODE_DEFAULT_NORM( + zero_func(Θ), Θ) end end @@ -542,8 +550,8 @@ function find_callback_time(integrator, callback::VectorContinuousCallback, coun end function apply_callback!(integrator, - callback::Union{ContinuousCallback, VectorContinuousCallback}, - cb_time, prev_sign, event_idx) + callback::Union{ContinuousCallback, VectorContinuousCallback}, + cb_time, prev_sign, event_idx) if isadaptive(integrator) set_proposed_dt!(integrator, integrator.tdir * max(nextfloat(integrator.opts.dtmin), @@ -622,8 +630,8 @@ end end @inline function apply_discrete_callback!(integrator, discrete_modified::Bool, - saved_in_cb::Bool, callback::DiscreteCallback, - args...) + saved_in_cb::Bool, callback::DiscreteCallback, + args...) 
bool, saved_in_cb2 = apply_discrete_callback!(integrator, apply_discrete_callback!(integrator, callback)..., @@ -632,7 +640,7 @@ end end @inline function apply_discrete_callback!(integrator, discrete_modified::Bool, - saved_in_cb::Bool, callback::DiscreteCallback) + saved_in_cb::Bool, callback::DiscreteCallback) bool, saved_in_cb2 = apply_discrete_callback!(integrator, callback) discrete_modified || bool, saved_in_cb || saved_in_cb2 end @@ -676,7 +684,7 @@ mutable struct CallbackCache{conditionType, signType} end function CallbackCache(u, max_len, ::Type{conditionType}, - ::Type{signType}) where {conditionType, signType} + ::Type{signType}) where {conditionType, signType} tmp_condition = similar(u, conditionType, max_len) previous_condition = similar(u, conditionType, max_len) next_sign = similar(u, signType, max_len) @@ -685,7 +693,7 @@ function CallbackCache(u, max_len, ::Type{conditionType}, end function CallbackCache(max_len, ::Type{conditionType}, - ::Type{signType}) where {conditionType, signType} + ::Type{signType}) where {conditionType, signType} tmp_condition = zeros(conditionType, max_len) previous_condition = zeros(conditionType, max_len) next_sign = zeros(signType, max_len) diff --git a/src/common_defaults.jl b/src/common_defaults.jl index f638d1505..c87f766de 100644 --- a/src/common_defaults.jl +++ b/src/common_defaults.jl @@ -67,11 +67,12 @@ end Base.FastMath.sqrt_fast(real(sum(abs2 ∘ f, u)) / max(length(u), 1)) end -@inline function ODE_DEFAULT_NORM(u::Union{ - AbstractArray, - RecursiveArrayTools.AbstractVectorOfArray, - }, - t) +@inline function ODE_DEFAULT_NORM( + u::Union{ + AbstractArray, + RecursiveArrayTools.AbstractVectorOfArray + }, + t) Base.FastMath.sqrt_fast(UNITLESS_ABS2(u) / max(recursive_length(u), 1)) end @@ -97,7 +98,8 @@ end @inline NAN_CHECK(x::Number) = isnan(x) @inline NAN_CHECK(x::Float64) = isnan(x) || (x > 1e50) @inline NAN_CHECK(x::Enum) = false -@inline NAN_CHECK(x::Union{AbstractArray, RecursiveArrayTools.AbstractVectorOfArray}) = any(NAN_CHECK, +@inline NAN_CHECK(x::Union{AbstractArray, RecursiveArrayTools.AbstractVectorOfArray}) = any( + NAN_CHECK, x) @inline NAN_CHECK(x::RecursiveArrayTools.ArrayPartition) = any(NAN_CHECK, x.x) @inline function NAN_CHECK(x::SparseArrays.AbstractSparseMatrixCSC) @@ -124,7 +126,8 @@ end end @inline function NONLINEARSOLVE_DEFAULT_NORM(f::F, - u::Union{Array{T}, Iterators.Zip{<:Tuple{Vararg{Array{T}}}}}) where {F, T <: Union{AbstractFloat, Complex}} + u::Union{Array{T}, Iterators.Zip{<:Tuple{Vararg{Array{T}}}}}) where { + F, T <: Union{AbstractFloat, Complex}} x = zero(T) @inbounds @fastmath for ui in u x += abs2(f(ui)) @@ -138,7 +141,8 @@ end end @inline function NONLINEARSOLVE_DEFAULT_NORM(f::F, - u::StaticArraysCore.StaticArray{<:Tuple, T}) where {F, T <: Union{AbstractFloat, Complex}} + u::StaticArraysCore.StaticArray{<:Tuple, T}) where { + F, T <: Union{AbstractFloat, Complex}} return Base.FastMath.sqrt_fast(real(sum(abs2 ∘ f, u))) end diff --git a/src/forwarddiff.jl b/src/forwarddiff.jl index 0e89d838f..997f1179b 100644 --- a/src/forwarddiff.jl +++ b/src/forwarddiff.jl @@ -12,29 +12,29 @@ space for solving the equation. 
promote_dual(::Type{T}, ::Type{T2}) where {T, T2} = T promote_dual(::Type{T}, ::Type{T2}) where {T <: ForwardDiff.Dual, T2} = T function promote_dual(::Type{T}, - ::Type{T2}) where {T <: ForwardDiff.Dual, T2 <: ForwardDiff.Dual} + ::Type{T2}) where {T <: ForwardDiff.Dual, T2 <: ForwardDiff.Dual} T end promote_dual(::Type{T}, ::Type{T2}) where {T, T2 <: ForwardDiff.Dual} = T2 function promote_dual(::Type{T}, - ::Type{T2}) where {T3, T4, V, V2 <: ForwardDiff.Dual, N, N2, - T <: ForwardDiff.Dual{T3, V, N}, - T2 <: ForwardDiff.Dual{T4, V2, N2}} + ::Type{T2}) where {T3, T4, V, V2 <: ForwardDiff.Dual, N, N2, + T <: ForwardDiff.Dual{T3, V, N}, + T2 <: ForwardDiff.Dual{T4, V2, N2}} T2 end function promote_dual(::Type{T}, - ::Type{T2}) where {T3, T4, V <: ForwardDiff.Dual, V2, N, N2, - T <: ForwardDiff.Dual{T3, V, N}, - T2 <: ForwardDiff.Dual{T4, V2, N2}} + ::Type{T2}) where {T3, T4, V <: ForwardDiff.Dual, V2, N, N2, + T <: ForwardDiff.Dual{T3, V, N}, + T2 <: ForwardDiff.Dual{T4, V2, N2}} T end function promote_dual(::Type{T}, - ::Type{T2}) where { - T3, V <: ForwardDiff.Dual, V2 <: ForwardDiff.Dual, - N, - T <: ForwardDiff.Dual{T3, V, N}, - T2 <: ForwardDiff.Dual{T3, V2, N}} + ::Type{T2}) where { + T3, V <: ForwardDiff.Dual, V2 <: ForwardDiff.Dual, + N, + T <: ForwardDiff.Dual{T3, V, N}, + T2 <: ForwardDiff.Dual{T3, V2, N}} ForwardDiff.Dual{T3, promote_dual(V, V2), N} end @@ -52,15 +52,15 @@ end reduce_tup(op, map(f, x)) end # For other container types, we probably just want to call `mapreduce` -@inline diffeqmapreduce(f::F, op::OP, x) where {F, OP} = mapreduce(f, op, x, init=Any) +@inline diffeqmapreduce(f::F, op::OP, x) where {F, OP} = mapreduce(f, op, x, init = Any) struct DualEltypeChecker{T, T2} x::T counter::T2 end -getval(::Val{I}) where I = I -getval(::Type{Val{I}}) where I = I +getval(::Val{I}) where {I} = I +getval(::Type{Val{I}}) where {I} = I getval(I::Int) = I function (dec::DualEltypeChecker)(::Val{Y}) where {Y} @@ -92,20 +92,20 @@ upconversion is not done automatically, the user is required to upconvert all in themselves, for an example of how this can be confusing to a user see https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a-forced-differential-equation/82937 """ -@generated function anyeltypedual(x, ::Type{Val{counter}} = Val{0}) where counter +@generated function anyeltypedual(x, ::Type{Val{counter}} = Val{0}) where {counter} x = x.name === Core.Compiler.typename(Type) ? 
x.parameters[1] : x if x <: ForwardDiff.Dual :($x) elseif fieldnames(x) === () :(Any) elseif counter < DUALCHECK_RECURSION_MAX - T = diffeqmapreduce(x->anyeltypedual(x, Val{counter+1}), promote_dual, + T = diffeqmapreduce(x -> anyeltypedual(x, Val{counter + 1}), promote_dual, x.parameters) if T === Any || isconcretetype(T) :($T) else - :(diffeqmapreduce(DualEltypeChecker($x, $counter+1), promote_dual, - map(Val, fieldnames($(typeof(x)))))) + :(diffeqmapreduce(DualEltypeChecker($x, $counter + 1), promote_dual, + map(Val, fieldnames($(typeof(x)))))) end else :(Any) @@ -113,22 +113,44 @@ https://discourse.julialang.org/t/typeerror-in-julia-turing-when-sampling-for-a- end # Opt out since these are using for preallocation, not differentiation -anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module}, ::Type{Val{counter}} = Val{0}) where {counter} = Any -anyeltypedual(x::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: ForwardDiff.AbstractConfig} = Any -anyeltypedual(x::SciMLBase.RecipesBase.AbstractPlot, ::Type{Val{counter}} = Val{0}) where {counter} = Any -anyeltypedual(x::Returns, ::Type{Val{counter}} = Val{0}) where {counter} = anyeltypedual(x.value, Val{counter}) +function anyeltypedual(x::Union{ForwardDiff.AbstractConfig, Module}, + ::Type{Val{counter}} = Val{0}) where {counter} + Any +end +function anyeltypedual(x::Type{T}, + ::Type{Val{counter}} = Val{0}) where {counter} where {T <: + ForwardDiff.AbstractConfig} + Any +end +function anyeltypedual(x::SciMLBase.RecipesBase.AbstractPlot, + ::Type{Val{counter}} = Val{0}) where {counter} + Any +end +function anyeltypedual(x::Returns, ::Type{Val{counter}} = Val{0}) where {counter} + anyeltypedual(x.value, Val{counter}) +end if isdefined(PreallocationTools, :FixedSizeDiffCache) - anyeltypedual(x::PreallocationTools.FixedSizeDiffCache, ::Type{Val{counter}} = Val{0}) where {counter} = Any + function anyeltypedual(x::PreallocationTools.FixedSizeDiffCache, + ::Type{Val{counter}} = Val{0}) where {counter} + Any + end end Base.@pure function __anyeltypedual(::Type{T}) where {T} hasproperty(T, :parameters) ? 
mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any) : T end -anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T} = __anyeltypedual(T) -anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: ForwardDiff.Dual} = T -function anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: Union{AbstractArray, Set}} +function anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T} + __anyeltypedual(T) +end +function anyeltypedual(::Type{T}, + ::Type{Val{counter}} = Val{0}) where {counter} where {T <: ForwardDiff.Dual} + T +end +function anyeltypedual(::Type{T}, + ::Type{Val{counter}} = Val{0}) where {counter} where {T <: + Union{AbstractArray, Set}} anyeltypedual(eltype(T)) end Base.@pure function __anyeltypedual_ntuple(::Type{T}) where {T <: NTuple} @@ -141,28 +163,39 @@ Base.@pure function __anyeltypedual_ntuple(::Type{T}) where {T <: NTuple} mapreduce(anyeltypedual, promote_dual, T.parameters; init = Any) end end -anyeltypedual(::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: NTuple} = __anyeltypedual_ntuple(T) +function anyeltypedual( + ::Type{T}, ::Type{Val{counter}} = Val{0}) where {counter} where {T <: NTuple} + __anyeltypedual_ntuple(T) +end # Any in this context just means not Dual -anyeltypedual(x::SciMLBase.NullParameters, ::Type{Val{counter}} = Val{0}) where {counter} = Any +function anyeltypedual( + x::SciMLBase.NullParameters, ::Type{Val{counter}} = Val{0}) where {counter} + Any +end function anyeltypedual(sol::RecursiveArrayTools.AbstractDiffEqArray, counter = 0) diffeqmapreduce(anyeltypedual, promote_dual, (sol.u, sol.t)) end -anyeltypedual(x::Number, ::Type{Val{counter}} = Val{0}) where {counter} = anyeltypedual(typeof(x)) -anyeltypedual(x::Union{String, Symbol}, ::Type{Val{counter}} = Val{0}) where {counter} = typeof(x) +function anyeltypedual(x::Number, ::Type{Val{counter}} = Val{0}) where {counter} + anyeltypedual(typeof(x)) +end +function anyeltypedual( + x::Union{String, Symbol}, ::Type{Val{counter}} = Val{0}) where {counter} + typeof(x) +end function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}}, - ::Type{Val{counter}} = Val{0}) where {counter} where { - T <: - Union{Number, + ::Type{Val{counter}} = Val{0}) where {counter} where { + T <: + Union{Number, Symbol, String}} anyeltypedual(T) end function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}}, - ::Type{Val{counter}} = Val{0}) where {counter} where { - T <: Union{ + ::Type{Val{counter}} = Val{0}) where {counter} where { + T <: Union{ AbstractArray{ <:Number, }, @@ -172,7 +205,7 @@ function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}}, anyeltypedual(eltype(x)) end function anyeltypedual(x::Union{Array{T}, AbstractArray{T}, Set{T}}, - ::Type{Val{counter}} = Val{0}) where {counter} where {N, T <: NTuple{N, <:Number}} + ::Type{Val{counter}} = Val{0}) where {counter} where {N, T <: NTuple{N, <:Number}} anyeltypedual(eltype(x)) end @@ -182,7 +215,7 @@ function anyeltypedual(x::AbstractArray, ::Type{Val{counter}} = Val{0}) where {c anyeltypedual(eltype(x)) elseif !isempty(x) && all(i -> isassigned(x, i), 1:length(x)) && counter < DUALCHECK_RECURSION_MAX - _counter = Val{counter+1} + _counter = Val{counter + 1} mapreduce(y -> anyeltypedual(y, _counter), promote_dual, x) else # This fallback to Any is required since otherwise we cannot handle `undef` in all cases @@ -238,7 +271,7 @@ end end function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p, 
- tspan::Tuple{<:ForwardDiff.Dual, <:ForwardDiff.Dual}, prob, kwargs) + tspan::Tuple{<:ForwardDiff.Dual, <:ForwardDiff.Dual}, prob, kwargs) return _promote_tspan(tspan, kwargs) end @@ -252,7 +285,7 @@ function promote_tspan(u0::AbstractArray{<:ForwardDiff.Dual}, p, tspan, prob, kw end function promote_tspan(u0::AbstractArray{<:Complex{<:ForwardDiff.Dual}}, p, tspan, prob, - kwargs) + kwargs) return _promote_tspan(real(eltype(u0)).(tspan), kwargs) end diff --git a/src/internal_euler.jl b/src/internal_euler.jl index 7a4d3d9c0..9bac8800f 100644 --- a/src/internal_euler.jl +++ b/src/internal_euler.jl @@ -13,10 +13,10 @@ struct FwdEulerAlg <: EulerAlgs end struct BwdEulerAlg <: EulerAlgs end function DiffEqBase.solve(prob::DiffEqBase.AbstractODEProblem{uType, tType, isinplace}, - Alg::FwdEulerAlg; - dt = (prob.tspan[2] - prob.tspan[1]) / 100, - tstops = tType[], - kwargs...) where {uType, tType, isinplace} + Alg::FwdEulerAlg; + dt = (prob.tspan[2] - prob.tspan[1]) / 100, + tstops = tType[], + kwargs...) where {uType, tType, isinplace} u0 = prob.u0 f = prob.f tspan = prob.tspan @@ -46,12 +46,12 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractODEProblem{uType, tType, isin end function DiffEqBase.solve(prob::DiffEqBase.AbstractODEProblem{uType, tType, isinplace}, - Alg::BwdEulerAlg; - dt = (prob.tspan[2] - prob.tspan[1]) / 100, - tstops = tType[], - tol = 1e-5, - maxiter = 100, - kwargs...) where {uType, tType, isinplace} + Alg::BwdEulerAlg; + dt = (prob.tspan[2] - prob.tspan[1]) / 100, + tstops = tType[], + tol = 1e-5, + maxiter = 100, + kwargs...) where {uType, tType, isinplace} u0 = prob.u0 f = prob.f tspan = prob.tspan diff --git a/src/internal_falsi.jl b/src/internal_falsi.jl index d8f693510..f71978bd9 100644 --- a/src/internal_falsi.jl +++ b/src/internal_falsi.jl @@ -22,8 +22,8 @@ simpler dependencies. struct InternalFalsi end function SciMLBase.solve(prob::IntervalNonlinearProblem, alg::InternalFalsi, args...; - maxiters = 1000, - kwargs...) + maxiters = 1000, + kwargs...) f = Base.Fix2(prob.f, prob.p) left, right = prob.tspan fl, fr = f(left), f(right) @@ -128,10 +128,11 @@ function scalar_nlsolve_ad(prob, alg::InternalFalsi, args...; kwargs...) return sol, partials end -function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, - <:ForwardDiff.Dual{T, V, P}}, - alg::InternalFalsi, args...; - kwargs...) where {uType, iip, T, V, P} +function SciMLBase.solve( + prob::IntervalNonlinearProblem{uType, iip, + <:ForwardDiff.Dual{T, V, P}}, + alg::InternalFalsi, args...; + kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials), sol.resid; retcode = sol.retcode, @@ -139,14 +140,15 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, right = ForwardDiff.Dual{T, V, P}(sol.right, partials)) end -function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, - <:AbstractArray{ - <:ForwardDiff.Dual{T, +function SciMLBase.solve( + prob::IntervalNonlinearProblem{uType, iip, + <:AbstractArray{ + <:ForwardDiff.Dual{T, V, P}, - }}, - alg::InternalFalsi, args...; - kwargs...) where {uType, iip, T, V, P} + }}, + alg::InternalFalsi, args...; + kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) 
return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials), diff --git a/src/internal_itp.jl b/src/internal_itp.jl index 0b389bf97..a5776e0fd 100644 --- a/src/internal_itp.jl +++ b/src/internal_itp.jl @@ -11,8 +11,8 @@ end InternalITP() = InternalITP(0.007, 1.5, 10) function SciMLBase.solve(prob::IntervalNonlinearProblem{IP, Tuple{T, T}}, alg::InternalITP, - args...; - maxiters = 1000, kwargs...) where {IP, T} + args...; + maxiters = 1000, kwargs...) where {IP, T} f = Base.Fix2(prob.f, prob.p) left, right = prob.tspan # a and b fl, fr = f(left), f(right) @@ -127,10 +127,11 @@ function scalar_nlsolve_ad(prob, alg::InternalITP, args...; kwargs...) return sol, partials end -function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, - <:ForwardDiff.Dual{T, V, P}}, - alg::InternalITP, args...; - kwargs...) where {uType, iip, T, V, P} +function SciMLBase.solve( + prob::IntervalNonlinearProblem{uType, iip, + <:ForwardDiff.Dual{T, V, P}}, + alg::InternalITP, args...; + kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials), sol.resid; retcode = sol.retcode, @@ -138,14 +139,15 @@ function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, right = ForwardDiff.Dual{T, V, P}(sol.right, partials)) end -function SciMLBase.solve(prob::IntervalNonlinearProblem{uType, iip, - <:AbstractArray{ - <:ForwardDiff.Dual{T, +function SciMLBase.solve( + prob::IntervalNonlinearProblem{uType, iip, + <:AbstractArray{ + <:ForwardDiff.Dual{T, V, P}, - }}, - alg::InternalITP, args...; - kwargs...) where {uType, iip, T, V, P} + }}, + alg::InternalITP, args...; + kwargs...) where {uType, iip, T, V, P} sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...) 
return SciMLBase.build_solution(prob, alg, ForwardDiff.Dual{T, V, P}(sol.u, partials), diff --git a/src/norecompile.jl b/src/norecompile.jl index 8446c8bcc..a097fa0f0 100644 --- a/src/norecompile.jl +++ b/src/norecompile.jl @@ -4,7 +4,8 @@ const dualT = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, Float64}, Floa dualgen(::Type{T}) where {T} = ForwardDiff.Dual{ForwardDiff.Tag{OrdinaryDiffEqTag, T}, T, 1} -const NORECOMPILE_IIP_SUPPORTED_ARGS = (Tuple{Vector{Float64}, Vector{Float64}, +const NORECOMPILE_IIP_SUPPORTED_ARGS = ( + Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}, Tuple{Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, Float64}) @@ -52,7 +53,7 @@ function wrapfun_oop(ff, inputs::Tuple = ()) end function wrapfun_iip(ff, - inputs::Tuple{T1, T2, T3, T4}) where {T1, T2, T3, T4} + inputs::Tuple{T1, T2, T3, T4}) where {T1, T2, T3, T4} T = eltype(T2) dualT = dualgen(T) dualT1 = ArrayInterface.promote_eltype(T1, dualT) @@ -72,20 +73,21 @@ function wrapfun_iip(ff, FunctionWrappersWrappers.FunctionWrappersWrapper{typeof(fwt), false}(fwt) end -const iip_arglists_default = (Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, +const iip_arglists_default = ( + Tuple{Vector{Float64}, Vector{Float64}, Vector{Float64}, Float64}, Tuple{Vector{Float64}, Vector{Float64}, SciMLBase.NullParameters, - Float64, + Float64 }, Tuple{Vector{dualT}, Vector{Float64}, Vector{Float64}, dualT}, Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, dualT}, Tuple{Vector{dualT}, Vector{dualT}, Vector{Float64}, Float64}, Tuple{Vector{dualT}, Vector{dualT}, SciMLBase.NullParameters, - Float64, + Float64 }, Tuple{Vector{dualT}, Vector{Float64}, - SciMLBase.NullParameters, dualT, + SciMLBase.NullParameters, dualT }) const iip_returnlists_default = ntuple(x -> Nothing, length(iip_arglists_default)) diff --git a/src/solve.jl b/src/solve.jl index 795e05546..e8e4bef9b 100644 --- a/src/solve.jl +++ b/src/solve.jl @@ -501,7 +501,7 @@ function Base.showerror(io::IO, e::IncompatibleMassMatrixError) end function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing, - kwargs...) + kwargs...) kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? _prob.kwargs[:kwargshandle] : kwargshandle @@ -510,8 +510,9 @@ function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothin if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) kwargs_temp = NamedTuple{ Base.diff_names(Base._nt_names(values(kwargs)), - (:callback,))}(values(kwargs)) - callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet(_prob.kwargs[:callback], + (:callback,))}(values(kwargs)) + callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + _prob.kwargs[:callback], values(kwargs).callback),)) kwargs = merge(kwargs_temp, callbacks) end @@ -530,8 +531,9 @@ function init_call(_prob, args...; merge_callbacks = true, kwargshandle = nothin end end -function init(prob::Union{AbstractDEProblem, NonlinearProblem}, args...; sensealg = nothing, - u0 = nothing, p = nothing, kwargs...) +function init( + prob::Union{AbstractDEProblem, NonlinearProblem}, args...; sensealg = nothing, + u0 = nothing, p = nothing, kwargs...) if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] end @@ -565,7 +567,7 @@ function init_up(prob::AbstractDEProblem, sensealg, u0, p, args...; kwargs...) 
end function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothing, - kwargs...) + kwargs...) kwargshandle = kwargshandle === nothing ? KeywordArgError : kwargshandle kwargshandle = has_kwargs(_prob) && haskey(_prob.kwargs, :kwargshandle) ? _prob.kwargs[:kwargshandle] : kwargshandle @@ -574,8 +576,9 @@ function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothi if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) kwargs_temp = NamedTuple{ Base.diff_names(Base._nt_names(values(kwargs)), - (:callback,))}(values(kwargs)) - callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet(_prob.kwargs[:callback], + (:callback,))}(values(kwargs)) + callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + _prob.kwargs[:callback], values(kwargs).callback),)) kwargs = merge(kwargs_temp, callbacks) end @@ -590,7 +593,7 @@ function solve_call(_prob, args...; merge_callbacks = true, kwargshandle = nothi end if !(eltype(_prob.u0) <: Number) && !(eltype(_prob.u0) <: Enum) && - !(_prob.u0 isa AbstractVector{<:AbstractArray} && _prob isa BVProblem) + !(_prob.u0 isa AbstractVector{<:AbstractArray} && _prob isa BVProblem) # Allow Enums for FunctionMaps, make into a trait in the future # BVPs use Vector of Arrays for initial guesses throw(NonNumberEltypeError(eltype(_prob.u0))) @@ -621,10 +624,11 @@ mutable struct NullODEIntegrator{IIP, ProbType, T, SolType, F, P} <: p::P end function build_null_integrator(prob::AbstractDEProblem, args...; - kwargs...) + kwargs...) sol = solve(prob, args...; kwargs...) - return NullODEIntegrator{isinplace(prob), typeof(prob), eltype(prob.tspan), typeof(sol), - typeof(prob.f), typeof(prob.p), + return NullODEIntegrator{ + isinplace(prob), typeof(prob), eltype(prob.tspan), typeof(sol), + typeof(prob.f), typeof(prob.p) }(Float64[], Float64[], prob.tspan[1], @@ -647,13 +651,13 @@ function step!(integ::NullODEIntegrator, dt = nothing, stop_at_tdt = false) end function build_null_solution(prob::AbstractDEProblem, args...; - saveat = (), - save_everystep = true, - save_on = true, - save_start = save_everystep || isempty(saveat) || - saveat isa Number || prob.tspan[1] in saveat, - save_end = true, - kwargs...) + saveat = (), + save_everystep = true, + save_on = true, + save_start = save_everystep || isempty(saveat) || + saveat isa Number || prob.tspan[1] in saveat, + save_end = true, + kwargs...) ts = if saveat === () if save_start && save_end [prob.tspan[1], prob.tspan[2]] @@ -676,13 +680,13 @@ function build_null_solution(prob::AbstractDEProblem, args...; end function build_null_solution(prob::Union{SteadyStateProblem, NonlinearProblem}, args...; - saveat = (), - save_everystep = true, - save_on = true, - save_start = save_everystep || isempty(saveat) || - saveat isa Number || prob.tspan[1] in saveat, - save_end = true, - kwargs...) + saveat = (), + save_everystep = true, + save_on = true, + save_start = save_everystep || isempty(saveat) || + saveat isa Number || prob.tspan[1] in saveat, + save_end = true, + kwargs...) SciMLBase.build_solution(prob, nothing, Float64[], nothing; retcode = ReturnCode.Success) end @@ -969,7 +973,7 @@ the extension to other types is straightforward. `progress = true` you are enabling the progress bar. """ function solve(prob::AbstractDEProblem, args...; sensealg = nothing, - u0 = nothing, p = nothing, wrap = Val(true), kwargs...) + u0 = nothing, p = nothing, wrap = Val(true), kwargs...) 
if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] end @@ -1026,7 +1030,7 @@ problems. https://docs.sciml.ai/SciMLSensitivity/stable/ """ function solve(prob::NonlinearProblem, args...; sensealg = nothing, - u0 = nothing, p = nothing, wrap = Val(true), kwargs...) + u0 = nothing, p = nothing, wrap = Val(true), kwargs...) if sensealg === nothing && haskey(prob.kwargs, :sensealg) sensealg = prob.kwargs[:sensealg] end @@ -1042,7 +1046,7 @@ function solve(prob::NonlinearProblem, args...; sensealg = nothing, end function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0, p, - args...; kwargs...) + args...; kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, @@ -1061,8 +1065,8 @@ function solve_up(prob::Union{AbstractDEProblem, NonlinearProblem}, sensealg, u0 end function solve_call(prob::SteadyStateProblem, - alg::SciMLBase.AbstractNonlinearAlgorithm, args...; - kwargs...) + alg::SciMLBase.AbstractNonlinearAlgorithm, args...; + kwargs...) solve_call(NonlinearProblem(prob), alg, args...; kwargs...) @@ -1133,12 +1137,12 @@ function get_concrete_problem(prob::AbstractEnsembleProblem, isadapt; kwargs...) end function solve(prob::PDEProblem, alg::AbstractDEAlgorithm, args...; - kwargs...) + kwargs...) solve(prob.prob, alg, args...; kwargs...) end function init(prob::PDEProblem, alg::AbstractDEAlgorithm, args...; - kwargs...) + kwargs...) init(prob.prob, alg, args...; kwargs...) end @@ -1361,8 +1365,9 @@ handle_distribution_u0(_u0) = _u0 eval_u0(u0::Function) = true eval_u0(u0) = false -function __solve(prob::AbstractDEProblem, args...; default_set = false, second_time = false, - kwargs...) +function __solve( + prob::AbstractDEProblem, args...; default_set = false, second_time = false, + kwargs...) if second_time throw(NoDefaultAlgorithmError()) elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm}) @@ -1373,7 +1378,7 @@ function __solve(prob::AbstractDEProblem, args...; default_set = false, second_t end function __init(prob::AbstractDEProblem, args...; default_set = false, second_time = false, - kwargs...) + kwargs...) if second_time throw(NoDefaultAlgorithmError()) elseif length(args) > 0 && !(first(args) isa Union{Nothing, AbstractDEAlgorithm}) @@ -1456,7 +1461,7 @@ end nothing end elseif first(solve_args) isa SciMLBase.AbstractSciMLAlgorithm && - !(first(solve_args) isa SciMLBase.EnsembleAlgorithm) + !(first(solve_args) isa SciMLBase.EnsembleAlgorithm) first(solve_args) else nothing @@ -1485,7 +1490,7 @@ struct SensitivityADPassThrough <: AbstractDEAlgorithm end kwargs...) function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, - kwargs...) + kwargs...) 
alg = extract_alg(args, kwargs, prob.kwargs) if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, @@ -1498,8 +1503,9 @@ function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callba if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) kwargs_temp = NamedTuple{ Base.diff_names(Base._nt_names(values(kwargs)), - (:callback,))}(values(kwargs)) - callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet(_prob.kwargs[:callback], + (:callback,))}(values(kwargs)) + callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + _prob.kwargs[:callback], values(kwargs).callback),)) kwargs = merge(kwargs_temp, callbacks) end @@ -1515,7 +1521,7 @@ function _solve_adjoint(prob, sensealg, u0, p, originator, args...; merge_callba end function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callbacks = true, - kwargs...) + kwargs...) alg = extract_alg(args, kwargs, prob.kwargs) if isnothing(alg) || !(alg isa AbstractDEAlgorithm) # Default algorithm handling _prob = get_concrete_problem(prob, !(prob isa DiscreteProblem); u0 = u0, @@ -1528,8 +1534,9 @@ function _solve_forward(prob, sensealg, u0, p, originator, args...; merge_callba if merge_callbacks && haskey(_prob.kwargs, :callback) && haskey(kwargs, :callback) kwargs_temp = NamedTuple{ Base.diff_names(Base._nt_names(values(kwargs)), - (:callback,))}(values(kwargs)) - callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet(_prob.kwargs[:callback], + (:callback,))}(values(kwargs)) + callbacks = NamedTuple{(:callback,)}((DiffEqBase.CallbackSet( + _prob.kwargs[:callback], values(kwargs).callback),)) kwargs = merge(kwargs_temp, callbacks) end diff --git a/src/stats.jl b/src/stats.jl index 71e66ed2e..14833f69b 100644 --- a/src/stats.jl +++ b/src/stats.jl @@ -1,62 +1,62 @@ @static if isdefined(SciMLBase, :DEStats) - const Stats = SciMLBase.DEStats + const Stats = SciMLBase.DEStats else - """ - $(TYPEDEF) - - Statistics from the differential equation solver about the solution process. - - ## Fields - - - nf: Number of function evaluations. If the differential equation is a split function, - such as a `SplitFunction` for implicit-explicit (IMEX) integration, then `nf` is the - number of function evaluations for the first function (the implicit function) - - nf2: If the differential equation is a split function, such as a `SplitFunction` - for implicit-explicit (IMEX) integration, then `nf2` is the number of function - evaluations for the second function, i.e. the function treated explicitly. Otherwise - it is zero. - - nw: The number of W=I-gamma*J (or W=I/gamma-J) matrices constructed during the solving - process. - - nsolve: The number of linear solves `W\b` required for the integration. - - njacs: Number of Jacobians calculated during the integration. - - nnonliniter: Total number of iterations for the nonlinear solvers. - - nnonlinconvfail: Number of nonlinear solver convergence failures. - - ncondition: Number of calls to the condition function for callbacks. - - naccept: Number of accepted steps. - - nreject: Number of rejected steps. - - maxeig: Maximum eigenvalue over the solution. This is only computed if the - method is an auto-switching algorithm. 
- """ - mutable struct Stats - nf::Int - nf2::Int - nw::Int - nsolve::Int - njacs::Int - nnonliniter::Int - nnonlinconvfail::Int - ncondition::Int - naccept::Int - nreject::Int - maxeig::Float64 - end - - Base.@deprecate_binding DEStats Stats false - - Stats(x::Int = -1) = Stats(x, x, x, x, x, x, x, x, x, x, 0.0) - - function Base.show(io::IO, s::Stats) - println(io, summary(s)) - @printf io "%-50s %-d\n" "Number of function 1 evaluations:" s.nf - @printf io "%-50s %-d\n" "Number of function 2 evaluations:" s.nf2 - @printf io "%-50s %-d\n" "Number of W matrix evaluations:" s.nw - @printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve - @printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs - @printf io "%-50s %-d\n" "Number of nonlinear solver iterations:" s.nnonliniter - @printf io "%-50s %-d\n" "Number of nonlinear solver convergence failures:" s.nnonlinconvfail - @printf io "%-50s %-d\n" "Number of rootfind condition calls:" s.ncondition - @printf io "%-50s %-d\n" "Number of accepted steps:" s.naccept - @printf io "%-50s %-d" "Number of rejected steps:" s.nreject - iszero(s.maxeig) || @printf io "\n%-50s %-d" "Maximum eigenvalue recorded:" s.maxeig - end + """ + $(TYPEDEF) + + Statistics from the differential equation solver about the solution process. + + ## Fields + + - nf: Number of function evaluations. If the differential equation is a split function, + such as a `SplitFunction` for implicit-explicit (IMEX) integration, then `nf` is the + number of function evaluations for the first function (the implicit function) + - nf2: If the differential equation is a split function, such as a `SplitFunction` + for implicit-explicit (IMEX) integration, then `nf2` is the number of function + evaluations for the second function, i.e. the function treated explicitly. Otherwise + it is zero. + - nw: The number of W=I-gamma*J (or W=I/gamma-J) matrices constructed during the solving + process. + - nsolve: The number of linear solves `W\b` required for the integration. + - njacs: Number of Jacobians calculated during the integration. + - nnonliniter: Total number of iterations for the nonlinear solvers. + - nnonlinconvfail: Number of nonlinear solver convergence failures. + - ncondition: Number of calls to the condition function for callbacks. + - naccept: Number of accepted steps. + - nreject: Number of rejected steps. + - maxeig: Maximum eigenvalue over the solution. This is only computed if the + method is an auto-switching algorithm. 
+ """ + mutable struct Stats + nf::Int + nf2::Int + nw::Int + nsolve::Int + njacs::Int + nnonliniter::Int + nnonlinconvfail::Int + ncondition::Int + naccept::Int + nreject::Int + maxeig::Float64 + end + + Base.@deprecate_binding DEStats Stats false + + Stats(x::Int = -1) = Stats(x, x, x, x, x, x, x, x, x, x, 0.0) + + function Base.show(io::IO, s::Stats) + println(io, summary(s)) + @printf io "%-50s %-d\n" "Number of function 1 evaluations:" s.nf + @printf io "%-50s %-d\n" "Number of function 2 evaluations:" s.nf2 + @printf io "%-50s %-d\n" "Number of W matrix evaluations:" s.nw + @printf io "%-50s %-d\n" "Number of linear solves:" s.nsolve + @printf io "%-50s %-d\n" "Number of Jacobians created:" s.njacs + @printf io "%-50s %-d\n" "Number of nonlinear solver iterations:" s.nnonliniter + @printf io "%-50s %-d\n" "Number of nonlinear solver convergence failures:" s.nnonlinconvfail + @printf io "%-50s %-d\n" "Number of rootfind condition calls:" s.ncondition + @printf io "%-50s %-d\n" "Number of accepted steps:" s.naccept + @printf io "%-50s %-d" "Number of rejected steps:" s.nreject + iszero(s.maxeig) || @printf io "\n%-50s %-d" "Maximum eigenvalue recorded:" s.maxeig + end end diff --git a/src/tableaus.jl b/src/tableaus.jl index 1ba51f327..ffbbdea5e 100644 --- a/src/tableaus.jl +++ b/src/tableaus.jl @@ -17,9 +17,9 @@ mutable struct ExplicitRKTableau{MType <: AbstractMatrix, VType <: AbstractVecto stability_size::S end function ExplicitRKTableau(A::MType, c::VType, α::VType, order; - adaptiveorder = 0, αEEst = similar(α, 0), - fsal = false, stability_size = 0.0, - d = similar(α, 0)) where {MType, VType} + adaptiveorder = 0, αEEst = similar(α, 0), + fsal = false, stability_size = 0.0, + d = similar(α, 0)) where {MType, VType} S = typeof(stability_size) ExplicitRKTableau{MType, VType, S}(A, c, α, αEEst, d, length(α), order, adaptiveorder, fsal, stability_size) @@ -41,6 +41,6 @@ mutable struct ImplicitRKTableau{MType <: AbstractMatrix, VType <: AbstractVecto adaptiveorder::Int #The lower order of the pair. Only used for adaptivity. end function ImplicitRKTableau(A::MType, c::VType, α::VType, order; - adaptiveorder = 0, αEEst = VType()) where {MType, VType} + adaptiveorder = 0, αEEst = VType()) where {MType, VType} ImplicitRKTableau{MType, VType}(A, c, α, αEEst, length(α), order, adaptiveorder) end diff --git a/src/termination_conditions.jl b/src/termination_conditions.jl index c71beebe3..5b1d5b1ed 100644 --- a/src/termination_conditions.jl +++ b/src/termination_conditions.jl @@ -259,7 +259,8 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T Vector{TT}(undef, mode.max_stalled_steps) best_value = initial_objective max_stalled_steps = mode.max_stalled_steps - if ArrayInterface.can_setindex(u_) && !(u_ isa Number) && step_norm_trace !== nothing + if ArrayInterface.can_setindex(u_) && !(u_ isa Number) && + step_norm_trace !== nothing u_diff_cache = similar(u_) else u_diff_cache = u_ @@ -337,7 +338,8 @@ function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminati return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol) end -function (cache::NonlinearTerminationModeCache{uType, TT, dep_retcode})(mode::AbstractSafeNonlinearTerminationMode, +function (cache::NonlinearTerminationModeCache{uType, TT, dep_retcode})( + mode::AbstractSafeNonlinearTerminationMode, du, u, uprev, args...) 
where {uType, TT, dep_retcode} if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode objective = maximum(abs, du) @@ -483,21 +485,27 @@ function check_convergence(::RelTerminationMode, duₙ, uₙ, uₙ₋₁, abstol return all(@. abs(duₙ) ≤ reltol * abs(uₙ)) end -function check_convergence(::Union{RelNormTerminationMode, RelSafeTerminationMode, - RelSafeBestTerminationMode}, duₙ::ZIPPABLE_TYPES, uₙ::ZIPPABLE_TYPES, +function check_convergence( + ::Union{RelNormTerminationMode, RelSafeTerminationMode, + RelSafeBestTerminationMode}, + duₙ::ZIPPABLE_TYPES, uₙ::ZIPPABLE_TYPES, uₙ₋₁::ZIPPABLE_TYPES, abstol, reltol) return maximum(abs, duₙ) ≤ reltol * maximum(((x, y),) -> abs(x + y), zip(duₙ, uₙ)) end -function check_convergence(::Union{RelNormTerminationMode, RelSafeTerminationMode, - RelSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) +function check_convergence( + ::Union{RelNormTerminationMode, RelSafeTerminationMode, + RelSafeBestTerminationMode}, + duₙ, uₙ, uₙ₋₁, abstol, reltol) return maximum(abs, duₙ) ≤ reltol * maximum(abs, duₙ .+ uₙ) end function check_convergence(::AbsTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol) return all(x -> abs(x) ≤ abstol, duₙ) end -function check_convergence(::Union{AbsNormTerminationMode, AbsSafeTerminationMode, - AbsSafeBestTerminationMode}, duₙ, uₙ, uₙ₋₁, abstol, reltol) +function check_convergence( + ::Union{AbsNormTerminationMode, AbsSafeTerminationMode, + AbsSafeBestTerminationMode}, + duₙ, uₙ, uₙ₋₁, abstol, reltol) return maximum(abs, duₙ) ≤ abstol end @@ -546,7 +554,8 @@ function NLSolveSafeTerminationResult(u = nothing; best_objective_value = Inf64, best_objective_value_iteration = 0, return_code = NLSolveSafeTerminationReturnCode.Failure) u = u !== nothing ? copy(u) : u - Base.depwarn("NLSolveSafeTerminationResult has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!", + Base.depwarn( + "NLSolveSafeTerminationResult has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!", :NLSolveSafeTerminationResult) return NLSolveSafeTerminationResult{typeof(best_objective_value), typeof(u)}(u, best_objective_value, best_objective_value_iteration, return_code) @@ -634,7 +643,8 @@ function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6, protective_threshold = 1e3, patience_steps::Int = 30, patience_objective_multiplier = 3, min_max_factor = 1.3) where {T} - Base.depwarn("NLSolveTerminationCondition has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!", + Base.depwarn( + "NLSolveTerminationCondition has been deprecated in favor of the new dispatch based termination conditions. 
Please use the new termination conditions API!", :NLSolveTerminationCondition) @assert mode ∈ instances(NLSolveTerminationMode.T) options = if mode ∈ SAFE_TERMINATION_MODES @@ -648,8 +658,8 @@ end function (cond::NLSolveTerminationCondition)(storage::Union{ NLSolveSafeTerminationResult, - Nothing, - }) + Nothing +}) mode = get_termination_mode(cond) # We need both the dispatches to support solvers that don't use the integrator # interface like SimpleNonlinearSolve diff --git a/test/callbacks.jl b/test/callbacks.jl index 98bcc2824..5f4f98e1a 100644 --- a/test/callbacks.jl +++ b/test/callbacks.jl @@ -58,11 +58,11 @@ struct EmptyIntegrator u::Vector{Float64} end function DiffEqBase.find_callback_time(integrator::EmptyIntegrator, - callback::ContinuousCallback, counter) + callback::ContinuousCallback, counter) 1.0 + counter, 0.9 + counter, true, counter end function DiffEqBase.find_callback_time(integrator::EmptyIntegrator, - callback::VectorContinuousCallback, counter) + callback::VectorContinuousCallback, counter) 1.0 + counter, 0.9 + counter, true, counter end find_first_integrator = EmptyIntegrator([1.0, 2.0]) diff --git a/test/downstream/community_callback_tests.jl b/test/downstream/community_callback_tests.jl index e8cc5a268..7bcf6ff0a 100644 --- a/test/downstream/community_callback_tests.jl +++ b/test/downstream/community_callback_tests.jl @@ -64,7 +64,7 @@ perror = [ 0.029177617589788596, 0.03064986043089549, 0.023280222517122397, - 6.931251277770224, + 6.931251277770224 ] y_max = 0.002604806609572015 u0 = [1, zeros(length(perror) - 1)...] @@ -103,7 +103,7 @@ function attactor(du, u, p, t) u.nodes[k][3], u.nodes[k][4], -β * u.nodes[k][3], - -β * u.nodes[k][4], + -β * u.nodes[k][4] ] else du.nodes[k][3:4] .+= α * (u.nodes[j][1:2] - u.nodes[k][1:2]) @@ -128,7 +128,7 @@ Newton = construct(PhysicsLaw, Thingy([-700.0, -350.0, 0.0, 0.0]), Thingy([-550.0, -150.0, 0.0, 0.0]), Thingy([-600.0, 15.0, 0.0, 10.0]), - Thingy([200.0, -200.0, 5.0, -5.0]), + Thingy([200.0, -200.0, 5.0, -5.0]) ]) parameters = [1e-2, 0.06] @@ -219,21 +219,20 @@ soln = solve(prob, Tsit5()) @test soln.t[end] ≈ 4.712347213360699 odefun = ODEFunction((u, p, t) -> [u[2], u[2] - p]; mass_matrix = [1 0; 0 0]) -callback = PresetTimeCallback(.5, integ->(integ.p=-integ.p;)) -prob = ODEProblem(odefun, [0.0, -1.0], (0., 1), 1; callback) +callback = PresetTimeCallback(0.5, integ -> (integ.p = -integ.p)) +prob = ODEProblem(odefun, [0.0, -1.0], (0.0, 1), 1; callback) #test that reinit happens for both FSAL and non FSAL integrators @testset "dae re-init" for alg in [FBDF(), Rodas5P()] sol = solve(prob, alg) # test that the callback flipping p caused u[2] to get flipped. first_t = findfirst(isequal(0.5), sol.t) - @test sol.u[first_t][2] == -sol.u[first_t+1][2] + @test sol.u[first_t][2] == -sol.u[first_t + 1][2] end - daefun = DAEFunction((du, u, p, t) -> [du[1] - u[2], u[2] - p]) -prob = DAEProblem(daefun, [0.0, 0.0], [0.0,-1.0], (0., 1), 1; +prob = DAEProblem(daefun, [0.0, 0.0], [0.0, -1.0], (0.0, 1), 1; differential_vars = [true, false], callback) sol = solve(prob, DFBDF()) # test that the callback flipping p caused u[2] to get flipped. 
first_t = findfirst(isequal(0.5), sol.t) -@test sol.u[first_t][2] == -sol.u[first_t+1][2] +@test sol.u[first_t][2] == -sol.u[first_t + 1][2] diff --git a/test/downstream/dual_detection_solution.jl b/test/downstream/dual_detection_solution.jl index 8ec0601f8..2bf6e1ae2 100644 --- a/test/downstream/dual_detection_solution.jl +++ b/test/downstream/dual_detection_solution.jl @@ -3,9 +3,9 @@ using OrdinaryDiffEq ## https://github.com/SciML/DifferentialEquations.jl/issues/1013 mutable struct SomeObject - position - velocity - trajectory + position::Any + velocity::Any + trajectory::Any end object = SomeObject(0, 1, nothing) @@ -26,11 +26,11 @@ end # https://github.com/SciML/DiffEqBase.jl/issues/1003 -f(u,p,t) = 1.01*u -u0=1/2 -tspan = (0.0,1.0) -prob = ODEProblem(f,u0,tspan) -sol = solve(prob,Tsit5(),reltol=1e-8,abstol=1e-8) +f(u, p, t) = 1.01 * u +u0 = 1 / 2 +tspan = (0.0, 1.0) +prob = ODEProblem(f, u0, tspan) +sol = solve(prob, Tsit5(), reltol = 1e-8, abstol = 1e-8) -prob2 = ODEProblem((du,u,p,t) -> du[1]=1, [0.0], (0,10), (;x=sol)) -solve(prob2, Tsit5()) \ No newline at end of file +prob2 = ODEProblem((du, u, p, t) -> du[1] = 1, [0.0], (0, 10), (; x = sol)) +solve(prob2, Tsit5()) diff --git a/test/downstream/ensemble_analysis.jl b/test/downstream/ensemble_analysis.jl index ad19a8f33..7d0b46ec0 100644 --- a/test/downstream/ensemble_analysis.jl +++ b/test/downstream/ensemble_analysis.jl @@ -1,5 +1,5 @@ using StochasticDiffEq, DiffEqBase, - OrdinaryDiffEq, DiffEqBase.EnsembleAnalysis + OrdinaryDiffEq, DiffEqBase.EnsembleAnalysis using Test import SDEProblemLibrary: prob_sde_linear, prob_sde_2Dlinear diff --git a/test/downstream/inference.jl b/test/downstream/inference.jl index 75b3604ff..207c89836 100644 --- a/test/downstream/inference.jl +++ b/test/downstream/inference.jl @@ -10,25 +10,32 @@ prob = ODEProblem(lorenz, u0, tspan) sol = solve(prob, Tsit5(), save_idxs = 1) @inferred solve(prob, Tsit5()) @inferred solve(prob, Tsit5(), save_idxs = 1) -@test_broken @inferred(remake(prob, u0 = Float32[1.0; 0.0; 0.0])) == remake(prob, u0 = Float32[1.0; 0.0; 0.0]) -@test_broken @inferred(solve(prob, Tsit5(), u0 = Float32[1.0; 0.0; 0.0])) == solve(prob, Tsit5(), u0 = Float32[1.0; 0.0; 0.0]) +@test_broken @inferred(remake(prob, u0 = Float32[1.0; 0.0; 0.0])) == + remake(prob, u0 = Float32[1.0; 0.0; 0.0]) +@test_broken @inferred(solve(prob, Tsit5(), u0 = Float32[1.0; 0.0; 0.0])) == + solve(prob, Tsit5(), u0 = Float32[1.0; 0.0; 0.0]) prob = ODEProblem{true, SciMLBase.FullSpecialize}(lorenz, u0, tspan) @inferred SciMLBase.wrapfun_iip(prob.f) @inferred remake(prob, u0 = [1.0; 0.0; 0.0]) @inferred remake(prob, u0 = Float32[1.0; 0.0; 0.0]) -@test_broken @inferred(solve(prob, Tsit5(), u0 = Float32[1.0; 0.0; 0.0])) == solve(prob, Tsit5(), u0 = Float32[1.0; 0.0; 0.0]) +@test_broken @inferred(solve(prob, Tsit5(), u0 = Float32[1.0; 0.0; 0.0])) == + solve(prob, Tsit5(), u0 = Float32[1.0; 0.0; 0.0]) prob = ODEProblem(lorenz, Float32[1.0; 0.0; 0.0], tspan) @inferred solve(prob, Tsit5(), save_idxs = 1) -@test_broken @inferred(solve(prob, Tsit5(), u0 = [1.0; 0.0; 0.0])) == solve(prob, Tsit5(), u0 = [1.0; 0.0; 0.0]) +@test_broken @inferred(solve(prob, Tsit5(), u0 = [1.0; 0.0; 0.0])) == + solve(prob, Tsit5(), u0 = [1.0; 0.0; 0.0]) remake(prob, u0 = [1.0; 0.0; 0.0]) @inferred SciMLBase.wrapfun_iip(prob.f) -@test_broken @inferred(ODEFunction{isinplace(prob), SciMLBase.FunctionWrapperSpecialize}(prob.f)) == ODEFunction{isinplace(prob), SciMLBase.FunctionWrapperSpecialize}(prob.f) +@test_broken @inferred(ODEFunction{ + 
isinplace(prob), SciMLBase.FunctionWrapperSpecialize}(prob.f)) == + ODEFunction{isinplace(prob), SciMLBase.FunctionWrapperSpecialize}(prob.f) @inferred remake(prob, u0 = [1.0; 0.0; 0.0]) -@test_broken @inferred(solve(prob, Tsit5(), u0 = [1.0; 0.0; 0.0])) == solve(prob, Tsit5(), u0 = [1.0; 0.0; 0.0]) +@test_broken @inferred(solve(prob, Tsit5(), u0 = [1.0; 0.0; 0.0])) == + solve(prob, Tsit5(), u0 = [1.0; 0.0; 0.0]) function f(du, u, p, t) du[1] = p.a @@ -63,12 +70,18 @@ function solve_ode(f::F, p::P, ensemblealg; kwargs...) where {F, P} end @inferred solve_ode(f, (a = 1, b = 1), EnsembleSerial()) @inferred solve_ode(f, (a = 1, b = 1), EnsembleThreads()) -@test_broken @inferred(solve_ode(f, (a = 1, b = 1), EnsembleDistributed())) == solve_ode(f, (a = 1, b = 1), EnsembleDistributed()) -@test_broken @inferred(solve_ode(f, (a = 1, b = 1), EnsembleSplitThreads())) == solve_ode(f, (a = 1, b = 1), EnsembleSplitThreads()) +@test_broken @inferred(solve_ode(f, (a = 1, b = 1), EnsembleDistributed())) == + solve_ode(f, (a = 1, b = 1), EnsembleDistributed()) +@test_broken @inferred(solve_ode(f, (a = 1, b = 1), EnsembleSplitThreads())) == + solve_ode(f, (a = 1, b = 1), EnsembleSplitThreads()) @inferred solve_ode(f, (a = 1, b = 1), EnsembleSerial(), save_idxs = 1) @inferred solve_ode(f, (a = 1, b = 1), EnsembleThreads(), save_idxs = 1) -@test_broken @inferred(solve_ode(f, (a = 1, b = 1), EnsembleDistributed(), save_idxs = 1)) == solve_ode(f, (a = 1, b = 1), EnsembleDistributed(), save_idxs = 1) -@test_broken @inferred(solve_ode(f, (a = 1, b = 1), EnsembleSplitThreads(), save_idxs = 1)) == solve_ode(f, (a = 1, b = 1), EnsembleSplitThreads(), save_idxs = 1) +@test_broken @inferred(solve_ode( + f, (a = 1, b = 1), EnsembleDistributed(), save_idxs = 1)) == + solve_ode(f, (a = 1, b = 1), EnsembleDistributed(), save_idxs = 1) +@test_broken @inferred(solve_ode( + f, (a = 1, b = 1), EnsembleSplitThreads(), save_idxs = 1)) == + solve_ode(f, (a = 1, b = 1), EnsembleSplitThreads(), save_idxs = 1) using StochasticDiffEq, Test u0 = 1 / 2 diff --git a/test/downstream/kwarg_warn.jl b/test/downstream/kwarg_warn.jl index 3c7e9b014..32e77976a 100644 --- a/test/downstream/kwarg_warn.jl +++ b/test/downstream/kwarg_warn.jl @@ -9,7 +9,8 @@ tspan = (0.0, 100.0) prob = ODEProblem(lorenz, u0, tspan) @test_nowarn sol = solve(prob, Tsit5(), reltol = 1e-6) sol = solve(prob, Tsit5(), rel_tol = 1e-6, kwargshandle = DiffEqBase.KeywordArgWarn) -@test_logs (:warn, DiffEqBase.KWARGWARN_MESSAGE) sol=solve(prob, Tsit5(), rel_tol = 1e-6, kwargshandle = DiffEqBase.KeywordArgWarn) +@test_logs (:warn, DiffEqBase.KWARGWARN_MESSAGE) sol=solve( + prob, Tsit5(), rel_tol = 1e-6, kwargshandle = DiffEqBase.KeywordArgWarn) @test_throws DiffEqBase.CommonKwargError sol=solve(prob, Tsit5(), rel_tol = 1e-6) prob = ODEProblem(lorenz, u0, tspan, test = 2.0, kwargshandle = DiffEqBase.KeywordArgWarn) diff --git a/test/downstream/null_de.jl b/test/downstream/null_de.jl index 9df66755e..d49a59aff 100644 --- a/test/downstream/null_de.jl +++ b/test/downstream/null_de.jl @@ -2,7 +2,7 @@ using ModelingToolkit, OrdinaryDiffEq, SteadyStateDiffEq, Test @variables t x(t) y(t) eqs = [0 ~ x - y - 0 ~ y - x] + 0 ~ y - x] @named sys = ODESystem(eqs, t) sys = structural_simplify(sys) @@ -16,7 +16,7 @@ sol = solve(prob, Tsit5()) for kwargs in [ Dict(:saveat => 0:0.1:1), Dict(:save_start => false), - Dict(:save_end => false), + Dict(:save_end => false) ] sol = solve(prob, kwargs...) init_integ = init(prob, kwargs...) 
@@ -31,7 +31,7 @@ end @variables t x y eqs = [0 ~ x - y - 0 ~ y - x] + 0 ~ y - x] @named sys = NonlinearSystem(eqs, [x, y], []) sys = structural_simplify(sys) diff --git a/test/downstream/solve_error_handling.jl b/test/downstream/solve_error_handling.jl index 4a7447104..91af4e028 100644 --- a/test/downstream/solve_error_handling.jl +++ b/test/downstream/solve_error_handling.jl @@ -75,10 +75,10 @@ prob = SDEProblem(f, noise = StochasticDiffEq.RealWienerProcess(0.0, zeros(3))) @test_throws DiffEqBase.NoiseSizeIncompatabilityError solve(prob, LambaEM()) -function g!(du,u,p,t) - du[1] .= u[1] + ones(3,3) - du[2] .= ones(3,3) - end - u0 = [zeros(3,3),zeros(3,3)] - prob = ODEProblem(g!,u0,(0,1.0)) -@test_throws DiffEqBase.NonNumberEltypeError solve(prob,Tsit5()) +function g!(du, u, p, t) + du[1] .= u[1] + ones(3, 3) + du[2] .= ones(3, 3) +end +u0 = [zeros(3, 3), zeros(3, 3)] +prob = ODEProblem(g!, u0, (0, 1.0)) +@test_throws DiffEqBase.NonNumberEltypeError solve(prob, Tsit5()) diff --git a/test/forwarddiff_dual_detection.jl b/test/forwarddiff_dual_detection.jl index 100b80797..c390132c5 100644 --- a/test/forwarddiff_dual_detection.jl +++ b/test/forwarddiff_dual_detection.jl @@ -35,7 +35,7 @@ p_possibilities = [ForwardDiff.Dual(2.0), (ForwardDiff.Dual(2.0), 2.0), (; x = 2.0, y = [ForwardDiff.Dual(2.0)]), (; x = 2.0, y = [[ForwardDiff.Dual(2.0)]]), Set([2.0, ForwardDiff.Dual(2.0)]), (SciMLBase.NullParameters(), ForwardDiff.Dual(2.0)), ((), ForwardDiff.Dual(2.0)), ForwardDiff.Dual{Nothing}(ForwardDiff.Dual{MyStruct}(2.0)), - (plot(), ForwardDiff.Dual(2.0)), + (plot(), ForwardDiff.Dual(2.0)) ] for p in p_possibilities @@ -58,7 +58,7 @@ higher_order_p_possibilities = [ForwardDiff.Dual{Nothing}(ForwardDiff.Dual{MyStr (ForwardDiff.Dual{Nothing}(ForwardDiff.Dual{MyStruct}(2.0)), ForwardDiff.Dual{Nothing}(2.0)), (ForwardDiff.Dual{Nothing}(2.0), - ForwardDiff.Dual{Nothing}(ForwardDiff.Dual{MyStruct}(2.0))), + ForwardDiff.Dual{Nothing}(ForwardDiff.Dual{MyStruct}(2.0))) ] for p in higher_order_p_possibilities @@ -83,7 +83,7 @@ p_possibilities17 = [ [MyStruct(2.0, (2.0, ForwardDiff.Dual(2.0)))], ((;), ForwardDiff.Dual(2.0)), MyStruct3(ForwardDiff.Dual(2.0)), (Mod, ForwardDiff.Dual(2.0)), (() -> 2.0, ForwardDiff.Dual(2.0)), - (Base.pointer([2.0]), ForwardDiff.Dual(2.0)), + (Base.pointer([2.0]), ForwardDiff.Dual(2.0)) ] push!(p_possibilities17, Returns((a = 2, b = 1.3, c = ForwardDiff.Dual(2.0f0)))) @@ -125,7 +125,7 @@ p_possibilities_uninferrred = [ [MyStruct(2.0, ForwardDiff.Dual(2.0))], (; x = 2.0, y = [[MyStruct3(ForwardDiff.Dual(2.0))]]), (; x = Vector{Float64}(undef, 2), y = [[MyStruct3(ForwardDiff.Dual(2.0))]]), - (; x = Matrix{Any}(undef, 2, 2), y = [[MyStruct3(ForwardDiff.Dual(2.0))]]), + (; x = Matrix{Any}(undef, 2, 2), y = [[MyStruct3(ForwardDiff.Dual(2.0))]]) ] for p in p_possibilities_uninferrred @@ -144,7 +144,7 @@ end p_possibilities_missed = [ Set([2.0, "s", ForwardDiff.Dual(2.0)]), Set([2.0, ForwardDiff.Dual(2.0), SciMLBase.NullParameters()]), - Set([Matrix{Float64}(undef, 2, 2), ForwardDiff.Dual(2.0)]), + Set([Matrix{Float64}(undef, 2, 2), ForwardDiff.Dual(2.0)]) ] for p in p_possibilities_missed @@ -161,7 +161,7 @@ for p in p_possibilities_missed end p_possibilities_notdual = [ - (), (;), [2.0], [2.0, 2], [2.0, (2.0)], [2.0, MyStruct(2.0, 2.0f0)], + (), (;), [2.0], [2.0, 2], [2.0, (2.0)], [2.0, MyStruct(2.0, 2.0f0)] ] for p in p_possibilities_notdual @@ -186,7 +186,7 @@ p_possibilities_notdual_uninferred = [ [Dict(:x => 2, "y" => 5), MyStruct2(2.0)], # Dictionaries can have inference issues 
- Dict(:x => 2, :y => 5), Dict(:x => 2, "y" => 5), + Dict(:x => 2, :y => 5), Dict(:x => 2, "y" => 5) ] # Also check circular references @@ -226,7 +226,7 @@ f(du, u, p, t) = du .= u config = ForwardDiff.JacobianConfig(f, ones(5)) p_possibilities_configs = [ - (config, config), (config, 2.0), config, (; x = config, y = 2.0), + (config, config), (config, 2.0), config, (; x = config, y = 2.0) ] for p in p_possibilities_configs @@ -244,7 +244,7 @@ for p in p_possibilities_configs end p_possibilities_configs_not_inferred = [ - [2.0, (2.0,), config], [2.0, config, MyStruct(2.0, 2.0f0)], + [2.0, (2.0,), config], [2.0, config, MyStruct(2.0, 2.0f0)] ] for p in p_possibilities_configs_not_inferred @@ -292,4 +292,4 @@ end p = EOS() @test !(DiffEqBase.anyeltypedual(p) <: ForwardDiff.Dual) -@inferred DiffEqBase.anyeltypedual(p) \ No newline at end of file +@inferred DiffEqBase.anyeltypedual(p) diff --git a/test/gpu/termination_conditions.jl b/test/gpu/termination_conditions.jl index 1e429478b..d0811a311 100644 --- a/test/gpu/termination_conditions.jl +++ b/test/gpu/termination_conditions.jl @@ -9,7 +9,7 @@ const TERMINATION_CONDITIONS = [ SteadyStateDiffEqTerminationMode(), SimpleNonlinearSolveTerminationMode(), NormTerminationMode(), RelTerminationMode(), RelNormTerminationMode(), AbsTerminationMode(), AbsNormTerminationMode(), RelSafeTerminationMode(), - AbsSafeTerminationMode(), RelSafeBestTerminationMode(), AbsSafeBestTerminationMode(), + AbsSafeTerminationMode(), RelSafeBestTerminationMode(), AbsSafeBestTerminationMode() ] @testset "Termination Conditions: Allocations" begin diff --git a/test/ode_default_unstable_check.jl b/test/ode_default_unstable_check.jl index ec50fc1de..7cb20c4ab 100644 --- a/test/ode_default_unstable_check.jl +++ b/test/ode_default_unstable_check.jl @@ -20,7 +20,7 @@ u2′[2] = SA[1.0 NaN; 1.0 1.0] u3 = VectorOfArray([ones(5), ones(5)]) @test !NAN_CHECK(u3) u3′ = recursivecopy(u3) -u3′[3,2] = NaN +u3′[3, 2] = NaN @test NAN_CHECK(u3′) u4 = ArrayPartition(u1, u2, u3) diff --git a/test/plot_vars.jl b/test/plot_vars.jl index a66c6c57a..a0a0403c1 100644 --- a/test/plot_vars.jl +++ b/test/plot_vars.jl @@ -21,27 +21,27 @@ syms = [:x, :y, :z] @test SciMLBase.interpret_vars([(0, 1), (1, 3), (4, 5)], sol) == [ (SciMLBase.DEFAULT_PLOT_FUNC, 0, 1), (SciMLBase.DEFAULT_PLOT_FUNC, 1, 3), - (SciMLBase.DEFAULT_PLOT_FUNC, 4, 5), + (SciMLBase.DEFAULT_PLOT_FUNC, 4, 5) ] @test SciMLBase.interpret_vars([1, (1, 3), (4, 5)], sol) == [ (SciMLBase.DEFAULT_PLOT_FUNC, 0, 1), (SciMLBase.DEFAULT_PLOT_FUNC, 1, 3), - (SciMLBase.DEFAULT_PLOT_FUNC, 4, 5), + (SciMLBase.DEFAULT_PLOT_FUNC, 4, 5) ] @test SciMLBase.interpret_vars([1, 3, 4], sol) == [ (SciMLBase.DEFAULT_PLOT_FUNC, 0, 1), (SciMLBase.DEFAULT_PLOT_FUNC, 0, 3), - (SciMLBase.DEFAULT_PLOT_FUNC, 0, 4), + (SciMLBase.DEFAULT_PLOT_FUNC, 0, 4) ] @test SciMLBase.interpret_vars(([1, 2, 3], [4, 5, 6]), sol) == [ (SciMLBase.DEFAULT_PLOT_FUNC, 1, 4), (SciMLBase.DEFAULT_PLOT_FUNC, 2, 5), - (SciMLBase.DEFAULT_PLOT_FUNC, 3, 6), + (SciMLBase.DEFAULT_PLOT_FUNC, 3, 6) ] @test SciMLBase.interpret_vars((1, [2, 3, 4]), sol) == [ (SciMLBase.DEFAULT_PLOT_FUNC, 1, 2), (SciMLBase.DEFAULT_PLOT_FUNC, 1, 3), - (SciMLBase.DEFAULT_PLOT_FUNC, 1, 4), + (SciMLBase.DEFAULT_PLOT_FUNC, 1, 4) ] f(x, y) = (x + y, y) diff --git a/test/problem_creation_tests.jl b/test/problem_creation_tests.jl index 4139488a3..c213fdeef 100644 --- a/test/problem_creation_tests.jl +++ b/test/problem_creation_tests.jl @@ -94,7 +94,8 @@ differential_vars = [true, true, false] prob_dae_resrob = DAEProblem(f, du0, 
u0, (0.0, 100000.0)) prob_dae_resrob = DAEProblem{true}(f, du0, u0, (0.0, 100000.0)) -@test_broken @inferred(DAEProblem(f, du0, u0, (0.0, 100000.0))) == DAEProblem(f, du0, u0, (0.0, 100000.0)) +@test_broken @inferred(DAEProblem(f, du0, u0, (0.0, 100000.0))) == + DAEProblem(f, du0, u0, (0.0, 100000.0)) @inferred DAEProblem{true}(f, du0, u0, (0.0, 100000.0)) # Ensures uniform dimensionality of u0, du0, and differential_vars diff --git a/test/remake_tests.jl b/test/remake_tests.jl index a4950b877..f1a49b945 100644 --- a/test/remake_tests.jl +++ b/test/remake_tests.jl @@ -79,7 +79,7 @@ noise2 = remake(noise1; tspan = tspan2); # Test remake with TwoPointBVPFunction (manually defined): f1 = SciMLBase.TwoPointBVPFunction((u, p, t) -> 1, ((u_a, p) -> 2, (u_b, p) -> 2)) -@test_broken f2 = remake(f1; bc = ((u_a, p) -> 3, (u_b, p) -> 4)) +@test_broken f2 = remake(f1; bc = ((u_a, p) -> 3, (u_b, p) -> 4)) @test_broken f1.bc() == 1 @test_broken f2.bc() == 2 diff --git a/test/termination_conditions.jl b/test/termination_conditions.jl index a9906d055..3403262f0 100644 --- a/test/termination_conditions.jl +++ b/test/termination_conditions.jl @@ -8,12 +8,12 @@ const TERMINATION_CONDITIONS = [ SteadyStateDiffEqTerminationMode(), SimpleNonlinearSolveTerminationMode(), NormTerminationMode(), RelTerminationMode(), RelNormTerminationMode(), AbsTerminationMode(), AbsNormTerminationMode(), RelSafeTerminationMode(), - AbsSafeTerminationMode(), RelSafeBestTerminationMode(), AbsSafeBestTerminationMode(), + AbsSafeTerminationMode(), RelSafeBestTerminationMode(), AbsSafeBestTerminationMode() ] @testset "Termination Conditions: Allocations" begin @testset "Mode: $(tcond)" for tcond in TERMINATION_CONDITIONS @test (@ballocated DiffEqBase.check_convergence($tcond, $du, $u, $uprev, 1e-3, - 1e-3)) == 0 + 1e-3)) == 0 end end diff --git a/test/utils.jl b/test/utils.jl index cdf12b72f..24c2ad584 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -36,10 +36,10 @@ end end @testset "_rate_prototype" begin - @test _rate_prototype([1f0], 1.0, 1.0) isa Vector{Float32} - td = Dual{Tag{typeof(+), Float64}}(2.0,1.0) - @test _rate_prototype([1f0], td, td) isa Vector{Float32} - xd = [Dual{Tag{typeof(+), Float32}}(2.0,1.0)] + @test _rate_prototype([1.0f0], 1.0, 1.0) isa Vector{Float32} + td = Dual{Tag{typeof(+), Float64}}(2.0, 1.0) + @test _rate_prototype([1.0f0], td, td) isa Vector{Float32} + xd = [Dual{Tag{typeof(+), Float32}}(2.0, 1.0)] @test _rate_prototype(xd, 1.0, 1.0) isa typeof(xd) @test _rate_prototype([u"1f0m"], u"1.0s", 1.0) isa typeof([u"1f0m/s"]) end
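For readers unfamiliar with the termination-condition API exercised by the allocation tests above, here is a minimal, illustrative sketch of the call pattern those tests use. It is not part of this patch; the chosen mode, array sizes, and tolerance values are arbitrary assumptions for demonstration only.

using DiffEqBase

# Arbitrary iterate/residual data, purely for demonstration.
u = ones(4)              # current iterate
uprev = ones(4) .- 1e-9  # previous iterate
du = u .- uprev          # last step, used as the residual proxy

# Same signature as in the tests: (mode, du, u, uprev, abstol, reltol).
mode = DiffEqBase.AbsNormTerminationMode()
converged = DiffEqBase.check_convergence(mode, du, u, uprev, 1e-3, 1e-3)
@show converged  # true here, since maximum(abs, du) is far below abstol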