Skip to content

Commit

Permalink
Should make tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 30, 2023
1 parent 1a0df4e commit 81c86f3
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ uuid = "ade2ca70-3891-5945-98fb-dc099432e06a"

[[deps.DiffEqBase]]
deps = ["ArrayInterface", "ChainRulesCore", "DataStructures", "DocStringExtensions", "EnumX", "EnzymeCore", "FastBroadcast", "ForwardDiff", "FunctionWrappers", "FunctionWrappersWrappers", "LinearAlgebra", "Logging", "Markdown", "MuladdMacro", "Parameters", "PreallocationTools", "PrecompileTools", "Printf", "RecursiveArrayTools", "Reexport", "Requires", "SciMLBase", "SciMLOperators", "Setfield", "SparseArrays", "Static", "StaticArraysCore", "Statistics", "Tricks", "TruncatedStacktraces", "ZygoteRules"]
git-tree-sha1 = "4e661d0beddac31da05e71b79afd769232622de8"
git-tree-sha1 = "0ab52aef95c5cc71e9a8c9d26919ce1f7fb472fa"
repo-rev = "ap/tstable_termination"
repo-url = "https://github.com/SciML/DiffEqBase.jl"
uuid = "2b5f629d-d688-5b77-993f-72d75c75574e"
Expand Down
1 change: 1 addition & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ end
get_fu(cache::AbstractNonlinearSolveCache) = cache.fu1
set_fu!(cache::AbstractNonlinearSolveCache, fu) = (cache.fu1 = fu)
get_u(cache::AbstractNonlinearSolveCache) = cache.u
set_u!(cache::AbstractNonlinearSolveCache, u) = (cache.u = u)

function SciMLBase.solve!(cache::AbstractNonlinearSolveCache)
while not_terminated(cache)
Expand Down
1 change: 1 addition & 0 deletions src/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ end
get_fu(cache::DFSaneCache) = cache.fuₙ
set_fu!(cache::DFSaneCache, fu) = (cache.fuₙ = fu)
get_u(cache::DFSaneCache) = cache.uₙ
set_u!(cache::DFSaneCache, u) = (cache.uₙ = u)

function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::DFSane, args...;
alias_u0 = false, maxiters = 1000, abstol = nothing, reltol = nothing,
Expand Down
14 changes: 6 additions & 8 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,7 @@ function check_and_update!(tc_cache, cache, fu, u, uprev,
if isinplace(cache)
cache.prob.f(get_fu(cache), u, cache.prob.p)
else
cache.u = u
set_fu!(cache, cache.prob.f(cache.u, cache.prob.p))
set_fu!(cache, cache.prob.f(u, cache.prob.p))
end
cache.force_stop = true
end
Expand All @@ -252,8 +251,7 @@ function check_and_update!(tc_cache, cache, fu, u, uprev,
if isinplace(cache)
cache.prob.f(get_fu(cache), u, cache.prob.p)
else
cache.u = u
set_fu!(cache, cache.prob.f(cache.u, cache.prob.p))
set_fu!(cache, cache.prob.f(u, cache.prob.p))
end
cache.force_stop = true
end
Expand All @@ -271,11 +269,11 @@ function check_and_update!(tc_cache, cache, fu, u, uprev,
cache.retcode = ReturnCode.Unstable
end
if isinplace(cache)
copyto!(u, tc_cache.u)
cache.prob.f(get_fu(cache), u, cache.prob.p)
copyto!(get_u(cache), tc_cache.u)
cache.prob.f(get_fu(cache), get_u(cache), cache.prob.p)
else
cache.u = tc_cache.u
set_fu!(cache, cache.prob.f(cache.u, cache.prob.p))
set_u!(cache, tc_cache.u)
set_fu!(cache, cache.prob.f(get_u(cache), cache.prob.p))
end
cache.force_stop = true
end
Expand Down
15 changes: 2 additions & 13 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -453,17 +453,13 @@ end
end

@testset "[OOP] [Immutable AD]" begin
broken_forwarddiff = [3.0, 4.0, 81.0]
for p in 1.1:0.1:100.0
res = abs.(benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p).u)

if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res)
@test_broken all(res .≈ sqrt(p))
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p)) 1 / (2 * sqrt(p))
elseif p in broken_forwarddiff
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
@SVector[1.0, 1.0], p).u[end], p)) 1 / (2 * sqrt(p))
else
@test all(res .≈ sqrt(p))
@test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
Expand All @@ -473,17 +469,13 @@ end
end

@testset "[OOP] [Scalar AD]" begin
broken_forwarddiff = [3.0, 4.0, 81.0]
for p in 1.1:0.1:100.0
res = abs(benchmark_nlsolve_oop(quadratic_f, 1.0, p).u)

if any(x -> isnan(x) || x <= 1e-5 || x >= 1e5, res)
@test_broken res sqrt(p)
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
1.0, p).u, p)) 1 / (2 * sqrt(p))
elseif p in broken_forwarddiff
@test_broken abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
1.0, p).u, p)) 1 / (2 * sqrt(p))
else
@test res sqrt(p)
@test isapprox(abs.(ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
Expand Down Expand Up @@ -549,7 +541,6 @@ end

probN = NonlinearProblem{false}(quadratic_f, [1.0, 1.0], 2.0)
sol = solve(probN, alg, abstol = 1e-11)
println(abs.(quadratic_f(sol.u, 2.0)))
@test all(abs.(quadratic_f(sol.u, 2.0)) .< 1e-10)
end
end
Expand Down Expand Up @@ -644,13 +635,11 @@ end

function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip}
probN = NonlinearProblem{iip}(f, iip ? [0.5] : 0.5, p_range[begin])
cache = init(probN,
PseudoTransient(alpha_initial = 10.0);
maxiters = 100,
cache = init(probN, PseudoTransient(alpha_initial = 10.0); maxiters = 100,
abstol = 1e-10)
sols = zeros(length(p_range))
for (i, p) in enumerate(p_range)
reinit!(cache, iip ? [cache.u[1]] : cache.u; p = p, alpha_new = 10.0)
reinit!(cache, iip ? [cache.u[1]] : cache.u; p = p, alpha = 10.0)
sol = solve!(cache)
sols[i] = iip ? sol.u[1] : sol.u
end
Expand Down

0 comments on commit 81c86f3

Please sign in to comment.