Skip to content

Commit

Permalink
day two
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Aug 20, 2022
1 parent 328432f commit fe779b3
Show file tree
Hide file tree
Showing 8 changed files with 193 additions and 61 deletions.
30 changes: 9 additions & 21 deletions src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,34 +249,22 @@ end
#####

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray) where {F}
y, back = rrule_via_ad(cfg, Broadcast.broadcasted, f, x) # could be broadcast, but Yota likes this one
return Broadcast.materialize(y), back
# y, back = rrule_via_ad(cfg, Broadcast.broadcasted, f, x) # Yota likes this one
# return Broadcast.materialize(y), back
y, back = rrule_via_ad(cfg, broadcast, f, x) # but testing like this one
return y, back
end

# Could accept Any?
# `_unmap_pad` is also used for `zip`
function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F}
@debug "rrule(map, f, arrays...)" f
z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...)
# z, backs = unzip(map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...))
function map_pullback(dz)
df, dxy... = unzip_map(|>, unthunk(dz), backs)
# df, dxy... = unzip(map(|>, unthunk(dz), backs))
return (NoTangent(), ProjectTo(sum(df)), map(_unmap_pad, (x, ys...), dxy)...)
function map_pullback_2(dz)
df, dxy... = unzip_map_reversed(|>, unthunk(dz), backs)
return (NoTangent(), ProjectTo(f)(sum(df)), map(_unmap_pad, (x, ys...), dxy)...)
end
z, map_pullback
z, map_pullback_2
end

# function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F}
# z, zip_back = rrule(zip, x, ys...)
# m, map_back = rrule(config, map, Splat(f), z) # maybe this is inefficient?
# function map_pullback(dm)
# _, dsplatf, dz = map_back(dm)
# _, dxys... = zip_back(dz)
# return (NoTangent(), 0, dxys...)
# end
# return m, map_back
# end

#####
##### `task_local_storage`
#####
Expand Down
25 changes: 22 additions & 3 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,16 +119,35 @@ function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F}
end

# Path 4: The most generic, save all the pullbacks. Can be 1000x slower.
# Since broadcast makes no guarantee about order of calls, and un-fusing
# can change the number of calls, don't bother to try to reverse the iteration.
# While broadcast makes no guarantee about order of calls, it's cheap to reverse the iteration.

#=
julia> Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), [1,2,3.0])
┌ Debug: split broadcasting generic
│ f = #69 (generic function with 1 method)
│ N = 1
└ @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:126
(14.0, (ZeroTangent(), [2.0, 4.0, 6.0]))
julia> ENV["JULIA_DEBUG"] = nothing
julia> @btime Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), $(rand(1000)));
min 1.321 ms, mean 1.434 ms (23010 allocations, 594.66 KiB) # with unzip_map, as before
min 1.279 ms, mean 1.393 ms (23029 allocations, 595.73 KiB) # with unzip_map_reversed
julia> @btime Yota.grad(xs -> sum(abs2, abs.(xs)), $(randn(1000))); # Debug: split broadcasting derivative
min 2.144 μs, mean 6.620 μs (6 allocations, 23.88 KiB)
=#

function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
@debug("split broadcasting generic", f, N)
ys3, backs = unzip_broadcast(args...) do a...
rrule_via_ad(cfg, f, a...)
end
function back_generic(dys)
deltas = unzip_broadcast(backs, unthunk(dys)) do back, dy # (could be map, sizes match)
deltas = unzip_map_reversed(backs, unthunk(dys)) do back, dy
map(unthunk, back(dy))
end
dargs = map(unbroadcast, args, Base.tail(deltas))
Expand Down
41 changes: 14 additions & 27 deletions src/rulesets/Base/iterators.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor
tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor, https://github.com/JuliaDiff/Diffractor.jl/pull/86

#####
##### Comprehension: Iterators.map
Expand All @@ -7,38 +7,18 @@ tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor
# Comprehension does guarantee iteration order. Thus its gradient must reverse.

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) where {G<:Base.Generator}
# ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter)
ys, backs = unzip(map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter))
@debug "collect generator"
ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter)
proj_f = ProjectTo(gen.f)
proj_iter = ProjectTo(gen.iter)
function generator_pullback(dys_raw)
dys = unthunk(dys_raw)
# dfs, dxs = unzip_map(|>, Iterators.reverse(dys), Iterators.reverse(backs))
dfs, dxs = unzip(map(|>, Iterators.reverse(dys), Iterators.reverse(backs)))
return (NoTangent(), Tangent{G}(; f = proj_f(sum(dfs)), iter = proj_iter(reverse!!(dxs))))
dfs, dxs = unzip_map_reversed(|>, dys, backs)
return (NoTangent(), Tangent{G}(; f = proj_f(sum(dfs)), iter = proj_iter(dxs)))
end
ys, generator_pullback
end

"""
reverse!!(x)
Reverses `x` in-place if possible, according to `ChainRulesCore.is_inplaceable_destination`.
Only safe if you are quite sure nothing else closes over `x`.
"""
function reverse!!(x::AbstractArray)
if ChainRulesCore.is_inplaceable_destination(x)
Base.reverse!(x)
else
Base.reverse(x)
end
end
frule((_, xdot), ::typeof(reverse!!), x::AbstractArray) = reverse!!(x), reverse!!(xdot)
function rrule(::typeof(reverse!!), x::AbstractArray)
reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy)))
return reverse!!(x), reverse!!_back
end

# Needed for Yota, but shouldn't these be automatic?
ChainRulesCore.rrule(::Type{<:Base.Generator}, f, iter) = Base.Generator(f, iter), dy -> (NoTangent(), dy.f, dy.iter)
ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters) = Iterators.ProductIterator(iters), dy -> (NoTangent(), dy.iterators)
Expand Down Expand Up @@ -107,18 +87,25 @@ function rrule(::typeof(zip), xs::AbstractArray...)
end

_tangent_unzip(xs::AbstractArray{Tangent{T,B}}) where {T<:Tuple, B<:Tuple} = unzip(reinterpret(B, xs))
_tangent_unzip(xs::AbstractArray) = unzip(xs) # Diffractor
_tangent_unzip(xs::AbstractArray) = unzip(xs) # temp fix for Diffractor

# This is like unbroadcast, except for map's stopping-short behaviour, not broadcast's extension.
# Closing over `x` lets us re-use ∇getindex.
function _unmap_pad(x::AbstractArray, dx::AbstractArray)
if length(x) == length(dx)
ProjectTo(x)(reshape(dx, axes(x)))
else
@debug "_unmap_pad is extending gradient" length(x) == length(dx)
i1 = firstindex(x)
∇getindex(x, vec(dx), i1:i1+length(dx)-1)
# dx2 = vcat(vec(dx), similar(x, ZeroTangent, length(x) - length(dx)))
# ProjectTo(x)(reshape(dx2, axes(x)))
end
end


# For testing
function rrule(::ComposedFunction{typeof(collect), typeof(zip)}, xs::AbstractArray...)
y, back = rrule(zip, xs...)
return collect(y), back
end

59 changes: 59 additions & 0 deletions src/unzipped.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,65 @@ function unzip_map(f::F, args...) where {F}
return StructArrays.components(StructArray(Iterators.map(f, args...)))
end

unzip_map(f::F, args::Tuple...) where {F} = unzip(map(f, args...))

unzip_map(f::F, args::AbstractGPUArray...) where {F} = unzip(map(f, args...))

function unzip_map_reversed(f::F, args...) where {F}
T = Broadcast.combine_eltypes(f, args)
if isconcretetype(T)
T <: Tuple || throw(ArgumentError("""unzip_map_reversed(f, args) only works on functions returning a tuple,
but f = $(sprint(show, f)) returns type T = $T"""))
end
len1 = length(first(args))
if all(a -> length(a)==len1, args)
rev_args = map(Iterators.reverse, args)
outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...)))
else
len = minimum(length, args)
rev_args = map(a -> Iterators.reverse(@view a[begin:begin+len-1]), args)
outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...)))
end
return map(reverse!!, outs)
end

function unzip_map_reversed(f::F, args::Tuple...) where {F}
len = minimum(length, args)
rev_args = map(a -> reverse(a[1:len]), args)
# vlen = Val(len)
# rev_args = map(args) do a
# reverse(ntuple(i -> a[i], vlen)) # does not infer better
# end
return map(reverse, unzip(map(f, rev_args...)))
end
# function unzip_map_reversed(f::F, args::Tuple{Vararg{Any, N}}...) where {F,N}
# rev_args = map(reverse, args)
# return map(reverse, unzip(map(f, rev_args...)))
# end

"""
reverse!!(x)
Reverses `x` in-place if possible, according to `ChainRulesCore.is_inplaceable_destination`.
Only safe if you are quite sure nothing else closes over `x`.
"""
function reverse!!(x::AbstractArray)
if ChainRulesCore.is_inplaceable_destination(x)
Base.reverse!(x)
else
Base.reverse(x)
end
end
reverse!!(x::AbstractArray{<:AbstractZero}) = x

frule((_, xdot), ::typeof(reverse!!), x::AbstractArray) = reverse!!(x), reverse!!(xdot)

function rrule(::typeof(reverse!!), x::AbstractArray)
reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy)))
return reverse!!(x), reverse!!_back
end


#####
##### unzip
#####
Expand Down
34 changes: 34 additions & 0 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -229,4 +229,38 @@
test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false)
end
end

@testset "map(f, ::Array)" begin
test_rrule(map, identity, [1.0, 2.0], check_inferred=false)
test_rrule(map, conj, [1, 2+im, 3.0]', check_inferred=false)
test_rrule(map, make_two_vec, [4.0, 5.0 + 6im], check_inferred=false)
# @test rrule(CFG, map, make_two_vec, [4.0, 5.0 + 6im])[2]([1:2, 3:4])[3] ≈ [1 + 2im, 3 + 4im] # FiniteDifferences DimensionMismatch

@test_skip test_rrule(map, Multiplier(rand() + im), rand(3), check_inferred=false)
rrule(CFG, map, Multiplier(2.0), [3, 4, 5.0])[2]([10, 20, 30]) # (NoTangent(), Multiplier{Float64}(259.99999), [19.99999, 40.000, 60.000]) -- WTF?
@test_skip test_rrule(map, Multiplier(rand() + im) NoTangent(), rand(3), check_inferred=false) # Expression: ad_cotangent isa NoTangent Evaluated: Multiplier{ComplexF64}(-3.7869064372333963 + 2.046139872866103im) isa NoTangent

y1, bk1 = rrule(CFG, map, abs2, [1.0, 2.0, 3.0])
@test y1 == [1, 4, 9]
@test bk1([4, 5, 6.0])[3] 2 .* (1:3) .* (4:6)

y2, bk2 = rrule(CFG, map, Counter(), [11, 12, 13.0])
@test y2 == map(Counter(), 11:13)
@test_skip bk2(ones(3))[3] == [93, 83, 73] # FiniteDifferences has incremented the counter very high
end

@testset "map(f, ::Array, ::Array)" begin
test_rrule(map, +, [1.0, 2.0], [3.0, 4.0], check_inferred=false) # NoTangent does not match Union{NoTangent, ZeroTangent}
test_rrule(map, /, [1.0, 2.0], [3.0, 4.0, 5.0], check_inferred=false)
test_rrule(map, atan, [1, 2, 3.0], [4 5; 6 7.0], check_inferred=false)

test_rrule(map, Multiplier(rand()), rand(3), rand(4), check_inferred=false)

cnt3 = Counter()
y3, bk3 = rrule(CFG, map, cnt3, [1, 2, 3.0], [0, -1, -2, -33.3])
@test y3 == 1:3
@test cnt3 == Counter(3)
z3 = bk3([1, 1, 1000])
@test z3[3] == [53, 33, 13000]
end
end
26 changes: 26 additions & 0 deletions test/rulesets/Base/iterators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@

@testset "Comprehension" begin
@testset "simple" begin
y1, bk1 = rrule(CFG, collect, (i^2 for i in [1.0, 2.0, 3.0]))
@test y1 == [1,4,9]
t1 = bk1(4:6)[2]
@test t1 isa Tangent{<:Base.Generator}
@test t1.f == NoTangent()
@test t1.iter 2 .* (1:3) .* (4:6)

y2, bk2 = rrule(CFG, collect, Iterators.map(Counter(), [11, 12, 13.0]))
@test y2 == map(Counter(), 11:13)
@test bk2(ones(3))[2].iter == [93, 83, 73]
end
end

@testset "Iterators" begin
@testset "zip" begin
test_rrule(collectzip, rand(3), rand(3))
test_rrule(collectzip, rand(2,2), rand(2,2), rand(2,2))
test_rrule(collectzip, rand(4), rand(2,2))

test_rrule(collectzip, rand(3), rand(5))
test_rrule(collectzip, rand(3,2), rand(5))
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ end
include_test("rulesets/Base/mapreduce.jl")
include_test("rulesets/Base/sort.jl")
include_test("rulesets/Base/broadcast.jl")
include_test("rulesets/Base/iterators.jl")

include_test("unzipped.jl") # used primarily for broadcast

Expand Down
38 changes: 28 additions & 10 deletions test/unzipped.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@

using ChainRules: unzip_broadcast, unzip #, unzip_map
using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed

@testset "unzipped.jl" begin
@testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzipmap, unzipbroadcast] # unzip_map,
@testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzipmap, unzipbroadcast, unzip_map, unzip_map_reversed]
@test_throws Exception fun(sqrt, 1:3)

@test fun(tuple, 1:3, 4:6) == ([1, 2, 3], [4, 5, 6])
@test fun(tuple, [1, 10, 100]) == ([1, 10, 100],)
@test fun(tuple, 1:3, fill(nothing, 3)) == (1:3, fill(nothing, 3))
@test fun(tuple, [1, 10, 100], fill(nothing, 3)) == ([1, 10, 100], fill(nothing, 3))
@test fun(tuple, fill(nothing, 3), fill(nothing, 3)) == (fill(nothing, 3), fill(nothing, 3))
@test @inferred(fun(tuple, 1:3, 4:6)) == ([1, 2, 3], [4, 5, 6])
@test @inferred(fun(tuple, [1, 10, 100])) == ([1, 10, 100],)
@test @inferred(fun(tuple, 1:3, fill(nothing, 3))) == (1:3, fill(nothing, 3))
@test @inferred(fun(tuple, [1, 10, 100], fill(nothing, 3))) == ([1, 10, 100], fill(nothing, 3))
@test @inferred(fun(tuple, fill(nothing, 3), fill(nothing, 3))) == (fill(nothing, 3), fill(nothing, 3))

if contains(string(fun), "map")
@test fun(tuple, 1:3, 4:999) == ([1, 2, 3], [4, 5, 6])
@test @inferred(fun(tuple, 1:3, 4:999)) == ([1, 2, 3], [4, 5, 6])
else
@test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5])
@test @inferred(fun(tuple, [1,2,3], [4 5])) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5])
@test @inferred(fun(tuple, [1,2,3], 6)) == ([1, 2, 3], [6, 6, 6])
end

if contains(string(fun), "map")
Expand All @@ -24,7 +25,24 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map
@test fun(tuple, (1,2,3), (7,)) == ((1, 2, 3), (7, 7, 7))
@test fun(tuple, (1,2,3), 8) == ((1, 2, 3), (8, 8, 8))
end
@test fun(tuple, (1,2,3), [4,5,6]) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector
@test @inferred(fun(tuple, (1,2,3), [4,5,6])) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector
end

@testset "zip behaviour: $unzip_map" for unzip_map in [unzip_map, unzip_map_reversed]
check(f, args...) = @inferred(unzip_map(f, args...)) == unzip(map(f, args...))
@test check(tuple, [1 2; 3 4], [5,6,7,8]) # makes a vector
@test check(tuple, [1 2; 3 4], [5,6,7])
@test check(tuple, [1 2; 3 4], [5,6,7,8,9,10])
end

@testset "unzip_map_reversed" begin
cnt(x, y) = (x, y) .+ (CNT[] += 1)
CNT = Ref(0)
@test unzip_map_reversed(cnt, [10, 20], [30, 40, 50]) == ([12, 21], [32, 41])
@test CNT[] == 2

CNT = Ref(0)
@test unzip_map_reversed(cnt, (10, 20, 99), (30, 40)) == ((12, 21), (32, 41))
end

@testset "rrules" begin
Expand Down

0 comments on commit fe779b3

Please sign in to comment.