diff --git a/src/lib/array.jl b/src/lib/array.jl index eef7244bd..bea53542b 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -167,6 +167,7 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U) # With mismatched lengths, map stops early. With mismatched shapes, it makes a vector. # So we keep axes(x) to restore gradient dx to its full length & correct shape. +_tryaxes(x) = (s = Base.IteratorSize(x); s isa Base.HasShape ? axes(x) : s isa Base.HasLength ? (Base.OneTo(length(x)),) : throw(ArgumentError("iterator size must be finite"))) _tryaxes(x::AbstractArray) = axes(x) _tryaxes(x::Tuple) = Val(length(x)) _restore(dx, ax::Tuple) = axes(dx) == ax ? dx : reshape(vcat(dx, falses(prod(length, ax) - length(dx))), ax) @@ -296,6 +297,20 @@ end Iterators.Zip(xs), back end +takefunc(itr, dy) = _restore(dy, _tryaxes(itr)) + +@adjoint function Iterators.take(itr, n) + take_pullback(dy::NamedTuple{(:xs,:n)}) = (dy.xs, dy.n) + take_pullback(dy::NamedTuple{(:n,:xs)}) = (dy.xs, dy.n) + take_pullback(dy) = (takefunc(itr, dy), nothing) + Iterators.take(itr, n), take_pullback +end + +@adjoint function Base.collect(t::Iterators.Take) + collect_take_pullback(dy) = ((xs=takefunc(t.xs, dy), n=nothing),) + collect(t), collect_take_pullback +end + # Reductions @adjoint function sum(xs::AbstractArray; dims = :) if dims === (:) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 9cd1facc4..4391fba90 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -215,8 +215,22 @@ end @test gradient(xs -> sum(inv, [x^2 for x in xs]), ones(2)) == ([-2, -2],) - # no adjoint for generic iterators - @test_broken gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 5))), 1.0) + # adjoint of generators is available and should support generic arrays and iterators + # generator of array + @test gradient(p -> sum(collect(p*i for i in [1.0, 2.0, 3.0])), 2.0) == (6.0,) + # generator of iterator with HasShape + @test gradient(p -> sum(collect(p*i for (i,) in zip([1.0, 2.0, 3.0]))), 2.0) == (6.0,) + # generator of iterator with HasLength + @test gradient(p -> sum(collect(p*i for i in Iterators.take([1.0, 2.0, 3.0], 3))), 2.0) == (6.0,) + @test gradient(p -> sum(collect(p*i for i in Iterators.take(p*[1.0, 2.0, 3.0], 2))), 2.0) == (12.0,) + # generator 0-d behavior handled incorrectly + @test_broken gradient(p -> sum(collect(p*i for i in 1.0)), 2.0) + @test_broken gradient(p -> sum(collect(p*i for i in fill(1.0))), 2.0) + + # adjoints for iterators + @test gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 4))), 1.0) == (10.0,) + @test gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 5))), 1.0) == (15.0,) + @test_broken gradient(sum∘collect, 1.0) == (1.0,) # broken since no generic adjoint end @test gradtest(x -> reverse(x), rand(17))