Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
vyudu committed Dec 16, 2024
1 parent 0cb4893 commit 67d8164
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 25 deletions.
11 changes: 6 additions & 5 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,6 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
eval_expression = false,
eval_module = @__MODULE__,
kwargs...) where {iip, specialize}

if !iscomplete(sys)
error("A completed system is required. Call `complete` or `structural_simplify` on the system before creating an `BVProblem`")
end
Expand All @@ -528,12 +527,12 @@ function SciMLBase.BVProblem{iip, specialize}(sys::AbstractODESystem, u0map = []
if cbs !== nothing
kwargs1 = merge(kwargs1, (callback = cbs,))
end

# Construct initial conditions.
_u0 = u0 isa Function ? u0(tspan[1]) : u0

# Define the boundary conditions.
bc = if iip
bc = if iip
(residual, u, p, t) -> (residual .= u[1] .- _u0)
else
(u, p, t) -> (u[1] - _u0)
Expand All @@ -544,11 +543,13 @@ end

get_callback(prob::BVProblem) = error("BVP solvers do not support callbacks.")

@inline function create_array(::Type{Base.ReinterpretArray}, ::Nothing, ::Val{1}, ::Val{dims}, elems...) where dims
@inline function create_array(::Type{Base.ReinterpretArray}, ::Nothing,
::Val{1}, ::Val{dims}, elems...) where {dims}
[elems...]
end

@inline function create_array(::Type{Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where dims
@inline function create_array(
::Type{Base.ReinterpretArray}, T, ::Val{1}, ::Val{dims}, elems...) where {dims}
T[elems...]
end

Expand Down
42 changes: 22 additions & 20 deletions test/bvproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,51 @@ using ModelingToolkit: t_nounits as t, D_nounits as D

solvers = [MIRK4, RadauIIa5, LobattoIIIa3]

@parameters α = 7.5 β = 4. γ = 8. δ = 5.
@variables x(t) = 1. y(t) = 2.
@parameters α=7.5 β=4.0 γ=8.0 δ=5.0
@variables x(t)=1.0 y(t)=2.0

eqs = [D(x) ~ α*x - β*x*y,
D(y) ~ -γ*y + δ*x*y]
eqs = [D(x) ~ α * x - β * x * y,
D(y) ~ -γ * y + δ * x * y]

u0map = [:x => 1., :y => 2.]
parammap = [ => 7.5, => 4, => 8., => 5.]
tspan = (0., 10.)
u0map = [:x => 1.0, :y => 2.0]
parammap = [ => 7.5, => 4, => 8.0, => 5.0]
tspan = (0.0, 10.0)

@mtkbuild lotkavolterra = ODESystem(eqs, t)
op = ODEProblem(lotkavolterra, u0map, tspan, parammap)
osol = solve(op, Vern9())

bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; eval_expression = true)
bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(
lotkavolterra, u0map, tspan, parammap; eval_expression = true)

for solver in solvers
sol = solve(bvp, solver(), dt = 0.01)
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test sol.u[1] == [1., 2.]
@test sol.u[1] == [1.0, 2.0]
end

# Test out of place
bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(lotkavolterra, u0map, tspan, parammap; eval_expression = true)
bvp2 = SciMLBase.BVProblem{false, SciMLBase.AutoSpecialize}(
lotkavolterra, u0map, tspan, parammap; eval_expression = true)

for solver in solvers
sol = solve(bvp2, solver(), dt = 0.01)
@test isapprox(sol.u[end],osol.u[end]; atol = 0.01)
@test sol.u[1] == [1., 2.]
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test sol.u[1] == [1.0, 2.0]
end

### Testing on pendulum

@parameters g = 9.81 L = 1.
@variables θ(t) = π/2
@parameters g=9.81 L=1.0
@variables θ(t) = π / 2

eqs = [D(D(θ)) ~ -(g / L) * sin(θ)]

@mtkbuild pend = ODESystem(eqs, t)

u0map ==> π/2, D(θ) => π/2]
parammap = [:L => 1., :g => 9.81]
tspan = (0., 6.)
u0map ==> π / 2, D(θ) => π / 2]
parammap = [:L => 1.0, :g => 9.81]
tspan = (0.0, 6.0)

op = ODEProblem(pend, u0map, tspan, parammap)
osol = solve(op, Vern9())
Expand All @@ -55,14 +57,14 @@ bvp = SciMLBase.BVProblem{true, SciMLBase.AutoSpecialize}(pend, u0map, tspan, pa
for solver in solvers
sol = solve(bvp, solver(), dt = 0.01)
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test sol.u[1] ==/2, π/2]
@test sol.u[1] == / 2, π / 2]
end

# Test out-of-place
bvp2 = SciMLBase.BVProblem{false, SciMLBase.FullSpecialize}(pend, u0map, tspan, parammap)

for solver in solvers
sol = solve(bvp2, solver(), dt = 0.01)
@test isapprox(sol.u[end],osol.u[end]; atol = 0.01)
@test sol.u[1] ==/2, π/2]
@test isapprox(sol.u[end], osol.u[end]; atol = 0.01)
@test sol.u[1] == / 2, π / 2]
end

0 comments on commit 67d8164

Please sign in to comment.