Skip to content

Commit

Permalink
Fix Aqua error
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikQQY committed Dec 3, 2024
1 parent de6eb96 commit 8bacad1
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 37 deletions.
43 changes: 22 additions & 21 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@ module NonlinearSolveBaseForwardDiffExt

using ADTypes: ADTypes, AutoForwardDiff, AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using CommonSolve: CommonSolve, solve
using CommonSolve: CommonSolve, solve, solve!, init
using ConcreteStructs: @concrete
using DifferentiationInterface: DifferentiationInterface
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, remake

using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem,
AbstractNonlinearSolveAlgorithm, Utils, InternalAPI,
AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, Utils, InternalAPI,
AbstractNonlinearSolveCache, NonlinearSolvePolyAlgorithm,
NonlinearSolveForwardDiffCache

const DI = DifferentiationInterface

const GENERAL_SOLVER_TYPES = [
Nothing, AbstractNonlinearSolveAlgorithm, NonlinearSolvePolyAlgorithm
Nothing, NonlinearSolvePolyAlgorithm
]

const DualNonlinearProblem = NonlinearProblem{
Expand Down Expand Up @@ -135,24 +135,16 @@ for algType in GENERAL_SOLVER_TYPES
end
end

@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
cache
prob
alg
p
values_p
partials_p
end

function InternalAPI.reinit!(
cache::NonlinearSolveForwardDiffCache, args...;
p = cache.p, u0 = NonlinearSolveBase.get_u(cache.cache), kwargs...
)
InternalAPI.reinit!(
cache.cache; p = nodual_value(p), u0 = nodual_value(u0), kwargs...
cache.cache; p = NonlinearSolveBase.nodual_value(p),
u0 = NonlinearSolveBase.nodual_value(u0), kwargs...
)
cache.p = p
cache.values_p = nodual_value(p)
cache.values_p = NonlinearSolveBase.nodual_value(p)
cache.partials_p = ForwardDiff.partials(p)
return cache
end
Expand All @@ -161,8 +153,8 @@ for algType in GENERAL_SOLVER_TYPES
@eval function SciMLBase.__init(
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
)
p = nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
p = NonlinearSolveBase.nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = NonlinearSolveBase.nodual_value(prob.u0), p)
cache = init(newprob, alg, args...; kwargs...)
return NonlinearSolveForwardDiffCache(
cache, newprob, alg, prob.p, p, ForwardDiff.partials(prob.p)
Expand Down Expand Up @@ -196,8 +188,17 @@ function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
)
end

nodual_value(x) = x
nodual_value(x::Dual) = ForwardDiff.value(x)
nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)
NonlinearSolveBase.nodual_value(x) = x
NonlinearSolveBase.nodual_value(x::Dual) = ForwardDiff.value(x)
NonlinearSolveBase.nodual_value(x::AbstractArray{<:Dual}) = map(ForwardDiff.value, x)

"""
pickchunksize(x) = pickchunksize(length(x))
pickchunksize(x::Int)
Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length.
"""
@inline NonlinearSolveBase.pickchunksize(x) = pickchunksize(length(x))
@inline NonlinearSolveBase.pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)

end
2 changes: 2 additions & 0 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ include("descent/geodesic_acceleration.jl")

include("solve.jl")

include("forward_diff.jl")

# Unexported Public API
@compat(public, (L2_NORM, Linf_NORM, NAN_CHECK, UNITLESS_ABS2, get_tolerance))
@compat(public, (nonlinearsolve_forwarddiff_solve, nonlinearsolve_dual_solution))
Expand Down
9 changes: 0 additions & 9 deletions lib/NonlinearSolveBase/src/common_defaults.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,3 @@ function get_tolerance(::Union{StaticArray, Number}, ::Nothing, ::Type{T}) where
# Rational numbers can throw an error if used inside GPU Kernels
return T(real(oneunit(T)) * (eps(real(one(T)))^(real(T)(0.8))))
end

"""
pickchunksize(x) = pickchunksize(length(x))
pickchunksize(x::Int)
Determine the chunk size for ForwardDiff and PolyesterForwardDiff based on the input length.
"""
@inline pickchunksize(x) = pickchunksize(length(x))
@inline pickchunksize(x::Int) = ForwardDiff.pickchunksize(x)
8 changes: 8 additions & 0 deletions lib/NonlinearSolveBase/src/forward_diff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
@concrete mutable struct NonlinearSolveForwardDiffCache <: AbstractNonlinearSolveCache
cache
prob
alg
p
values_p
partials_p
end
2 changes: 2 additions & 0 deletions lib/NonlinearSolveBase/src/public.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ function nonlinearsolve_dual_solution end
function nonlinearsolve_∂f_∂p end
function nonlinearsolve_∂f_∂u end
function nlls_generate_vjp_function end
function nodual_value end
function pickchunksize end

# Nonlinear Solve Termination Conditions
abstract type AbstractNonlinearTerminationMode end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using NonlinearSolveBase: NonlinearSolveBase, AbstractNonlinearSolveAlgorithm,
Utils, InternalAPI, get_timer_output, @static_timeit,
update_trace!, L2_NORM, NonlinearSolvePolyAlgorithm,
NewtonDescent, DampedNewtonDescent, GeodesicAcceleration,
Dogleg
Dogleg, NonlinearSolveForwardDiffCache
using SciMLBase: SciMLBase, AbstractNonlinearProblem, NLStats, ReturnCode,
NonlinearFunction,
NonlinearLeastSquaresProblem, NonlinearProblem, NoSpecialize
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
module NonlinearSolveQuasiNewtonForwardDiffExt

using CommonSolve: CommonSolve, solve
using CommonSolve: CommonSolve, init
using ForwardDiff: ForwardDiff, Dual
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, remake
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem

using NonlinearSolveBase: NonlinearSolveBase

Expand Down
1 change: 1 addition & 0 deletions lib/NonlinearSolveSpectralMethods/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ CommonSolve = "0.2.4"
ConcreteStructs = "0.2.3"
DiffEqBase = "6.158.3"
ExplicitImports = "1.5"
ForwardDiff = "0.10.36"
Hwloc = "3"
InteractiveUtils = "<0.0.1, 1"
LineSearch = "0.1.4"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
module NonlinearSolveSpectralMethodsForwardDiffExt

using CommonSolve: CommonSolve, solve
using CommonSolve: CommonSolve, init
using ForwardDiff: ForwardDiff, Dual
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, remake
using SciMLBase: SciMLBase, NonlinearProblem, NonlinearLeastSquaresProblem

using NonlinearSolveBase: NonlinearSolveBase

Expand Down

0 comments on commit 8bacad1

Please sign in to comment.