Skip to content

Commit

Permalink
parameter types should not be converted to eltype(u). For now, defaul…
Browse files Browse the repository at this point in the history
…t to Float64.
  • Loading branch information
FHoltorf committed Sep 26, 2023
1 parent 0e99655 commit 439415b
Showing 1 changed file with 59 additions and 59 deletions.
118 changes: 59 additions & 59 deletions src/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,10 @@ end
fu_new
make_new_J::Bool
r::floatType
p1::floatType
p2::floatType
p3::floatType
p4::floatType
p1::parType
p2::parType
p3::parType
p4::parType
ϵ::floatType
stats::NLStats
end
Expand All @@ -227,23 +227,6 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
uf, linsolve, J, fu2, jac_cache, du = jacobian_caches(alg, f, u, p, Val(iip);
linsolve_kwargs)

radius_update_scheme = alg.radius_update_scheme
max_trust_radius = convert(eltype(u), alg.max_trust_radius)
initial_trust_radius = convert(eltype(u), alg.initial_trust_radius)
step_threshold = convert(eltype(u), alg.step_threshold)
shrink_threshold = convert(eltype(u), alg.shrink_threshold)
expand_threshold = convert(eltype(u), alg.expand_threshold)
shrink_factor = convert(eltype(u), alg.shrink_factor)
expand_factor = convert(eltype(u), alg.expand_factor)

# Set default trust region radius if not specified
if iszero(max_trust_radius)
max_trust_radius = convert(eltype(u), max(norm(fu1), maximum(u) - minimum(u)))
end
if iszero(initial_trust_radius)
initial_trust_radius = convert(eltype(u), max_trust_radius / 11)
end

loss_new = loss
H = zero(J)
g = _mutable_zero(fu1)
Expand All @@ -253,31 +236,50 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
make_new_J = true
r = loss

# set trust region update scheme
radius_update_scheme = alg.radius_update_scheme

# set default type for all trust region parameters
trustType = Float64 #typeof(alg.initial_trust_radius)
max_trust_radius = convert(trustType, alg.max_trust_radius)
if iszero(max_trust_radius)
max_trust_radius = convert(trustType, max(norm(fu), maximum(u) - minimum(u)))
end
initial_trust_radius = convert(trustType, alg.initial_trust_radius)
if iszero(initial_trust_radius)
initial_trust_radius = convert(trustType, max_trust_radius / 11)
end
step_threshold = convert(trustType, alg.step_threshold)
shrink_threshold = convert(trustType, alg.shrink_threshold)
expand_threshold = convert(trustType, alg.expand_threshold)
shrink_factor = convert(trustType, alg.shrink_factor)
expand_factor = convert(trustType, alg.expand_factor)

# Parameters for the Schemes
p1 = convert(eltype(u), 0.0)
p2 = convert(eltype(u), 0.0)
p3 = convert(eltype(u), 0.0)
p4 = convert(eltype(u), 0.0)
ϵ = convert(eltype(u), 1.0e-8)
parType = Float64
p1 = convert(parType, 0.0)
p2 = convert(parType, 0.0)
p3 = convert(parType, 0.0)
p4 = convert(parType, 0.0)
ϵ = convert(typeof(r), 1.0e-8)
if radius_update_scheme === RadiusUpdateSchemes.NLsolve
p1 = convert(eltype(u), 0.5)
p1 = convert(parType, 0.5)
elseif radius_update_scheme === RadiusUpdateSchemes.Hei
step_threshold = convert(eltype(u), 0.0)
shrink_threshold = convert(eltype(u), 0.25)
expand_threshold = convert(eltype(u), 0.25)
p1 = convert(eltype(u), 5.0) # M
p2 = convert(eltype(u), 0.1) # β
p3 = convert(eltype(u), 0.15) # γ1
p4 = convert(eltype(u), 0.15) # γ2
initial_trust_radius = convert(eltype(u), 1.0)
step_threshold = convert(trustType, 0.0)
shrink_threshold = convert(trustType, 0.25)
expand_threshold = convert(trustType, 0.25)
p1 = convert(parType, 5.0) # M
p2 = convert(parType, 0.1) # β
p3 = convert(parType, 0.15) # γ1
p4 = convert(parType, 0.15) # γ2
initial_trust_radius = convert(trustType, 1.0)
elseif radius_update_scheme === RadiusUpdateSchemes.Yuan
step_threshold = convert(eltype(u), 0.0001)
shrink_threshold = convert(eltype(u), 0.25)
expand_threshold = convert(eltype(u), 0.25)
p1 = convert(eltype(u), 2.0) # μ
p2 = convert(eltype(u), 1 / 6) # c5
p3 = convert(eltype(u), 6.0) # c6
p4 = convert(eltype(u), 0.0)
step_threshold = convert(trustType, 0.0001)
shrink_threshold = convert(trustType, 0.25)
expand_threshold = convert(trustType, 0.25)
p1 = convert(parType, 2.0) # μ
p2 = convert(parType, 1 / 6) # c5
p3 = convert(parType, 6.0) # c6
if iip
auto_jacvec!(g, (fu, x) -> f(fu, x, p), u, fu1)
else
Expand All @@ -287,25 +289,23 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::TrustRegion,
g = auto_jacvec(x -> f(x, p), u, fu1)
end
end
initial_trust_radius = convert(eltype(u), p1 * norm(g))
initial_trust_radius = convert(trustType, p1 * norm(g))
elseif radius_update_scheme === RadiusUpdateSchemes.Fan
step_threshold = convert(eltype(u), 0.0001)
shrink_threshold = convert(eltype(u), 0.25)
expand_threshold = convert(eltype(u), 0.75)
p1 = convert(eltype(u), 0.1) # μ
p2 = convert(eltype(u), 1 / 4) # c5
p3 = convert(eltype(u), 12) # c6
p4 = convert(eltype(u), 1.0e18) # M
initial_trust_radius = convert(eltype(u), p1 * (norm(fu1)^0.99))
step_threshold = convert(trustType, 0.0001)
shrink_threshold = convert(trustType, 0.25)
expand_threshold = convert(trustType, 0.75)
p1 = convert(parType, 0.1) # μ
p2 = convert(parType, 0.25) # c5
p3 = convert(parType, 12.0) # c6
p4 = convert(parType, 1.0e18) # M
initial_trust_radius = convert(trustType, p1 * (norm(fu)^0.99))
elseif radius_update_scheme === RadiusUpdateSchemes.Bastin
step_threshold = convert(eltype(u), 0.05)
shrink_threshold = convert(eltype(u), 0.05)
expand_threshold = convert(eltype(u), 0.9)
p1 = convert(eltype(u), 2.5) #alpha_1
p2 = convert(eltype(u), 0.25) # alpha_2
p3 = convert(eltype(u), 0) # not required
p4 = convert(eltype(u), 0) # not required
initial_trust_radius = convert(eltype(u), 1.0)
step_threshold = convert(trustType, 0.05)
shrink_threshold = convert(trustType, 0.05)
expand_threshold = convert(trustType, 0.9)
p1 = convert(parType, 2.5) #alpha_1
p2 = convert(parType, 0.25) # alpha_2
initial_trust_radius = convert(trustType, 1.0)
end

return TrustRegionCache{iip}(f, alg, u_prev, u, fu_prev, fu1, fu2, p, uf, linsolve, J,
Expand Down

0 comments on commit 439415b

Please sign in to comment.