Skip to content

Commit

Permalink
Simplify the whole I code a bit
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Jun 13, 2023
1 parent 11d6ded commit a58030d
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function OrdinaryDiffEq.alg_cache(
D = d * (q + 1)

KRONECKER = alg isa EK0
Id = _mul_stable_I(d)
Id = _I(d)

uType = typeof(u)
# uElType = eltype(u_vec)
Expand Down
2 changes: 1 addition & 1 deletion src/initialization/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ function condition_on!(
end

d, q1 = size(H.A, 1), size(x.Σ.R.B, 1)
_I = kronecker(_mul_stable_I(d), I(q1))
_I = kronecker(I(d), I(q1))
KH = K*H
@assert _I.A == KH.A
@. KH.B = _I.B - KH.B
Expand Down
19 changes: 17 additions & 2 deletions src/kronecker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,23 @@ end
copy(A::KP) = kronecker((A.A), copy(A.B))

"""
_mul_stable_I(d) = I(d) * I(d)
_I(d) = I(d) * I(d)
Create an identity matrix that does not change its type when multiplied by another identity matrix.
# Examples
```julia-repl
julia> I(2)|> typeof
Diagonal{Bool, Vector{Bool}}
julia> I(2) * I(2) |> typeof
Diagonal{Bool, BitVector}
julia> _I(2) |> typeof
Diagonal{Bool, BitVector}
julia> _I(2) * _I(2) |> typeof
Diagonal{Bool, BitVector}
```
"""
_mul_stable_I(d) = I(d) * I(d)
_I(d) = I(d) * I(d)
2 changes: 1 addition & 1 deletion src/preconditioning.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function init_preconditioner(d, q, ::Type{elType}=typeof(1.0)) where {elType}
Id = _mul_stable_I(d)
Id = _I(d)
P = kronecker(Id, Diagonal(ones(elType, q + 1)))
PI = kronecker(Id, Diagonal(ones(elType, q + 1)))
return P, PI
Expand Down
4 changes: 2 additions & 2 deletions src/priors/iwp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ function preconditioned_discretize(iwp::IWP)
QR_breve = Q_breve.R |> Matrix

d = iwp.wiener_process_dimension
Id = _mul_stable_I(d)
Id = _I(d)
A = kronecker(Id, A_breve)
QR = kronecker(Id, QR_breve)
Q = PSDMatrix(QR)
Expand Down Expand Up @@ -100,7 +100,7 @@ end
function discretize(p::IWP, dt::Real)
A_breve, Q_breve = discretize_1d(p, dt)
d = p.wiener_process_dimension
Id = _mul_stable_I(d)
Id = _I(d)
A = kronecker(Id, A_breve)
QR = kronecker(Id, Q_breve.R)
Q = PSDMatrix(QR)
Expand Down
2 changes: 1 addition & 1 deletion src/projection.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function projection(d::Integer, q::Integer, ::Type{elType}=typeof(1.0)) where {elType}
Id = _mul_stable_I(d)
Id = _I(d)
Proj(deriv) = kronecker(Id, [i == (deriv + 1) ? 1 : 0 for i in 1:q+1]')

# Slightly faster version of the above:
Expand Down
2 changes: 1 addition & 1 deletion src/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ function DiffEqBase.build_solution(
uElType = eltype(prob.u0)
D = d
KRONECKER = true
Id = _mul_stable_I(d)
Id = _I(d)
pu_cov = if KRONECKER
PSDMatrix(kronecker(Id, zeros(uElType, D ÷ d + 1)))
else
Expand Down

0 comments on commit a58030d

Please sign in to comment.