Skip to content

Commit

Permalink
Added tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
axla-io committed Oct 17, 2023
1 parent e98ec26 commit 0df8aaf
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 56 deletions.
115 changes: 59 additions & 56 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,21 @@ function DFSane(; σ_min = 1e-10,
η_strategy = (fn_1, n, x_n, f_n) -> fn_1 / n^2,
max_inner_iterations = 1000)
return DFSane{typeof(σ_min), typeof(η_strategy)}(σ_min,
σ_max,
σ_1,
M,
γ,
τ_min,
τ_max,
n_exp,
η_strategy,
max_inner_iterations)
σ_max,
σ_1,
M,
γ,
τ_min,
τ_max,
n_exp,
η_strategy,
max_inner_iterations)
end
mutable struct DFSaneCache{iip, fType, algType, uType, resType, T, pType,
mutable struct DFSaneCache{iip, algType, uType, resType, T, pType,
INType,
tolType,
probType}
f::fType
f::Function
alg::algType
uₙ::uType
uₙ₋₁::uType
Expand All @@ -109,19 +109,19 @@ mutable struct DFSaneCache{iip, fType, algType, uType, resType, T, pType,
abstol::tolType
prob::probType
stats::NLStats
function DFSaneCache{iip}(f::fType, alg::algType, uₙ::uType, uₙ₋₁::uType,
function DFSaneCache{iip}(f::Function, alg::algType, uₙ::uType, uₙ₋₁::uType,
fuₙ::resType, fuₙ₋₁::resType, 𝒹::uType, ℋ::Vector{T},
f₍ₙₒᵣₘ₎ₙ₋₁::T, f₍ₙₒᵣₘ₎₀::T, M::Int, σₙ::T, σₘᵢₙ::T, σₘₐₓ::T,
α₁::T, γ::T, τₘᵢₙ::T, τₘₐₓ::T, nₑₓₚ::Int, p::pType,
force_stop::Bool, maxiters::Int, internalnorm::INType,
retcode::SciMLBase.ReturnCode.T, abstol::tolType,
prob::probType,
stats::NLStats) where {iip, fType, algType, uType,
stats::NLStats) where {iip, algType, uType,
resType, T, pType, INType,
tolType,
probType
}
new{iip, fType, algType, uType, resType, T, pType, INType, tolType,
new{iip, algType, uType, resType, T, pType, INType, tolType,
probType
}(f, alg, uₙ, uₙ₋₁, fuₙ, fuₙ₋₁, 𝒹, ℋ, f₍ₙₒᵣₘ₎ₙ₋₁, f₍ₙₒᵣₘ₎₀, M, σₙ,
σₘᵢₙ, σₘₐₓ, α₁, γ, τₘᵢₙ,
Expand All @@ -146,7 +146,8 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane,

p = prob.p
T = eltype(uₙ)
σₘᵢₙ, σₘₐₓ, γ, τₘᵢₙ, τₘₐₓ = T(alg.σ_min), T(alg.σ_max), T(alg.γ), T(alg.τ_min), T(alg.τ_max)
σₘᵢₙ, σₘₐₓ, γ, τₘᵢₙ, τₘₐₓ = T(alg.σ_min), T(alg.σ_max), T(alg.γ), T(alg.τ_min),
T(alg.τ_max)
α₁ = one(T)
γ = T(alg.γ)
f₍ₙₒᵣₘ₎ₙ₋₁ = α₁
Expand Down Expand Up @@ -262,16 +263,16 @@ function perform_step!(cache::DFSaneCache{false})
σₙ = sign(σₙ) * clamp(abs(σₙ), σₘᵢₙ, σₘₐₓ)

# Line search direction
@. cache.𝒹 = -σₙ * cache.fuₙ₋₁
cache.𝒹 = -σₙ * cache.fuₙ₋₁

η = alg.ηₛ(f₍ₙₒᵣₘ₎₀, n, cache.uₙ₋₁, cache.fuₙ₋₁)
η = alg.η_strategy(f₍ₙₒᵣₘ₎₀, n, cache.uₙ₋₁, cache.fuₙ₋₁)

= maximum(cache.ℋ)
α₊ = α₁
α₋ = α₁
@. cache.uₙ = cache.uₙ₋₁ + α₊ * cache.𝒹
cache.uₙ = cache.uₙ₋₁ + α₊ * cache.𝒹

cache.fuₙ .= f(cache.uₙ)
cache.fuₙ = f(cache.uₙ)
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
for _ in 1:(cache.alg.max_inner_iterations)
𝒸 =+ η - γ * α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁
Expand All @@ -282,9 +283,9 @@ function perform_step!(cache::DFSaneCache{false})
(f₍ₙₒᵣₘ₎ₙ + (T(2) * α₊ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₊,
τₘₐₓ * α₊)
@. cache.uₙ = cache.uₙ₋₁ - α₋ * cache.𝒹
cache.uₙ = @. cache.uₙ₋₁ - α₋ * cache.𝒹

cache.fuₙ .= f(cache.uₙ)
cache.fuₙ = f(cache.uₙ)
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ

f₍ₙₒᵣₘ₎ₙ .≤ 𝒸 && break
Expand All @@ -293,8 +294,8 @@ function perform_step!(cache::DFSaneCache{false})
τₘᵢₙ * α₋,
τₘₐₓ * α₋)

@. cache.uₙ = cache.uₙ₋₁ + α₊ * cache.𝒹
cache.fuₙ .= f(cache.uₙ)
cache.uₙ = @. cache.uₙ₋₁ + α₊ * cache.𝒹
cache.fuₙ = f(cache.uₙ)
f₍ₙₒᵣₘ₎ₙ = norm(cache.fuₙ)^nₑₓₚ
end

Expand All @@ -303,11 +304,11 @@ function perform_step!(cache::DFSaneCache{false})
end

# Update spectral parameter
@. cache.uₙ₋₁ = cache.uₙ - cache.uₙ₋₁
@. cache.fuₙ₋₁ = cache.fuₙ - cache.fuₙ₋₁
cache.uₙ₋₁ = @. cache.uₙ - cache.uₙ₋₁
cache.fuₙ₋₁ = @. cache.fuₙ - cache.fuₙ₋₁

α₊ = sum(abs2, cache.uₙ₋₁)
@. cache.uₙ₋₁ = cache.uₙ₋₁ * cache.fuₙ₋₁
cache.uₙ₋₁ = @. cache.uₙ₋₁ * cache.fuₙ₋₁
α₋ = sum(cache.uₙ₋₁)
cache.σₙ = α₊ / α₋

Expand All @@ -318,8 +319,8 @@ function perform_step!(cache::DFSaneCache{false})
end

# Take step
@. cache.uₙ₋₁ = cache.uₙ
@. cache.fuₙ₋₁ = cache.fuₙ
cache.uₙ₋₁ = cache.uₙ
cache.fuₙ₋₁ = cache.fuₙ
cache.f₍ₙₒᵣₘ₎ₙ₋₁ = f₍ₙₒᵣₘ₎ₙ

# Update history
Expand All @@ -345,32 +346,34 @@ function SciMLBase.solve!(cache::DFSaneCache)
end

function SciMLBase.reinit!(cache::DFSaneCache{iip}, u0 = cache.uₙ; p = cache.p,
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.uₙ, u0)
recursivecopy!(cache.uₙ₋₁, u0)
cache.f(cache.fuₙ, cache.uₙ, p)
cache.f(cache.fuₙ₋₁, cache.uₙ, p)
else
cache.uₙ = u0
cache.uₙ₋₁ = u0
cache.fuₙ = cache.f(cache.uₙ, p)
cache.fuₙ₋₁ = cache.f(cache.uₙ, p)
end

cache.f₍ₙₒᵣₘ₎ₙ₋₁ = norm(fuₙ₋₁)^nₑₓₚ
cache.f₍ₙₒᵣₘ₎₀ = cache.f₍ₙₒᵣₘ₎ₙ₋₁
fill!(cache.ℋ, cache.f₍ₙₒᵣₘ₎ₙ₋₁, cache.M)

T = eltype(cache.uₙ)
cache.σₙ = T(cache.alg.σ_1)

cache.abstol = abstol
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
abstol = cache.abstol, maxiters = cache.maxiters) where {iip}
cache.p = p
if iip
recursivecopy!(cache.uₙ, u0)
recursivecopy!(cache.uₙ₋₁, u0)
cache.f = (dx, x) -> cache.prob.f(dx, x, p)
cache.f(cache.fuₙ, cache.uₙ)
cache.f(cache.fuₙ₋₁, cache.uₙ)
else
cache.uₙ = u0
cache.uₙ₋₁ = u0
cache.f = (x) -> cache.prob.f(x, p)
cache.fuₙ = cache.f(cache.uₙ)
cache.fuₙ₋₁ = cache.f(cache.uₙ)
end

cache.f₍ₙₒᵣₘ₎ₙ₋₁ = norm(cache.fuₙ₋₁)^cache.nₑₓₚ
cache.f₍ₙₒᵣₘ₎₀ = cache.f₍ₙₒᵣₘ₎ₙ₋₁
fill!(cache.ℋ, cache.f₍ₙₒᵣₘ₎ₙ₋₁)

T = eltype(cache.uₙ)
cache.σₙ = T(cache.alg.σ_1)

cache.abstol = abstol
cache.maxiters = maxiters
cache.stats.nf = 1
cache.stats.nsteps = 1
cache.force_stop = false
cache.retcode = ReturnCode.Default
return cache
end
145 changes: 145 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,3 +390,148 @@ end
end
end
end


# --- DFSane tests ---

@testset "DFSane" begin
function benchmark_nlsolve_oop(f, u0, p=2.0)
prob = NonlinearProblem{false}(f, u0, p)
return solve(prob, DFSane(), abstol=1e-9)
end

function benchmark_nlsolve_iip(f, u0, p=2.0)
prob = NonlinearProblem{true}(f, u0, p)
return solve(prob, DFSane(), abstol=1e-9)
end

u0s = ([1.0, 1.0], @SVector[1.0, 1.0], 1.0)

@testset "[OOP] u0: $(typeof(u0))" for u0 in u0s
sol = benchmark_nlsolve_oop(quadratic_f, u0)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0), DFSane(),
abstol=1e-9)
@test (@ballocated solve!($cache)) < 200
end

@testset "[IIP] u0: $(typeof(u0))" for u0 in ([
1.0, 1.0],)
sol = benchmark_nlsolve_iip(quadratic_f!, u0)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0),
DFSane(), abstol=1e-9)
@test (@ballocated solve!($cache)) 64
end


@testset "[OOP] [Immutable AD]" begin
broken_forwarddiff = [1.6, 2.9, 3.0, 3.5, 4.0, 81.0]
for p in 1.1:0.1:100.0
res = abs.(benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p).u)

if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res)
@test_broken all(res .≈ sqrt(p))
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p)) 1 / (2 * sqrt(p))
elseif p in broken_forwarddiff
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p)) 1 / (2 * sqrt(p))
else
@test all(res .≈ sqrt(p))
@test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p)), 1 / (2 * sqrt(p)))
end
end
end

@testset "[OOP] [Scalar AD]" begin
broken_forwarddiff = [1.6, 2.9, 3.0, 3.5, 4.0, 81.0]
for p in 1.1:0.1:100.0
res = abs(benchmark_nlsolve_oop(quadratic_f, 1.0, p).u)

if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res)
@test_broken res sqrt(p)
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)) 1 / (2 * sqrt(p))
elseif p in broken_forwarddiff
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)) 1 / (2 * sqrt(p))
else
@test res sqrt(p)
@test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u, p)), 1 / (2 * sqrt(p)))
end
end
end

t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u sqrt(p[2] / p[1])
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
p) ForwardDiff.jacobian(t, p)

# Iterator interface
function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip}
probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin])
cache = init(probN, DFSane(); maxiters=100, abstol=1e-10)
sols = zeros(length(p_range))
for (i, p) in enumerate(p_range)
reinit!(cache, iip ? [cache.uₙ[1]] : cache.uₙ; p=p)
sol = solve!(cache)
sols[i] = iip ? sol.u[1] : sol.u
end
return sols
end
p = range(0.01, 2, length=200)
@test abs.(nlprob_iterator_interface(quadratic_f, p, Val(false))) sqrt.(p)
@test abs.(nlprob_iterator_interface(quadratic_f!, p, Val(true))) sqrt.(p)


# Test that `DFSane` passes a test that `NewtonRaphson` fails on.
@testset "Newton Raphson Fails" begin
u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
sol = benchmark_nlsolve_oop(newton_fails, u0, p)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(newton_fails(sol.u, p)) .< 1e-9)
end

# Test kwargs in `DFSane`
@testset "Keyword Arguments" begin
σ_min = [1e-10, 1e-5, 1e-4]
σ_max = [1e10, 1e5, 1e4]
σ_1 = [1.0, 0.5, 2.0]
M = [10, 1, 100]
γ = [1e-4, 1e-3, 1e-5]
τ_min = [0.1, 0.2, 0.3]
τ_max = [0.5, 0.8, 0.9]
nexp = [2, 1, 2]
η_strategy = [
(f_1, k, x, F) -> f_1 / k^2,
(f_1, k, x, F) -> f_1 / k^3,
(f_1, k, x, F) -> f_1 / k^4,
]

list_of_options = zip(σ_min, σ_max, σ_1, M, γ, τ_min, τ_max, nexp,
η_strategy)
for options in list_of_options
local probN, sol, alg
alg = DFSane(σ_min=options[1],
σ_max=options[2],
σ_1=options[3],
M=options[4],
γ=options[5],
τ_min=options[6],
τ_max=options[7],
n_exp=options[8],
η_strategy=options[9])

probN = NonlinearProblem{false}(quadratic_f, [1.0, 1.0], 2.0)
sol = solve(probN, alg, abstol=1e-11)
println(abs.(quadratic_f(sol.u, 2.0)))
@test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10)
end
end
end

0 comments on commit 0df8aaf

Please sign in to comment.