Skip to content

Commit

Permalink
fix: exotic types
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 7, 2024
1 parent 4d9c30e commit 4c28213
Show file tree
Hide file tree
Showing 11 changed files with 71 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using SimpleNonlinearSolve: SimpleNonlinearSolve, simplenonlinearsolve_solve_up,
solve_adjoint

function ChainRulesCore.rrule(::typeof(simplenonlinearsolve_solve_up),
prob::Union{InternalNonlinearProblem, NonlinearLeastSquaresProblem},
prob::Union{ImmutableNonlinearProblem, NonlinearLeastSquaresProblem},
sensealg, u0, u0_changed, p, p_changed, alg, args...; kwargs...)
out, ∇internal = solve_adjoint(
prob, sensealg, u0, p, ChainRulesOriginator(), alg, args...; kwargs...)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ using ReverseDiff: ReverseDiff, TrackedArray, TrackedReal
using SciMLBase: ReverseDiffOriginator, NonlinearLeastSquaresProblem, remake

using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint
import SimpleNonlinearSolve: simplenonlinearsolve_solve_up

for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem)
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any)
for (uT, pT) in collect(Iterators.product(aTypes, aTypes))[1:(end - 1)]
@eval function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(
@eval function simplenonlinearsolve_solve_up(
prob::$(pType), sensealg, u0::$(uT), u0_changed,
p::$(pT), p_changed, alg, args...; kwargs...)
return ReverseDiff.track(SimpleNonlinearSolve.simplenonlinearsolve_solve_up,
Expand All @@ -19,7 +20,7 @@ for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem)
end
end

@eval ReverseDiff.@grad function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(
@eval ReverseDiff.@grad function simplenonlinearsolve_solve_up(
tprob::$(pType), sensealg, tu0, u0_changed,
tp, p_changed, alg, args...; kwargs...)
u0, p = ReverseDiff.value(tu0), ReverseDiff.value(tp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ using Tracker: Tracker, TrackedArray, TrackedReal

using SimpleNonlinearSolve: SimpleNonlinearSolve, solve_adjoint

for pType in (InternalNonlinearProblem, NonlinearLeastSquaresProblem)
for pType in (ImmutableNonlinearProblem, NonlinearLeastSquaresProblem)
aTypes = (TrackedArray, AbstractArray{<:TrackedReal}, Any)
for (uT, pT) in collect(Iterators.product(aTypes, aTypes))[1:(end - 1)]
@eval function SimpleNonlinearSolve.simplenonlinearsolve_solve_up(
Expand Down
7 changes: 4 additions & 3 deletions lib/SimpleNonlinearSolve/src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,20 @@ using SciMLBase: AbstractNonlinearAlgorithm, NonlinearProblem, ReturnCode
using StaticArraysCore: StaticArray, SArray, SVector, MArray

# AD Dependencies
using ADTypes: AbstractADType, AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
using ADTypes: AutoFiniteDiff, AutoForwardDiff, AutoPolyesterForwardDiff
using DifferentiationInterface: DifferentiationInterface
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff

using BracketingNonlinearSolve: Alefeld, Bisection, Brent, Falsi, ITP, Ridder
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, get_tolerance,
L2_NORM
using NonlinearSolveBase: NonlinearSolveBase, ImmutableNonlinearProblem, L2_NORM

const DI = DifferentiationInterface

abstract type AbstractSimpleNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end

const safe_similar = NonlinearSolveBase.Utils.safe_similar

is_extension_loaded(::Val) = false

include("utils.jl")
Expand Down
6 changes: 3 additions & 3 deletions lib/SimpleNonlinearSolve/src/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ function SciMLBase.__solve(

strait = setindex_trait(x)

A = strait isa CanSetindex ? similar(x, length(x), length(x)) : x
Aaᵢ = strait isa CanSetindex ? similar(x, length(x)) : x
cᵢ = strait isa CanSetindex ? similar(x) : x
A = strait isa CanSetindex ? safe_similar(x, length(x), length(x)) : x
Aaᵢ = strait isa CanSetindex ? safe_similar(x, length(x)) : x
cᵢ = strait isa CanSetindex ? safe_similar(x) : x

for _ in 1:maxiters
fx, J, H = Utils.compute_jacobian_and_hessian(autodiff, prob, fx, x)
Expand Down
6 changes: 3 additions & 3 deletions lib/SimpleNonlinearSolve/src/lbroyden.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ end
return :(return SVector{$N, $T}(($(getcalls...))))
end

lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} = similar(x, threshold)
lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} = safe_similar(x, threshold)
function lbroyden_threshold_cache(x::StaticArray, ::Val{threshold}) where {threshold}
return zeros(MArray{Tuple{threshold}, eltype(x)})
end
Expand All @@ -327,7 +327,7 @@ end
end
end
function init_low_rank_jacobian(u, fu, ::Val{threshold}) where {threshold}
Vᵀ = similar(u, threshold, length(u))
U = similar(u, length(fu), threshold)
Vᵀ = safe_similar(u, threshold, length(u))
U = safe_similar(u, length(fu), threshold)
return U, Vᵀ
end
4 changes: 2 additions & 2 deletions lib/SimpleNonlinearSolve/src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ function SciMLBase.__solve(
NonlinearSolveBase.select_jacobian_autodiff(prob, alg.autodiff)

@bb xo = similar(x)
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? similar(fx) :
nothing
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
safe_similar(fx) : nothing
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)

Expand Down
4 changes: 2 additions & 2 deletions lib/SimpleNonlinearSolve/src/trust_region.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ function SciMLBase.__solve(prob::ImmutableNonlinearProblem, alg::SimpleTrustRegi
norm_fx = L2_NORM(fx)

@bb xo = copy(x)
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ? similar(fx) :
nothing
fx_cache = (SciMLBase.isinplace(prob) && !SciMLBase.has_jac(prob.f)) ?
safe_similar(fx) : nothing
jac_cache = Utils.prepare_jacobian(prob, autodiff, fx_cache, x)
J = Utils.compute_jacobian!!(nothing, prob, autodiff, fx_cache, x, jac_cache)

Expand Down
5 changes: 2 additions & 3 deletions lib/SimpleNonlinearSolve/src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
module Utils

using ADTypes: AbstractADType, AutoForwardDiff, AutoFiniteDiff, AutoPolyesterForwardDiff
using ArrayInterface: ArrayInterface
using ConcreteStructs: @concrete
using DifferentiationInterface: DifferentiationInterface, Constant
Expand Down Expand Up @@ -164,7 +163,7 @@ function compute_jacobian!!(J, prob, autodiff, fx, x, extras)
if J === nothing
if extras isa AnalyticJacobian
if SciMLBase.isinplace(prob.f)
J = similar(fx, length(fx), length(x))
J = safe_similar(fx, length(fx), length(x))
prob.f.jac(J, x, prob.p)
return J
else
Expand Down Expand Up @@ -219,7 +218,7 @@ end
function compute_jacobian_and_hessian(autodiff, prob, fx, x)
if SciMLBase.isinplace(prob)
jac_fn = @closure (u, p) -> begin
du = similar(fx, promote_type(eltype(fx), eltype(u)))
du = safe_similar(fx, promote_type(eltype(fx), eltype(u)))
return DI.jacobian(prob.f, du, autodiff, u, Constant(p))
end
J, H = DI.value_and_jacobian(jac_fn, autodiff, x, Constant(prob.p))
Expand Down
30 changes: 30 additions & 0 deletions lib/SimpleNonlinearSolve/test/core/exotic_type_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
@testitem "BigFloat Support" tags=[:core] begin
using SimpleNonlinearSolve, LinearAlgebra

fn_iip = NonlinearFunction{true}((du, u, p) -> du .= u .* u .- p)
fn_oop = NonlinearFunction{false}((u, p) -> u .* u .- p)

u0 = BigFloat[1.0, 1.0, 1.0]
prob_iip_bf = NonlinearProblem{true}(fn_iip, u0, BigFloat(2))
prob_oop_bf = NonlinearProblem{false}(fn_oop, u0, BigFloat(2))

@testset "$(nameof(typeof(alg)))" for alg in (
SimpleNewtonRaphson(),
SimpleBroyden(),
SimpleKlement(),
SimpleDFSane(),
SimpleTrustRegion(),
SimpleLimitedMemoryBroyden(),
SimpleHalley()
)
sol = solve(prob_oop_bf, alg)
@test maximum(abs, sol.resid) < 1e-6
@test SciMLBase.successful_retcode(sol.retcode)

alg isa SimpleHalley && continue

sol = solve(prob_iip_bf, alg)
@test maximum(abs, sol.resid) < 1e-6
@test SciMLBase.successful_retcode(sol.retcode)
end
end
19 changes: 19 additions & 0 deletions lib/SimpleNonlinearSolve/test/core/qa_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
@testitem "Aqua" tags=[:core] begin
using Aqua, SimpleNonlinearSolve

Aqua.test_all(SimpleNonlinearSolve; piracies = false, ambiguities = false)
Aqua.test_piracies(SimpleNonlinearSolve;
treat_as_own = [
NonlinearProblem, NonlinearLeastSquaresProblem, IntervalNonlinearProblem])
Aqua.test_ambiguities(SimpleNonlinearSolve; recursive = false)
end

@testitem "Explicit Imports" tags=[:core] begin
import ReverseDiff, Tracker, StaticArrays, Zygote
using ExplicitImports, SimpleNonlinearSolve

@test check_no_implicit_imports(
SimpleNonlinearSolve; skip = (Base, Core, SciMLBase)) === nothing
@test check_no_stale_explicit_imports(SimpleNonlinearSolve) === nothing
@test check_all_qualified_accesses_via_owners(SimpleNonlinearSolve) === nothing
end

0 comments on commit 4c28213

Please sign in to comment.