Skip to content

Commit

Permalink
Implemented pt anded tests and docs
Browse files Browse the repository at this point in the history
  • Loading branch information
yonatanwesen committed Oct 19, 2023
1 parent 7fadae1 commit d92f2d7
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ include("trustRegion.jl")
include("levenberg.jl")
include("gaussnewton.jl")
include("dfsane.jl")
include("pseudotransient.jl")
include("jacobian.jl")
include("ad.jl")
include("default.jl")
Expand Down Expand Up @@ -95,7 +96,7 @@ end

export RadiusUpdateSchemes

export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton
export NewtonRaphson, TrustRegion, LevenbergMarquardt, DFSane, GaussNewton, PseudoTransient
export LeastSquaresOptimJL, FastLevenbergMarquardtJL
export RobustMultiNewton, FastShortcutNonlinearPolyalg

Expand Down
6 changes: 6 additions & 0 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@ function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p, ::Val{ii
linprob = LinearProblem(needsJᵀJ ? __maybe_symmetric(JᵀJ) : J,
needsJᵀJ ? _vec(Jᵀfu) : _vec(fu); u0 = _vec(du))

if alg isa PseudoTransient
alpha = convert(eltype(u), alg.alpha_initial)
J_new = J - (1 / alpha) * I
linprob = LinearProblem(J_new, _vec(fu); u0 = _vec(du))
end

weight = similar(u)
recursivefill!(weight, true)

Expand Down
121 changes: 121 additions & 0 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -543,3 +543,124 @@ end
end
end
end

# --- PseudoTransient tests ---

@testset "PseudoTransient" begin
#these are tests for NewtonRaphson so we should set alpha_initial to be high so that we converge quickly

function benchmark_nlsolve_oop(f, u0, p = 2.0; alpha_initial = 10.0)
prob = NonlinearProblem{false}(f, u0, p)
return solve(prob, PseudoTransient(; alpha_initial), abstol = 1e-9)
end

function benchmark_nlsolve_iip(f, u0, p = 2.0; linsolve, precs,
alpha_initial = 10.0)
prob = NonlinearProblem{true}(f, u0, p)
return solve(prob, PseudoTransient(; linsolve, precs, alpha_initial), abstol = 1e-9)
end

@testset "PT: alpha_initial = 10.0 PT AD: $(ad)" for ad in (AutoFiniteDiff(),
AutoZygote())
u0s = VERSION v"1.9" ? ([1.0, 1.0], @SVector[1.0, 1.0], 1.0) : ([1.0, 1.0], 1.0)

@testset "[OOP] u0: $(typeof(u0))" for u0 in u0s
sol = benchmark_nlsolve_oop(quadratic_f, u0)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

cache = init(NonlinearProblem{false}(quadratic_f, u0, 2.0),
PseudoTransient(alpha_initial = 10.0),
abstol = 1e-9)
@test (@ballocated solve!($cache)) < 200
end

precs = [NonlinearSolve.DEFAULT_PRECS, :Random]

@testset "[IIP] u0: $(typeof(u0)) precs: $(_nameof(prec)) linsolve: $(_nameof(linsolve))" for u0 in ([
1.0, 1.0],), prec in precs, linsolve in (nothing, KrylovJL_GMRES())
ad isa AutoZygote && continue
if prec === :Random
prec = (args...) -> (Diagonal(randn!(similar(u0))), nothing)
end
sol = benchmark_nlsolve_iip(quadratic_f!, u0; linsolve, precs = prec)
@test SciMLBase.successful_retcode(sol)
@test all(abs.(sol.u .* sol.u .- 2) .< 1e-9)

cache = init(NonlinearProblem{true}(quadratic_f!, u0, 2.0),
PseudoTransient(; alpha_initial = 10.0, linsolve, precs = prec),
abstol = 1e-9)
@test (@ballocated solve!($cache)) 64
end
end

if VERSION v"1.9"
@testset "[OOP] [Immutable AD]" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
res_true = sqrt(p)
all(res.u .≈ res_true)
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
end
end
end

@testset "[OOP] [Scalar AD]" begin
for p in 1.0:0.1:100.0
@test begin
res = benchmark_nlsolve_oop(quadratic_f, 1.0, p)
res_true = sqrt(p)
res.u res_true
end
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f, 1.0, p).u,
p)
1 / (2 * sqrt(p))
end
end

if VERSION v"1.9"
t = (p) -> [sqrt(p[2] / p[1])]
p = [0.9, 50.0]
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u sqrt(p[2] / p[1])
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
p)
ForwardDiff.jacobian(t, p)
end

function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip}
probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin])
cache = init(probN,
PseudoTransient(alpha_initial = 10.0);
maxiters = 100,
abstol = 1e-10)
sols = zeros(length(p_range))
for (i, p) in enumerate(p_range)
reinit!(cache, iip ? [cache.u[1]] : cache.u; p = p, alpha_new = 10.0)
sol = solve!(cache)
sols[i] = iip ? sol.u[1] : sol.u
end
return sols
end
p = range(0.01, 2, length = 200)
@test nlprob_iterator_interface(quadratic_f, p, Val(false)) sqrt.(p)
@test nlprob_iterator_interface(quadratic_f!, p, Val(true)) sqrt.(p)

@testset "ADType: $(autodiff) u0: $(_nameof(u0))" for autodiff in (false, true,
AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),
AutoSparseZygote(), AutoSparseEnzyme()), u0 in (1.0, [1.0, 1.0])
probN = NonlinearProblem(quadratic_f, u0, 2.0)
@test all(solve(probN, PseudoTransient(; alpha_initial = 10.0, autodiff)).u .≈
sqrt(2.0))
end

@testset "NewtonRaphson Fails but PT passes" begin # Test that `PseudoTransient` passes a test that `NewtonRaphson` fails on.
p = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
u0 = [-10.0, -1.0, 1.0, 2.0, 3.0, 4.0, 10.0]
probN = NonlinearProblem{false}(newton_fails, u0, p)
sol = solve(probN, PseudoTransient(alpha_initial = 1.0), abstol = 1e-10)
@test all(abs.(newton_fails(sol.u, p)) .< 1e-10)
end
end

0 comments on commit d92f2d7

Please sign in to comment.