Skip to content

Commit

Permalink
Make cache.H the correct one again
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Sep 2, 2023
1 parent 1d653be commit 07c98d0
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 16 deletions.
12 changes: 8 additions & 4 deletions src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
########################################################################################
mutable struct EKCache{
RType,ProjType,SolProjType,PType,PIType,EType,uType,duType,xType,PriorType,AType,QType,
matType,bkType,diffusionType,diffModelType,measModType,measType,puType,llType,dtType,
HType,matType,bkType,diffusionType,diffModelType,measModType,measType,puType,llType,dtType,
rateType,UF,JC,uNoUnitsType,
} <: AbstractODEFilterCache
# Constants
Expand Down Expand Up @@ -40,7 +40,7 @@ mutable struct EKCache{
measurement::measType
m_tmp::measType
pu_tmp::puType
H::matType
H::HType
du::duType
ddu::matType
K1::matType
Expand Down Expand Up @@ -144,7 +144,11 @@ function OrdinaryDiffEq.alg_cache(

# Measurement model related things
R = zeros(uElType, d, d)
H = zeros(uElType, d, D)
H = if KRONECKER
copy(E1)
else
zeros(uElType, d, D)
end
v = zeros(uElType, d)
S = if KRONECKER
PSDMatrix(kronecker(Id, zeros(uElType, q + 1)))
Expand Down Expand Up @@ -199,7 +203,7 @@ function OrdinaryDiffEq.alg_cache(
ll = zero(uEltypeNoUnits)
return EKCache{
typeof(R),typeof(Proj),typeof(SolProj),typeof(P),typeof(PI),typeof(E0),
uType,typeof(du),typeof(x0),typeof(prior),typeof(A),typeof(Q),matType,
uType,typeof(du),typeof(x0),typeof(prior),typeof(A),typeof(Q),typeof(H),matType,
typeof(backward_kernel),typeof(initdiff),
typeof(diffmodel),typeof(measurement_model),typeof(measurement),typeof(pu_tmp),
uEltypeNoUnits,typeof(dt),typeof(du1),typeof(uf),typeof(jac_config),typeof(atmp),
Expand Down
8 changes: 1 addition & 7 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,12 @@ function calc_H_EK0!(H, integ, cache)
@unpack f = integ
@unpack d, ddu, E1, E2 = cache

return nothing

if f isa DynamicalODEFunction
@assert f.mass_matrix === I
H .= E2
else
if f.mass_matrix === I
H .= E1
copy!(H, E1)
elseif f.mass_matrix isa UniformScaling
H .= f.mass_matrix.λ .* E1
else
Expand All @@ -33,7 +31,3 @@ function calc_H_EK0!(H, integ, cache)
end
return nothing
end

get_H(alg::EK1, cache) = cache.H
get_H(alg::EK0, cache) = cache.E1
get_H(cache) = cache.E1
1 change: 0 additions & 1 deletion src/diffusions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@ For more background information
"""
function local_scalar_diffusion(cache)
@unpack d, R, H, Qh, measurement, m_tmp, Smat = cache
H = get_H(cache)
z = measurement.μ
e, HQH = m_tmp.μ, m_tmp.Σ
fast_X_A_Xt!(HQH, Qh, H)
Expand Down
1 change: 0 additions & 1 deletion src/filtering/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ function update!(
M_cache::AbstractMatrix,
C_dxd::AbstractMatrix,
)
# @info "Kronecker version of update!"
_x_out = Gaussian(x_out.μ, PSDMatrix(x_out.Σ.R.B))
_x_pred = Gaussian(x_pred.μ, PSDMatrix(x_pred.Σ.R.B))
_measurement = Gaussian(measurement.μ, PSDMatrix(measurement.Σ.R.B))
Expand Down
5 changes: 2 additions & 3 deletions src/perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,13 +148,13 @@ function evaluate_ode!(integ, x_pred, t)
end

compute_measurement_covariance!(cache) =
fast_X_A_Xt!(cache.measurement.Σ, cache.x_pred.Σ, get_H(cache))
fast_X_A_Xt!(cache.measurement.Σ, cache.x_pred.Σ, cache.H)

function update!(cache, prediction)
@unpack measurement, H, x_filt, K1, m_tmp, C_DxD = cache
@unpack C_dxd, C_Dxd = cache
K2 = C_Dxd
update!(x_filt, prediction, measurement, get_H(cache), K1, K2, C_DxD, C_dxd)
update!(x_filt, prediction, measurement, H, K1, K2, C_DxD, C_dxd)
return x_filt
end

Expand Down Expand Up @@ -210,7 +210,6 @@ To save allocations, the function modifies the given `cache` and writes into
"""
function estimate_errors!(cache::AbstractODEFilterCache)
@unpack local_diffusion, Qh, H, d = cache
H = get_H(cache)

if local_diffusion isa Real && isinf(local_diffusion)
return Inf
Expand Down

0 comments on commit 07c98d0

Please sign in to comment.