Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the DiagonalEK1 #301

Merged
merged 99 commits into from
Feb 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
99 commits
Select commit Hold shift + click to select a range
58c63eb
Add a BlockDiagonal implementation
nathanaelbosch Feb 13, 2024
d0a3eb0
It works and it's (a little bit) faster than dense!
nathanaelbosch Feb 13, 2024
5519c31
Implement a first version of the DiagonalEK1
nathanaelbosch Feb 13, 2024
c81e5ca
Added smoothing
nathanaelbosch Feb 13, 2024
05c93df
Add BlockDiagonals to the tests
nathanaelbosch Feb 13, 2024
bc6c372
This should be the proper logic to choose the cov factorization
nathanaelbosch Feb 13, 2024
f112c15
We can now select the covariance from the outside!
nathanaelbosch Feb 13, 2024
1404ee7
Add some SIMD here and there
nathanaelbosch Feb 13, 2024
ff81928
Make views of BlockDiagonals illegal as they are super slow
nathanaelbosch Feb 13, 2024
2bb610e
Change how the diffusions work
nathanaelbosch Feb 13, 2024
d7a49c6
Make some more diffusions work
nathanaelbosch Feb 13, 2024
473a140
Better checking for validity of algorithm arguments
nathanaelbosch Feb 14, 2024
051a36f
More BlockDiagonal linalg things
nathanaelbosch Feb 14, 2024
051689d
Better handling of the diffusion for prediction
nathanaelbosch Feb 14, 2024
390c468
The global diffusion is now written into the cache directly
nathanaelbosch Feb 14, 2024
4c56635
Implement rmul! or the IsometricKroneckerProduct
nathanaelbosch Feb 14, 2024
dff3e05
Properly ply the new diffusion after the solve
nathanaelbosch Feb 14, 2024
27912d5
Properly estimate the global scalar diffusion
nathanaelbosch Feb 14, 2024
3685edd
Properly implement the global MV diffusion
nathanaelbosch Feb 14, 2024
a184952
This should be a proper implementation of the dynamic MV diffusion
nathanaelbosch Feb 14, 2024
e4fd99f
Try to fix how the prediction handles the diffusion (I failed)
nathanaelbosch Feb 14, 2024
57e533d
Try to get the DynamicMV diff to work with BlockDiag cov (but fail)
nathanaelbosch Feb 14, 2024
097ea98
Get the DiagonalEK1 to work with a dense covariance factorization
nathanaelbosch Feb 14, 2024
895b666
Check diffusion and factorization compat somewhere else and warn inst…
nathanaelbosch Feb 14, 2024
efeaaad
JuliaFormatter.jl
nathanaelbosch Feb 14, 2024
2fb44be
Implement my own BlockDiag type
nathanaelbosch Feb 14, 2024
427392e
JuliaFormatter.jl
nathanaelbosch Feb 14, 2024
d808525
Start fixing some tests
nathanaelbosch Feb 14, 2024
cdf74e1
Remove duplicate matmul implementation
nathanaelbosch Feb 14, 2024
a583738
Fix some failing state init tests
nathanaelbosch Feb 15, 2024
d9fc137
Improve the diffusion handling some more
nathanaelbosch Feb 15, 2024
385d534
Enable the EK0 again with priors that are not Kronecker
nathanaelbosch Feb 15, 2024
ca92591
Remove one test case that's not yet supported
nathanaelbosch Feb 15, 2024
c766d6e
Significantly speed up the secondorderodeproblem tests
nathanaelbosch Feb 15, 2024
ed20760
Remove a test that currently fails
nathanaelbosch Feb 15, 2024
3f9e00d
Fix many of the tests that I had to temporally remove
nathanaelbosch Feb 15, 2024
15d3c28
Make it more obvious that BlockDiagonals and second order ODEs are no…
nathanaelbosch Feb 15, 2024
4c3be25
Rename our BlockDiagonals to ProbNumDiffEqBlockDiagonal (+ shortcut)
nathanaelbosch Feb 15, 2024
aafb416
Add a BlockDiagonals extension to transfomr ours to theirs
nathanaelbosch Feb 15, 2024
6501878
Add unit-tests for our `BlockDiag`s
nathanaelbosch Feb 15, 2024
85edcaa
Check that the K.b is actually empty
nathanaelbosch Feb 15, 2024
a27079d
Add versions to overload also the non-blasfloat matmuls
nathanaelbosch Feb 15, 2024
20b8430
Make some code more compact and readable
nathanaelbosch Feb 15, 2024
faa4d55
Change order to fit acronym
nathanaelbosch Feb 15, 2024
6733003
For some reson the eval tests failed; so fix them
nathanaelbosch Feb 15, 2024
597f968
Make the if else order in predict and backward kernel easier
nathanaelbosch Feb 15, 2024
ee26220
misc
nathanaelbosch Feb 15, 2024
cbb4e09
Properly implement `size`
nathanaelbosch Feb 15, 2024
5b2186b
Remove an inbounds as we don't explicitly do a sizecheck
nathanaelbosch Feb 16, 2024
9e2177a
Remove some checks again as they are irrelevant for the cov
nathanaelbosch Feb 16, 2024
e9946ac
Add a very minimal docstring to ProbNumDiffEqBlockDiagonal
nathanaelbosch Feb 16, 2024
b15fae8
Better BlockDiagonals and a bit of Kronecker
nathanaelbosch Feb 16, 2024
5ece898
Make the DiagonalEK1 work again (except for secondorderodes)
nathanaelbosch Feb 16, 2024
8d0ae85
Test the diffusions (found a bug! unittests are actually nice)
nathanaelbosch Feb 16, 2024
e613401
Give the diffusions much more space
nathanaelbosch Feb 16, 2024
fa9fa33
Grealy simplify the local error estimate code
nathanaelbosch Feb 16, 2024
5aa15f2
Better predict tests
nathanaelbosch Feb 16, 2024
7d07e9a
Beter update and smoothing tests
nathanaelbosch Feb 16, 2024
4669e9f
Praise the lord for unittests
nathanaelbosch Feb 16, 2024
20d2a59
JuliaFormatter.jl
nathanaelbosch Feb 16, 2024
4bf2568
Actually git the diffusion tests
nathanaelbosch Feb 16, 2024
46dbd17
Testfix
nathanaelbosch Feb 16, 2024
b9de4c0
Remove some comments
nathanaelbosch Feb 16, 2024
802f322
Test that the computed log-likelihood is correct
nathanaelbosch Feb 16, 2024
cc38b17
Add BlockDiagonals compat entry
nathanaelbosch Feb 16, 2024
33d5852
Add more solvers to the autodiff tests
nathanaelbosch Feb 16, 2024
3cbcd20
Add much more tests
nathanaelbosch Feb 16, 2024
d7e0cc0
JuliaFormatter.jl
nathanaelbosch Feb 16, 2024
9bf8770
Better complexity test
nathanaelbosch Feb 16, 2024
2a2533f
Actually do a proper test to check the scaling of the solvers
nathanaelbosch Feb 16, 2024
8fd54bc
JuliaFormatter.jl
nathanaelbosch Feb 16, 2024
d9501bf
Make preconditioner computation simpler and test better
nathanaelbosch Feb 17, 2024
3609d7e
One more prior test
nathanaelbosch Feb 17, 2024
f08ff1e
Fix some tests
nathanaelbosch Feb 17, 2024
786b1ed
Get the data likelihoods to work with DiagonalEK1 and the EK0
nathanaelbosch Feb 17, 2024
29be0c0
Make the data likelihoods better
nathanaelbosch Feb 17, 2024
3b53494
JuliaFormatter.jl
nathanaelbosch Feb 17, 2024
6d83599
Relax the complexity tests even more
nathanaelbosch Feb 17, 2024
be1fbef
JuliaFormatter.jl
nathanaelbosch Feb 17, 2024
dd7d08d
Remove some unused code
nathanaelbosch Feb 17, 2024
d74e3c8
Bisschen code upgrade für die data likelihoods
nathanaelbosch Feb 17, 2024
eb2e235
Make the complexity tests more compact
nathanaelbosch Feb 17, 2024
367b0bf
Remove some unused things, mainly to re-trigger gh actions
nathanaelbosch Feb 17, 2024
c85602a
Fix the bad getindex for BlockDiag
nathanaelbosch Feb 17, 2024
f3a04fd
Use the built-in matrix exponential
nathanaelbosch Feb 17, 2024
8b8a19a
Remore parts of a test that are not implemented anymore
nathanaelbosch Feb 17, 2024
ef76c67
Improve coverage a bit
nathanaelbosch Feb 18, 2024
5f218de
JuliaFormatter.jl
nathanaelbosch Feb 18, 2024
8d7fd1f
Check better what observation noise works with what factorization
nathanaelbosch Feb 18, 2024
3ab4520
Fix the pn_observation_noise check
nathanaelbosch Feb 18, 2024
a9f24d7
More BlockDiag tests
nathanaelbosch Feb 18, 2024
a2d527a
Add some more Kronecker tests
nathanaelbosch Feb 18, 2024
54c62ec
JuliaFormatter.jl
nathanaelbosch Feb 18, 2024
7d7bf1a
Fix test
nathanaelbosch Feb 18, 2024
569b9e3
Make the FixedMVDiffusion work with dense matrices
nathanaelbosch Feb 19, 2024
e67c665
Make the FixedMVDiffusion work with the data likelihoods
nathanaelbosch Feb 19, 2024
d9b4a36
Remove some comments
nathanaelbosch Feb 19, 2024
71e4be4
Add `add!` to the Kronecker tests and shorten them a bit
nathanaelbosch Feb 19, 2024
5f75325
Fix the failing test that I just found
nathanaelbosch Feb 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,8 @@ version = "0.15.0"
ArrayAllocators = "c9d4266f-a5cb-439d-837c-c97b191379f5"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
ExponentialUtilities = "d4d017d3-3776-5f7e-afef-a10c40355c18"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteHorizonGramians = "b59a298d-d283-4a37-9369-85a9f9a111a5"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -30,7 +27,6 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
SpecialMatrices = "928aab9d-ef52-54ac-8ca1-acd7ca42c160"
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Expand All @@ -39,22 +35,23 @@ TaylorSeries = "6aa5eb33-94cf-58f4-a9d0-e4b2c4fc25ea"
ToeplitzMatrices = "c751599d-da0a-543b-9d20-d0a503d91d24"

[weakdeps]
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"

[extensions]
BlockDiagonalsExt = "BlockDiagonals"
DiffEqDevToolsExt = "DiffEqDevTools"
RecipesBaseExt = "RecipesBase"

[compat]
ArrayAllocators = "0.3"
BlockDiagonals = "0.1"
DiffEqBase = "6.122"
DiffEqCallbacks = "2.36"
DiffEqDevTools = "2"
DocStringExtensions = "0.9"
ExponentialUtilities = "1"
FastBroadcast = "0.2"
FastGaussQuadrature = "0.5, 1"
FillArrays = "1.9"
FiniteHorizonGramians = "0.2"
ForwardDiff = "0.10"
Expand All @@ -73,7 +70,6 @@ RecursiveArrayTools = "2, 3"
Reexport = "1"
SciMLBase = "1.90, 2"
SimpleUnPack = "1"
SpecialMatrices = "3"
StaticArrayInterface = "1.3"
Statistics = "1"
StructArrays = "0.4, 0.5, 0.6"
Expand Down
8 changes: 8 additions & 0 deletions ext/BlockDiagonalsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module BlockDiagonalsExt

import ProbNumDiffEq: ProbNumDiffEqBlockDiagonal, blocks
import BlockDiagonals: BlockDiagonal

BlockDiagonal(M::ProbNumDiffEqBlockDiagonal) = BlockDiagonal(blocks(M))

end
25 changes: 19 additions & 6 deletions src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ __precompile__()

module ProbNumDiffEq

import Base: copy, copy!, show, size, ndims, similar, isapprox, isequal, iterate, ==, length
import Base:
copy, copy!, show, size, ndims, similar, isapprox, isequal, iterate, ==, length, zero

using LinearAlgebra
import LinearAlgebra: mul!
Expand All @@ -15,7 +16,7 @@ using Reexport
import SciMLBase
import SciMLBase: interpret_vars, getsyms, remake
using OrdinaryDiffEq
using SpecialMatrices, ToeplitzMatrices
using ToeplitzMatrices
using FastBroadcast
using StaticArrayInterface
using FunctionWrappersWrappers
Expand All @@ -24,9 +25,7 @@ using TaylorSeries, TaylorIntegration
using SimpleUnPack
using RecursiveArrayTools
using ForwardDiff
using ExponentialUtilities
using Octavian
using FastGaussQuadrature
import Kronecker
using ArrayAllocators
using FiniteHorizonGramians
Expand All @@ -45,15 +44,27 @@ vecvec2mat(x) = reduce(hcat, x)'

cov2psdmatrix(cov::Number; d) = PSDMatrix(sqrt(cov) * Eye(d))
cov2psdmatrix(cov::UniformScaling; d) = PSDMatrix(sqrt(cov.λ) * Eye(d))
cov2psdmatrix(cov::Diagonal{<:Number,<:FillArrays.Fill}; d) =
(@assert size(cov, 1) == size(cov, 2) == d; cov2psdmatrix(cov.diag.value; d))
cov2psdmatrix(cov::Diagonal; d) =
(@assert size(cov, 1) == size(cov, 2) == d; PSDMatrix(sqrt.(cov)))
cov2psdmatrix(cov::AbstractMatrix; d) =
(@assert size(cov, 1) == size(cov, 2) == d; PSDMatrix(Matrix(cholesky(cov).U)))
cov2psdmatrix(cov::PSDMatrix; d) = (@assert size(cov, 1) == size(cov, 2) == d; cov)

"""
add!(out, toadd)

Add `toadd` to `out` in-place.
"""
add!
add!(out, toadd) = (out .+= toadd)

include("fast_linalg.jl")
include("kronecker.jl")
include("blockdiagonals.jl")
include("covariance_structure.jl")
export IsometricKroneckerCovariance, DenseCovariance, BlockDiagonalCovariance

abstract type AbstractODEFilterCache <: OrdinaryDiffEq.OrdinaryDiffEqCache end

Expand All @@ -65,14 +76,16 @@ include("priors/ltisde.jl")
include("priors/ioup.jl")
include("priors/matern.jl")
export IWP, IOUP, Matern
include("diffusions.jl")
include("diffusions/typedefs.jl")
include("diffusions/apply_diffusion.jl")
include("diffusions/calibration.jl")
export FixedDiffusion, DynamicDiffusion, FixedMVDiffusion, DynamicMVDiffusion

include("initialization/common.jl")
export TaylorModeInit, ClassicSolverInit, SimpleInit, ForwardDiffInit

include("algorithms.jl")
export EK0, EK1
export EK0, EK1, DiagonalEK1
export ExpEK, RosenbrockExpEK

include("alg_utils.jl")
Expand Down
18 changes: 12 additions & 6 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,25 @@
############################################################################################

OrdinaryDiffEq._alg_autodiff(::AbstractEK) = Val{true}()
OrdinaryDiffEq._alg_autodiff(::EK1{CS,AD}) where {CS,AD} = Val{AD}()
OrdinaryDiffEq.alg_difftype(::EK1{CS,AD,DiffType}) where {CS,AD,DiffType} = DiffType
OrdinaryDiffEq.standardtag(::AbstractEK) = false
OrdinaryDiffEq.standardtag(::EK1{CS,AD,DiffType,ST}) where {CS,AD,DiffType,ST} = ST
OrdinaryDiffEq.concrete_jac(::AbstractEK) = nothing
OrdinaryDiffEq.concrete_jac(::EK1{CS,AD,DiffType,ST,CJ}) where {CS,AD,DiffType,ST,CJ} = CJ

@inline DiffEqBase.get_tmp_cache(integ, alg::AbstractEK, cache::AbstractODEFilterCache) =
(cache.tmp, cache.atmp)
OrdinaryDiffEq.get_chunksize(::EK1{CS}) where {CS} = Val(CS)
OrdinaryDiffEq.isfsal(::AbstractEK) = false

OrdinaryDiffEq.isimplicit(::EK1) = true
for ALG in [:EK1, :DiagonalEK1]
@eval OrdinaryDiffEq._alg_autodiff(::$ALG{CS,AD}) where {CS,AD} = Val{AD}()
@eval OrdinaryDiffEq.alg_difftype(::$ALG{CS,AD,DiffType}) where {CS,AD,DiffType} =
DiffType
@eval OrdinaryDiffEq.standardtag(::$ALG{CS,AD,DiffType,ST}) where {CS,AD,DiffType,ST} =
ST
@eval OrdinaryDiffEq.concrete_jac(
::$ALG{CS,AD,DiffType,ST,CJ},
) where {CS,AD,DiffType,ST,CJ} = CJ
@eval OrdinaryDiffEq.get_chunksize(::$ALG{CS}) where {CS} = Val(CS)
@eval OrdinaryDiffEq.isimplicit(::$ALG) = true
end

############################################
# Step size control
Expand Down
151 changes: 133 additions & 18 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
########################################################################################
abstract type AbstractEK <: OrdinaryDiffEq.OrdinaryDiffEqAdaptiveAlgorithm end

function ekargcheck(alg; diffusionmodel, pn_observation_noise, kwargs...)
function ekargcheck(
alg;
diffusionmodel,
pn_observation_noise,
covariance_factorization,
kwargs...,
)
if (isstatic(diffusionmodel) && diffusionmodel.calibrate) &&
(!isnothing(pn_observation_noise) && !iszero(pn_observation_noise))
throw(
Expand All @@ -12,14 +18,63 @@
),
)
end
if (
(diffusionmodel isa FixedMVDiffusion && diffusionmodel.calibrate) ||
diffusionmodel isa DynamicMVDiffusion) && alg == EK1
throw(
ArgumentError(
"The `EK1` algorithm does not support automatic calibration of multivariate diffusion models. Either use the `EK0` instead, or use a scalar diffusion model, or set `calibrate=false` and calibrate manually by optimizing `sol.pnstats.log_likelihood`.",
),
)
if alg == EK1
if diffusionmodel isa FixedMVDiffusion && diffusionmodel.calibrate
throw(
ArgumentError(
"The `EK1` algorithm does not support automatic global calibration of multivariate diffusion models. Either use a scalar diffusion model, or set `calibrate=false` and calibrate manually by optimizing `sol.pnstats.log_likelihood`. Or use a different solve, like `EK0` or `DiagonalEK1`.",
),
)
elseif diffusionmodel isa DynamicMVDiffusion
throw(
ArgumentError(
"The `EK1` algorithm does not support automatic calibration of local multivariate diffusion models. Either use a scalar diffusion model, or use a different solve, like `EK0` or `DiagonalEK1`.",
),
)
end
end
if !(isnothing(pn_observation_noise) || ismissing(pn_observation_noise))
if covariance_factorization == IsometricKroneckerCovariance && !(
pn_observation_noise isa Number
|| pn_observation_noise isa UniformScaling
|| pn_observation_noise isa Diagonal{<:Number,<:FillArrays.Fill})
throw(
ArgumentError(
"The supplied `pn_observation_noise` is not compatible with the chosen `IsometricKroneckerCovariance` factorization. Try one of `BlockDiagonalCovariance` or `DenseCovariance` instead!",
),
)
end
if covariance_factorization == BlockDiagonalCovariance && !(
pn_observation_noise isa Number
|| pn_observation_noise isa UniformScaling
|| pn_observation_noise isa Diagonal)
throw(
ArgumentError(
"The supplied `pn_observation_noise` is not compatible with the chosen `BlockDiagonalCovariance` factorization. Try `DenseCovariance` instead!",
),
)
end
end
end

function covariance_structure(::Type{Alg}, prior, diffusionmodel) where {Alg<:AbstractEK}
if Alg <: EK0
if prior isa IWP
if (diffusionmodel isa DynamicDiffusion || diffusionmodel isa FixedDiffusion)
return IsometricKroneckerCovariance
else
return BlockDiagonalCovariance
end
else
# This is not great as other priors can be Kronecker too; TODO
return DenseCovariance
end
elseif Alg <: DiagonalEK1
return BlockDiagonalCovariance
elseif Alg <: EK1
return DenseCovariance
else
throw(ArgumentError("Unknown algorithm type $Alg"))

Check warning on line 77 in src/algorithms.jl

View check run for this annotation

Codecov / codecov/patch

src/algorithms.jl#L77

Added line #L77 was not covered by tests
end
end

Expand Down Expand Up @@ -58,22 +113,25 @@

# [References](@ref references)
"""
struct EK0{PT,DT,IT,RT} <: AbstractEK
struct EK0{PT,DT,IT,RT,CF} <: AbstractEK
prior::PT
diffusionmodel::DT
smooth::Bool
initialization::IT
pn_observation_noise::RT
covariance_factorization::CF
EK0(; order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
pn_observation_noise::RT=nothing,
) where {PT,DT,IT,RT} = begin
ekargcheck(EK0; diffusionmodel, pn_observation_noise)
new{PT,DT,IT,RT}(
prior, diffusionmodel, smooth, initialization, pn_observation_noise)
covariance_factorization::CF=covariance_structure(EK0, prior, diffusionmodel),
) where {PT,DT,IT,RT,CF} = begin
ekargcheck(EK0; diffusionmodel, pn_observation_noise, covariance_factorization)
new{PT,DT,IT,RT,CF}(
prior, diffusionmodel, smooth, initialization, pn_observation_noise,
covariance_factorization)
end
end

Expand Down Expand Up @@ -117,12 +175,13 @@

# [References](@ref references)
"""
struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT} <: AbstractEK
struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
prior::PT
diffusionmodel::DT
smooth::Bool
initialization::IT
pn_observation_noise::RT
covariance_factorization::CF
EK1(;
order=3,
prior::PT=IWP(order),
Expand All @@ -135,8 +194,57 @@
standardtag=Val{true}(),
concrete_jac=nothing,
pn_observation_noise::RT=nothing,
) where {PT,DT,IT,RT} = begin
ekargcheck(EK1; diffusionmodel, pn_observation_noise)
covariance_factorization::CF=covariance_structure(EK1, prior, diffusionmodel),
) where {PT,DT,IT,RT,CF} = begin
ekargcheck(EK1; diffusionmodel, pn_observation_noise, covariance_factorization)
new{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac),
PT,
DT,
IT,
RT,
CF,
}(
prior,
diffusionmodel,
smooth,
initialization,
pn_observation_noise,
covariance_factorization,
)
end
end

struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
prior::PT
diffusionmodel::DT
smooth::Bool
initialization::IT
pn_observation_noise::RT
covariance_factorization::CF
DiagonalEK1(;
order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
chunk_size=Val{0}(),
autodiff=Val{true}(),
diff_type=Val{:forward},
standardtag=Val{true}(),
concrete_jac=nothing,
pn_observation_noise::RT=nothing,
covariance_factorization::CF=covariance_structure(
DiagonalEK1,
prior,
diffusionmodel,
),
) where {PT,DT,IT,RT,CF} = begin
ekargcheck(DiagonalEK1; diffusionmodel, pn_observation_noise, covariance_factorization)
new{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
Expand All @@ -147,12 +255,14 @@
DT,
IT,
RT,
CF,
}(
prior,
diffusionmodel,
smooth,
initialization,
pn_observation_noise,
covariance_factorization,
)
end
end
Expand Down Expand Up @@ -236,7 +346,12 @@
)
end

function DiffEqBase.prepare_alg(alg::EK1{0}, u0::AbstractArray{T}, p, prob) where {T}
function DiffEqBase.prepare_alg(
alg::Union{EK1{0},DiagonalEK1{0}},
u0::AbstractArray{T},
p,
prob,
) where {T}
# See OrdinaryDiffEq.jl: ./src/alg_utils.jl (where this is copied from).
# In the future we might want to make EK1 an OrdinaryDiffEqAdaptiveImplicitAlgorithm and
# use the prepare_alg from OrdinaryDiffEq; but right now, we do not use `linsolve` which
Expand Down
Loading
Loading