diff --git a/src/trustRegion.jl b/src/trustRegion.jl index f380cef4c..19d5ef0ad 100644 --- a/src/trustRegion.jl +++ b/src/trustRegion.jl @@ -25,6 +25,13 @@ EnumX.@enumx RadiusUpdateSchemes begin """ Simple + """ + `RadiusUpdateSchemes.NLsolve` + + The same updating rule as in NLsolve's trust region implementation + """ + NLsolve + """ `RadiusUpdateSchemes.Hei` @@ -177,8 +184,8 @@ function TrustRegion(; chunk_size = Val{0}(), max_trust_radius::Real = 0 // 1, initial_trust_radius::Real = 0 // 1, step_threshold::Real = 1 // 10000, - shrink_threshold::Real = 1 // 4, - expand_threshold::Real = 3 // 4, + shrink_threshold::Real = 1 // 10, #1 // 4, + expand_threshold::Real = 9 // 10, #3 // 4, shrink_factor::Real = 1 // 4, expand_factor::Real = 2 // 1, max_shrink_times::Int = 32) @@ -340,7 +347,9 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion, p3 = convert(eltype(u), 0.0) p4 = convert(eltype(u), 0.0) ϵ = convert(eltype(u), 1.0e-8) - if radius_update_scheme === RadiusUpdateSchemes.Hei + if radius_update_scheme === RadiusUpdateSchemes.NLsolve + p1 = convert(eltype(u), 0.5) + elseif radius_update_scheme === RadiusUpdateSchemes.Hei step_threshold = convert(eltype(u), 0.0) shrink_threshold = convert(eltype(u), 0.25) expand_threshold = convert(eltype(u), 0.25) @@ -407,7 +416,7 @@ function perform_step!(cache::TrustRegionCache{true}) cache.stats.njacs += 1 end - linres = dolinsolve(alg.precs, linsolve, A = cache.H, b = _vec(cache.g), + linres = dolinsolve(alg.precs, linsolve, A = J, b = _vec(fu), # cache.H, b = _vec(cache.g), linu = _vec(u_tmp), p = p, reltol = cache.abstol) cache.linsolve = linres.cache @@ -479,7 +488,7 @@ function trust_region_step!(cache::TrustRegionCache) # Compute the ratio of the actual reduction to the predicted reduction. cache.r = -(loss - cache.loss_new) / (dot(step_size, g) + dot(step_size, H, step_size) / 2) - @unpack r = cache + @unpack r = cache if radius_update_scheme === RadiusUpdateSchemes.Simple # Update the trust region radius. @@ -508,6 +517,30 @@ function trust_region_step!(cache::TrustRegionCache) cache.force_stop = true end + elseif radius_update_scheme === RadiusUpdateSchemes.NLsolve + # accept/reject decision + if r > cache.step_threshold # accept + take_step!(cache) + cache.loss = cache.loss_new + cache.make_new_J = true + else # reject + cache.make_new_J = false + end + + # trust region update + if r < cache.shrink_threshold # default 1 // 10 + cache.trust_r *= cache.shrink_factor # default 1 // 2 + elseif r >= cache.expand_threshold # default 9 // 10 + cache.trust_r = cache.expand_factor * norm(cache.step_size) # default 2 + elseif r >= cache.p1 # default 1 // 2 + cache.trust_r = max(cache.trust_r, cache.expand_factor * norm(cache.step_size)) + end + + # convergence test + if iszero(cache.fu) || cache.internalnorm(cache.fu) < cache.abstol + cache.force_stop = true + end + elseif radius_update_scheme === RadiusUpdateSchemes.Hei if r > cache.step_threshold take_step!(cache)