Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement observation noise for the PN likelihood #299

Merged
merged 17 commits into from
Feb 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,14 @@ X_A_Xt(A, X) = X * A * X'
stack(x) = copy(reduce(hcat, x)')
vecvec2mat(x) = reduce(hcat, x)'

# Throw a `DimensionMismatch` unless `cov` is a square `d`×`d` matrix.
# An explicit check is used instead of `@assert`, since assertions may be disabled at
# some optimization levels and must not be relied on for input validation.
function _check_cov_size(cov, d)
    if !(size(cov, 1) == size(cov, 2) == d)
        throw(
            DimensionMismatch(
                "covariance must have size ($d, $d), but has size $(size(cov))",
            ),
        )
    end
    return nothing
end

"""
    cov2psdmatrix(cov; d)

Convert the covariance specification `cov` into a `PSDMatrix` for a `d`-dimensional
quantity.

Depending on the type of `cov`, the matrix passed to the `PSDMatrix` constructor is:

- `Number`: `sqrt(cov) * Eye(d)`.
- `UniformScaling`: `sqrt(cov.λ) * Eye(d)`.
- `Diagonal`: the element-wise square root `sqrt.(cov)`.
- `AbstractMatrix`: the upper Cholesky factor `Matrix(cholesky(cov).U)`.
- `PSDMatrix`: `cov` itself is returned unchanged.

# Keywords
- `d`: expected dimension. For matrix-valued `cov` it is only used to validate the size.

# Throws
- `DimensionMismatch` if a matrix-valued `cov` is not of size `(d, d)`.
- Whatever `sqrt` / `cholesky` throw for invalid input (e.g. negative scalars or
  matrices that are not positive definite).
"""
cov2psdmatrix(cov::Number; d) = PSDMatrix(sqrt(cov) * Eye(d))
cov2psdmatrix(cov::UniformScaling; d) = PSDMatrix(sqrt(cov.λ) * Eye(d))
function cov2psdmatrix(cov::Diagonal; d)
    _check_cov_size(cov, d)
    return PSDMatrix(sqrt.(cov))
end
function cov2psdmatrix(cov::AbstractMatrix; d)
    _check_cov_size(cov, d)
    return PSDMatrix(Matrix(cholesky(cov).U))
end
function cov2psdmatrix(cov::PSDMatrix; d)
    _check_cov_size(cov, d)
    return cov
end

include("fast_linalg.jl")
include("kronecker.jl")
include("covariance_structure.jl")
Expand Down
103 changes: 67 additions & 36 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,26 @@
########################################################################################
abstract type AbstractEK <: OrdinaryDiffEq.OrdinaryDiffEqAdaptiveAlgorithm end

"""
    ekargcheck(alg; diffusionmodel, pn_observation_noise, kwargs...)

Validate the combination of keyword arguments given to an EK solver constructor.

Throws an `ArgumentError` if

- `diffusionmodel` is static (`isstatic(diffusionmodel)`) with `calibrate` enabled while
  `pn_observation_noise` is neither `nothing` nor zero, or
- `alg == EK1` and `diffusionmodel` is a `DynamicMVDiffusion` or a `FixedMVDiffusion`
  with `calibrate` enabled.

Any additional keyword arguments are accepted and ignored. Returns `nothing` if all
checks pass.
"""
function ekargcheck(alg; diffusionmodel, pn_observation_noise, kwargs...)
    # The order of the short-circuiting checks matters: `diffusionmodel.calibrate` is only
    # accessed once `isstatic(diffusionmodel)` holds, and `iszero` is only called on a
    # non-`nothing` noise specification.
    if (isstatic(diffusionmodel) && diffusionmodel.calibrate) &&
       (!isnothing(pn_observation_noise) && !iszero(pn_observation_noise))
        # NOTE: Codecov reported this `throw` (src/algorithms.jl#L9) as not covered by
        # tests.
        throw(
            ArgumentError(
                "Automatic calibration of global diffusion models is not possible when using observation noise. If you want to calibrate a global diffusion parameter, do so setting `calibrate=false` and optimizing `sol.pnstats.log_likelihood` manually.",
            ),
        )
    end
    if (
        (diffusionmodel isa FixedMVDiffusion && diffusionmodel.calibrate) ||
        diffusionmodel isa DynamicMVDiffusion
    ) && alg == EK1
        throw(
            ArgumentError(
                "The `EK1` algorithm does not support automatic calibration of multivariate diffusion models. Either use the `EK0` instead, or use a scalar diffusion model, or set `calibrate=false` and calibrate manually by optimizing `sol.pnstats.log_likelihood`.",
            ),
        )
    end
    return nothing
end

"""
EK0(; order=3,
smooth=true,
Expand Down Expand Up @@ -38,19 +58,24 @@

# [References](@ref references)
"""
struct EK0{PT,DT,IT} <: AbstractEK
struct EK0{PT,DT,IT,RT} <: AbstractEK
prior::PT
diffusionmodel::DT
smooth::Bool
initialization::IT
pn_observation_noise::RT
EK0(; order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
pn_observation_noise::RT=nothing,
) where {PT,DT,IT,RT} = begin
ekargcheck(EK0; diffusionmodel, pn_observation_noise)
new{PT,DT,IT,RT}(
prior, diffusionmodel, smooth, initialization, pn_observation_noise)
end
end
EK0(;
order=3,
prior=IWP(order),
diffusionmodel=DynamicDiffusion(),
smooth=true,
initialization=TaylorModeInit(num_derivatives(prior)),
) = EK0(prior, diffusionmodel, smooth, initialization)

_unwrap_val(::Val{B}) where {B} = B
_unwrap_val(B) = B
Expand Down Expand Up @@ -92,39 +117,45 @@

# [References](@ref references)
"""
struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT} <: AbstractEK
struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT} <: AbstractEK
prior::PT
diffusionmodel::DT
smooth::Bool
initialization::IT
pn_observation_noise::RT
EK1(;
order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
chunk_size=Val{0}(),
autodiff=Val{true}(),
diff_type=Val{:forward},
standardtag=Val{true}(),
concrete_jac=nothing,
pn_observation_noise::RT=nothing,
) where {PT,DT,IT,RT} = begin
ekargcheck(EK1; diffusionmodel, pn_observation_noise)
new{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac),
PT,
DT,
IT,
RT,
}(
prior,
diffusionmodel,
smooth,
initialization,
pn_observation_noise,
)
end
end
EK1(;
order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
chunk_size=Val{0}(),
autodiff=Val{true}(),
diff_type=Val{:forward},
standardtag=Val{true}(),
concrete_jac=nothing,
) where {PT,DT,IT} =
EK1{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac),
PT,
DT,
IT,
}(
prior,
diffusionmodel,
smooth,
initialization,
)

"""
ExpEK(; L, order=3, kwargs...)
Expand Down
6 changes: 4 additions & 2 deletions src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,10 +168,12 @@ function OrdinaryDiffEq.alg_cache(
copy!(x0.Σ, apply_diffusion(x0.Σ, initdiff))

# Measurement model related things
R = nothing # factorized_similar(FAC, d, d)
R =
isnothing(alg.pn_observation_noise) ? nothing :
to_factorized_matrix(FAC, cov2psdmatrix(alg.pn_observation_noise; d))
H = factorized_similar(FAC, d, D)
v = similar(Array{uElType}, d)
S = PSDMatrix(factorized_zeros(FAC, D, d))
S = factorized_zeros(FAC, d, d)
measurement = Gaussian(v, S)

# Caches
Expand Down
10 changes: 1 addition & 9 deletions src/callbacks/dataupdate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,7 @@ function DataUpdateCallback(
obs_mean = _matmul!(view(m_tmp.μ, 1:o), H, x.μ)
obs_mean .-= val

R = if observation_noise_cov isa PSDMatrix
observation_noise_cov
elseif observation_noise_cov isa Number
PSDMatrix(sqrt(observation_noise_cov) * Eye(o))
elseif observation_noise_cov isa UniformScaling
PSDMatrix(sqrt(observation_noise_cov.λ) * Eye(o))
else
PSDMatrix(cholesky(observation_noise_cov).U)
end
R = cov2psdmatrix(observation_noise_cov; d=o)

# _A = x.Σ.R * H'
# obs_cov = _A'_A + R
Expand Down
20 changes: 6 additions & 14 deletions src/data_likelihoods/fenrir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,7 @@ function fenrir_data_loglik(

# Fit the ODE solution / PN posterior to the provided data; this is the actual Fenrir
o = length(data.u[1])
R = if observation_noise_cov isa PSDMatrix
observation_noise_cov
elseif observation_noise_cov isa Number
PSDMatrix(sqrt(observation_noise_cov) * Eye(o))
elseif observation_noise_cov isa UniformScaling
PSDMatrix(sqrt(observation_noise_cov.λ) * Eye(o))
else
PSDMatrix(cholesky(observation_noise_cov).U)
end
R = cov2psdmatrix(observation_noise_cov; d=o)
LL, _, _ = fit_pnsolution_to_data!(sol, R, data; proj=observation_matrix)

return LL
Expand All @@ -91,7 +83,7 @@ function fit_pnsolution_to_data!(
C_d=view(C_d, 1:o),
K1=view(K1, :, 1:o),
K2=view(C_Dxd, :, 1:o),
m_tmp=Gaussian(view(m_tmp.μ, 1:o), PSDMatrix(view(m_tmp.Σ.R, :, 1:o))),
m_tmp=Gaussian(view(m_tmp.μ, 1:o), view(m_tmp.Σ, 1:o, 1:o)),
)

x_posterior = copy(sol.x_filt) # the object to be filled
Expand Down Expand Up @@ -144,10 +136,10 @@ function measure_and_update!(x, u, H, R::PSDMatrix, cache)
z, S = cache.m_tmp
_matmul!(z, H, x.μ)
z .-= u
fast_X_A_Xt!(S, x.Σ, H)
# _S = PSDMatrix(S.R'S.R + R.R'R.R)
_S = PSDMatrix(triangularize!([S.R; R.R], cachemat=cache.C_DxD))
msmnt = Gaussian(z, _S)
_matmul!(cache.C_Dxd, x.Σ.R, H')
_matmul!(S, cache.C_Dxd', cache.C_Dxd)
S .+= _matmul!(cache.C_dxd, R.R', R.R)
msmnt = Gaussian(z, S)

return update!(x, copy!(cache.x_tmp, x), msmnt, H; R=R, cache)
end
39 changes: 19 additions & 20 deletions src/diffusions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,12 @@ function estimate_global_diffusion(::FixedDiffusion, integ)

v, S = measurement.μ, measurement.Σ
e = m_tmp.μ
_S = _matmul!(Smat, S.R', S.R)
e .= v
diffusion_t = if _S isa IsometricKroneckerProduct
@assert length(_S.B) == 1
dot(v, e) / d / _S.B[1]
diffusion_t = if S isa IsometricKroneckerProduct
@assert length(S.B) == 1
dot(v, e) / d / S.B[1]
else
S_chol = cholesky!(_S)
S_chol = cholesky!(S)
ldiv!(S_chol, e)
dot(v, e) / d
end
Expand Down Expand Up @@ -123,13 +122,12 @@ function estimate_global_diffusion(::FixedMVDiffusion, integ)
@unpack d, q, measurement, local_diffusion = integ.cache

v, S = measurement.μ, measurement.Σ
# S_11 = diag(S)[1]
c1 = view(S.R, :, 1)
S_11 = dot(c1, c1)
# @assert diag(S) |> unique |> length == 1
S_11 = S[1, 1]

Σ_ii = v .^ 2 ./ S_11
Σ = Diagonal(Σ_ii)
Σ_out = kron(Σ, I(q + 1))
Σ_out = kron(Σ, I(q + 1)) # -> Different for each dimension; same for each derivative

if integ.success_iter == 0
# @assert length(diffusions) == 0
Expand Down Expand Up @@ -159,17 +157,17 @@ For more background information
* [bosch20capos](@cite) Bosch et al, "Calibrated Adaptive Probabilistic ODE Solvers", AISTATS (2021)
"""
function local_scalar_diffusion(cache)
@unpack d, R, H, Qh, measurement, m_tmp, Smat = cache
@unpack d, R, H, Qh, measurement, m_tmp, Smat, C_Dxd = cache
z = measurement.μ
e, HQH = m_tmp.μ, m_tmp.Σ
fast_X_A_Xt!(HQH, Qh, H)
HQHmat = _matmul!(Smat, HQH.R', HQH.R)
_matmul!(C_Dxd, Qh.R, H')
_matmul!(HQH, C_Dxd', C_Dxd)
e .= z
σ² = if HQHmat isa IsometricKroneckerProduct
@assert length(HQHmat.B) == 1
dot(z, e) / d / HQHmat.B[1]
σ² = if HQH isa IsometricKroneckerProduct
@assert length(HQH.B) == 1
dot(z, e) / d / HQH.B[1]
else
C = cholesky!(HQHmat)
C = cholesky!(HQH)
ldiv!(C, e)
dot(z, e) / d
end
Expand All @@ -195,16 +193,17 @@ function local_diagonal_diffusion(cache)
@unpack d, q, H, Qh, measurement, m_tmp, tmp = cache
@unpack local_diffusion = cache
z = measurement.μ
HQH = fast_X_A_Xt!(m_tmp.Σ, Qh, H)
# Q0_11 = diag(HQH)[1]
c1 = view(HQH.R, :, 1)
# HQH = H * unfactorize(Qh) * H'
# @assert HQH |> diag |> unique |> length == 1
# c1 = view(_matmul!(cache.C_Dxd, Qh.R, H'), :, 1)
c1 = mul!(view(cache.C_Dxd, :, 1:1), Qh.R, view(H, 1:1, :)')
Q0_11 = dot(c1, c1)

Σ_ii = @. m_tmp.μ = z^2 / Q0_11
# Σ_ii .= max.(Σ_ii, eps(eltype(Σ_ii)))
Σ = Diagonal(Σ_ii)

# local_diffusion = kron(Σ, I(q+1))
# -> Different for each dimension; same for each derivative
for i in 1:d
for j in (i-1)*(q+1)+1:i*(q+1)
local_diffusion[j, j] = Σ[i, i]
Expand Down
22 changes: 18 additions & 4 deletions src/filtering/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,16 @@ function update!(
fast_X_A_Xt!(x_out.Σ, P_p, M_cache)

if !isnothing(R)
x_out.Σ.R .= triangularize!([x_out.Σ.R; R.R * K']; cachemat=M_cache)
# M = Matrix(x_out.Σ) + K * Matrix(R) * K'
_matmul!(M_cache, x_out.Σ.R', x_out.Σ.R)
_matmul!(K1_cache, K, R.R')
_matmul!(M_cache, K1_cache, K1_cache', 1, 1)
chol = cholesky!(Symmetric(M_cache), check=false)
if issuccess(chol)
copy!(x_out.Σ.R, chol.U)
else
x_out.Σ.R .= triangularize!([x_out.Σ.R; K1_cache']; cachemat=M_cache)
end
end

return x_out, loglikelihood
Expand All @@ -141,7 +150,10 @@ end
function update!(
x_out::SRGaussian{T,<:IsometricKroneckerProduct},
x_pred::SRGaussian{T,<:IsometricKroneckerProduct},
measurement::SRGaussian{T,<:IsometricKroneckerProduct},
measurement::Gaussian{
<:AbstractVector,
<:Union{<:PSDMatrix{T,<:IsometricKroneckerProduct},<:IsometricKroneckerProduct},
},
H::IsometricKroneckerProduct,
K1_cache::IsometricKroneckerProduct,
K2_cache::IsometricKroneckerProduct,
Expand All @@ -156,7 +168,9 @@ function update!(
_x_out = Gaussian(reshape_no_alloc(x_out.μ, Q, d), PSDMatrix(x_out.Σ.R.B))
_x_pred = Gaussian(reshape_no_alloc(x_pred.μ, Q, d), PSDMatrix(x_pred.Σ.R.B))
_measurement = Gaussian(
reshape_no_alloc(measurement.μ, 1, d), PSDMatrix(measurement.Σ.R.B))
reshape_no_alloc(measurement.μ, 1, d),
measurement.Σ isa PSDMatrix ? PSDMatrix(measurement.Σ.R.B) : measurement.Σ.B,
)
_H = H.B
_K1_cache = K1_cache.B
_K2_cache = K2_cache.B
Expand All @@ -180,7 +194,7 @@ function update!(
end

# Short-hand with cache
function update!(x_out, x, measurement, H; R=nothing, cache)
function update!(x_out, x, measurement, H; cache, R=nothing)
@unpack K1, m_tmp, C_DxD, C_dxd, C_Dxd, C_d = cache
K2 = C_Dxd
return update!(x_out, x, measurement, H, K1, K2, C_DxD, C_dxd, C_d; R)
Expand Down
3 changes: 2 additions & 1 deletion src/initialization/classicsolverinit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ function rk_init_improve(cache::AbstractODEFilterCache, ts, us, dt)

H = cache.E0 * PI
measurement.μ .= H * x_pred.μ .- u
fast_X_A_Xt!(measurement.Σ, x_pred.Σ, H)
_matmul!(C_Dxd, x_pred.Σ.R, H')
_matmul!(measurement.Σ, C_Dxd', C_Dxd)

update!(x_filt, x_pred, measurement, H, K1, C_Dxd, C_DxD, C_dxd, C_d)
push!(filts, copy(x_filt))
Expand Down
3 changes: 2 additions & 1 deletion src/initialization/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ function init_condition_on!(
m_tmp.μ .-= data

# measurement cov
fast_X_A_Xt!(m_tmp.Σ, x.Σ, H)
_matmul!(C_Dxd, x.Σ.R, H')
_matmul!(m_tmp.Σ, C_Dxd', C_Dxd)
copy!(x_tmp, x)
update!(x, x_tmp, m_tmp, H, K1, C_Dxd, C_DxD, C_dxd, C_d)
end
Loading
Loading