Skip to content

Commit

Permalink
Merge pull request #308 from JuliaDiff/os/remove-tricks
Browse files Browse the repository at this point in the history
Remove Tricks
  • Loading branch information
ChrisRackauckas authored Oct 10, 2024
2 parents 4e36edd + 4601a51 commit 8d30c5e
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 21 deletions.
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
VertexSafeGraphs = "19fa3120-7c27-5ec5-8db8-b0b0aa330d6f"

Expand Down Expand Up @@ -64,7 +63,6 @@ SparseArrays = "<0.0.1, 1"
StaticArrayInterface = "1.3"
StaticArrays = "1"
Symbolics = "5.5, 6"
Tricks = "0.1.6"
UnPack = "1"
VertexSafeGraphs = "0.2"
Zygote = "0.6"
Expand Down
5 changes: 2 additions & 3 deletions ext/SparseDiffToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import SparseDiffTools: SparseDiffTools, DeivVecTag, AutoDiffVJP, __test_backend
import ForwardDiff: ForwardDiff, Dual, partials
import SciMLOperators: update_coefficients, update_coefficients!
import Setfield: @set!
import Tricks: static_hasmethod

import SparseDiffTools: numback_hesvec!,
numback_hesvec, autoback_hesvec!, autoback_hesvec, auto_vecjac!,
Expand Down Expand Up @@ -101,7 +100,7 @@ end

# VJP methods
function auto_vecjac!(du, f::F, x, v) where {F}
!static_hasmethod(f, typeof((x,))) &&
!hasmethod(f, typeof((x,))) &&
error("For inplace function use autodiff = AutoFiniteDiff()")
du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du))
end
Expand All @@ -113,7 +112,7 @@ end

# overload operator interface
function SparseDiffTools._vecjac(f::F, _, u, autodiff::AutoZygote) where {F}
!static_hasmethod(f, typeof((u,))) &&
!hasmethod(f, typeof((u,))) &&
error("For inplace function use autodiff = AutoFiniteDiff()")
pullback = Zygote.pullback(f, u)
return AutoDiffVJP(f, u, (), autodiff, pullback)
Expand Down
1 change: 0 additions & 1 deletion src/SparseDiffTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ using SciMLOperators, LinearAlgebra, Random
import DataStructures: DisjointSets, find_root!, union!
import SciMLOperators: update_coefficients, update_coefficients!
import Setfield: @set!
import Tricks: Tricks, static_hasmethod

import PackageExtensionCompat: @require_extensions
function __init__()
Expand Down
16 changes: 8 additions & 8 deletions src/differentiation/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ function JacFunctionWrapper(f::F, fu_, u, p, t;
if deporder
# Check this first else we were breaking things
# In the next breaking release, we will fix the ordering of the checks
iip = static_hasmethod(f, typeof((fu, u)))
oop = static_hasmethod(f, typeof((u,)))
iip = hasmethod(f, typeof((fu, u)))
oop = hasmethod(f, typeof((u,)))
if iip || oop
if p !== nothing || t !== nothing
Base.depwarn(
Expand All @@ -74,17 +74,17 @@ function JacFunctionWrapper(f::F, fu_, u, p, t;
end

if t !== nothing
iip = static_hasmethod(f, typeof((fu, u, p, t)))
oop = static_hasmethod(f, typeof((u, p, t)))
iip = hasmethod(f, typeof((fu, u, p, t)))
oop = hasmethod(f, typeof((u, p, t)))
if !iip && !oop
throw(ArgumentError("""`p` and `t` provided but `f(u, p, t)` or `f(fu, u, p, t)`
not defined for `f`!"""))
end
return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f,
fu, p, t)
elseif p !== nothing
iip = static_hasmethod(f, typeof((fu, u, p)))
oop = static_hasmethod(f, typeof((u, p)))
iip = hasmethod(f, typeof((fu, u, p)))
oop = hasmethod(f, typeof((u, p)))
if !iip && !oop
throw(ArgumentError("""`p` is provided but `f(u, p)` or `f(fu, u, p)`
not defined for `f`!"""))
Expand All @@ -94,8 +94,8 @@ function JacFunctionWrapper(f::F, fu_, u, p, t;
end

if !deporder
iip = static_hasmethod(f, typeof((fu, u)))
oop = static_hasmethod(f, typeof((u,)))
iip = hasmethod(f, typeof((fu, u)))
oop = hasmethod(f, typeof((u,)))
if !iip && !oop
throw(ArgumentError("""`p` is provided but `f(u)` or `f(fu, u)` not defined for
`f`!"""))
Expand Down
12 changes: 6 additions & 6 deletions src/differentiation/jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ function (L::FwdModeAutoDiffVecProd)(dv, v, p, t)
end

function Base.resize!(L::FwdModeAutoDiffVecProd, n::Integer)
static_hasmethod(resize!, typeof((L.f, n))) && resize!(L.f, n)
hasmethod(resize!, typeof((L.f, n))) && resize!(L.f, n)
resize!(L.u, n)

for v in L.cache
Expand Down Expand Up @@ -304,7 +304,7 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing;

(cache1, cache2, cache3), numauto_hesvec, numauto_hesvec!
elseif autodiff isa AutoZygote
@assert static_hasmethod(autoback_hesvec, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"
@assert hasmethod(autoback_hesvec, typeof((f, u, u))) "To use AutoZygote() AD, first load Zygote with `using Zygote`, or `import Zygote`"

cache1 = Dual{
typeof(ForwardDiff.Tag(tag, eltype(u))), eltype(u), 1
Expand All @@ -316,8 +316,8 @@ function HesVec(f, u::AbstractArray, p = nothing, t = nothing;
error("Set autodiff to either AutoForwardDiff(), AutoZygote(), or AutoFiniteDiff()")
end

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u,)))
outofplace = hasmethod(f, typeof((u,)))
isinplace = hasmethod(f, typeof((u,)))

if !(isinplace) & !(outofplace)
error("$f must have signature f(u).")
Expand Down Expand Up @@ -347,8 +347,8 @@ function HesVecGrad(f, u::AbstractArray, p = nothing, t = nothing;
error("Set autodiff to either AutoForwardDiff(), or AutoFiniteDiff()")
end

outofplace = static_hasmethod(f, typeof((u,)))
isinplace = static_hasmethod(f, typeof((u, u)))
outofplace = hasmethod(f, typeof((u,)))
isinplace = hasmethod(f, typeof((u, u)))

if !(isinplace) & !(outofplace)
error("$f must have signature f(u), or f(du, u).")
Expand Down
2 changes: 1 addition & 1 deletion src/differentiation/vecjac_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ function (L::AutoDiffVJP{<:AutoFiniteDiff})(dv, v, p, t; VJP_input = nothing)
end

function Base.resize!(L::AutoDiffVJP, n::Integer)
static_hasmethod(resize!, typeof((L.f, n))) && resize!(L.f, n)
hasmethod(resize!, typeof((L.f, n))) && resize!(L.f, n)
resize!(L.u, n)
for v in L.cache
resize!(v, n)
Expand Down

0 comments on commit 8d30c5e

Please sign in to comment.