Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rules for map, zip and some comprehensions #671

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

mcabbott
Copy link
Member

@mcabbott mcabbott commented Aug 25, 2022

Comment on lines 54 to 64
@btime Yota.grad($(rand(1000)), $(rand(1000))) do xs, ys
zs = map(xs, ys) do x, y
atan(x/y)
end
sum(abs2, zs)
end
# Yota + CR: min 2.643 ms, mean 2.781 ms (35011 allocations, 915.19 KiB)
# Diffractor + CR: min 1.184 ms, mean 1.285 ms (23026 allocations, 775.09 KiB)
# without unzip_map min 947.084 μs, mean 1.036 ms (18062 allocations, 697.86 KiB)

# Zygote: min 21.291 μs, mean 36.456 μs (26 allocations, 79.59 KiB)
Copy link
Member Author

@mcabbott mcabbott Aug 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These rules seem quite slow compared to what Zygote does, 100x. Not sure why.

Adding some complexity to the function eventually causes Zygote to be slow, e.g. with x > 0 ? atan(x/y) : atan(y/x) it is worse. Not sure whether its fast cases come from a different path which doesn't save the pullbacks, or some lucky optimisation.

function zip_pullback(dy::Tangent)
@debug "zip Tangent pullback"
return (NoTangent(), dy.is...)
end
Copy link
Member Author

@mcabbott mcabbott Aug 25, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

zip_pullback(dy::Tangent) is here because Zygote's rule needed this. Not sure it's tested, nor whether it is in fact required. In trying to cook up examples to hit this, using Diffractor or Yota, I just get errors.

Comment on lines +252 to +253
@testset "map(f, ::Array, ::Array)" begin
test_rrule(map, +, [1.0, 2.0], [3.0, 4.0], check_inferred=false) # NoTangent does not match Union{NoTangent, ZeroTangent}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As for broadcasting, many rules fail inference tests only because of Union{NoTangent, ZeroTangent}. Why do we have two types again, I am unclear as to what difference they encode & why this is worth doing.


function split_bc_pullbacks(cfg::RCR, f::F, args::Vararg{Any,N}) where {F,N}
@debug("split broadcasting generic", f, N)
ys3, backs = unzip_broadcast(args...) do a...
rrule_via_ad(cfg, f, a...)
end
function back_generic(dys)
deltas = unzip_broadcast(backs, unthunk(dys)) do back, dy # (could be map, sizes match)
deltas = unzip_map_reversed(backs, unthunk(dys)) do back, dy
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since generic broadcasting is slow anyway, maybe I change my mind to thinking it should reverse the order of iteration. Even though the order isn't guaranteed by Julia, perhaps it's better that the rule at least fixes forward & reverse passes to match.

Copy link
Member

@oxinabox oxinabox left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fell of my radar sorry.
We should merge this whenever you are happy

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Rules for zip missing rrule for map
2 participants