-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[BACKEND] Implement 3xTF32 trick #3234
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is awesome.
@@ -62,6 +62,7 @@ class CUDAOptions: | |||
ptx_version: int = None | |||
enable_fp_fusion: bool = True | |||
allow_fp8e4nv: bool = False | |||
allow_tf32: bool = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure about the name of this option. Without the context in this patch, I'd think this means, do we allow implicitly upgrading fp32 dots to tf32, which is not what this controls.
Perhaps supports_tf32, with a comment saying that this indicates whether the hw supports tf32 but doesn't give us permission to use tf32 to silently, except where the result is "as if" we'd computed in fp32?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I wasn't sure about this name. I followed the convention of allow_fp8e4nv
, but I agree that supports_tf32
is a much better name. Let me change it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think there is going to be a precision problem. Maybe we would need a separate control for it but we want the precision to match fp32
Some comments:
|
A few updates:
As for how to expose this. I agree that it'd be good to give the user the possibility to use the IEEE version. I can either add a new |
It depends if we consider it to be a library level or a backend optimization.
The downside of this strategy is that we need to add another attribute to language. As @sbodenstein pointed out there may be a bunch of other algorithm and we don't want to keep adding attributes to the language.
The current default behavior is tf32, are you suggesting changing it to 3xtf32? Or are you saying it should be default if use_tf32=False? In this case default behavior doesn't mean much and the best would be having an enum of precisions. I would still think a library solution if possible would be nice. |
I have a different thought regarding this. I'm curious about what optimizations you'd like to implement? Maybe we can do this:
|
I think the concern is that if the user writes You could write the user code differently, but I think he's trying to make "reasonable user code" do the right thing. Which, given our goal of portability, makes sense to me.
We already have a tf32 attribute. It actually sounds like he wants to generalize this into OTOH if we don't do this, then it seems to me that every hardware vendor with its different dot precisions is going to want attributes on dot similar to the use_tf32 that we have today. IOW it feels like this is a solution to the problem of attribute-creep? |
Exactly. If you write If you expect the user to always write |
I personally think that replacing However, I do think the default for Leaving the perf considerations aside, doing it in user code with the current limitations of triton metaprogramming seems troublesome. That would force users to carry over constexpr flags that depend on the hardware they're targeting |
I completely agree with Philippe's comment. I'll go ahead and implement that. Yes, I meant the default behaviour for |
Updated the PR. In particular:
The PR may be easier to review commit-by-commit, given the amount of files it touches. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That looks good, added couple minor comments
@@ -565,7 +565,7 @@ def TT_DotOp : TT_Op<"dot", [Pure, | |||
TT_FpIntTensor:$a, | |||
TT_FpIntTensor:$b, | |||
TT_FpIntTensor:$c, | |||
BoolAttr:$allowTF32, | |||
StrAttr:$f32Backend, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: making it an enum would be a bit less error-prone. You can use I32EnumAttr
for that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we use an enum, do we have to list all possible values in TritonOps.td? That doesn't work well for out-of-tree backends.
(I think it would be good to have the list of valid strings written down somewhere in the nvidia backend, though.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
correct the enum has to be in TritonOps.td, it still feels better than doing string checks? I believe we could decouple it using interfaces but that's probably an overkill to do it right now. Even if we have a lit of valid strings, it feels easy to make a typo when checkin the value.
I'd be fine with the possible constants be declared in a header as constexpr and used instead of literal strings
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
correct the enum has to be in TritonOps.td, it still feels better than doing string checks?
I just thought we had it as a goal to make it possible for people to write out-of-tree backends? If you think that's not relevant here, then definitely enums would be better.
I'd be fine with the possible constants be declared in a header as constexpr and used instead of literal strings
The problem is there's no way to enforce this, and people will use strings anyway. (And indeed the IR uses a string, so that's the place where we're most likely to have a typo.) So I think there's an argument for leaning into the strings in the C++ and accepting that we need to write tests. At least, that is my experience with this sort of thing from XLA, where it's used a lot.
Again I do think it needs to be documented in the backend which strings are acceptable. Right now I only see it in the frontend. I also think (orthogonal to all this) that we need a check during lowering that we know how to lower the string we got (i.e. check for invalid strings during lowering -- I don't see this anywhere) in order to catch typos in the IR.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I went with the enum implementation, as we already use enums in quite a few places. Going from enums to any other representation in the future should be trivial tho.
@@ -565,7 +565,7 @@ def TT_DotOp : TT_Op<"dot", [Pure, | |||
TT_FpIntTensor:$a, | |||
TT_FpIntTensor:$b, | |||
TT_FpIntTensor:$c, | |||
BoolAttr:$allowTF32, | |||
StrAttr:$f32Backend, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perhaps "f32Backend" isn't the best name. We can use this field for the f16 (or whatever) backend too, right? No need to have multiple fields depending on the dtype.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I followed #3234 (comment). I'm happy to bikeshed if a better name / convention is proposed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
$precision
seems good to me.
@@ -565,7 +565,7 @@ def TT_DotOp : TT_Op<"dot", [Pure, | |||
TT_FpIntTensor:$a, | |||
TT_FpIntTensor:$b, | |||
TT_FpIntTensor:$c, | |||
BoolAttr:$allowTF32, | |||
StrAttr:$f32Backend, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we use an enum, do we have to list all possible values in TritonOps.td? That doesn't work well for out-of-tree backends.
(I think it would be good to have the list of valid strings written down somewhere in the nvidia backend, though.)
Addressed the reviews. Should we bikeshed the name tho? @jlebar proposes "precision" for it to be more generic. |
The name what proposed by @ptillet who is in vacation until Monday. Not sure we can discuss it without him. Other than the name where I don't have a strong opinion I think this looks good. Note that I did add an env var override for TF32 behavior (based on user request): #3290 which might be in your way. If you want to rebase with it that's great but feel free to break it and I'll add it back after your change |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That looks good to me modulo the naming decision
If you're not comfortable making an executive decision, then I propose we use the current name and I volunteer as tribute to run |
tf32: use TC with tf32 ops. | ||
tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp | ||
ieee: don't use TC, implement dot in software. | ||
If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Presumably these descriptions should be on the enum now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this comment still applies?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I figured it'd be better to leave them in the place where one would first look at when finding what this flag really does, but sure, I can move them
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For example the enum is also used on dot_async.
(If you wanted a "see X" comment I'd be fine with that, although it's probably pretty obvious where to look.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, let's wait @jlebar to approve as well
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks generally good to me, but I think we need to rename f32_backend to input_precision everywhere. That and a few other smallish changes.
OAI is going on vacation next week and there's some ongoing discussion about whether we want to merge this on the Friday before everyone is out or if that's a Bad Idea. In any case if we decide this should wait there shouldn't be as many merge conflicts as usual.
tf32: use TC with tf32 ops. | ||
tf32x3: implement the 3xTF32 trick. For more info see the pass in F32DotTC.cpp | ||
ieee: don't use TC, implement dot in software. | ||
If the GPU does not have Tensor cores or the inputs are not f32, this flag is ignored. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this comment still applies?
@@ -1473,16 +1473,16 @@ def dot(input, other, acc=None, allow_tf32=None, max_num_imprecise_acc=None, out | |||
:type input: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} | |||
:param other: The second tensor to be multiplied. | |||
:type other: 2D tensor of scalar-type in {:code:`float16`, :code:`bfloat16`, :code:`float32`} | |||
:param input_precision: How to exercise the Tenors cores for f32 x f32. If the device does not have Tensor Cores | |||
or the inputs are not of dtype f32, this option is ignored. | |||
:type other: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Avaliable options for amd: :code:`"ieee"`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is the one place where we describe in the user documentation what these fields mean. Seems like we should write a lot more and cite sources for further reading.
test/Analysis/test-alias.mlir
Outdated
@@ -29,7 +29,7 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, | |||
%a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> | |||
%b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL> | |||
%b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> | |||
%c = tt.dot %a, %b, %prev_c {transA = false, transB = false, allowTF32 = true, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> | |||
%c = tt.dot %a, %b, %prev_c {transA = false, transB = false, f32Backend = 0 : i32, maxNumImpreciseAcc = 0 : i32} : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds like Thomas does not have time to help out with this right now. We're OK merging as-is and figuring out how to fix this as a follow-up.
The result of the discussion about waiting a week is: If we keep the |
(We don't have a full story on bw compatibility, but for us, having a week or two where the old and new APIs are usable makes integration easier. Because the next release will be 3.0, we're ok making bw-incompat changes in general.) |
Alas, I'm leaving now (European time) and I don't have time to do those changes. If anyone wants to champion this through the finish line, feel free to push to my branch to get it merged. |
Sorry we've been so slow on this one. I'll probably be working some next week and I have approval from folks to merge this. You have been very patient and I don't want to keep you in rebase hell indefinitely. |
fwiw, I'll be on PTO on Monday / Tuesday, will be back on Wed. If that's too late for you, we can work on mergin this the Monday after. No rush. |
e614d5b
to
e3dde35
Compare
Just kicked CI for you. Thanks again for pushing through with this one. |
This PR implements the [3xTF32 trick](NVIDIA/cutlass#385) to make use of the TCs on F32 tensors without sacrificing accuracy. This is particularly relevant for PyTorch, as TF32 is off by default. Benchmarks on A100 from `python/tutorials/03-matrix-multiplication.py` run on `float32` data using `use_tf32=False`: ``` M N K cuBLAS Triton This PR 0 256.0 256.0 256.0 1.927529 1.092267 1.489455 1 384.0 384.0 384.0 5.026909 3.567484 3.686400 2 512.0 512.0 512.0 8.192000 6.553600 6.898527 3 640.0 640.0 640.0 12.190476 10.448980 10.666666 4 768.0 768.0 768.0 13.405091 10.287628 14.503869 5 896.0 896.0 896.0 14.049280 13.380267 20.070399 6 1024.0 1024.0 1024.0 15.887515 12.264046 19.239927 7 1152.0 1152.0 1152.0 16.681475 15.633424 24.883201 8 1280.0 1280.0 1280.0 16.516129 15.340824 28.248276 9 1408.0 1408.0 1408.0 17.090206 14.774461 24.016635 10 1536.0 1536.0 1536.0 17.014154 15.624477 26.021647 11 1664.0 1664.0 1664.0 17.043394 15.073554 25.858942 12 1792.0 1792.0 1792.0 17.107190 16.171833 29.577431 13 1920.0 1920.0 1920.0 17.883570 15.762828 26.331430 14 2048.0 2048.0 2048.0 17.623127 17.032706 27.413751 15 2176.0 2176.0 2176.0 17.887688 16.686275 29.945905 16 2304.0 2304.0 2304.0 19.019006 17.933838 33.787654 17 2432.0 2432.0 2432.0 17.940270 17.288901 31.181425 18 2560.0 2560.0 2560.0 18.164080 17.075561 31.844508 19 2688.0 2688.0 2688.0 17.594183 16.703239 30.370742 20 2816.0 2816.0 2816.0 18.766871 18.089676 33.242537 21 2944.0 2944.0 2944.0 18.735350 17.855977 33.695763 22 3072.0 3072.0 3072.0 18.420008 17.766898 32.768000 23 3200.0 3200.0 3200.0 18.470418 17.704011 33.255391 24 3328.0 3328.0 3328.0 18.253370 17.710036 32.753092 25 3456.0 3456.0 3456.0 18.546485 17.793328 33.634362 26 3584.0 3584.0 3584.0 18.368824 17.833278 33.142423 27 3712.0 3712.0 3712.0 18.665424 17.938112 34.036574 28 3840.0 3840.0 3840.0 18.638578 18.076496 33.794348 29 3968.0 3968.0 3968.0 18.965486 18.190808 34.324595 30 4096.0 4096.0 4096.0 19.035276 18.365864 34.450135 ``` It's an overall win, getting roughly a 85% speed-up on large sizes. Note that the rounding is differs a little bit to the one [implemented in CUTLASS](https://github.com/NVIDIA/cutlass/blob/a8f2c80db0564c74f4efccac71993b971dfc448b/include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h#L99-L100). We could implement that rounding if we wanted though. This is still a bit far from the 2x speed-ups announced by CUTLASS. To get close to those numbers, we should probably need to remove the stores to shared before `ldmatrix`.
This PR implements the 3xTF32 trick to make use of the TCs on F32 tensors without sacrificing accuracy. This is particularly relevant for PyTorch, as TF32 is off by default.
Benchmarks on A100 from
python/tutorials/03-matrix-multiplication.py
run onfloat32
data usinguse_tf32=False
:It's an overall win, getting roughly a 85% speed-up on large sizes.
Note that the rounding is differs a little bit to the one implemented in CUTLASS. We could implement that rounding if we wanted though.
This is still a bit far from the 2x speed-ups announced by CUTLASS. To get close to those numbers, we should probably need to remove the stores to shared before
ldmatrix
.