From 07c98d06ce7e0977d31a4859c938035f168b37b1 Mon Sep 17 00:00:00 2001 From: Nathanael Bosch Date: Sat, 2 Sep 2023 21:31:49 +0300 Subject: [PATCH] Make cache.H the correct one again --- src/caches.jl | 12 ++++++++---- src/derivative_utils.jl | 8 +------- src/diffusions.jl | 1 - src/filtering/update.jl | 1 - src/perform_step.jl | 5 ++--- 5 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/caches.jl b/src/caches.jl index e5b403943..bed37d63f 100644 --- a/src/caches.jl +++ b/src/caches.jl @@ -3,7 +3,7 @@ ######################################################################################## mutable struct EKCache{ RType,ProjType,SolProjType,PType,PIType,EType,uType,duType,xType,PriorType,AType,QType, - matType,bkType,diffusionType,diffModelType,measModType,measType,puType,llType,dtType, + HType,matType,bkType,diffusionType,diffModelType,measModType,measType,puType,llType,dtType, rateType,UF,JC,uNoUnitsType, } <: AbstractODEFilterCache # Constants @@ -40,7 +40,7 @@ mutable struct EKCache{ measurement::measType m_tmp::measType pu_tmp::puType - H::matType + H::HType du::duType ddu::matType K1::matType @@ -144,7 +144,11 @@ function OrdinaryDiffEq.alg_cache( # Measurement model related things R = zeros(uElType, d, d) - H = zeros(uElType, d, D) + H = if KRONECKER + copy(E1) + else + zeros(uElType, d, D) + end v = zeros(uElType, d) S = if KRONECKER PSDMatrix(kronecker(Id, zeros(uElType, q + 1))) @@ -199,7 +203,7 @@ function OrdinaryDiffEq.alg_cache( ll = zero(uEltypeNoUnits) return EKCache{ typeof(R),typeof(Proj),typeof(SolProj),typeof(P),typeof(PI),typeof(E0), - uType,typeof(du),typeof(x0),typeof(prior),typeof(A),typeof(Q),matType, + uType,typeof(du),typeof(x0),typeof(prior),typeof(A),typeof(Q),typeof(H),matType, typeof(backward_kernel),typeof(initdiff), typeof(diffmodel),typeof(measurement_model),typeof(measurement),typeof(pu_tmp), uEltypeNoUnits,typeof(dt),typeof(du1),typeof(uf),typeof(jac_config),typeof(atmp), diff --git a/src/derivative_utils.jl b/src/derivative_utils.jl index 15167d650..c3082c2df 100644 --- a/src/derivative_utils.jl +++ b/src/derivative_utils.jl @@ -17,14 +17,12 @@ function calc_H_EK0!(H, integ, cache) @unpack f = integ @unpack d, ddu, E1, E2 = cache - return nothing - if f isa DynamicalODEFunction @assert f.mass_matrix === I H .= E2 else if f.mass_matrix === I - H .= E1 + copy!(H, E1) elseif f.mass_matrix isa UniformScaling H .= f.mass_matrix.λ .* E1 else @@ -33,7 +31,3 @@ function calc_H_EK0!(H, integ, cache) end return nothing end - -get_H(alg::EK1, cache) = cache.H -get_H(alg::EK0, cache) = cache.E1 -get_H(cache) = cache.E1 diff --git a/src/diffusions.jl b/src/diffusions.jl index 6c0fa19c5..d6d235596 100644 --- a/src/diffusions.jl +++ b/src/diffusions.jl @@ -153,7 +153,6 @@ For more background information """ function local_scalar_diffusion(cache) @unpack d, R, H, Qh, measurement, m_tmp, Smat = cache - H = get_H(cache) z = measurement.μ e, HQH = m_tmp.μ, m_tmp.Σ fast_X_A_Xt!(HQH, Qh, H) diff --git a/src/filtering/update.jl b/src/filtering/update.jl index bc1fa9644..fcb5ace51 100644 --- a/src/filtering/update.jl +++ b/src/filtering/update.jl @@ -135,7 +135,6 @@ function update!( M_cache::AbstractMatrix, C_dxd::AbstractMatrix, ) - # @info "Kronecker version of update!" _x_out = Gaussian(x_out.μ, PSDMatrix(x_out.Σ.R.B)) _x_pred = Gaussian(x_pred.μ, PSDMatrix(x_pred.Σ.R.B)) _measurement = Gaussian(measurement.μ, PSDMatrix(measurement.Σ.R.B)) diff --git a/src/perform_step.jl b/src/perform_step.jl index af240f73b..3bc0c7654 100644 --- a/src/perform_step.jl +++ b/src/perform_step.jl @@ -148,13 +148,13 @@ function evaluate_ode!(integ, x_pred, t) end compute_measurement_covariance!(cache) = - fast_X_A_Xt!(cache.measurement.Σ, cache.x_pred.Σ, get_H(cache)) + fast_X_A_Xt!(cache.measurement.Σ, cache.x_pred.Σ, cache.H) function update!(cache, prediction) @unpack measurement, H, x_filt, K1, m_tmp, C_DxD = cache @unpack C_dxd, C_Dxd = cache K2 = C_Dxd - update!(x_filt, prediction, measurement, get_H(cache), K1, K2, C_DxD, C_dxd) + update!(x_filt, prediction, measurement, H, K1, K2, C_DxD, C_dxd) return x_filt end @@ -210,7 +210,6 @@ To save allocations, the function modifies the given `cache` and writes into """ function estimate_errors!(cache::AbstractODEFilterCache) @unpack local_diffusion, Qh, H, d = cache - H = get_H(cache) if local_diffusion isa Real && isinf(local_diffusion) return Inf