Skip to content

Commit

Permalink
add take iterator and test map adjoint on iterators
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Jan 17, 2024
1 parent b37baf7 commit 1495b9b
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
15 changes: 15 additions & 0 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ _reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U)

# With mismatched lengths, map stops early. With mismatched shapes, it makes a vector.
# So we keep axes(x) to restore gradient dx to its full length & correct shape.
# Axes that a gradient for the generic iterator `x` has to be restored to, selected by
# dispatching on the `Base.IteratorSize` trait of `x`:
#   * `HasShape`  -> `axes(x)`
#   * `HasLength` -> a 1-tuple holding `Base.OneTo(length(x))`
#   * any other trait (`IsInfinite`, `SizeUnknown`) -> throws `ArgumentError`
_tryaxes(x) = _sized_axes(x, Base.IteratorSize(x))

# Shaped iterators report their own axes.
_sized_axes(x, ::Base.HasShape) = axes(x)
# Iterators that only know their length are treated as one-dimensional.
_sized_axes(x, ::Base.HasLength) = (Base.OneTo(length(x)),)
# Fallback for the remaining traits. An iterator of unknown size is not necessarily
# infinite, so the message names the trait that was actually found.
function _sized_axes(x, s::Base.IteratorSize)
    msg = "iterator size must be finite and known, got $(s)"
    throw(ArgumentError(msg))
end
# Arrays: return `axes(x)` directly.
function _tryaxes(x::AbstractArray)
    return axes(x)
end

# Tuples: return the tuple length lifted into the type domain as a `Val`.
function _tryaxes(x::Tuple)
    n = length(x)
    return Val(n)
end
# Bring the gradient `dx` back to the axes `ax`.
# When `dx` already has exactly those axes it is returned untouched. Otherwise it is
# padded at the end with `false` entries up to `prod(length, ax)` elements and reshaped
# to `ax`.
function _restore(dx, ax::Tuple)
    axes(dx) == ax && return dx
    npad = prod(length, ax) - length(dx)
    padded = vcat(dx, falses(npad))
    return reshape(padded, ax)
end
Expand Down Expand Up @@ -296,6 +297,20 @@ end
Iterators.Zip(xs), back
end

# Restore the gradient `dy` to the axes reported by `_tryaxes(itr)`.
function takefunc(itr, dy)
    ax = _tryaxes(itr)
    return _restore(dy, ax)
end

# Adjoint for the lazy `Iterators.take(itr, n)`.
# The pullback returns one gradient per argument, in the order `(itr, n)`.
@adjoint function Iterators.take(itr, n)
    taken = Iterators.take(itr, n)
    # A NamedTuple gradient carrying the fields `xs` and `n` (in either order) is
    # unpacked field by field.
    function take_pullback(dy::Union{NamedTuple{(:xs, :n)},NamedTuple{(:n, :xs)}})
        return (dy.xs, dy.n)
    end
    # Any other gradient is restored against `itr` via `takefunc`; `n` gets no gradient.
    function take_pullback(dy)
        return (takefunc(itr, dy), nothing)
    end
    return taken, take_pullback
end

# Adjoint for `collect` applied to an `Iterators.Take`.
# The pullback returns a 1-tuple holding a NamedTuple with the fields `xs` and `n`: `xs`
# is `dy` restored against `t.xs` via `takefunc`, and `n` is `nothing`.
@adjoint function Base.collect(t::Iterators.Take)
    function collect_take_pullback(dy)
        dxs = takefunc(t.xs, dy)
        return ((xs = dxs, n = nothing),)
    end
    return collect(t), collect_take_pullback
end

# Reductions
@adjoint function sum(xs::AbstractArray; dims = :)
if dims === (:)
Expand Down
18 changes: 16 additions & 2 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,22 @@ end

@test gradient(xs -> sum(inv, [x^2 for x in xs]), ones(2)) == ([-2, -2],)

# no adjoint for generic iterators
@test_broken gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 5))), 1.0)
# adjoint of generators is available and should support generic arrays and iterators
# generator of array
@test gradient(p -> sum(collect(p*i for i in [1.0, 2.0, 3.0])), 2.0) == (6.0,)
# generator of iterator with HasShape
@test gradient(p -> sum(collect(p*i for (i,) in zip([1.0, 2.0, 3.0]))), 2.0) == (6.0,)
# generator of iterator with HasLength
@test gradient(p -> sum(collect(p*i for i in Iterators.take([1.0, 2.0, 3.0], 3))), 2.0) == (6.0,)
@test gradient(p -> sum(collect(p*i for i in Iterators.take(p*[1.0, 2.0, 3.0], 2))), 2.0) == (12.0,)
# generator 0-d behavior handled incorrectly
@test_broken gradient(p -> sum(collect(p*i for i in 1.0)), 2.0)
@test_broken gradient(p -> sum(collect(p*i for i in fill(1.0))), 2.0)

# adjoints for iterators
@test gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 4))), 1.0) == (10.0,)
@test gradient(x -> sum(collect(Iterators.take([x*i for i in 1:5], 5))), 1.0) == (15.0,)
# The callable under test is the composition of `sum` and `collect`; a bare `sumcollect`
# identifier is not defined in the visible code and would make this test "broken" through
# an UndefVarError rather than through the missing adjoint.
@test_broken gradient(sum ∘ collect, 1.0) == (1.0,) # broken since no generic adjoint
end

@test gradtest(x -> reverse(x), rand(17))
Expand Down

0 comments on commit 1495b9b

Please sign in to comment.