Try launching unit tests on TPUs from CI #596
Conversation
cc @rjpower since I fixed up a few tests on TPU (mostly just scaled down some initializers?)
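(A hedged illustration of what "scaling down an initializer" can look like in JAX; the specific initializers and values touched in this PR are not shown here and these numbers are hypothetical.)

```python
import jax
import jax.nn.initializers as init

key = jax.random.PRNGKey(0)

# A default-ish scale vs. a scaled-down initializer (hypothetical stddev values).
default_init = init.truncated_normal(stddev=0.02)
scaled_init = init.truncated_normal(stddev=0.002)

w_default = default_init(key, (512, 512))
w_scaled = scaled_init(key, (512, 512))
# Smaller weights keep attention logits and matmul outputs closer to zero, where
# floating point has more absolute precision, so fixed test tolerances are easier to hit.
```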
Curious... it's not surprising that the numerics would be slightly different for splash attention, but it is interesting that changing the initialization helps that much. I think we expect a lot of noise through the matmul unit, so this is likely just that. I wonder if the scaling is implicitly adjusting the tolerance here, if that makes sense? That is, in the original, our output might be on the scale of 1000.00 ± 0.5, and now it's 1.00000 ± 0.0005. We're accurate to the same number of significant digits in both cases; only the digits we're testing against are different. As a dumb example, assume one attention went through bf16 and the other didn't; we'd see something like this through the matmuls:
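(A minimal sketch of that bf16-vs-float32 contrast in JAX, with hypothetical shapes rather than the actual test tensors: the absolute error grows with the magnitude of the inputs while the relative error stays roughly constant, which is why a fixed atol behaves differently at scale 1000 vs. scale 1.)

```python
import jax
import jax.numpy as jnp

def matmul_error(scale):
    # Hypothetical shapes and values, not the actual test tensors from this PR.
    a = scale * jax.random.normal(jax.random.PRNGKey(0), (256, 256), dtype=jnp.float32)
    b = jax.random.normal(jax.random.PRNGKey(1), (256, 256), dtype=jnp.float32)
    exact = a @ b  # float32 reference matmul
    # Round the inputs through bf16 before the matmul, roughly what the MXU
    # does at default precision.
    approx = (a.astype(jnp.bfloat16) @ b.astype(jnp.bfloat16)).astype(jnp.float32)
    abs_err = jnp.max(jnp.abs(exact - approx))
    rel_err = jnp.max(jnp.abs(exact - approx) / (jnp.abs(exact) + 1e-6))
    return float(abs_err), float(rel_err)

for scale in (1.0, 1000.0):
    abs_err, rel_err = matmul_error(scale)
    # Absolute error grows with the input scale; relative error stays roughly flat,
    # so a fixed atol that passes at scale 1 can fail at scale 1000.
    print(f"scale={scale}: abs_err={abs_err:.3g}, rel_err={rel_err:.3g}")
```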
One of the fixed tests is actually my "pure JAX" flash attention vs. vanilla dot-product attention, so it's the same algorithm on both CPU and TPU, but the matmuls are different enough on TPU that it matters, even with HIGHEST precision. I mostly agree with your hypothesis, though I framed it as "floating point numbers are more precise near 0." Floating point is so annoying.
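(For context, a hedged sketch of pinning a reference dot-product attention to the highest matmul precision; the actual test code in the PR may differ, and the shapes here are made up.)

```python
import jax
import jax.numpy as jnp

def vanilla_attention(q, k, v):
    # Pin both matmuls to the most accurate mode; on TPU this still need not be
    # bit-identical to the CPU float32 result.
    logits = jnp.einsum("qd,kd->qk", q, k, precision=jax.lax.Precision.HIGHEST)
    weights = jax.nn.softmax(logits / jnp.sqrt(q.shape[-1]))
    return jnp.einsum("qk,kd->qd", weights, v, precision=jax.lax.Precision.HIGHEST)

q = jax.random.normal(jax.random.PRNGKey(0), (128, 64))
k = jax.random.normal(jax.random.PRNGKey(1), (128, 64))
v = jax.random.normal(jax.random.PRNGKey(2), (128, 64))
reference = vanilla_attention(q, k, v)
# A flash-attention implementation would then be compared against `reference`
# with a relative (or loosened absolute) tolerance, e.g.:
# np.testing.assert_allclose(flash_out, reference, rtol=1e-3, atol=1e-3)
```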