[FRONTEND] Implement 3xTF32 trick
This PR implements the [3xTF32 trick](NVIDIA/cutlass#385) to make use of the Tensor Cores on FP32
tensors without sacrificing accuracy.
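
The idea: split each fp32 operand into a TF32 value plus an fp32 remainder, then recover the
product from three TF32 Tensor Core matmuls (the remainder-times-remainder term is dropped).
A minimal NumPy sketch of the decomposition, not part of this PR, with TF32 rounding emulated
by truncating the low 13 mantissa bits (the real `cvt.rna.tf32.f32` instruction rounds instead):

```
import numpy as np

def to_tf32(x):
    # Keep sign, exponent and the top 10 mantissa bits (TF32); zero the rest.
    return (x.view(np.uint32) & np.uint32(0xFFFFE000)).view(np.float32)

rng = np.random.default_rng(0)
a = rng.standard_normal((512, 512), dtype=np.float32)
b = rng.standard_normal((512, 512), dtype=np.float32)

# Split each operand into a TF32 part and its fp32 remainder.
a_big, a_small = to_tf32(a), a - to_tf32(a)
b_big, b_small = to_tf32(b), b - to_tf32(b)

ref = a.astype(np.float64) @ b.astype(np.float64)               # fp64 reference
one_tf32 = a_big @ b_big                                        # plain TF32 inputs
three_tf32 = a_big @ b_small + a_small @ b_big + a_big @ b_big  # 3xTF32

print(np.abs(one_tf32 - ref).max())    # roughly 1e-2 on this data
print(np.abs(three_tf32 - ref).max())  # orders of magnitude smaller, close to full fp32
```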

Benchmarks on an A100 from `python/tutorials/03-matrix-multiplication.py`,
run on `float32` data with `use_tf32=False` (throughput in TFLOPS):
```
         M       N       K     cuBLAS     Triton    This PR
0    256.0   256.0   256.0   1.927529   1.092267   0.799220
1    384.0   384.0   384.0   5.026909   3.567484   2.457600
2    512.0   512.0   512.0   8.192000   6.553600   5.698782
3    640.0   640.0   640.0  12.190476  10.448980  10.039216
4    768.0   768.0   768.0  13.405091  10.287628  13.405091
5    896.0   896.0   896.0  14.049280  13.380267  17.133269
6   1024.0  1024.0  1024.0  15.887515  12.264046  14.563555
7   1152.0  1152.0  1152.0  16.681475  15.633424  17.360372
8   1280.0  1280.0  1280.0  16.516129  15.340824  21.557894
9   1408.0  1408.0  1408.0  17.090206  14.774461  17.473642
10  1536.0  1536.0  1536.0  17.014154  15.624477  19.285798
11  1664.0  1664.0  1664.0  17.043394  15.073554  19.228444
12  1792.0  1792.0  1792.0  17.107190  16.171833  22.433981
13  1920.0  1920.0  1920.0  17.883570  15.762828  19.280335
14  2048.0  2048.0  2048.0  17.623127  17.032706  22.046277
15  2176.0  2176.0  2176.0  17.887688  16.686275  21.909252
16  2304.0  2304.0  2304.0  19.019006  17.933838  24.935148
17  2432.0  2432.0  2432.0  17.940270  17.288901  22.730149
18  2560.0  2560.0  2560.0  18.164080  17.075561  23.043601
19  2688.0  2688.0  2688.0  17.594183  16.703239  22.015703
20  2816.0  2816.0  2816.0  18.766871  18.089676  24.189799
21  2944.0  2944.0  2944.0  18.735350  17.855977  24.513541
22  3072.0  3072.0  3072.0  18.420008  17.766898  23.801221
23  3200.0  3200.0  3200.0  18.470418  17.704011  24.015009
24  3328.0  3328.0  3328.0  18.253370  17.710036  23.732089
25  3456.0  3456.0  3456.0  18.546485  17.793328  24.445594
26  3584.0  3584.0  3584.0  18.368824  17.833278  24.131882
27  3712.0  3712.0  3712.0  18.665424  17.938112  24.659923
28  3840.0  3840.0  3840.0  18.638578  18.076496  24.510639
29  3968.0  3968.0  3968.0  18.965486  18.190808  24.974199
30  4096.0  4096.0  4096.0  19.035276  18.365864  25.195745
```
lezcano committed Feb 28, 2024
1 parent 20038f9 commit 2edbebd
Showing 2 changed files with 25 additions and 0 deletions.
23 changes: 23 additions & 0 deletions python/triton/language/semantic.py
@@ -1290,6 +1290,7 @@ def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope

def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_num_imprecise_acc: int,
        out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
    kwargs = locals()

    def assert_dtypes_valid(lhs_dtype, rhs_dtype, options):
        if not options.allow_fp8e4nv:
@@ -1347,6 +1348,28 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, options):
    _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0)
    ret_scalar_ty = out_dtype

    if not allow_tf32 and lhs.dtype.is_fp32() and rhs.dtype.is_fp32() and getattr(builder.options, "allow_tf32"):
        # Implement 3xTF32 trick.
        # Split x into a TF32 value ($0) and the fp32 rounding remainder ($1 = x - $0).
        def fp32_to_tf32(x):
            return tl.inline_asm_elementwise(
                asm="""cvt.rna.tf32.f32 $0, $2;\n\tsub.f32 $1, $2, $0;""",
                constraints="=r,=r,r",
                args=[x],
                dtype=[tl.float32, tl.float32],
                is_pure=True,
                pack=1,
                _builder=builder
            )
        xb, xs = fp32_to_tf32(lhs)
        yb, ys = fp32_to_tf32(rhs)
        # Re-dispatch to dot() with TF32 enabled, reusing the arguments captured
        # by kwargs = locals() at function entry.
        del kwargs["lhs"]
        del kwargs["rhs"]
        del kwargs["allow_tf32"]
        # lhs @ rhs ~= xb @ yb + xb @ ys + xs @ yb; the xs @ ys term is dropped.
        # The first two products are chained through the accumulator.
        kwargs["acc"] = dot(xb, ys, allow_tf32=True, **kwargs)
        kwargs["acc"] = dot(xs, yb, allow_tf32=True, **kwargs)
        out = dot(xb, yb, allow_tf32=True, **kwargs)
        return out

    M = lhs.type.shape[-2]
    N = rhs.type.shape[-1]
    B = lhs.type.shape[0] if lhs_rank == 3 else None
2 changes: 2 additions & 0 deletions third_party/nvidia/backend/compiler.py
@@ -62,6 +62,7 @@ class CUDAOptions:
    ptx_version: int = None
    enable_fp_fusion: bool = True
    allow_fp8e4nv: bool = False
    allow_tf32: bool = False
    max_num_imprecise_acc_default: bool = None
    extern_libs: dict = None
    debug: bool = False
@@ -94,6 +95,7 @@ def __init__(self, target: tuple) -> None:
    def parse_options(self, opts) -> Any:
        args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts}
        args["allow_fp8e4nv"] = self.capability >= 89
        args["allow_tf32"] = self.capability >= 80
        args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
        return CUDAOptions(**args)

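For context, a minimal kernel sketch (not part of this PR; names and tile layout are illustrative)
showing the user-facing path that now triggers the rewrite: `float32` inputs to `tl.dot` with
`allow_tf32=False` on hardware where the backend enables `allow_tf32` (compute capability >= 80):

```
import triton
import triton.language as tl


@triton.jit
def dot_f32_kernel(a_ptr, b_ptr, c_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    a = tl.load(a_ptr + offs[:, None] * BLOCK + offs[None, :])
    b = tl.load(b_ptr + offs[:, None] * BLOCK + offs[None, :])
    # Previously this fell back to fp32 FMA instructions; with this PR the
    # frontend rewrites it into three TF32 Tensor Core dots.
    c = tl.dot(a, b, allow_tf32=False)
    tl.store(c_ptr + offs[:, None] * BLOCK + offs[None, :], c)
```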
