From fe779b33f0705b5ccf67716f2d53f177cf9bcda8 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sat, 20 Aug 2022 12:57:26 -0400 Subject: [PATCH] day two --- src/rulesets/Base/base.jl | 30 +++++------------ src/rulesets/Base/broadcast.jl | 25 ++++++++++++-- src/rulesets/Base/iterators.jl | 41 ++++++++--------------- src/unzipped.jl | 59 +++++++++++++++++++++++++++++++++ test/rulesets/Base/base.jl | 34 +++++++++++++++++++ test/rulesets/Base/iterators.jl | 26 +++++++++++++++ test/runtests.jl | 1 + test/unzipped.jl | 38 +++++++++++++++------ 8 files changed, 193 insertions(+), 61 deletions(-) create mode 100644 test/rulesets/Base/iterators.jl diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 94a90bd9f..c385e87ea 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -249,34 +249,22 @@ end ##### function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray) where {F} - y, back = rrule_via_ad(cfg, Broadcast.broadcasted, f, x) # could be broadcast, but Yota likes this one - return Broadcast.materialize(y), back + # y, back = rrule_via_ad(cfg, Broadcast.broadcasted, f, x) # Yota likes this one + # return Broadcast.materialize(y), back + y, back = rrule_via_ad(cfg, broadcast, f, x) # but testing like this one + return y, back end -# Could accept Any? -# `_unmap_pad` is also used for `zip` function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F} + @debug "rrule(map, f, arrays...)" f z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...) - # z, backs = unzip(map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...)) - function map_pullback(dz) - df, dxy... = unzip_map(|>, unthunk(dz), backs) - # df, dxy... = unzip(map(|>, unthunk(dz), backs)) - return (NoTangent(), ProjectTo(sum(df)), map(_unmap_pad, (x, ys...), dxy)...) + function map_pullback_2(dz) + df, dxy... = unzip_map_reversed(|>, unthunk(dz), backs) + return (NoTangent(), ProjectTo(f)(sum(df)), map(_unmap_pad, (x, ys...), dxy)...) end - z, map_pullback + z, map_pullback_2 end -# function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F} -# z, zip_back = rrule(zip, x, ys...) -# m, map_back = rrule(config, map, Splat(f), z) # maybe this is inefficient? -# function map_pullback(dm) -# _, dsplatf, dz = map_back(dm) -# _, dxys... = zip_back(dz) -# return (NoTangent(), 0, dxys...) -# end -# return m, map_back -# end - ##### ##### `task_local_storage` ##### diff --git a/src/rulesets/Base/broadcast.jl b/src/rulesets/Base/broadcast.jl index ddf4dc426..0aae75ab5 100644 --- a/src/rulesets/Base/broadcast.jl +++ b/src/rulesets/Base/broadcast.jl @@ -119,8 +119,27 @@ function split_bc_inner(frule_fun::R, cfg::RuleConfig, f::F, arg) where {R,F} end # Path 4: The most generic, save all the pullbacks. Can be 1000x slower. -# Since broadcast makes no guarantee about order of calls, and un-fusing -# can change the number of calls, don't bother to try to reverse the iteration. +# While broadcast makes no guarantee about order of calls, it's cheap to reverse the iteration. + +#= + +julia> Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), [1,2,3.0]) +┌ Debug: split broadcasting generic +│ f = #69 (generic function with 1 method) +│ N = 1 +└ @ ChainRules ~/.julia/dev/ChainRules/src/rulesets/Base/broadcast.jl:126 +(14.0, (ZeroTangent(), [2.0, 4.0, 6.0])) + +julia> ENV["JULIA_DEBUG"] = nothing + +julia> @btime Yota.grad(xs -> sum(abs2, (x -> abs(x)).(xs)), $(rand(1000))); + min 1.321 ms, mean 1.434 ms (23010 allocations, 594.66 KiB) # with unzip_map, as before + min 1.279 ms, mean 1.393 ms (23029 allocations, 595.73 KiB) # with unzip_map_reversed + +julia> @btime Yota.grad(xs -> sum(abs2, abs.(xs)), $(randn(1000))); # Debug: split broadcasting derivative + min 2.144 μs, mean 6.620 μs (6 allocations, 23.88 KiB) + +=# function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} @debug("split broadcasting generic", f, N) @@ -128,7 +147,7 @@ function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N} rrule_via_ad(cfg, f, a...) end function back_generic(dys) - deltas = unzip_broadcast(backs, unthunk(dys)) do back, dy # (could be map, sizes match) + deltas = unzip_map_reversed(backs, unthunk(dys)) do back, dy map(unthunk, back(dy)) end dargs = map(unbroadcast, args, Base.tail(deltas)) diff --git a/src/rulesets/Base/iterators.jl b/src/rulesets/Base/iterators.jl index 8c801aff2..f264b53d0 100644 --- a/src/rulesets/Base/iterators.jl +++ b/src/rulesets/Base/iterators.jl @@ -1,4 +1,4 @@ -tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor +tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor, https://github.com/JuliaDiff/Diffractor.jl/pull/86 ##### ##### Comprehension: Iterators.map @@ -7,38 +7,18 @@ tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor # Comprehension does guarantee iteration order. Thus its gradient must reverse. function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) where {G<:Base.Generator} - # ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter) - ys, backs = unzip(map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter)) + @debug "collect generator" + ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter) proj_f = ProjectTo(gen.f) proj_iter = ProjectTo(gen.iter) function generator_pullback(dys_raw) dys = unthunk(dys_raw) - # dfs, dxs = unzip_map(|>, Iterators.reverse(dys), Iterators.reverse(backs)) - dfs, dxs = unzip(map(|>, Iterators.reverse(dys), Iterators.reverse(backs))) - return (NoTangent(), Tangent{G}(; f = proj_f(sum(dfs)), iter = proj_iter(reverse!!(dxs)))) + dfs, dxs = unzip_map_reversed(|>, dys, backs) + return (NoTangent(), Tangent{G}(; f = proj_f(sum(dfs)), iter = proj_iter(dxs))) end ys, generator_pullback end -""" - reverse!!(x) - -Reverses `x` in-place if possible, according to `ChainRulesCore.is_inplaceable_destination`. -Only safe if you are quite sure nothing else closes over `x`. -""" -function reverse!!(x::AbstractArray) - if ChainRulesCore.is_inplaceable_destination(x) - Base.reverse!(x) - else - Base.reverse(x) - end -end -frule((_, xdot), ::typeof(reverse!!), x::AbstractArray) = reverse!!(x), reverse!!(xdot) -function rrule(::typeof(reverse!!), x::AbstractArray) - reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy))) - return reverse!!(x), reverse!!_back -end - # Needed for Yota, but shouldn't these be automatic? ChainRulesCore.rrule(::Type{<:Base.Generator}, f, iter) = Base.Generator(f, iter), dy -> (NoTangent(), dy.f, dy.iter) ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters) = Iterators.ProductIterator(iters), dy -> (NoTangent(), dy.iterators) @@ -107,12 +87,15 @@ function rrule(::typeof(zip), xs::AbstractArray...) end _tangent_unzip(xs::AbstractArray{Tangent{T,B}}) where {T<:Tuple, B<:Tuple} = unzip(reinterpret(B, xs)) -_tangent_unzip(xs::AbstractArray) = unzip(xs) # Diffractor +_tangent_unzip(xs::AbstractArray) = unzip(xs) # temp fix for Diffractor +# This is like unbroadcast, except for map's stopping-short behaviour, not broadcast's extension. +# Closing over `x` lets us re-use ∇getindex. function _unmap_pad(x::AbstractArray, dx::AbstractArray) if length(x) == length(dx) ProjectTo(x)(reshape(dx, axes(x))) else + @debug "_unmap_pad is extending gradient" length(x) == length(dx) i1 = firstindex(x) ∇getindex(x, vec(dx), i1:i1+length(dx)-1) # dx2 = vcat(vec(dx), similar(x, ZeroTangent, length(x) - length(dx))) @@ -120,5 +103,9 @@ function _unmap_pad(x::AbstractArray, dx::AbstractArray) end end - +# For testing +function rrule(::ComposedFunction{typeof(collect), typeof(zip)}, xs::AbstractArray...) + y, back = rrule(zip, xs...) + return collect(y), back +end diff --git a/src/unzipped.jl b/src/unzipped.jl index 8da3c30fd..6fcc7eecf 100644 --- a/src/unzipped.jl +++ b/src/unzipped.jl @@ -85,6 +85,65 @@ function unzip_map(f::F, args...) where {F} return StructArrays.components(StructArray(Iterators.map(f, args...))) end +unzip_map(f::F, args::Tuple...) where {F} = unzip(map(f, args...)) + +unzip_map(f::F, args::AbstractGPUArray...) where {F} = unzip(map(f, args...)) + +function unzip_map_reversed(f::F, args...) where {F} + T = Broadcast.combine_eltypes(f, args) + if isconcretetype(T) + T <: Tuple || throw(ArgumentError("""unzip_map_reversed(f, args) only works on functions returning a tuple, + but f = $(sprint(show, f)) returns type T = $T""")) + end + len1 = length(first(args)) + if all(a -> length(a)==len1, args) + rev_args = map(Iterators.reverse, args) + outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...))) + else + len = minimum(length, args) + rev_args = map(a -> Iterators.reverse(@view a[begin:begin+len-1]), args) + outs = StructArrays.components(StructArray(Iterators.map(f, rev_args...))) + end + return map(reverse!!, outs) +end + +function unzip_map_reversed(f::F, args::Tuple...) where {F} + len = minimum(length, args) + rev_args = map(a -> reverse(a[1:len]), args) + # vlen = Val(len) + # rev_args = map(args) do a + # reverse(ntuple(i -> a[i], vlen)) # does not infer better + # end + return map(reverse, unzip(map(f, rev_args...))) +end +# function unzip_map_reversed(f::F, args::Tuple{Vararg{Any, N}}...) where {F,N} +# rev_args = map(reverse, args) +# return map(reverse, unzip(map(f, rev_args...))) +# end + +""" + reverse!!(x) + +Reverses `x` in-place if possible, according to `ChainRulesCore.is_inplaceable_destination`. +Only safe if you are quite sure nothing else closes over `x`. +""" +function reverse!!(x::AbstractArray) + if ChainRulesCore.is_inplaceable_destination(x) + Base.reverse!(x) + else + Base.reverse(x) + end +end +reverse!!(x::AbstractArray{<:AbstractZero}) = x + +frule((_, xdot), ::typeof(reverse!!), x::AbstractArray) = reverse!!(x), reverse!!(xdot) + +function rrule(::typeof(reverse!!), x::AbstractArray) + reverse!!_back(dy) = (NoTangent(), reverse(unthunk(dy))) + return reverse!!(x), reverse!!_back +end + + ##### ##### unzip ##### diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 36452da1e..a7c166b0f 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -229,4 +229,38 @@ test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false) end end + + @testset "map(f, ::Array)" begin + test_rrule(map, identity, [1.0, 2.0], check_inferred=false) + test_rrule(map, conj, [1, 2+im, 3.0]', check_inferred=false) + test_rrule(map, make_two_vec, [4.0, 5.0 + 6im], check_inferred=false) + # @test rrule(CFG, map, make_two_vec, [4.0, 5.0 + 6im])[2]([1:2, 3:4])[3] ≈ [1 + 2im, 3 + 4im] # FiniteDifferences DimensionMismatch + + @test_skip test_rrule(map, Multiplier(rand() + im), rand(3), check_inferred=false) + rrule(CFG, map, Multiplier(2.0), [3, 4, 5.0])[2]([10, 20, 30]) # (NoTangent(), Multiplier{Float64}(259.99999), [19.99999, 40.000, 60.000]) -- WTF? + @test_skip test_rrule(map, Multiplier(rand() + im) ⊢ NoTangent(), rand(3), check_inferred=false) # Expression: ad_cotangent isa NoTangent Evaluated: Multiplier{ComplexF64}(-3.7869064372333963 + 2.046139872866103im) isa NoTangent + + y1, bk1 = rrule(CFG, map, abs2, [1.0, 2.0, 3.0]) + @test y1 == [1, 4, 9] + @test bk1([4, 5, 6.0])[3] ≈ 2 .* (1:3) .* (4:6) + + y2, bk2 = rrule(CFG, map, Counter(), [11, 12, 13.0]) + @test y2 == map(Counter(), 11:13) + @test_skip bk2(ones(3))[3] == [93, 83, 73] # FiniteDifferences has incremented the counter very high + end + + @testset "map(f, ::Array, ::Array)" begin + test_rrule(map, +, [1.0, 2.0], [3.0, 4.0], check_inferred=false) # NoTangent does not match Union{NoTangent, ZeroTangent} + test_rrule(map, /, [1.0, 2.0], [3.0, 4.0, 5.0], check_inferred=false) + test_rrule(map, atan, [1, 2, 3.0], [4 5; 6 7.0], check_inferred=false) + + test_rrule(map, Multiplier(rand()), rand(3), rand(4), check_inferred=false) + + cnt3 = Counter() + y3, bk3 = rrule(CFG, map, cnt3, [1, 2, 3.0], [0, -1, -2, -33.3]) + @test y3 == 1:3 + @test cnt3 == Counter(3) + z3 = bk3([1, 1, 1000]) + @test z3[3] == [53, 33, 13000] + end end diff --git a/test/rulesets/Base/iterators.jl b/test/rulesets/Base/iterators.jl new file mode 100644 index 000000000..d9060d985 --- /dev/null +++ b/test/rulesets/Base/iterators.jl @@ -0,0 +1,26 @@ + +@testset "Comprehension" begin + @testset "simple" begin + y1, bk1 = rrule(CFG, collect, (i^2 for i in [1.0, 2.0, 3.0])) + @test y1 == [1,4,9] + t1 = bk1(4:6)[2] + @test t1 isa Tangent{<:Base.Generator} + @test t1.f == NoTangent() + @test t1.iter ≈ 2 .* (1:3) .* (4:6) + + y2, bk2 = rrule(CFG, collect, Iterators.map(Counter(), [11, 12, 13.0])) + @test y2 == map(Counter(), 11:13) + @test bk2(ones(3))[2].iter == [93, 83, 73] + end +end + +@testset "Iterators" begin + @testset "zip" begin + test_rrule(collect∘zip, rand(3), rand(3)) + test_rrule(collect∘zip, rand(2,2), rand(2,2), rand(2,2)) + test_rrule(collect∘zip, rand(4), rand(2,2)) + + test_rrule(collect∘zip, rand(3), rand(5)) + test_rrule(collect∘zip, rand(3,2), rand(5)) + end +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 71444f388..7052d608e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -61,6 +61,7 @@ end include_test("rulesets/Base/mapreduce.jl") include_test("rulesets/Base/sort.jl") include_test("rulesets/Base/broadcast.jl") + include_test("rulesets/Base/iterators.jl") include_test("unzipped.jl") # used primarily for broadcast diff --git a/test/unzipped.jl b/test/unzipped.jl index 97aaa23f5..1677d7c9d 100644 --- a/test/unzipped.jl +++ b/test/unzipped.jl @@ -1,20 +1,21 @@ -using ChainRules: unzip_broadcast, unzip #, unzip_map +using ChainRules: unzip_broadcast, unzip, unzip_map, unzip_map_reversed @testset "unzipped.jl" begin - @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast] # unzip_map, + @testset "basics: $(sprint(show, fun))" for fun in [unzip_broadcast, unzip∘map, unzip∘broadcast, unzip_map, unzip_map_reversed] @test_throws Exception fun(sqrt, 1:3) - @test fun(tuple, 1:3, 4:6) == ([1, 2, 3], [4, 5, 6]) - @test fun(tuple, [1, 10, 100]) == ([1, 10, 100],) - @test fun(tuple, 1:3, fill(nothing, 3)) == (1:3, fill(nothing, 3)) - @test fun(tuple, [1, 10, 100], fill(nothing, 3)) == ([1, 10, 100], fill(nothing, 3)) - @test fun(tuple, fill(nothing, 3), fill(nothing, 3)) == (fill(nothing, 3), fill(nothing, 3)) + @test @inferred(fun(tuple, 1:3, 4:6)) == ([1, 2, 3], [4, 5, 6]) + @test @inferred(fun(tuple, [1, 10, 100])) == ([1, 10, 100],) + @test @inferred(fun(tuple, 1:3, fill(nothing, 3))) == (1:3, fill(nothing, 3)) + @test @inferred(fun(tuple, [1, 10, 100], fill(nothing, 3))) == ([1, 10, 100], fill(nothing, 3)) + @test @inferred(fun(tuple, fill(nothing, 3), fill(nothing, 3))) == (fill(nothing, 3), fill(nothing, 3)) if contains(string(fun), "map") - @test fun(tuple, 1:3, 4:999) == ([1, 2, 3], [4, 5, 6]) + @test @inferred(fun(tuple, 1:3, 4:999)) == ([1, 2, 3], [4, 5, 6]) else - @test fun(tuple, [1,2,3], [4 5]) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5]) + @test @inferred(fun(tuple, [1,2,3], [4 5])) == ([1 1; 2 2; 3 3], [4 5; 4 5; 4 5]) + @test @inferred(fun(tuple, [1,2,3], 6)) == ([1, 2, 3], [6, 6, 6]) end if contains(string(fun), "map") @@ -24,7 +25,24 @@ using ChainRules: unzip_broadcast, unzip #, unzip_map @test fun(tuple, (1,2,3), (7,)) == ((1, 2, 3), (7, 7, 7)) @test fun(tuple, (1,2,3), 8) == ((1, 2, 3), (8, 8, 8)) end - @test fun(tuple, (1,2,3), [4,5,6]) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector + @test @inferred(fun(tuple, (1,2,3), [4,5,6])) == ([1, 2, 3], [4, 5, 6]) # mix tuple & vector + end + + @testset "zip behaviour: $unzip_map" for unzip_map in [unzip_map, unzip_map_reversed] + check(f, args...) = @inferred(unzip_map(f, args...)) == unzip(map(f, args...)) + @test check(tuple, [1 2; 3 4], [5,6,7,8]) # makes a vector + @test check(tuple, [1 2; 3 4], [5,6,7]) + @test check(tuple, [1 2; 3 4], [5,6,7,8,9,10]) + end + + @testset "unzip_map_reversed" begin + cnt(x, y) = (x, y) .+ (CNT[] += 1) + CNT = Ref(0) + @test unzip_map_reversed(cnt, [10, 20], [30, 40, 50]) == ([12, 21], [32, 41]) + @test CNT[] == 2 + + CNT = Ref(0) + @test unzip_map_reversed(cnt, (10, 20, 99), (30, 40)) == ((12, 21), (32, 41)) end @testset "rrules" begin