Rules for map, zip and some comprehensions #671
src/rulesets/Base/iterators.jl
Outdated
@btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys
    zs = map(xs, ys) do x, y
        atan(x/y)
    end
    sum(abs2, zs)
end
# Yota + CR: min 2.643 ms, mean 2.781 ms (35011 allocations, 915.19 KiB)
# Diffractor + CR: min 1.184 ms, mean 1.285 ms (23026 allocations, 775.09 KiB)
# without unzip_map min 947.084 μs, mean 1.036 ms (18062 allocations, 697.86 KiB)

# Zygote: min 21.291 μs, mean 36.456 μs (26 allocations, 79.59 KiB)
These rules seem quite slow compared to what Zygote does, on the order of 100x by the timings above. Not sure why.
Adding some complexity to the function eventually causes Zygote to be slow too, e.g. with x > 0 ? atan(x/y) : atan(y/x) it is worse. Not sure whether its fast cases come from a different path which doesn't save the pullbacks, or from some lucky optimisation.
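To make the "saves the pullbacks" point concrete, here is a minimal package-free sketch of the generic strategy: call a rule on every pair of elements, keep one pullback closure per element, and apply them all on the way back. The names atan_div_rule and map_with_pullbacks are hypothetical and are not the code in this PR; the rule for atan(x/y) is written by hand for illustration.

```julia
# Hand-written rule for (x, y) -> atan(x / y): returns the value and a pullback closure.
function atan_div_rule(x, y)
    d = x^2 + y^2
    pullback(dz) = (dz * y / d, -dz * x / d)   # partial derivatives w.r.t. x and y
    return atan(x / y), pullback
end

# Generic strategy (hypothetical helper): one stored pullback per element.
function map_with_pullbacks(rule, xs, ys)
    results = map(rule, xs, ys)      # Vector of (value, pullback) tuples
    zs = map(first, results)
    backs = map(last, results)       # n closures kept alive until the reverse pass
    function map_back(dzs)
        grads = map((pb, dz) -> pb(dz), backs, dzs)
        return map(first, grads), map(last, grads)
    end
    return zs, map_back
end

xs, ys = rand(5), rand(5) .+ 1
zs, back = map_with_pullbacks(atan_div_rule, xs, ys)
dxs, dys = back(2 .* zs)             # gradient of sum(abs2, zs) w.r.t. xs and ys
```

Storing a closure per element is one plausible source of the extra allocations, though this sketch does not establish that it explains the whole gap.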
function zip_pullback(dy::Tangent)
    @debug "zip Tangent pullback"
    return (NoTangent(), dy.is...)
end
zip_pullback(dy::Tangent) is here because Zygote's rule needed this. Not sure it's tested, nor whether it is in fact required. In trying to cook up examples to hit this, using Diffractor or Yota, I just get errors.
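For readers unfamiliar with the case this method covers, the sketch below builds such a Tangent by hand. It assumes ChainRulesCore is installed and relies on Base's Zip iterator storing its arguments in a field named is. The function name zip_pullback_sketch and the numeric values are illustrative only. It shows what the method would return if it were reached, not that any AD system actually reaches it.

```julia
using ChainRulesCore

xs, ys = [1.0, 2.0], [3.0, 4.0]
z = zip(xs, ys)                      # Base.Iterators.Zip keeps its arguments in `z.is`

# A structural tangent for the Zip object, constructed by hand for illustration.
dy = Tangent{typeof(z)}(; is = ([0.1, 0.2], [0.3, 0.4]))

# Same body as the method under review, minus the @debug line.
zip_pullback_sketch(dy::Tangent) = (NoTangent(), dy.is...)

_, dxs, dys = zip_pullback_sketch(dy)   # one tangent per zipped argument
```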
@testset "map(f, ::Array, ::Array)" begin
    test_rrule(map, +, [1.0, 2.0], [3.0, 4.0], check_inferred=false) # NoTangent does not match Union{NoTangent, ZeroTangent}
As for broadcasting, many rules fail inference tests only because of Union{NoTangent, ZeroTangent}. Why do we have two types again? I am unclear as to what difference they encode and why this is worth doing.
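The inference failure can be reproduced in isolation. The sketch below assumes ChainRulesCore is installed; pick_zero is a hypothetical helper, not part of any rule. Its return type depends on a runtime value, so inference can only narrow it to a Union of the two zero types, which is enough to fail an @inferred check.

```julia
using ChainRulesCore, Test

# Hypothetical helper: which zero comes back depends on a runtime value, not on types.
pick_zero(flag::Bool) = flag ? NoTangent() : ZeroTangent()

inferred = only(Base.return_types(pick_zero, (Bool,)))
@show inferred                   # a Union of the two zero types
@show isconcretetype(inferred)   # not concrete, which is what inference tests object to

# @inferred compares the actual return type against the inferred one, so it throws here.
@test_throws ErrorException @inferred pick_zero(true)
```

Both types are subtypes of AbstractZero, so code that only needs to detect "some zero" can dispatch on that instead.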
function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
    @debug("split broadcasting generic", f, N)
    ys3, backs = unzip_broadcast(args...) do a...
        rrule_via_ad(cfg, f, a...)
    end
    function back_generic(dys)
        deltas = unzip_broadcast(backs, unthunk(dys)) do back, dy # (could be map, sizes match)
        deltas = unzip_map_reversed(backs, unthunk(dys)) do back, dy
Since generic broadcasting is slow anyway, I may be changing my mind: perhaps it should reverse the order of iteration. Even though the order isn't guaranteed by Julia, perhaps it's better that the rule at least fixes the forward and reverse passes to match.
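As an illustration of what reversing the order means, here is a small package-free sketch. map_reversed is a hypothetical stand-in, not the unzip_map_reversed in this PR. It visits indices last to first while returning results in the original positions, so pullbacks created in the order 1, 2, ..., n on the forward pass would be called in the order n, ..., 2, 1 on the reverse pass. It assumes non-empty inputs of equal length and a uniform result type.

```julia
# Hypothetical stand-in: apply f from the last index to the first,
# storing each result at its original position.
function map_reversed(f, xs, ys)
    n = length(xs)
    length(ys) == n || error("lengths must match")
    last_val = f(xs[n], ys[n])                   # assumes n >= 1
    out = Vector{typeof(last_val)}(undef, n)     # assumes every result has this type
    out[n] = last_val
    for i in (n - 1):-1:1
        out[i] = f(xs[i], ys[i])
    end
    return out
end

visited = Int[]
result = map_reversed([1, 2, 3], [10, 20, 30]) do x, y
    push!(visited, x)    # record the order in which elements are visited
    x + y
end
```

The order only matters when f, or its pullback, has side effects or carries state. For pure functions the result is the same either way.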
Force-pushed: 45e2ae7 to 7f56d8d
Force-pushed: b88d70f to 4f144fd
Force-pushed: d6b38e9 to 1a948c1
This fell off my radar, sorry.
We should merge this whenever you are happy.
Closes #507, closes #314
Some commit made dfdx/Yota.jl#78 work, but it needs these lines:
4f144fd#diff-dbac688fe6656bd395f7931766da39b9a03d2e640f1dcc9864cddf506bbf65e1L20-L21