Triton and Torch matmul result do not really match? #1334

Answered by bertmaher
zw2326 asked this question in Q&A

So, I repro'ed this, and it seemed like an oddly large difference to me, until I remembered that Triton is almost certainly using tf32. Torch, on the other hand, may or may not be: it uses tf32 only if you enable it (e.g. via torch.set_float32_matmul_precision("high")). In fact, for a problem this small I observe ampere_sgemm_* kernels in the profile, which are fp32 (not tf32) kernels.
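To get a feel for how much error tf32 alone introduces, here is a minimal pure-Python sketch (not from the discussion) that simulates tf32 rounding by truncating a float32 mantissa from 23 bits down to tf32's 10 bits. The helper name `to_tf32` is hypothetical:

```python
import struct

def to_tf32(x: float) -> float:
    # Reinterpret x as float32 bits, then round the 23-bit mantissa
    # to tf32's 10 bits (round-to-nearest by adding half an ulp,
    # then clearing the low 13 bits).
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits = (bits + (1 << 12)) & ~((1 << 13) - 1)
    return struct.unpack("<f", struct.pack("<I", bits))[0]

# tf32 keeps roughly 3 decimal digits of mantissa, so the relative
# rounding error per element can be as large as about 2**-11 (~5e-4).
x = 1.2345678
print(abs(to_tf32(x) - x) / x)
```

That per-element error then accumulates across the k-dimension of the matmul, which is why a tf32 result can differ noticeably from an fp32 reference.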

Comparing gemms for equality is hard :). I usually like an approach like this: https://twitter.com/bwasti/status/1621370782436687872 Basically, build the inputs as (torch.randn(shape) + 1.0) / k.
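A rough numpy sketch of that recipe (my reading of the linked tweet, not code from the discussion): shifting the Gaussian entries by 1.0 and scaling by 1/k keeps each length-k dot product O(1), so rounding noise stays small and predictable instead of growing with the problem size. The helper name `make_inputs` is hypothetical, and the plain `a @ b` stands in for whatever kernel is under test:

```python
import numpy as np

def make_inputs(m, n, k, seed=0):
    # Shift-and-scale trick: entries are (randn + 1.0) / k, so the
    # magnitude of each output element is bounded independently of k.
    rng = np.random.default_rng(seed)
    a = (rng.standard_normal((m, k)).astype(np.float32) + 1.0) / k
    b = (rng.standard_normal((k, n)).astype(np.float32) + 1.0) / k
    return a, b

a, b = make_inputs(64, 64, 256)
# Reference in float64, then compare against the fp32 result.
ref = (a.astype(np.float64) @ b.astype(np.float64)).astype(np.float32)
out = a @ b  # stand-in for the Triton (or torch) kernel under test
print(np.abs(out - ref).max())  # small fp32 accumulation noise
```

With inputs conditioned like this, a tolerance check such as `np.allclose` with a modest `atol` becomes meaningful, because the expected error no longer scales with k.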

Answer selected by zw2326