Skip to content

Commit

Permalink
Use Zygote for LineSearch if loaded
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 14, 2023
1 parent 40c6de3 commit 5eb49ef
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions src/linesearch.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
LineSearch(method = Static(), autodiff = AutoFiniteDiff(), alpha = true)
LineSearch(method = nothing, autodiff = nothing, alpha = true)
Wrapper over algorithms from
[LineSeaches.jl](https://github.com/JuliaNLSolvers/LineSearches.jl/). Allows automatic
Expand All @@ -13,7 +13,7 @@ differentiation for fast Vector Jacobian Products.
- `autodiff`: the automatic differentiation backend to use for the line search. Defaults to
`AutoFiniteDiff()`, which means that finite differencing is used to compute the VJP.
`AutoZygote()` will be faster in most cases, but it requires `Zygote.jl` to be manually
installed and loaded
installed and loaded.
- `alpha`: the initial step size to use. Defaults to `true` (which is equivalent to `1`).
"""
@concrete struct LineSearch
Expand All @@ -22,7 +22,7 @@ differentiation for fast Vector Jacobian Products.
α
end

function LineSearch(; method = nothing, autodiff = AutoFiniteDiff(), alpha = true)
function LineSearch(; method = nothing, autodiff = nothing, alpha = true)

Check warning on line 25 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L25

Added line #L25 was not covered by tests
return LineSearch(method, autodiff, alpha)
end

Expand Down Expand Up @@ -113,12 +113,21 @@ function LineSearchesJLCache(ls::LineSearch, f::F, u, p, fu1, IIP::Val{iip}) whe

g₀ = _mutable_zero(u)

autodiff = if iip && (ls.autodiff isa AutoZygote || ls.autodiff isa AutoSparseZygote)
@warn "Attempting to use Zygote.jl for linesearch on an in-place problem. Falling \
back to finite differencing."
AutoFiniteDiff()
autodiff = if ls.autodiff === nothing
if !iip && haskey(Base.loaded_modules,

Check warning on line 117 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L116-L117

Added lines #L116 - L117 were not covered by tests
Base.PkgId(Base.UUID("e88e6eb3-aa80-5325-afca-941959d7151f"), "Zygote"))
AutoZygote()

Check warning on line 119 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L119

Added line #L119 was not covered by tests
else
AutoFiniteDiff()

Check warning on line 121 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L121

Added line #L121 was not covered by tests
end
else
ls.autodiff
if iip && (ls.autodiff isa AutoZygote || ls.autodiff isa AutoSparseZygote)
@warn "Attempting to use Zygote.jl for linesearch on an in-place problem. \

Check warning on line 125 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L124-L125

Added lines #L124 - L125 were not covered by tests
Falling back to finite differencing."
AutoFiniteDiff()

Check warning on line 127 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L127

Added line #L127 was not covered by tests
else
ls.autodiff

Check warning on line 129 in src/linesearch.jl

View check run for this annotation

Codecov / codecov/patch

src/linesearch.jl#L129

Added line #L129 was not covered by tests
end
end

function g!(u, fu)
Expand Down

0 comments on commit 5eb49ef

Please sign in to comment.