Skip to content

Commit

Permalink
Use allowed_setindex! for cache
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jan 17, 2024
1 parent 4ad919a commit e675a10
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/SparseDiffTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import ADTypes: AbstractADType, AutoSparseZygote, AbstractSparseForwardMode,
import ForwardDiff: Dual, jacobian, partials, DEFAULT_CHUNK_THRESHOLD
# Array Packages
using ArrayInterface, SparseArrays
import ArrayInterface: matrix_colors
import ArrayInterface: matrix_colors, allowed_setindex!
import StaticArrays
import StaticArrays: StaticArray, SArray, MArray, Size
# Others
Expand Down
21 changes: 18 additions & 3 deletions src/differentiation/vecjac_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,21 @@ function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v),
return du
end

# Special Non-Allocating case for StaticArrays
function num_vecjac(f::F, x::SArray, v::SArray, f0 = nothing) where {F}
f0 === nothing ? (_f0 = f(x)) : (_f0 = f0)
vv = reshape(v, axes(_f0))
T = eltype(x)
ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x)))
du = zeros(typeof(x))
for i in 1:length(x)
cache = Base.setindex(x, x[i] + ϵ, i)
f0 = f(cache)
du = Base.setindex(du, (((f0 .- _f0) ./ ϵ)' * vv), i)
end

Check warning on line 29 in src/differentiation/vecjac_products.jl

View check run for this annotation

Codecov / codecov/patch

src/differentiation/vecjac_products.jl#L29

Added line #L29 was not covered by tests
return du
end

function num_vecjac(f::F, x, v, f0 = nothing) where {F}
f0 === nothing ? (_f0 = f(x)) : (_f0 = f0)
vv = reshape(v, axes(_f0))
Expand All @@ -25,10 +40,10 @@ function num_vecjac(f::F, x, v, f0 = nothing) where {F}
cache = similar(x)
copyto!(cache, x)
for i in 1:length(x)
cache[i] += ϵ
cache = allowed_setindex!(cache, x[i] + ϵ, i)
f0 = f(cache)
cache[i] = x[i]
du[i] = (((f0 .- _f0) ./ ϵ)' * vv)[1]
cache = allowed_setindex!(cache, x[i], i)
du = allowed_setindex!(du, (((f0 .- _f0) ./ ϵ)' * vv)[1], i)
end
return vec(du)
end
Expand Down
17 changes: 16 additions & 1 deletion test/test_vecjac_products.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using SparseDiffTools, Zygote
using SparseDiffTools, Zygote, ForwardDiff
using LinearAlgebra, Test
using StaticArrays

using Random
Random.seed!(123)
Expand Down Expand Up @@ -170,3 +171,17 @@ L = VecJac(f3_iip, copy(x); autodiff = AutoFiniteDiff(), fu = copy(y))
L = VecJac(f3_oop, copy(x); autodiff = AutoZygote())
@test size(L) == (length(x), length(y))
@test L * y Zygote.jacobian(f3_oop, copy(x))[1]' * y

@info "Testing StaticArrays"

const A_sa = rand(SMatrix{4, 4, Float32})
_f_sa(x) = A_sa * (x .^ 2)

u = rand(SVector{4, Float32})
v = rand(SVector{4, Float32})

J = ForwardDiff.jacobian(_f_sa, u)
Jᵀv_true = J' * v

@test num_vecjac(_f_sa, u, v) isa SArray
@test num_vecjac(_f_sa, u, v) Jᵀv_true atol=1e-3

0 comments on commit e675a10

Please sign in to comment.