Skip to content

Commit

Permalink
Avoid Runtime Checks for Zygote Being loaded
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 22, 2023
1 parent 6c52956 commit e8af6a0
Show file tree
Hide file tree
Showing 9 changed files with 31 additions and 16 deletions.
1 change: 1 addition & 0 deletions .JuliaFormatter.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
style = "sciml"
format_markdown = true
annotate_untyped_fields_with_any = false
format_docstrings = true
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
FastLevenbergMarquardt = "7a0df574-e128-4d35-8cbd-3d84502bf7ce"
LeastSquaresOptim = "0fc2ff8b-aaa3-5acd-a817-1944a5e08891"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
NonlinearSolveBandedMatricesExt = "BandedMatrices"
NonlinearSolveFastLevenbergMarquardtExt = "FastLevenbergMarquardt"
NonlinearSolveLeastSquaresOptimExt = "LeastSquaresOptim"
NonlinearSolveZygoteExt = "Zygote"

[compat]
ADTypes = "0.2"
Expand Down
7 changes: 7 additions & 0 deletions ext/NonlinearSolveZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
module NonlinearSolveZygoteExt

import NonlinearSolve, Zygote

NonlinearSolve.is_extension_loaded(::Val{:Zygote}) = true

end
3 changes: 3 additions & 0 deletions src/NonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ import DiffEqBase: AbstractNonlinearTerminationMode,
const AbstractSparseADType = Union{ADTypes.AbstractSparseFiniteDifferences,
ADTypes.AbstractSparseForwardMode, ADTypes.AbstractSparseReverseMode}

# Type-Inference Friendly Check for Extension Loading
is_extension_loaded(::Val) = false

Check warning on line 42 in src/NonlinearSolve.jl

View check run for this annotation

Codecov / codecov/patch

src/NonlinearSolve.jl#L42

Added line #L42 was not covered by tests

abstract type AbstractNonlinearSolveLineSearchAlgorithm end

abstract type AbstractNonlinearSolveAlgorithm <: AbstractNonlinearAlgorithm end
Expand Down
21 changes: 13 additions & 8 deletions src/extension_algs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ for solving `NonlinearLeastSquaresProblem`.
## Arguments:
- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If
`nothing`, then `LeastSquaresOptim.jl` will choose the best linear solver based
on the Jacobian structure.
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or `:forward`.
- `alg`: Algorithm to use. Can be `:lm` or `:dogleg`.
- `linsolve`: Linear solver to use. Can be `:qr`, `:cholesky` or `:lsmr`. If `nothing`,
then `LeastSquaresOptim.jl` will choose the best linear solver based on the Jacobian
structure.
- `autodiff`: Automatic differentiation / Finite Differences. Can be `:central` or
`:forward`.
!!! note
This algorithm is only available if `LeastSquaresOptim.jl` is installed.
"""
struct LeastSquaresOptimJL{alg, linsolve} <: AbstractNonlinearSolveAlgorithm
Expand All @@ -36,21 +38,24 @@ end
"""
FastLevenbergMarquardtJL(linsolve = :cholesky)
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl) for solving
`NonlinearLeastSquaresProblem`.
Wrapper over [FastLevenbergMarquardt.jl](https://github.com/kamesy/FastLevenbergMarquardt.jl)
for solving `NonlinearLeastSquaresProblem`.
!!! warning
This is not really the fastest solver. It is called that since the original package
is called "Fast". `LevenbergMarquardt()` is almost always a better choice.
!!! warning
This algorithm requires the jacobian function to be provided!
## Arguments:
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
- `linsolve`: Linear solver to use. Can be `:qr` or `:cholesky`.
!!! note
This algorithm is only available if `FastLevenbergMarquardt.jl` is installed.
"""
@concrete struct FastLevenbergMarquardtJL{linsolve} <: AbstractNonlinearSolveAlgorithm
Expand Down
8 changes: 2 additions & 6 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,8 @@ function __concrete_vjp_autodiff(vjp_autodiff, jvp_autodiff, uf)
# Short circuit if we see that FiniteDiff was used for J computation
jvp_autodiff isa AutoFiniteDiff && return jvp_autodiff
# Check if Zygote is loaded then use Zygote else use FiniteDiff
if haskey(Base.loaded_modules,
Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote"))
return AutoZygote()
else
return AutoFiniteDiff()
end
is_extension_loaded(Val{:Zygote}()) && return AutoZygote()
return AutoFiniteDiff()

Check warning on line 170 in src/jacobian.jl

View check run for this annotation

Codecov / codecov/patch

src/jacobian.jl#L170

Added line #L170 was not covered by tests
end
else
ad = __get_nonsparse_ad(vjp_autodiff)
Expand Down
3 changes: 1 addition & 2 deletions src/linesearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe
g₀ = _mutable_zero(u)

autodiff = if ls.autodiff === nothing
if !iip && haskey(Base.loaded_modules,
Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote"))
if !iip && is_extension_loaded(Val{:Zygote}())
AutoZygote()
else
AutoFiniteDiff()
Expand Down
1 change: 1 addition & 0 deletions src/pseudotransient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ the time-stepping and algorithm, please see the paper:
SIAM Journal on Scientific Computing,25, 553-569.](https://doi.org/10.1137/S106482750241044X)
### Keyword Arguments
- `autodiff`: determines the backend used for the Jacobian. Note that this argument is
ignored if an analytical Jacobian is passed, as that will be used instead. Defaults to
`nothing` which means that a default is selected according to the problem specification!
Expand Down
1 change: 1 addition & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ Construct the AD type from the arguments. This is mostly needed for compatibilit
code.
!!! warning
`chunk_size`, `standardtag`, `diff_type`, and `autodiff::Union{Val, Bool}` are
deprecated and will be removed in v3. Update your code to directly specify
`autodiff=<ADTypes>`.
Expand Down

0 comments on commit e8af6a0

Please sign in to comment.