Try launching unit tests on TPUs from CI #596

Merged
31 commits merged into main from tpu_ci on May 22, 2024

Conversation

@dlwh (Member) commented May 20, 2024

No description provided.

dlwh merged commit 57bbadf into main on May 22, 2024 (5 checks passed)
dlwh deleted the tpu_ci branch on May 22, 2024 at 06:35
@dlwh (Member, Author) commented May 22, 2024

cc @rjpower since I fixed up a few tests on TPU (mostly just scaled down some initializers?)

@rjpower (Collaborator) commented May 22, 2024

Curious... it's not surprising that the numerics would be slightly different for splash attention, but it's interesting that changing the initialization helps that much. We do expect a lot of noise through the matmul unit, so this is probably what we're seeing.

I wonder if the scaling is implicitly adjusting the tolerance here, if that makes sense? That is, in the original, our output might be on the scale of 1000.00 ± 0.5, and now it's 1.00000 ± 0.5. We're accurate to the same number of digits in both cases; it's just that the digits we're testing against are different. As a dumb example, if one attention went through bf16 and the other didn't, we'd see something like this through the matmuls:

import numpy as np
from paddle_bfloat import bfloat16

# 1000 samples (mean -1, std 1), plus a copy scaled down by 0.02
x = np.random.normal(-1, 1, 1000)
xs = x * 0.02

# Round both through bfloat16
b = x.astype(bfloat16)
bs = xs.astype(bfloat16)

# Compare dot products of the rounded arrays against the full-precision ones
print(x @ x - b @ b)
print(xs @ xs - bs @ bs)

# Output:
# -0.1878888127475875
# 8.110935221994353e-05
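
A quick way to check the "same number of digits" framing is to look at relative rather than absolute error. Below is a small standalone sketch of that idea; it uses numpy's float16 as the low-precision stand-in (so it runs without paddle_bfloat), and the seed and sizes are arbitrary:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, 1000)
xs = x * 0.02  # scaled-down copy, like the smaller initializer

# Round-trip through float16 (standing in for bf16), then accumulate in float64
b = x.astype(np.float16).astype(np.float64)
bs = xs.astype(np.float16).astype(np.float64)

abs_err = abs(x @ x - b @ b)
abs_err_s = abs(xs @ xs - bs @ bs)

print("absolute:", abs_err, abs_err_s)  # the scaled case looks ~0.02**2 smaller
print("relative:", abs_err / abs(x @ x), abs_err_s / abs(xs @ xs))  # roughly comparable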

@dlwh (Member, Author) commented May 22, 2024

One of the fixed tests is actually my "pure JAX" flash attention vs. vanilla dot-product attention, so it's the same algorithm on both CPU and TPU, but the matmuls are different enough on TPU that it matters, even with precision set to HIGHEST.
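
For reference on the precision knob: JAX matmuls accept a precision argument, and on TPU the default setting routes the multiply through bf16 passes on the matrix unit. The sketch below (shapes and seeds are made up, purely illustrative) shows how default and HIGHEST results can be compared on whatever backend is active:

import jax
import jax.numpy as jnp

# A single attention-style score matmul at two precision settings
q = jax.random.normal(jax.random.PRNGKey(0), (128, 64), dtype=jnp.float32)
k = jax.random.normal(jax.random.PRNGKey(1), (128, 64), dtype=jnp.float32)

scores_default = jnp.matmul(q, k.T)
scores_highest = jnp.matmul(q, k.T, precision=jax.lax.Precision.HIGHEST)

# On CPU these typically agree; on TPU the default path uses bf16 matmul
# units, so the gap is visible.
print(jnp.max(jnp.abs(scores_default - scores_highest)))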

I mostly agree with your hypothesis, though I framed it in terms of "floating point numbers are more precise near 0".
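
In absolute terms, the spacing between adjacent representable floats grows with the magnitude of the value, so numbers near 0 carry more absolute precision for the same relative error. A tiny float32 illustration of that (the sample values are arbitrary):

import numpy as np

# Distance from each value to the next representable float32
for v in [1000.0, 1.0, 0.02]:
    print(v, np.spacing(np.float32(v)))
# roughly 6.1e-05 at 1000, 1.2e-07 at 1, 1.9e-09 at 0.02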

Floating point is so annoying
