Skip to content

Commit

Permalink
Rename destats to stats (#223)
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch authored Mar 11, 2023
1 parent 778c347 commit 21748ad
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 43 deletions.
6 changes: 3 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,20 @@ TaylorSeries = "6aa5eb33-94cf-58f4-a9d0-e4b2c4fc25ea"
ToeplitzMatrices = "c751599d-da0a-543b-9d20-d0a503d91d24"

[compat]
DiffEqBase = "6"
DiffEqBase = "6.122"
DiffEqDevTools = "2"
ExponentialUtilities = "1"
FastBroadcast = "0.2"
ForwardDiff = "0.10"
FunctionWrappersWrappers = "0.1.3"
GaussianDistributions = "0.5"
Octavian = "0.3.17"
OrdinaryDiffEq = "6.2"
OrdinaryDiffEq = "6.49.1"
PSDMatrices = "0.4.2"
RecipesBase = "1"
RecursiveArrayTools = "2"
Reexport = "1"
SciMLBase = "1.34"
SciMLBase = "1.90"
SimpleUnPack = "1"
SnoopPrecompile = "1"
SpecialMatrices = "3"
Expand Down
4 changes: 2 additions & 2 deletions docs/src/dae.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ sol1 = solve(prob_index1, EK1())
truesol = solve(prob_index1, Rodas4(), abstol=1e-10, reltol=1e-10)
sol1_final_error = norm(sol1.u[end] - truesol.u[end])
sol1_f_evals = sol1.destats.nf
sol1_f_evals = sol1.stats.nf
sol3_final_error = norm(sol3.u[end] - truesol.u[end])
sol3_f_evals = sol3.destats.nf
sol3_f_evals = sol3.stats.nf
@info "Results" sol1_final_error sol1_f_evals sol3_final_error sol3_f_evals
```

Expand Down
18 changes: 9 additions & 9 deletions src/initialization/classicsolverinit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function initial_update!(integ, cache, ::ClassicSolverInit)
Mcache = cache.C_DxD
condition_on!(x, Proj(0), _u, cache)
is_secondorder ? f.f1(du, u.x[1], u.x[2], p, t) : f(du, u, p, t)
integ.destats.nf += 1
integ.stats.nf += 1
condition_on!(x, Proj(1), view(du, :), cache)

if q < 2
Expand Down Expand Up @@ -59,7 +59,7 @@ function initial_update!(integ, cache, ::ClassicSolverInit)
integ.sol.prob,
integ,
)
integ.destats.nf += 2
integ.stats.nf += 2

nsteps = q + 2
tmax = t0 + nsteps * dt
Expand All @@ -75,13 +75,13 @@ function initial_update!(integ, cache, ::ClassicSolverInit)
saveat=tstops,
)
# This is necessary in order to fairly account for the cost of initialization!
integ.destats.nf += sol.destats.nf
integ.destats.njacs += sol.destats.njacs
integ.destats.nsolve += sol.destats.nsolve
integ.destats.nw += sol.destats.nw
integ.destats.nnonliniter += sol.destats.nnonliniter
integ.destats.nnonlinconvfail += sol.destats.nnonlinconvfail
integ.destats.ncondition += sol.destats.ncondition
integ.stats.nf += sol.stats.nf
integ.stats.njacs += sol.stats.njacs
integ.stats.nsolve += sol.stats.nsolve
integ.stats.nw += sol.stats.nw
integ.stats.nnonliniter += sol.stats.nnonliniter
integ.stats.nnonlinconvfail += sol.stats.nnonlinconvfail
integ.stats.ncondition += sol.stats.ncondition

# Filter & smooth to fit these values!
us = [u for u in sol.u]
Expand Down
2 changes: 1 addition & 1 deletion src/initialization/taylormode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function initial_update!(integ, cache, init::TaylorModeInit)
end

f_derivatives = taylormode_get_derivatives(u, f, p, t, q)
integ.destats.nf += q
integ.stats.nf += q
@assert length(0:q) == length(f_derivatives)
m_cache = Gaussian(zeros(eltype(u), d), PSDMatrix(zeros(eltype(u), D, d)))
for (o, df) in zip(0:q, f_derivatives)
Expand Down
2 changes: 1 addition & 1 deletion src/perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ function evaluate_ode!(integ, x_pred, t)
z_tmp = integ.cache.m_tmp

integ.cache.measurement_model(z.μ, x_pred.μ, p, t)
integ.destats.nf += 1
integ.stats.nf += 1

calc_H!(H, integ, integ.cache)

Expand Down
30 changes: 15 additions & 15 deletions src/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,26 @@ mutable struct ProbODESolution{
cache::CType
dense::Bool
tslocation::Int
destats::DE
stats::DE
retcode::ReturnCode.T
end
ProbODESolution{T,N}(
u, pu, u_analytic, errors, t, k, x_filt, x_smooth, diffusions, log_likelihood, prob,
alg, interp, cache, dense, tslocation, destats, retcode,
alg, interp, cache, dense, tslocation, stats, retcode,
) where {T,N} = ProbODESolution{
T,N,typeof(u),typeof(pu),typeof(u_analytic),typeof(errors),typeof(t),typeof(k),
typeof(x_filt),typeof(diffusions),typeof(log_likelihood),typeof(prob),typeof(alg),
typeof(interp),typeof(cache),typeof(destats),
typeof(interp),typeof(cache),typeof(stats),
}(
u, pu, u_analytic, errors, t, k, x_filt, x_smooth, diffusions, log_likelihood, prob,
alg, interp, cache, dense, tslocation, destats, retcode,
alg, interp, cache, dense, tslocation, stats, retcode,
)

function DiffEqBase.solution_new_retcode(sol::ProbODESolution{T,N}, retcode) where {T,N}
return ProbODESolution{T,N}(
sol.u, sol.pu, sol.u_analytic, sol.errors, sol.t, sol.k, sol.x_filt, sol.x_smooth,
sol.diffusions, sol.log_likelihood, sol.prob, sol.alg, sol.interp, sol.cache,
sol.dense, sol.tslocation, sol.destats, retcode,
sol.dense, sol.tslocation, sol.stats, retcode,
)
end

Expand All @@ -61,7 +61,7 @@ function DiffEqBase.build_solution(
u;
k=nothing,
retcode=ReturnCode.Default,
destats=nothing,
stats=nothing,
dense=true,
kwargs...,
)
Expand Down Expand Up @@ -114,7 +114,7 @@ function DiffEqBase.build_solution(
return ProbODESolution{T,N}(
u, pu, u_analytic, errors, t, k, x_filt, x_smooth, typeof(diffusion_prototype)[],
ll, prob, alg, interp, cache,
dense, 0, destats, retcode,
dense, 0, stats, retcode,
)
end

Expand All @@ -126,7 +126,7 @@ function DiffEqBase.build_solution(
return ProbODESolution{T,N}(
sol.u, sol.pu, u_analytic, errors, sol.t, sol.k, sol.x_filt, sol.x_smooth,
sol.diffusions, sol.log_likelihood, sol.prob, sol.alg, sol.interp, sol.cache,
sol.dense, sol.tslocation, sol.destats, sol.retcode,
sol.dense, sol.tslocation, sol.stats, sol.retcode,
)
end

Expand Down Expand Up @@ -155,33 +155,33 @@ mutable struct MeanProbODESolution{
cache::CType
dense::Bool
tslocation::Int
destats::DE
stats::DE
retcode::ReturnCode.T
probsol::PSolType
end
MeanProbODESolution{T,N}(
u, u_analytic, errs, t, k, prob, alg, interp, cache, dense, tsl, destats, retcode,
u, u_analytic, errs, t, k, prob, alg, interp, cache, dense, tsl, stats, retcode,
probsol,
) where {T,N} = MeanProbODESolution{
T,N,typeof(u),typeof(u_analytic),typeof(errs),typeof(t),typeof(k),typeof(prob),
typeof(alg),typeof(interp),typeof(cache),typeof(destats),typeof(probsol)}(
u, u_analytic, errs, t, k, prob, alg, interp, cache, dense, tsl, destats, retcode,
typeof(alg),typeof(interp),typeof(cache),typeof(stats),typeof(probsol)}(
u, u_analytic, errs, t, k, prob, alg, interp, cache, dense, tsl, stats, retcode,
probsol,
)

DiffEqBase.build_solution(sol::MeanProbODESolution{T,N}, u_analytic, errors) where {T,N} =
MeanProbODESolution{T,N}(
sol.u, u_analytic, errors, sol.t, sol.k, sol.prob, sol.alg, sol.interp, sol.cache,
sol.dense, sol.tslocation, sol.destats, sol.retcode, sol.probsol)
sol.dense, sol.tslocation, sol.stats, sol.retcode, sol.probsol)

function mean(sol::ProbODESolution{T,N}) where {T,N}
return MeanProbODESolution{
T,N,typeof(sol.u),typeof(sol.u_analytic),typeof(sol.errors),typeof(sol.t),
typeof(sol.k),typeof(sol.prob),typeof(sol.alg),typeof(sol.interp),typeof(sol.cache),
typeof(sol.destats),typeof(sol),
typeof(sol.stats),typeof(sol),
}(
sol.u, sol.u_analytic, sol.errors, sol.t, sol.k, sol.prob, sol.alg, sol.interp,
sol.cache, sol.dense, sol.tslocation, sol.destats, sol.retcode, sol,
sol.cache, sol.dense, sol.tslocation, sol.stats, sol.retcode, sol,
)
end
(sol::MeanProbODESolution)(t::Real, args...) = mean(sol.probsol(t, args...))
Expand Down
2 changes: 1 addition & 1 deletion test/implicit_solver_kwarg_compat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ sol2 = solve(prob, EK1(autodiff=false))
@test sol2 isa ProbNumDiffEq.ProbODESolution

# check that forwarddiff leads to a smaller nf than finite diff
@test sol1.destats.nf < sol2.destats.nf
@test sol1.stats.nf < sol2.stats.nf
4 changes: 2 additions & 2 deletions test/ioup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ A_noisy = A + 1e-3 * randn(MersenneTwister(42), 2, 2)
@testset "Adaptive steps" begin
sol_ioup_noisy = solve(prob, EK1(prior=IOUP(3, A_noisy)))
err_ioup_noisy = norm(ref[end] - sol_ioup_noisy[end])
@test sol_ioup_noisy.destats.nf < sol_iwp.destats.nf
@test sol_ioup_noisy.stats.nf < sol_iwp.stats.nf
@test err_ioup_noisy < 2e-5

sol_ioup = solve(prob, EK1(prior=IOUP(3, A)))
err_ioup = norm(ref[end] - sol_ioup[end])
@test sol_ioup.destats.nf < sol_ioup_noisy.destats.nf
@test sol_ioup.stats.nf < sol_ioup_noisy.stats.nf
@test err_ioup < 5e-10
end

Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ const GROUP = get(ENV, "GROUP", "All")
include("solution.jl")
end
@timedsafetestset "DE-stats" begin
include("destats.jl")
include("stats.jl")
end
@timedsafetestset "Errors Thrown" begin
include("errors_thrown.jl")
Expand Down
8 changes: 4 additions & 4 deletions test/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ using ODEProblemLibrary: prob_ode_lotkavolterra
@test length(sol.t) == length(sol.u)
@test length(prob.u0) == length(sol.u[end])

# Destats
@testset "DEStats" begin
@test length(sol.t) == sol.destats.naccept + 1
@test sol.destats.naccept <= sol.destats.nf
# Stats
@testset "Stats" begin
@test length(sol.t) == sol.stats.naccept + 1
@test sol.stats.naccept <= sol.stats.nf
end

@testset "Hit the provided tspan" begin
Expand Down
8 changes: 4 additions & 4 deletions test/destats.jl → test/stats.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using LinearAlgebra

const q = 3

@testset "destats.nf testing $alg" for init in (TaylorModeInit(), ClassicSolverInit()),
@testset "stats.nf testing $alg" for init in (TaylorModeInit(), ClassicSolverInit()),
alg in (
EK0(order=q, smooth=false, initialization=init),
EK1(order=q, smooth=false, initialization=init),
Expand All @@ -22,10 +22,10 @@ const q = 3
tspan = (0.0, 1.0)
prob = ODEProblem(f, u0, tspan, p)
sol = solve(prob, alg, save_everystep=false, dense=false)
@test sol.destats.nf == f_counter[1]
@test sol.stats.nf == f_counter[1]
end

@testset "SecondOrderODEProblem destats.nf testing $alg" for init in (TaylorModeInit(),),
@testset "SecondOrderODEProblem stats.nf testing $alg" for init in (TaylorModeInit(),),
# ClassicSolverInit does not work for second order ODEs right now
alg in (
EK0(order=q, smooth=false, initialization=init),
Expand All @@ -45,5 +45,5 @@ end
tspan = (0.0, 1.0)
prob = SecondOrderODEProblem(f, du0, u0, tspan, p)
sol = solve(prob, alg, save_everystep=false, dense=false)
@test sol.destats.nf == f_counter[1]
@test sol.stats.nf == f_counter[1]
end

0 comments on commit 21748ad

Please sign in to comment.