Skip to content

Commit

Permalink
Fix bugs and make code efficient => It's FAST!
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Jun 15, 2023
1 parent a58030d commit 1d653be
Show file tree
Hide file tree
Showing 6 changed files with 18 additions and 25 deletions.
2 changes: 2 additions & 0 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ function calc_H_EK0!(H, integ, cache)
@unpack f = integ
@unpack d, ddu, E1, E2 = cache

return nothing

if f isa DynamicalODEFunction
@assert f.mass_matrix === I
H .= E2
Expand Down
18 changes: 3 additions & 15 deletions src/fast_linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,9 @@ _matmul!(C, A, B)
_matmul!(C, A, B) = mul!(C, A, B)
_matmul!(C, A, B, a, b) = mul!(C, A, B, a, b)
# Some special cases
_matmul!(
C::AbstractMatrix{T},
A::AbstractMatrix{T},
B::Diagonal{T},
) where {T<:LinearAlgebra.BlasFloat} = (@.. C = A * B.diag')
_matmul!(
C::AbstractMatrix{T},
A::Diagonal{T},
B::AbstractMatrix{T},
) where {T<:LinearAlgebra.BlasFloat} = (@.. C = A.diag * B)
_matmul!(
C::AbstractMatrix{T},
A::Diagonal{T},
B::Diagonal{T},
) where {T<:LinearAlgebra.BlasFloat} = @.. C = A * B
_matmul!(C::AbstractMatrix, A::AbstractMatrix, B::Diagonal) = (@.. C = A * B.diag')
_matmul!(C::AbstractMatrix, A::Diagonal, B::AbstractMatrix) = (@.. C = A.diag * B)
_matmul!(C::AbstractMatrix, A::Diagonal, B::Diagonal) = @.. C = A * B
_matmul!(
C::AbstractMatrix{T},
A::AbstractVecOrMat{T},
Expand Down
8 changes: 4 additions & 4 deletions src/initialization/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ function condition_on!(
_matmul!(x.μ, K, datadiff, 1, 1)

D = length(x.μ)
_matmul!(Mcache, K, H, -1, 0)
@inbounds @simd ivdep for i in 1:D
Mcache[i, i] += 1
end
# _matmul!(Mcache, K, H, -1, 0)
# @inbounds @simd ivdep for i in 1:D
# Mcache[i, i] += 1
# end

d, q1 = size(H.A, 1), size(x.Σ.R.B, 1)
_I = kronecker(I(d), I(q1))
Expand Down
3 changes: 2 additions & 1 deletion src/perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,12 @@ function estimate_errors!(cache::AbstractODEFilterCache)

if H isa Kronecker.KroneckerProduct
error_estimate = ones(d)
error_estimate .*= (Qh.R.B * H.B')[1]^2
error_estimate .*= sum(abs2, Qh.R.B * H.B')
# error_estimate = view(cache.tmp, 1:d)
# sum!(abs2, error_estimate', view(R, :, 1:d))
error_estimate .*= local_diffusion
error_estimate .= sqrt.(error_estimate)
return error_estimate
else

_matmul!(R, Qh.R, H')
Expand Down
6 changes: 2 additions & 4 deletions src/priors/iwp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ function make_transition_matrices!(cache, prior::IWP, dt)
make_preconditioners!(cache, dt)
# A, Q = preconditioned_discretize(p) # not necessary since it's dt-independent
# Ah = PI * A * P
# @.. Ah = PI.diag * A * P.diag'
Ah.B .= PI.B * A.B * P.B
@.. Ah.B = PI.B.diag * A.B * P.B.diag'
# X_A_Xt!(Qh, Q, PI)
# @.. Qh.R = Q.R * PI.diag
Qh.R.B .= Q.R.B * PI.B
@.. Qh.R.B = Q.R.B * PI.B.diag
end
6 changes: 5 additions & 1 deletion src/projection.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
function projection(d::Integer, q::Integer, ::Type{elType}=typeof(1.0)) where {elType}
Id = _I(d)
Proj(deriv) = kronecker(Id, [i == (deriv + 1) ? 1 : 0 for i in 1:q+1]')
Proj(deriv) = begin
e_i = zeros(q+1, 1)
e_i[deriv+1] = 1
kronecker(Id, e_i')
end

# Slightly faster version of the above:
# D = d * (q + 1)
Expand Down

0 comments on commit 1d653be

Please sign in to comment.