From 5eb49ef52ad4b6ece1c573cd7ecd964a81508575 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 13 Nov 2023 23:20:11 -0500 Subject: [PATCH] Use Zygote for LineSearch if loaded --- src/linesearch.jl | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/linesearch.jl b/src/linesearch.jl index a2b396b06..71a85be31 100644 --- a/src/linesearch.jl +++ b/src/linesearch.jl @@ -1,5 +1,5 @@ """ - LineSearch(method = Static(), autodiff = AutoFiniteDiff(), alpha = true) + LineSearch(method = nothing, autodiff = nothing, alpha = true) Wrapper over algorithms from [LineSeaches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl/). Allows automatic @@ -13,7 +13,7 @@ differentiation for fast Vector Jacobian Products. - `autodiff`: the automatic differentiation backend to use for the line search. Defaults to `AutoFiniteDiff()`, which means that finite differencing is used to compute the VJP. `AutoZygote()` will be faster in most cases, but it requires `Zygote.jl` to be manually - installed and loaded + installed and loaded. - `alpha`: the initial step size to use. Defaults to `true` (which is equivalent to `1`). """ @concrete struct LineSearch @@ -22,7 +22,7 @@ differentiation for fast Vector Jacobian Products. α end -function LineSearch(; method = nothing, autodiff = AutoFiniteDiff(), alpha = true) +function LineSearch(; method = nothing, autodiff = nothing, alpha = true) return LineSearch(method, autodiff, alpha) end @@ -113,12 +113,21 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe g₀ = _mutable_zero(u) - autodiff = if iip && (ls.autodiff isa AutoZygote || ls.autodiff isa AutoSparseZygote) - @warn "Attempting to use Zygote.jl for linesearch on an in-place problem. Falling \ - back to finite differencing." - AutoFiniteDiff() + autodiff = if ls.autodiff === nothing + if !iip && haskey(Base.loaded_modules, + Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote")) + AutoZygote() + else + AutoFiniteDiff() + end else - ls.autodiff + if iip && (ls.autodiff isa AutoZygote || ls.autodiff isa AutoSparseZygote) + @warn "Attempting to use Zygote.jl for linesearch on an in-place problem. \ + Falling back to finite differencing." + AutoFiniteDiff() + else + ls.autodiff + end end function g!(u, fu)