From cd5e5146261ead1bb98403ce1c20ce094d416f4e Mon Sep 17 00:00:00 2001 From: lxvm Date: Mon, 1 Jan 2024 11:38:05 -0800 Subject: [PATCH 01/12] initial commit --- src/lib/buffer.jl | 2 +- test/gradcheck.jl | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl index 83fceb713..62f904ba4 100644 --- a/src/lib/buffer.jl +++ b/src/lib/buffer.jl @@ -44,7 +44,7 @@ end end end -_pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x::AbstractArray) = +_pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x) = _pullback(cx, copyto!, b, x) @adjoint function copy(b::Buffer) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 8cb7e6e1a..d5c429e15 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -523,7 +523,7 @@ end @test gradtest(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4)) @test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9] - + # issue 1224, second order f1244(w, x) = sum(maximum((w * x).^2, dims=1)) g1244(w, x) = sum(gradient(f1244, w, x)[2].^2) @@ -1538,6 +1538,12 @@ using Zygote: Buffer return sum(copy(b)) end == ([2,2,2],) + @test gradient([1, 2, 3]) do xs + b = Zygote.Buffer(xs) + b .= 2 + return sum(copy(b)) + end == (nothing,) + @test gradient(2) do x b = Zygote.Buffer([]) push!(b, x) @@ -1701,7 +1707,7 @@ end end @testset "FillArrays" begin - + @test gradcheck(x->sum(Fill(x[], (2, 2))), [0.1]) @test first(Zygote.gradient(sz->sum(Ones(sz)), 6)) === nothing @test first(Zygote.gradient(sz->sum(Zeros(sz)), 6)) === nothing From bad5bae37de201c2a780a853bbb8a14930e80c0b Mon Sep 17 00:00:00 2001 From: lxvm Date: Mon, 1 Jan 2024 14:12:33 -0800 Subject: [PATCH 02/12] add adjoint for copyto!(::Buffer, ::Number) --- src/lib/buffer.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl index 62f904ba4..beb226a3f 100644 --- a/src/lib/buffer.jl +++ b/src/lib/buffer.jl @@ -28,7 +28,7 @@ end end end -@adjoint! function copyto!(b::Buffer, xs) +@adjoint! function copyto!(b::Buffer, xs::AbstractArray) copyto!(b, xs), function (_) grad = grad_mut(__context__, b) x̄s = copy(grad) @@ -37,6 +37,13 @@ end end end +@adjoint! function copyto!(b::Buffer, x::Number) + copyto!(b, x), function (_) + grad = grad_mut(__context__, b) + return (nothing, sum(grad)) + end +end + @adjoint! function push!(b::Buffer, x) push!(b, x), function (y) grad = grad_mut(__context__, b) @@ -44,8 +51,9 @@ end end end -_pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x) = - _pullback(cx, copyto!, b, x) +function _pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x) + _pullback(cx, copyto!, b, x) +end @adjoint function copy(b::Buffer) res = copy(b) From 37637255aa13ef64e3a31c42dc87a1584b33058a Mon Sep 17 00:00:00 2001 From: lxvm Date: Tue, 2 Jan 2024 00:33:44 -0800 Subject: [PATCH 03/12] buffer broadcasting --- src/lib/buffer.jl | 32 ++++++++++++++------------------ test/gradcheck.jl | 6 ++++++ 2 files changed, 20 insertions(+), 18 deletions(-) diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl index beb226a3f..335c64c31 100644 --- a/src/lib/buffer.jl +++ b/src/lib/buffer.jl @@ -28,21 +28,6 @@ end end end -@adjoint! function copyto!(b::Buffer, xs::AbstractArray) - copyto!(b, xs), function (_) - grad = grad_mut(__context__, b) - x̄s = copy(grad) - grad .= eltype(grad) <: Number ? 0 : nothing - return (nothing, x̄s) - end -end - -@adjoint! function copyto!(b::Buffer, x::Number) - copyto!(b, x), function (_) - grad = grad_mut(__context__, b) - return (nothing, sum(grad)) - end -end @adjoint! function push!(b::Buffer, x) push!(b, x), function (y) @@ -51,9 +36,6 @@ end end end -function _pullback(cx::AContext, ::typeof(Broadcast.materialize!), b::Buffer, x) - _pullback(cx, copyto!, b, x) -end @adjoint function copy(b::Buffer) res = copy(b) @@ -70,3 +52,17 @@ end return res, copy_sensitivity end + +Base.BroadcastStyle(::Type{Buffer{T,A}}) where {T,A} = Base.BroadcastStyle(A) + +@non_differentiable Base.Broadcast.Broadcasted(::Nothing) + +function _pullback(cx::AContext, ::typeof(copyto!), b::Buffer, bc::Base.Broadcast.Broadcasted) + xs, map_pullback = ∇map(cx, i -> bc[i], eachindex(bc)) + copyto!(b, xs), function (_) + grad = grad_mut(cx, b) + # ys = copy(grad) + d, = map_pullback(reshape(first(grad, length(xs)), size(xs))) + return (nothing, nothing, d.bc) + end +end diff --git a/test/gradcheck.jl b/test/gradcheck.jl index d5c429e15..50579e044 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1544,6 +1544,12 @@ using Zygote: Buffer return sum(copy(b)) end == (nothing,) + @test gradient(1.1) do p + b = Zygote.Buffer(zeros(3)) + b .= (p*i for i in eachindex(b)) + return sum(copy(b) .* (2:4)) + end[1] ≈ 1*2 + 2*3 + 3*4 + @test gradient(2) do x b = Zygote.Buffer([]) push!(b, x) From 739f1efc1ab1d1f26e1c95f1822f0ead86842585 Mon Sep 17 00:00:00 2001 From: lxvm Date: Wed, 17 Jan 2024 00:04:51 -0500 Subject: [PATCH 04/12] move definitions around --- src/lib/buffer.jl | 22 ++++++++-------------- src/tools/buffer.jl | 6 ++++-- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl index 335c64c31..fc6547343 100644 --- a/src/lib/buffer.jl +++ b/src/lib/buffer.jl @@ -28,6 +28,14 @@ end end end +@adjoint! function copyto!(b::Buffer, bc::Base.Broadcast.Broadcasted) + xs, map_pullback = ∇map(__context__, i -> bc[i], eachindex(bc)) + copyto!(b, xs), function (_) + grad = grad_mut(__context__, b) + d, = map_pullback(reshape(first(grad, length(xs)), size(xs))) + return (nothing, nothing, d.bc) + end +end @adjoint! function push!(b::Buffer, x) push!(b, x), function (y) @@ -52,17 +60,3 @@ end return res, copy_sensitivity end - -Base.BroadcastStyle(::Type{Buffer{T,A}}) where {T,A} = Base.BroadcastStyle(A) - -@non_differentiable Base.Broadcast.Broadcasted(::Nothing) - -function _pullback(cx::AContext, ::typeof(copyto!), b::Buffer, bc::Base.Broadcast.Broadcasted) - xs, map_pullback = ∇map(cx, i -> bc[i], eachindex(bc)) - copyto!(b, xs), function (_) - grad = grad_mut(cx, b) - # ys = copy(grad) - d, = map_pullback(reshape(first(grad, length(xs)), size(xs))) - return (nothing, nothing, d.bc) - end -end diff --git a/src/tools/buffer.jl b/src/tools/buffer.jl index 9409a74bc..d8c9a82d3 100644 --- a/src/tools/buffer.jl +++ b/src/tools/buffer.jl @@ -72,8 +72,8 @@ function Base.deleteat!(b::Buffer, i) return b end -@forward Buffer.data Base.eltype, Base.length, Base.ndims, Base.size, Base.axes, - Base.eachindex, Base.stride, Base.strides, Base.findfirst, +@forward Buffer.data Base.eltype, Base.length, Base.ndims, Base.size, Base.axes, + Base.eachindex, Base.stride, Base.strides, Base.findfirst, Base.keys Base.IteratorSize(::Type{<:Buffer{<:Any, A}}) where {A} = Base.IteratorSize(A) @@ -84,3 +84,5 @@ function Base.iterate(b::Buffer, state=(eachindex(b),)) y === nothing && return nothing b[y[1]], (state[1], tail(y)...) end + +Base.BroadcastStyle(::Type{Buffer{T,A}}) where {T,A} = Base.BroadcastStyle(A) From c749e8772416ae172610c17ce313f1ff22357c74 Mon Sep 17 00:00:00 2001 From: lxvm Date: Wed, 17 Jan 2024 14:51:47 -0500 Subject: [PATCH 05/12] gives adjoint names --- src/lib/buffer.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl index fc6547343..2ddf8396f 100644 --- a/src/lib/buffer.jl +++ b/src/lib/buffer.jl @@ -7,15 +7,16 @@ grad_mut(cx::Context, b::Buffer{T}, ::Type{S}=Union{}) where {T<:Number, S<:Numb @non_differentiable Buffer(::Any...) @adjoint function getindex(b::Buffer, i...) - b[i...], function (Δ) + function getindex_buffer_pullback(Δ) grad = grad_mut(__context__, b, eltype(Δ)) grad[i...] = accum(grad[i...], Δ) return end + b[i...], getindex_buffer_pullback end @adjoint! function setindex!(b::Buffer, v, i...) - setindex!(b, v, i...), function (_) + function setindex!_buffer_pullback(_) grad = grad_mut(__context__, b) v̄ = grad[i...] zero = eltype(grad) <: Number ? 0 : nothing @@ -26,22 +27,25 @@ end end (nothing, v̄, map(_->nothing, i)...) end + setindex!(b, v, i...), setindex!_buffer_pullback end @adjoint! function copyto!(b::Buffer, bc::Base.Broadcast.Broadcasted) xs, map_pullback = ∇map(__context__, i -> bc[i], eachindex(bc)) - copyto!(b, xs), function (_) + function copyto!_buffer_pullback(_) grad = grad_mut(__context__, b) d, = map_pullback(reshape(first(grad, length(xs)), size(xs))) return (nothing, nothing, d.bc) end + copyto!(b, xs), copyto!_buffer_pullback end @adjoint! function push!(b::Buffer, x) - push!(b, x), function (y) + function push!_buffer_pullback(_) grad = grad_mut(__context__, b) return (nothing, pop!(grad)) end + push!(b, x), push!_buffer_pullback end From 8ea82efdeb92e448f67844678838c34b9c5b663e Mon Sep 17 00:00:00 2001 From: lxvm Date: Wed, 17 Jan 2024 17:26:07 -0500 Subject: [PATCH 06/12] restore adjoint of copyto! buffer array --- src/lib/buffer.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl index 2ddf8396f..267e542af 100644 --- a/src/lib/buffer.jl +++ b/src/lib/buffer.jl @@ -30,14 +30,22 @@ end setindex!(b, v, i...), setindex!_buffer_pullback end +@adjoint! function copyto!(b::Buffer, src::AbstractArray) + function copyto!_buffer_array_pullback(_) + grad = grad_mut(__context__, b) + return (nothing, copy(grad)) + end + copyto!(b, src), copyto!_buffer_array_pullback +end + @adjoint! function copyto!(b::Buffer, bc::Base.Broadcast.Broadcasted) xs, map_pullback = ∇map(__context__, i -> bc[i], eachindex(bc)) - function copyto!_buffer_pullback(_) + function copyto!_buffer_broadcast_pullback(_) grad = grad_mut(__context__, b) d, = map_pullback(reshape(first(grad, length(xs)), size(xs))) - return (nothing, nothing, d.bc) + return (nothing, d.bc) end - copyto!(b, xs), copyto!_buffer_pullback + copyto!(b, xs), copyto!_buffer_broadcast_pullback end @adjoint! function push!(b::Buffer, x) From 09f888b6d5011f174548bd721e104cf77031a87c Mon Sep 17 00:00:00 2001 From: lxvm Date: Wed, 17 Jan 2024 17:28:57 -0500 Subject: [PATCH 07/12] add tests for copyto! buffer, including broken ones for iterators --- test/gradcheck.jl | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 50579e044..9cd1facc4 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -214,6 +214,9 @@ end @test gradient(x -> sum(inv, collect(view(x', 1,:))), ones(2,2)) == ([-1 0; -1 0],) @test gradient(xs -> sum(inv, [x^2 for x in xs]), ones(2)) == ([-2, -2],) + + # no adjoint for generic iterators + @test_broken gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 5))), 1.0) end @test gradtest(x -> reverse(x), rand(17)) @@ -1550,6 +1553,18 @@ using Zygote: Buffer return sum(copy(b) .* (2:4)) end[1] ≈ 1*2 + 2*3 + 3*4 + @test gradient(1.1) do p + b = Zygote.Buffer(zeros(3)) + copyto!(b, [p*i for i in eachindex(b)]) + return sum(copy(b) .* (2:4)) + end[1] ≈ 1*2 + 2*3 + 3*4 + + @test_broken gradient(1.1) do p + b = Zygote.Buffer(zeros(3)) + copyto!(b, (p*i for i in eachindex(b))) + return sum(copy(b) .* (2:4)) + end[1] ≈ 1*2 + 2*3 + 3*4 + @test gradient(2) do x b = Zygote.Buffer([]) push!(b, x) From 728a830da472128ad7e489706a65e718bb55475e Mon Sep 17 00:00:00 2001 From: lxvm Date: Wed, 17 Jan 2024 17:30:07 -0500 Subject: [PATCH 08/12] restrict _tryaxes to AbstractArrays since iterators would error --- src/lib/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 6d914d272..8e9d98ec7 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -168,7 +168,7 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U) # With mismatched lengths, map stops early. With mismatched shapes, it makes a vector. # So we keep axes(x) to restore gradient dx to its full length & correct shape. -_tryaxes(x) = axes(x) +_tryaxes(x::AbstractArray) = axes(x) _tryaxes(x::Tuple) = Val(length(x)) _tryaxes(x::Number) = x _restore(dx::AbstractArray{Nothing}, ax::Tuple) = similar(dx, ax) From e9314df043947a5d7483582c53dfad4fb999d20a Mon Sep 17 00:00:00 2001 From: lxvm Date: Wed, 17 Jan 2024 18:40:05 -0500 Subject: [PATCH 09/12] add take iterator and test map adjoint on iterators --- src/lib/array.jl | 15 +++++++++++++++ test/gradcheck.jl | 18 ++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 8e9d98ec7..f39506cc0 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -168,6 +168,7 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U) # With mismatched lengths, map stops early. With mismatched shapes, it makes a vector. # So we keep axes(x) to restore gradient dx to its full length & correct shape. +_tryaxes(x) = (s = Base.IteratorSize(x); s isa Base.HasShape ? axes(x) : s isa Base.HasLength ? (Base.OneTo(length(x)),) : throw(ArgumentError("iterator size must be finite"))) _tryaxes(x::AbstractArray) = axes(x) _tryaxes(x::Tuple) = Val(length(x)) _tryaxes(x::Number) = x @@ -319,6 +320,20 @@ end collect(z), collect_zip_pullback end +takefunc(itr, dy) = _restore(dy, _tryaxes(itr)) + +@adjoint function Iterators.take(itr, n) + take_pullback(dy::NamedTuple{(:xs,:n)}) = (dy.xs, dy.n) + take_pullback(dy::NamedTuple{(:n,:xs)}) = (dy.xs, dy.n) + take_pullback(dy) = (takefunc(itr, dy), nothing) + Iterators.take(itr, n), take_pullback +end + +@adjoint function Base.collect(t::Iterators.Take) + collect_take_pullback(dy) = ((xs=takefunc(t.xs, dy), n=nothing),) + collect(t), collect_take_pullback +end + # Reductions @adjoint function sum(xs::AbstractArray; dims = :) if dims === (:) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 9cd1facc4..4391fba90 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -215,8 +215,22 @@ end @test gradient(xs -> sum(inv, [x^2 for x in xs]), ones(2)) == ([-2, -2],) - # no adjoint for generic iterators - @test_broken gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 5))), 1.0) + # adjoint of generators is available and should support generic arrays and iterators + # generator of array + @test gradient(p -> sum(collect(p*i for i in [1.0, 2.0, 3.0])), 2.0) == (6.0,) + # generator of iterator with HasShape + @test gradient(p -> sum(collect(p*i for (i,) in zip([1.0, 2.0, 3.0]))), 2.0) == (6.0,) + # generator of iterator with HasLength + @test gradient(p -> sum(collect(p*i for i in Iterators.take([1.0, 2.0, 3.0], 3))), 2.0) == (6.0,) + @test gradient(p -> sum(collect(p*i for i in Iterators.take(p*[1.0, 2.0, 3.0], 2))), 2.0) == (12.0,) + # generator 0-d behavior handled incorrectly + @test_broken gradient(p -> sum(collect(p*i for i in 1.0)), 2.0) + @test_broken gradient(p -> sum(collect(p*i for i in fill(1.0))), 2.0) + + # adjoints for iterators + @test gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 4))), 1.0) == (10.0,) + @test gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 5))), 1.0) == (15.0,) + @test_broken gradient(sum∘collect, 1.0) == (1.0,) # broken since no generic adjoint end @test gradtest(x -> reverse(x), rand(17)) From 3cee46573fe808d7b40a4236b5d9068ae8bf6781 Mon Sep 17 00:00:00 2001 From: lxvm Date: Wed, 17 Jan 2024 19:01:32 -0500 Subject: [PATCH 10/12] add adjoint for copyto! buffer generator and another broken test --- src/lib/buffer.jl | 10 ++++++++++ test/gradcheck.jl | 8 +++++++- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl index 267e542af..7d5bf460f 100644 --- a/src/lib/buffer.jl +++ b/src/lib/buffer.jl @@ -48,6 +48,16 @@ end copyto!(b, xs), copyto!_buffer_broadcast_pullback end +function _pullback(cx::AContext, ::typeof(copyto!), b::Buffer, g::Base.Generator) + xs, collect_pullback = _pullback(cx, collect, g) + function copyto!_buffer_generator_pullback(_) + grad = grad_mut(cx, b) + _, dg = collect_pullback(reshape(first(grad, length(xs)), size(xs))) + return (nothing, nothing, dg) + end + copyto!(b, xs), copyto!_buffer_generator_pullback + end + @adjoint! function push!(b::Buffer, x) function push!_buffer_pullback(_) grad = grad_mut(__context__, b) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 4391fba90..133383048 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1573,12 +1573,18 @@ using Zygote: Buffer return sum(copy(b) .* (2:4)) end[1] ≈ 1*2 + 2*3 + 3*4 - @test_broken gradient(1.1) do p + @test gradient(1.1) do p b = Zygote.Buffer(zeros(3)) copyto!(b, (p*i for i in eachindex(b))) return sum(copy(b) .* (2:4)) end[1] ≈ 1*2 + 2*3 + 3*4 + @test_broken gradient(1.1) do p + b = Zygote.Buffer(zeros(3)) + copyto!(b, p) + return sum(copy(b) .* (2:4)) + end[1] ≈ 1*2 + @test gradient(2) do x b = Zygote.Buffer([]) push!(b, x) From c87697a48a0363b1ea941bdb7e094bf31e3bdc13 Mon Sep 17 00:00:00 2001 From: lxvm Date: Sun, 11 Feb 2024 23:04:23 -0500 Subject: [PATCH 11/12] restore zeroing out grad cache --- src/lib/buffer.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/lib/buffer.jl b/src/lib/buffer.jl index 7d5bf460f..0666172b4 100644 --- a/src/lib/buffer.jl +++ b/src/lib/buffer.jl @@ -33,7 +33,9 @@ end @adjoint! function copyto!(b::Buffer, src::AbstractArray) function copyto!_buffer_array_pullback(_) grad = grad_mut(__context__, b) - return (nothing, copy(grad)) + xs = copy(grad) + grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing + return (nothing, xs) end copyto!(b, src), copyto!_buffer_array_pullback end @@ -43,6 +45,7 @@ end function copyto!_buffer_broadcast_pullback(_) grad = grad_mut(__context__, b) d, = map_pullback(reshape(first(grad, length(xs)), size(xs))) + grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing return (nothing, d.bc) end copyto!(b, xs), copyto!_buffer_broadcast_pullback @@ -53,6 +56,7 @@ function _pullback(cx::AContext, ::typeof(copyto!), b::Buffer, g::Base.Generator function copyto!_buffer_generator_pullback(_) grad = grad_mut(cx, b) _, dg = collect_pullback(reshape(first(grad, length(xs)), size(xs))) + grad .= eltype(grad) <: Number ? zero(eltype(grad)) : nothing return (nothing, nothing, dg) end copyto!(b, xs), copyto!_buffer_generator_pullback From 9df71466e5bea46871764a017c16bb38af4c7430 Mon Sep 17 00:00:00 2001 From: lxvm Date: Tue, 13 Feb 2024 07:43:31 -0500 Subject: [PATCH 12/12] add tests for Iterators.take --- src/lib/array.jl | 3 ++- test/lib/array.jl | 27 +++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index f39506cc0..633aeaf06 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -323,9 +323,10 @@ end takefunc(itr, dy) = _restore(dy, _tryaxes(itr)) @adjoint function Iterators.take(itr, n) + take_pullback(::AbstractArray{Nothing}) = nothing take_pullback(dy::NamedTuple{(:xs,:n)}) = (dy.xs, dy.n) take_pullback(dy::NamedTuple{(:n,:xs)}) = (dy.xs, dy.n) - take_pullback(dy) = (takefunc(itr, dy), nothing) + take_pullback(dy::AbstractArray) = (takefunc(itr, dy), nothing) Iterators.take(itr, n), take_pullback end diff --git a/test/lib/array.jl b/test/lib/array.jl index 8016c9541..64ecfa9b5 100644 --- a/test/lib/array.jl +++ b/test/lib/array.jl @@ -50,6 +50,23 @@ end @test @inferred back(collect(y)) == (nothing, [1.0, 0.0, 0.0, 0.0, 0.0], fill(1.0)) end +@testset "adjoints of Iterators.take" begin + y, back = _pullback(Iterators.take, 1:5, 3) + @test back(collect(y)) == (nothing, [1.0, 2.0, 3.0, 0.0, 0.0], nothing) + @test back([nothing for i in 1:3]) === nothing + + @test gradient(x -> sum([2y for y in Iterators.take(x, 4)]), [1,2,3,4])[1] ≈ [2, 2, 2, 2] + @test gradient(x -> sum(2y for y in Iterators.take(x, 4)), [1,2,3,4])[1] ≈ [2, 2, 2, 2] + + for p in (1.0, fill(1.0), [1.0]) + @test gradient(p_ -> sum(map(prod, Iterators.take(p_, 1))), p) == (p,) + @test gradient(p_ -> sum(x for x in Iterators.take(p_, 1)), p) == (p,) + end + + y, back = _pullback(Iterators.take, ones(2, 2), 3) + @test @inferred back(collect(y)) == (nothing, [1.0 1.0; 1.0 0.0], nothing) +end + @testset "collect" begin @testset "Dict" begin d = Dict(1 => 5, 2 => 6) @@ -97,6 +114,16 @@ end @test gradient(x -> sum(broadcast(prod, Iterators.zip(x,x.^2))), ones(4)) == (3ones(4),) @test gradient(x -> sum(broadcast(prod, Iterators.zip(x.^2,x.^2))), ones(4)) == (4ones(4),) end + + + @testset "Iterators.Take" begin + z = Iterators.take(1:3, 2) + g = gradient(z -> sum(collect(z)), z)[1] + @test g == (xs=[1.0, 1.0, 0.0], n=nothing) + + @test gradient(x -> sum(broadcast(prod, Iterators.take(x,2))), ones(4)) == ([1.0,1.0,0.0,0.0],) + @test gradient(x -> sum(broadcast(prod, Iterators.take(x.^2,2))), ones(4)) == (2*[1.0,1.0,0.0,0.0],) + end end @testset "dictionary comprehension" begin