From 4c282134d069085f4d1ada673241e60ab8b10f99 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 6 Oct 2024 21:34:58 -0400 Subject: [PATCH] fix: exotic types --- .../SimpleNonlinearSolveChainRulesCoreExt.jl | 2 +- .../ext/SimpleNonlinearSolveReverseDiffExt.jl | 7 +++-- .../ext/SimpleNonlinearSolveTrackerExt.jl | 2 +- .../src/SimpleNonlinearSolve.jl | 7 +++-- lib/SimpleNonlinearSolve/src/halley.jl | 6 ++-- lib/SimpleNonlinearSolve/src/lbroyden.jl | 6 ++-- lib/SimpleNonlinearSolve/src/raphson.jl | 4 +-- lib/SimpleNonlinearSolve/src/trust_region.jl | 4 +-- lib/SimpleNonlinearSolve/src/utils.jl | 5 ++-- .../test/core/exotic_type_tests.jl | 30 +++++++++++++++++++ .../test/core/qa_tests.jl | 19 ++++++++++++ 11 files changed, 71 insertions(+), 21 deletions(-) create mode 100644 lib/SimpleNonlinearSolve/test/core/exotic_type_tests.jl create mode 100644 lib/SimpleNonlinearSolve/test/core/qa_tests.jl diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl index df0bd7573..f56dee537 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveChainRulesCoreExt.jl @@ -8,7 +8,7 @@ using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up, solve_adjoint function ChainRulesCore.rrule(::typeof(simplenonlinearsolve_solve_up), - prob::Union{InternalNonlinearProblem, NonlinearLeastSquaresProblem}, + prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem}, sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...) out, ∇internal = solve_adjoint( prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...) diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl index 1357bec83..7b476b3c5 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveReverseDiffExt.jl @@ -6,11 +6,12 @@ using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem, remake using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint +import SimpleNonlinearSolve: simplenonlinearsolve_solve_up -for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem) +for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem) aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any) for (uT, pT) in collect(Iterators.product(aTypes, aTypes))[1:(end - 1)] - @eval function SimpleNonlinearSolve.simplenonlinearsolve_solve_up( + @eval function simplenonlinearsolve_solve_up( prob::$(pType), sensealg, u0::$(uT), u0_changed, p::$(pT), p_changed, alg, args...; kwargs...) return ReverseDiff.track(SimpleNonlinearSolve.simplenonlinearsolve_solve_up, @@ -19,7 +20,7 @@ for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem) end end - @eval ReverseDiff.@grad function SimpleNonlinearSolve.simplenonlinearsolve_solve_up( + @eval ReverseDiff.@grad function simplenonlinearsolve_solve_up( tprob::$(pType), sensealg, tu0, u0_changed, tp, p_changed, alg, args...; kwargs...) u0, p = ReverseDiff.value(tu0), ReverseDiff.value(tp) diff --git a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl index a2fa8ff40..ead5a8e29 100644 --- a/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl +++ b/lib/SimpleNonlinearSolve/ext/SimpleNonlinearSolveTrackerExt.jl @@ -7,7 +7,7 @@ using Tracker: Tracker, TrackedArray, TrackedReal using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint -for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem) +for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem) aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any) for (uT, pT) in collect(Iterators.product(aTypes, aTypes))[1:(end - 1)] @eval function SimpleNonlinearSolve.simplenonlinearsolve_solve_up( diff --git a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl index c8df31e51..0b33c923e 100644 --- a/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl +++ b/lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl @@ -14,19 +14,20 @@ using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, ReturnCode using StaticArraysCore: StaticArray, SArray, SVector, MArray # AD Dependencies -using ADTypes: AbstractADType, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff +using ADTypes: AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff using DifferentiationInterface: DifferentiationInterface using FiniteDiff: FiniteDiff using ForwardDiff: ForwardDiff using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder -using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, get_tolerance, - L2_NORM +using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM const DI = DifferentiationInterface abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end +const safe_similar = NonlinearSolveBase.Utils.safe_similar + is_extension_loaded(::Val) = false include("utils.jl") diff --git a/lib/SimpleNonlinearSolve/src/halley.jl b/lib/SimpleNonlinearSolve/src/halley.jl index 23ba1c847..6b8948248 100644 --- a/lib/SimpleNonlinearSolve/src/halley.jl +++ b/lib/SimpleNonlinearSolve/src/halley.jl @@ -41,9 +41,9 @@ function SciMLBase.__solve( strait = setindex_trait(x) - A = strait isa CanSetindex ? similar(x, length(x), length(x)) : x - Aaᵢ = strait isa CanSetindex ? similar(x, length(x)) : x - cᵢ = strait isa CanSetindex ? similar(x) : x + A = strait isa CanSetindex ? safe_similar(x, length(x), length(x)) : x + Aaᵢ = strait isa CanSetindex ? safe_similar(x, length(x)) : x + cᵢ = strait isa CanSetindex ? safe_similar(x) : x for _ in 1:maxiters fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x) diff --git a/lib/SimpleNonlinearSolve/src/lbroyden.jl b/lib/SimpleNonlinearSolve/src/lbroyden.jl index 1ab200f74..ce39ca10d 100644 --- a/lib/SimpleNonlinearSolve/src/lbroyden.jl +++ b/lib/SimpleNonlinearSolve/src/lbroyden.jl @@ -301,7 +301,7 @@ end return :(return SVector{$N, $T}(($(getcalls...)))) end -lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} = similar(x, threshold) +lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} = safe_similar(x, threshold) function lbroyden_threshold_cache(x::StaticArray, ::Val{threshold}) where {threshold} return zeros(MArray{Tuple{threshold}, eltype(x)}) end @@ -327,7 +327,7 @@ end end end function init_low_rank_jacobian(u, fu, ::Val{threshold}) where {threshold} - Vᵀ = similar(u, threshold, length(u)) - U = similar(u, length(fu), threshold) + Vᵀ = safe_similar(u, threshold, length(u)) + U = safe_similar(u, length(fu), threshold) return U, Vᵀ end diff --git a/lib/SimpleNonlinearSolve/src/raphson.jl b/lib/SimpleNonlinearSolve/src/raphson.jl index 2af3a825a..763fcf32a 100644 --- a/lib/SimpleNonlinearSolve/src/raphson.jl +++ b/lib/SimpleNonlinearSolve/src/raphson.jl @@ -41,8 +41,8 @@ function SciMLBase.__solve( NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff) @bb xo = similar(x) - fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? similar(fx) : - nothing + fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? + safe_similar(fx) : nothing jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x) J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache) diff --git a/lib/SimpleNonlinearSolve/src/trust_region.jl b/lib/SimpleNonlinearSolve/src/trust_region.jl index 6bf543220..27b210d65 100644 --- a/lib/SimpleNonlinearSolve/src/trust_region.jl +++ b/lib/SimpleNonlinearSolve/src/trust_region.jl @@ -93,8 +93,8 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegi norm_fx = L2_NORM(fx) @bb xo = copy(x) - fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? similar(fx) : - nothing + fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? + safe_similar(fx) : nothing jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x) J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache) diff --git a/lib/SimpleNonlinearSolve/src/utils.jl b/lib/SimpleNonlinearSolve/src/utils.jl index 011788a1c..946c10529 100644 --- a/lib/SimpleNonlinearSolve/src/utils.jl +++ b/lib/SimpleNonlinearSolve/src/utils.jl @@ -1,6 +1,5 @@ module Utils -using ADTypes: AbstractADType, AutoForwardDiff, AutoFiniteDiff, AutoPolyesterForwardDiff using ArrayInterface: ArrayInterface using ConcreteStructs: @concrete using DifferentiationInterface: DifferentiationInterface, Constant @@ -164,7 +163,7 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras) if J === nothing if extras isa AnalyticJacobian if SciMLBase.isinplace(prob.f) - J = similar(fx, length(fx), length(x)) + J = safe_similar(fx, length(fx), length(x)) prob.f.jac(J, x, prob.p) return J else @@ -219,7 +218,7 @@ end function compute_jacobian_and_hessian(autodiff, prob, fx, x) if SciMLBase.isinplace(prob) jac_fn = @closure (u, p) -> begin - du = similar(fx, promote_type(eltype(fx), eltype(u))) + du = safe_similar(fx, promote_type(eltype(fx), eltype(u))) return DI.jacobian(prob.f, du, autodiff, u, Constant(p)) end J, H = DI.value_and_jacobian(jac_fn, autodiff, x, Constant(prob.p)) diff --git a/lib/SimpleNonlinearSolve/test/core/exotic_type_tests.jl b/lib/SimpleNonlinearSolve/test/core/exotic_type_tests.jl new file mode 100644 index 000000000..e19a3d32e --- /dev/null +++ b/lib/SimpleNonlinearSolve/test/core/exotic_type_tests.jl @@ -0,0 +1,30 @@ +@testitem "BigFloat Support" tags=[:core] begin + using SimpleNonlinearSolve, LinearAlgebra + + fn_iip = NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p) + fn_oop = NonlinearFunction{false}((u, p) -> u .* u .- p) + + u0 = BigFloat[1.0, 1.0, 1.0] + prob_iip_bf = NonlinearProblem{true}(fn_iip, u0, BigFloat(2)) + prob_oop_bf = NonlinearProblem{false}(fn_oop, u0, BigFloat(2)) + + @testset "$(nameof(typeof(alg)))" for alg in ( + SimpleNewtonRaphson(), + SimpleBroyden(), + SimpleKlement(), + SimpleDFSane(), + SimpleTrustRegion(), + SimpleLimitedMemoryBroyden(), + SimpleHalley() + ) + sol = solve(prob_oop_bf, alg) + @test maximum(abs, sol.resid) < 1e-6 + @test SciMLBase.successful_retcode(sol.retcode) + + alg isa SimpleHalley && continue + + sol = solve(prob_iip_bf, alg) + @test maximum(abs, sol.resid) < 1e-6 + @test SciMLBase.successful_retcode(sol.retcode) + end +end diff --git a/lib/SimpleNonlinearSolve/test/core/qa_tests.jl b/lib/SimpleNonlinearSolve/test/core/qa_tests.jl new file mode 100644 index 000000000..d6a5a9b8e --- /dev/null +++ b/lib/SimpleNonlinearSolve/test/core/qa_tests.jl @@ -0,0 +1,19 @@ +@testitem "Aqua" tags=[:core] begin + using Aqua, SimpleNonlinearSolve + + Aqua.test_all(SimpleNonlinearSolve; piracies = false, ambiguities = false) + Aqua.test_piracies(SimpleNonlinearSolve; + treat_as_own = [ + NonlinearProblem, NonlinearLeastSquaresProblem, IntervalNonlinearProblem]) + Aqua.test_ambiguities(SimpleNonlinearSolve; recursive = false) +end + +@testitem "Explicit Imports" tags=[:core] begin + import ReverseDiff, Tracker, StaticArrays, Zygote + using ExplicitImports, SimpleNonlinearSolve + + @test check_no_implicit_imports( + SimpleNonlinearSolve; skip = (Base, Core, SciMLBase)) === nothing + @test check_no_stale_explicit_imports(SimpleNonlinearSolve) === nothing + @test check_all_qualified_accesses_via_owners(SimpleNonlinearSolve) === nothing +end