Skip to content

Commit

Permalink
rm tup2, update times
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Aug 28, 2022
1 parent fe779b3 commit 45e2ae7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ end

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(map), f::F, x::AbstractArray, ys::AbstractArray...) where {F}
@debug "rrule(map, f, arrays...)" f
z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...)|>tup2, x, ys...)
z, backs = unzip_map((xy...) -> rrule_via_ad(cfg, f, xy...), x, ys...)
function map_pullback_2(dz)
df, dxy... = unzip_map_reversed(|>, unthunk(dz), backs)
return (NoTangent(), ProjectTo(f)(sum(df)), map(_unmap_pad, (x, ys...), dxy)...)
Expand Down
22 changes: 9 additions & 13 deletions src/rulesets/Base/iterators.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor, https://github.com/JuliaDiff/Diffractor.jl/pull/86

#####
##### Comprehension: Iterators.map
#####
Expand All @@ -8,7 +6,7 @@ tup2(x) = Tuple{Any,Any}(x) # temp fix for Diffractor, https://github.com/Julia

function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(collect), gen::G) where {G<:Base.Generator}
@debug "collect generator"
ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x)|>tup2, gen.iter)
ys, backs = unzip_map(x -> rrule_via_ad(cfg, gen.f, x), gen.iter)
proj_f = ProjectTo(gen.f)
proj_iter = ProjectTo(gen.iter)
function generator_pullback(dys_raw)
Expand All @@ -28,8 +26,8 @@ ChainRulesCore.rrule(::Type{<:Iterators.ProductIterator}, iters) = Iterators.Pro
Yota.grad(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
Diffractor.gradient(xs -> sum(abs, [sin(x) for x in xs]), [1,2,3]pi/3)
Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5])
Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5])
Yota.grad((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: all field arrays must have same shape
Diffractor.gradient((xs, ys) -> sum(abs, [atan(x/y) for x in xs, y in ys]), [1,2,3]pi/3, [4,5]) # ERROR: type Array has no field iterators
Yota.grad(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3)
Diffractor.gradient(xs -> sum(abs, map(sin, xs)), [1,2,3]pi/3) # fails internally
Expand All @@ -44,11 +42,10 @@ Diffractor.gradient(xs -> sum(abs, map((x,y) -> sin(x/y), xs, 1:2)), [1,2,3]pi/3
@btime Yota.grad($(rand(1000))) do xs
sum(abs2, [sqrt(x) for x in xs])
end
# Yota min 1.134 ms, mean 1.207 ms (22017 allocations, 548.50 KiB)
# Diffractor min 936.708 μs, mean 1.020 ms (18028 allocations, 611.25 KiB)
# without unzip_map min 734.292 μs, mean 810.341 μs (13063 allocations, 517.97 KiB)
# Yota min 759.000 μs, mean 800.754 μs (22041 allocations, 549.62 KiB)
# Diffractor min 559.000 μs, mean 622.464 μs (18051 allocations, 612.34 KiB)
# Zygote min 6.117 μs, mean 11.287 μs (24 allocations, 40.31 KiB)
# Zygote min 3.198 μs, mean 6.849 μs (20 allocations, 40.11 KiB)
@btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys
Expand All @@ -57,11 +54,10 @@ end
end
sum(abs2, zs)
end
# Yota + CR: min 2.643 ms, mean 2.781 ms (35011 allocations, 915.19 KiB)
# Diffractor + CR: min 1.184 ms, mean 1.285 ms (23026 allocations, 775.09 KiB)
# without unzip_map min 947.084 μs, mean 1.036 ms (18062 allocations, 697.86 KiB)
# Yota + CR: min 1.598 ms, mean 1.691 ms (38030 allocations, 978.75 KiB)
# Diffractor + CR: min 767.250 μs, mean 847.640 μs (26045 allocations, 838.66 KiB)
# Zygote: min 21.291 μs, mean 36.456 μs (26 allocations, 79.59 KiB)
# Zygote: min 13.417 μs, mean 22.896 μs (26 allocations, 79.59 KiB) -- 100x faster
=#
Expand Down

0 comments on commit 45e2ae7

Please sign in to comment.