Skip to content

Commit

Permalink
JuliaFormatter.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Nov 4, 2023
1 parent 3c8fda8 commit 32c4c2c
Showing 1 changed file with 23 additions and 5 deletions.
28 changes: 23 additions & 5 deletions src/initialization/autodiffinit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ function initial_update!(integ, cache, init::AutodiffInitializationScheme)

f_derivatives = get_derivatives(init, u, f, p, t)
integ.stats.nf += init.order
@assert length(f_derivatives) == init.order+1
@assert length(f_derivatives) == init.order + 1

# This is hacky and should definitely be removed. But it also works so 🤷
MM = if f.mass_matrix isa UniformScaling
Expand Down Expand Up @@ -49,7 +49,13 @@ end
"""
Compute initial derivatives of an IIP ODEProblem with TaylorIntegration.jl
"""
function get_derivatives(init::TaylorModeInit, u, f::SciMLBase.AbstractODEFunction{true}, p, t)
function get_derivatives(
init::TaylorModeInit,
u,
f::SciMLBase.AbstractODEFunction{true},
p,
t,
)
q = init.order
tT = Taylor1(typeof(t), q)
tT[0] = t
Expand All @@ -64,7 +70,13 @@ function get_derivatives(init::TaylorModeInit, u, f::SciMLBase.AbstractODEFuncti
return [evaluate.(differentiate.(uT, i)) for i in 0:q]
end

function get_derivatives(init::ForwardDiffInit, u, f::SciMLBase.AbstractODEFunction{true}, p, t)
function get_derivatives(
init::ForwardDiffInit,
u,
f::SciMLBase.AbstractODEFunction{true},
p,
t,
)
q = init.order
_f(u) = (du = copy(u); f(du, u, p, t); du)
f_n = _f
Expand All @@ -85,8 +97,14 @@ function forwarddiff_oop_vectorfield_derivative_iteration(f_n, f_0)
return df
end


function forwarddiff_get_derivatives!(out, u, f::SciMLBase.AbstractODEFunction{true}, p, t, q)
function forwarddiff_get_derivatives!(
out,
u,
f::SciMLBase.AbstractODEFunction{true},
p,
t,
q,
)
_f(du, u) = f(du, u, p, t)
d = length(u0)
f_n = _f
Expand Down

0 comments on commit 32c4c2c

Please sign in to comment.