Skip to content

Commit

Permalink
Merge pull request #1488 from lxvm/buffer
Browse files Browse the repository at this point in the history
fix broadcasting into buffers
  • Loading branch information
ToucheSir authored Sep 23, 2024
2 parents 512184d + 9df7146 commit fe393b0
Show file tree
Hide file tree
Showing 5 changed files with 131 additions and 15 deletions.
18 changes: 17 additions & 1 deletion src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U)

# With mismatched lengths, map stops early. With mismatched shapes, it makes a vector.
# So we keep axes(x) to restore gradient dx to its full length & correct shape.
_tryaxes(x) = axes(x)
# Fallback for generic iterators: derive the axes from the iterator-size trait.
# Shaped iterators report their own axes, iterators that only know their length get
# a single one-based axis, and any other trait value is rejected.
function _tryaxes(x)
  trait = Base.IteratorSize(x)
  if trait isa Base.HasShape
    return axes(x)
  elseif trait isa Base.HasLength
    return (Base.OneTo(length(x)),)
  end
  throw(ArgumentError("iterator size must be finite"))
end
# Arrays report their axes directly.
_tryaxes(x::AbstractArray) = axes(x)
# Tuples are described by their length, wrapped in `Val`.
_tryaxes(x::Tuple) = Val(length(x))
# Numbers are passed through unchanged.
_tryaxes(x::Number) = x
_restore(dx::AbstractArray{Nothing}, ax::Tuple) = similar(dx, ax)
Expand Down Expand Up @@ -320,6 +321,21 @@ end
collect(z), collect_zip_pullback
end

# Restore the cotangent `dy` onto the axes recorded for the source iterator `itr`.
function takefunc(itr, dy)
  ax = _tryaxes(itr)
  return _restore(dy, ax)
end

# Adjoint of `Iterators.take(itr, n)`. The pullback dispatches on the kind of
# cotangent it receives and returns one entry per argument `(itr, n)`.
@adjoint function Iterators.take(itr, n)
  # A cotangent array holding only `nothing` carries no gradient.
  function take_pullback(::AbstractArray{Nothing})
    return nothing
  end
  # Named-tuple cotangent with fields `xs` and `n`: unpack it, in either field order.
  function take_pullback(dy::NamedTuple{(:xs,:n)})
    return (dy.xs, dy.n)
  end
  function take_pullback(dy::NamedTuple{(:n,:xs)})
    return (dy.xs, dy.n)
  end
  # Array cotangent: restore it against `itr` via `takefunc`; `n` gets no gradient.
  function take_pullback(dy::AbstractArray)
    return (takefunc(itr, dy), nothing)
  end
  Iterators.take(itr, n), take_pullback
end

# Adjoint of `collect` on an `Iterators.Take`: the cotangent of the collected array
# is restored against the wrapped iterator `t.xs` and returned as a named tuple with
# the fields `xs` and `n`.
@adjoint function Base.collect(t::Iterators.Take)
  function collect_take_pullback(dy)
    dxs = takefunc(t.xs, dy)
    return ((xs=dxs, n=nothing),)
  end
  collect(t), collect_take_pullback
end

# Reductions
@adjoint function sum(xs::AbstractArray; dims = :)
if dims === (:)
Expand Down
44 changes: 34 additions & 10 deletions src/lib/buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,16 @@ grad_mut(cx::Context, b::Buffer{T}, ::Type{S}=Union{}) where {T<:Number, S<:Numb
@non_differentiable Buffer(::Any...)

@adjoint function getindex(b::Buffer, i...)
b[i...], function (Δ)
function getindex_buffer_pullback(Δ)
grad = grad_mut(__context__, b, eltype(Δ))
grad[i...] = accum(grad[i...], Δ)
return
end
b[i...], getindex_buffer_pullback
end

@adjoint! function setindex!(b::Buffer, v, i...)
setindex!(b, v, i...), function (_)
function setindex!_buffer_pullback(_)
grad = grad_mut(__context__, b)
v̄ = grad[i...]
zero = eltype(grad) <: Number ? 0 : nothing
Expand All @@ -26,26 +27,49 @@ end
end
(nothing, v̄, map(_->nothing, i)...)
end
setindex!(b, v, i...), setindex!_buffer_pullback
end

@adjoint! function copyto!(b::Buffer, xs)
copyto!(b, xs), function (_)
@adjoint! function copyto!(b::Buffer, src::AbstractArray)
function copyto!_buffer_array_pullback(_)
grad = grad_mut(__context__, b)
x̄s = copy(grad)
grad .= eltype(grad) <: Number ? 0 : nothing
return (nothing, x̄s)
xs = copy(grad)
grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing
return (nothing, xs)
end
copyto!(b, src), copyto!_buffer_array_pullback
end

# Adjoint of `copyto!(b, bc)` for a lazy `Broadcasted` source. The broadcast is
# evaluated element by element through `∇map`, so the pullback can route the
# buffer's accumulated gradient back into the arguments captured by `bc`.
@adjoint! function copyto!(b::Buffer, bc::Base.Broadcast.Broadcasted)
  vals, map_back = ∇map(__context__, i -> bc[i], eachindex(bc))
  function copyto!_buffer_broadcast_pullback(_)
    grad = grad_mut(__context__, b)
    # Pass the leading `length(vals)` gradient entries, shaped like `vals`.
    dvals = reshape(first(grad, length(vals)), size(vals))
    d, = map_back(dvals)
    # Reset the accumulator once its contents have been propagated.
    T = eltype(grad)
    grad .= T <: Number ? zero(T) : nothing
    # `d` is the gradient of the closure `i -> bc[i]`; its `bc` field is the
    # gradient of the broadcast object itself.
    return (nothing, d.bc)
  end
  copyto!(b, vals), copyto!_buffer_broadcast_pullback
end

# Pullback of `copyto!(b, g)` for a generator source. The generator is collected
# under differentiation, so the pullback can send the buffer's accumulated gradient
# back through `collect`.
function _pullback(cx::AContext, ::typeof(copyto!), b::Buffer, g::Base.Generator)
  vals, collect_back = _pullback(cx, collect, g)
  function copyto!_buffer_generator_pullback(_)
    grad = grad_mut(cx, b)
    # Pass the leading `length(vals)` gradient entries, shaped like `vals`.
    dvals = reshape(first(grad, length(vals)), size(vals))
    _, dg = collect_back(dvals)
    # Reset the accumulator once its contents have been propagated.
    T = eltype(grad)
    grad .= T <: Number ? zero(T) : nothing
    # One entry each for `copyto!`, `b`, and `g`.
    return (nothing, nothing, dg)
  end
  copyto!(b, vals), copyto!_buffer_generator_pullback
end

@adjoint! function push!(b::Buffer, x)
push!(b, x), function (y)
function push!_buffer_pullback(_)
grad = grad_mut(__context__, b)
return (nothing, pop!(grad))
end
push!(b, x), push!_buffer_pullback
end

_pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x::AbstractArray) =
_pullback(cx, copyto!, b, x)

@adjoint function copy(b::Buffer)
res = copy(b)
Expand Down
6 changes: 4 additions & 2 deletions src/tools/buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ function Base.deleteat!(b::Buffer, i)
return b
end

@forward Buffer.data Base.eltype, Base.length, Base.ndims, Base.size, Base.axes,
Base.eachindex, Base.stride, Base.strides, Base.findfirst,
@forward Buffer.data Base.eltype, Base.length, Base.ndims, Base.size, Base.axes,
Base.eachindex, Base.stride, Base.strides, Base.findfirst,
Base.keys

# A Buffer type reports the iterator-size trait of its second type parameter `A`
# (presumably the type of the wrapped array; confirm against the struct definition).
Base.IteratorSize(::Type{<:Buffer{<:Any, A}}) where {A} = Base.IteratorSize(A)
Expand All @@ -84,3 +84,5 @@ function Base.iterate(b::Buffer, state=(eachindex(b),))
y === nothing && return nothing
b[y[1]], (state[1], tail(y)...)
end

# A Buffer type takes the broadcast style of its second type parameter `A`
# (presumably the type of the wrapped array; confirm against the struct definition).
Base.BroadcastStyle(::Type{Buffer{T,A}}) where {T,A} = Base.BroadcastStyle(A)
51 changes: 49 additions & 2 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,23 @@ end
@test gradient(x -> sum(inv, collect(view(x', 1,:))), ones(2,2)) == ([-1 0; -1 0],)

@test gradient(xs -> sum(inv, [x^2 for x in xs]), ones(2)) == ([-2, -2],)

# adjoint of generators is available and should support generic arrays and iterators
# generator of array
@test gradient(p -> sum(collect(p*i for i in [1.0, 2.0, 3.0])), 2.0) == (6.0,)
# generator of iterator with HasShape
@test gradient(p -> sum(collect(p*i for (i,) in zip([1.0, 2.0, 3.0]))), 2.0) == (6.0,)
# generator of iterator with HasLength
@test gradient(p -> sum(collect(p*i for i in Iterators.take([1.0, 2.0, 3.0], 3))), 2.0) == (6.0,)
@test gradient(p -> sum(collect(p*i for i in Iterators.take(p*[1.0, 2.0, 3.0], 2))), 2.0) == (12.0,)
# generator 0-d behavior handled incorrectly
@test_broken gradient(p -> sum(collect(p*i for i in 1.0)), 2.0)
@test_broken gradient(p -> sum(collect(p*i for i in fill(1.0))), 2.0)

# adjoints for iterators
@test gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 4))), 1.0) == (10.0,)
@test gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 5))), 1.0) == (15.0,)
@test_broken gradient(sumcollect, 1.0) == (1.0,) # broken since no generic adjoint
end

@test gradtest(x -> reverse(x), rand(17))
Expand Down Expand Up @@ -523,7 +540,7 @@ end
@test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))

@test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9]

# issue 1224, second order
f1244(w, x) = sum(maximum((w * x).^2, dims=1))
g1244(w, x) = sum(gradient(f1244, w, x)[2].^2)
Expand Down Expand Up @@ -1538,6 +1555,36 @@ using Zygote: Buffer
return sum(copy(b))
end == ([2,2,2],)

@test gradient([1, 2, 3]) do xs
b = Zygote.Buffer(xs)
b .= 2
return sum(copy(b))
end == (nothing,)

@test gradient(1.1) do p
b = Zygote.Buffer(zeros(3))
b .= (p*i for i in eachindex(b))
return sum(copy(b) .* (2:4))
end[1] ≈ 1*2 + 2*3 + 3*4

@test gradient(1.1) do p
b = Zygote.Buffer(zeros(3))
copyto!(b, [p*i for i in eachindex(b)])
return sum(copy(b) .* (2:4))
end[1] ≈ 1*2 + 2*3 + 3*4

@test gradient(1.1) do p
b = Zygote.Buffer(zeros(3))
copyto!(b, (p*i for i in eachindex(b)))
return sum(copy(b) .* (2:4))
end[1] ≈ 1*2 + 2*3 + 3*4

@test_broken gradient(1.1) do p
b = Zygote.Buffer(zeros(3))
copyto!(b, p)
return sum(copy(b) .* (2:4))
end[1] ≈ 1*2

@test gradient(2) do x
b = Zygote.Buffer([])
push!(b, x)
Expand Down Expand Up @@ -1701,7 +1748,7 @@ end
end

@testset "FillArrays" begin

@test gradcheck(x->sum(Fill(x[], (2, 2))), [0.1])
@test first(Zygote.gradient(sz->sum(Ones(sz)), 6)) === nothing
@test first(Zygote.gradient(sz->sum(Zeros(sz)), 6)) === nothing
Expand Down
27 changes: 27 additions & 0 deletions test/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,23 @@ end
@test @inferred back(collect(y)) == (nothing, [1.0, 0.0, 0.0, 0.0, 0.0], fill(1.0))
end

@testset "adjoints of Iterators.take" begin
y, back = _pullback(Iterators.take, 1:5, 3)
@test back(collect(y)) == (nothing, [1.0, 2.0, 3.0, 0.0, 0.0], nothing)
@test back([nothing for i in 1:3]) === nothing

@test gradient(x -> sum([2y for y in Iterators.take(x, 4)]), [1,2,3,4])[1] ≈ [2, 2, 2, 2]
@test gradient(x -> sum(2y for y in Iterators.take(x, 4)), [1,2,3,4])[1] ≈ [2, 2, 2, 2]

for p in (1.0, fill(1.0), [1.0])
@test gradient(p_ -> sum(map(prod, Iterators.take(p_, 1))), p) == (p,)
@test gradient(p_ -> sum(x for x in Iterators.take(p_, 1)), p) == (p,)
end

y, back = _pullback(Iterators.take, ones(2, 2), 3)
@test @inferred back(collect(y)) == (nothing, [1.0 1.0; 1.0 0.0], nothing)
end

@testset "collect" begin
@testset "Dict" begin
d = Dict(1 => 5, 2 => 6)
Expand Down Expand Up @@ -97,6 +114,16 @@ end
@test gradient(x -> sum(broadcast(prod, Iterators.zip(x,x.^2))), ones(4)) == (3ones(4),)
@test gradient(x -> sum(broadcast(prod, Iterators.zip(x.^2,x.^2))), ones(4)) == (4ones(4),)
end


@testset "Iterators.Take" begin
z = Iterators.take(1:3, 2)
g = gradient(z -> sum(collect(z)), z)[1]
@test g == (xs=[1.0, 1.0, 0.0], n=nothing)

@test gradient(x -> sum(broadcast(prod, Iterators.take(x,2))), ones(4)) == ([1.0,1.0,0.0,0.0],)
@test gradient(x -> sum(broadcast(prod, Iterators.take(x.^2,2))), ones(4)) == (2*[1.0,1.0,0.0,0.0],)
end
end

@testset "dictionary comprehension" begin
Expand Down

0 comments on commit fe393b0

Please sign in to comment.