From e8af6a0a7d38e92fc0822a2a8878a775054754c2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 22 Nov 2023 11:55:56 -0500 Subject: [PATCH] Avoid Runtime Checks for Zygote Being loaded --- .JuliaFormatter.toml | 1 + Project.toml | 2 ++ ext/NonlinearSolveZygoteExt.jl | 7 +++++++ src/NonlinearSolve.jl | 3 +++ src/extension_algs.jl | 21 +++++++++++++-------- src/jacobian.jl | 8 ++------ src/linesearch.jl | 3 +-- src/pseudotransient.jl | 1 + src/utils.jl | 1 + 9 files changed, 31 insertions(+), 16 deletions(-) create mode 100644 ext/NonlinearSolveZygoteExt.jl diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index 320e0c073..1768a1a7f 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,3 +1,4 @@ style = "sciml" format_markdown = true annotate_untyped_fields_with_any = false +format_docstrings = true diff --git a/Project.toml b/Project.toml index bab6ce5a6..37513ae88 100644 --- a/Project.toml +++ b/Project.toml @@ -30,11 +30,13 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce" LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] NonlinearSolveBandedMatricesExt = "BandedMatrices" NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt" NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim" +NonlinearSolveZygoteExt = "Zygote" [compat] ADTypes = "0.2" diff --git a/ext/NonlinearSolveZygoteExt.jl b/ext/NonlinearSolveZygoteExt.jl new file mode 100644 index 000000000..d58faabbd --- /dev/null +++ b/ext/NonlinearSolveZygoteExt.jl @@ -0,0 +1,7 @@ +module NonlinearSolveZygoteExt + +import NonlinearSolve, Zygote + +NonlinearSolve.is_extension_loaded(::Val{:Zygote}) = true + +end diff --git a/src/NonlinearSolve.jl b/src/NonlinearSolve.jl index 58d10b290..369de3669 100644 --- a/src/NonlinearSolve.jl +++ b/src/NonlinearSolve.jl @@ -38,6 +38,9 @@ import DiffEqBase: AbstractNonlinearTerminationMode, const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences, ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode} +# Type-Inference Friendly Check for Extension Loading +is_extension_loaded(::Val) = false + abstract type AbstractNonlinearSolveLineSearchAlgorithm end abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end diff --git a/src/extension_algs.jl b/src/extension_algs.jl index c4e56b6a5..8f9ed4400 100644 --- a/src/extension_algs.jl +++ b/src/extension_algs.jl @@ -8,13 +8,15 @@ for solving `NonlinearLeastSquaresProblem`. ## Arguments: -- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`. -- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If - `nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based - on the Jacobian structure. -- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or `:forward`. + - `alg`: Algorithm to use. Can be `:lm` or `:dogleg`. + - `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If `nothing`, + then `LeastSquaresOptim.jl` will choose the best linear solver based on the Jacobian + structure. + - `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or + `:forward`. !!! note + This algorithm is only available if `LeastSquaresOptim.jl` is installed. """ struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm @@ -36,21 +38,24 @@ end """ FastLevenbergMarquardtJL(linsolve = :cholesky) -Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving -`NonlinearLeastSquaresProblem`. +Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) +for solving `NonlinearLeastSquaresProblem`. !!! warning + This is not really the fastest solver. It is called that since the original package is called "Fast". `LevenbergMarquardt()` is almost always a better choice. !!! warning + This algorithm requires the jacobian function to be provided! ## Arguments: -- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`. + - `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`. !!! note + This algorithm is only available if `FastLevenbergMarquardt.jl` is installed. """ @concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm diff --git a/src/jacobian.jl b/src/jacobian.jl index 6fef600af..41c7319a1 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -166,12 +166,8 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf) # Short circuit if we see that FiniteDiff was used for J computation jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff # Check if Zygote is loaded then use Zygote else use FiniteDiff - if haskey(Base.loaded_modules, - Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote")) - return AutoZygote() - else - return AutoFiniteDiff() - end + is_extension_loaded(Val{:Zygote}()) && return AutoZygote() + return AutoFiniteDiff() end else ad = __get_nonsparse_ad(vjp_autodiff) diff --git a/src/linesearch.jl b/src/linesearch.jl index c9e87a4cb..d67ac978c 100644 --- a/src/linesearch.jl +++ b/src/linesearch.jl @@ -114,8 +114,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe g₀ = _mutable_zero(u) autodiff = if ls.autodiff === nothing - if !iip && haskey(Base.loaded_modules, - Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote")) + if !iip && is_extension_loaded(Val{:Zygote}()) AutoZygote() else AutoFiniteDiff() diff --git a/src/pseudotransient.jl b/src/pseudotransient.jl index 5da1375d6..b343138de 100644 --- a/src/pseudotransient.jl +++ b/src/pseudotransient.jl @@ -12,6 +12,7 @@ the time-stepping and algorithm, please see the paper: SIAM Journal on Scientific Computing,25, 553-569.](https://doi.org/10.1137/S106482750241044X) ### Keyword Arguments + - `autodiff`: determines the backend used for the Jacobian. Note that this argument is ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to `nothing` which means that a default is selected according to the problem specification! diff --git a/src/utils.jl b/src/utils.jl index 9d96e7b75..c5161df7c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -23,6 +23,7 @@ Construct the AD type from the arguments. This is mostly needed for compatibilit code. !!! warning + `chunk_size`, `standardtag`, `diff_type`, and `autodiff::Union{Val, Bool}` are deprecated and will be removed in v3. Update your code to directly specify `autodiff=`.