Commit f886c42 ("Address review")

lezcano committed Feb 28, 2024
1 parent: 2edbebd
Showing 2 changed files with 10 additions and 14 deletions.
python/triton/language/semantic.py (8 additions, 12 deletions)
```diff
@@ -1348,27 +1348,23 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, options):
     _0 = builder.get_fp16(0) if out_dtype.is_fp16() else builder.get_fp32(0)
     ret_scalar_ty = out_dtype

-    if not allow_tf32 and lhs.dtype.is_fp32() and rhs.dtype.is_fp32() and getattr(builder.options, "allow_tf32"):
+    if not allow_tf32 and lhs.dtype.is_fp32() and rhs.dtype.is_fp32() and getattr(builder.options, "supports_tf32"):
         # Implement 3xTF32 trick
         # https://github.com/NVIDIA/cutlass/discussions/385

         def fp32_to_tf32(x):
-            return tl.inline_asm_elementwise(
-                asm="""cvt.rna.tf32.f32 $0, $2;\n\tsub.f32 $1, $2, $0;""",
-                constraints="=r,=r,r",
-                args=[x],
-                dtype=[tl.float32, tl.float32],
-                is_pure=True,
-                pack=1,
-                _builder=builder
-            )
+            return tl.inline_asm_elementwise(asm="""cvt.rna.tf32.f32 $0, $2;\n\tsub.f32 $1, $2, $0;""",
+                                             constraints="=r,=r,r", args=[x], dtype=[tl.float32, tl.float32],
+                                             is_pure=True, pack=1, _builder=builder)

         xb, xs = fp32_to_tf32(lhs)
         yb, ys = fp32_to_tf32(rhs)
+        del kwargs["lhs"]
+        del kwargs["rhs"]
+        del kwargs["allow_tf32"]
         kwargs["acc"] = dot(xb, ys, allow_tf32=True, **kwargs)
         kwargs["acc"] = dot(xs, yb, allow_tf32=True, **kwargs)
-        out = dot(xb, yb, allow_tf32=True, **kwargs)
-        return out
+        return dot(xb, yb, allow_tf32=True, **kwargs)

     M = lhs.type.shape[-2]
     N = rhs.type.shape[-1]
```
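For context: the 3xTF32 trick (see the CUTLASS discussion linked in the diff) splits each fp32 operand x into a TF32-representable "big" part xb and an fp32 remainder xs = x - xb (this is exactly what the `cvt.rna.tf32.f32` / `sub.f32` inline PTX computes), then approximates the fp32 matmul with three TF32 matmuls: x @ y ≈ xb @ ys + xs @ yb + xb @ yb, dropping the negligible xs @ ys term. Below is a minimal, self-contained NumPy sketch of the decomposition; it is illustrative only and not part of the commit. It emulates TF32's 10-bit mantissa by truncating the low 13 mantissa bits (the hardware instruction rounds to nearest instead):

```python
import numpy as np

def fp32_to_tf32(x: np.ndarray):
    """Split x into a TF32-representable 'big' part and the fp32 remainder."""
    # TF32 keeps 10 mantissa bits; zero out the low 13 bits of the fp32 mantissa.
    big = (x.view(np.uint32) & np.uint32(0xFFFFE000)).view(np.float32)
    return big, x - big  # the remainder is exact in fp32

rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64), dtype=np.float32)
b = rng.standard_normal((64, 64), dtype=np.float32)

ab, as_ = fp32_to_tf32(a)  # "as" is a Python keyword, hence as_
bb, bs = fp32_to_tf32(b)

# Same accumulation order as the diff: acc = ab@bs, acc += as_@bb, plus ab@bb.
acc = ab @ bs + as_ @ bb + ab @ bb

ref = a.astype(np.float64) @ b.astype(np.float64)
print("1xTF32 max error:", np.abs(ab @ bb - ref).max())
print("3xTF32 max error:", np.abs(acc - ref).max())
```

A single truncated matmul loses roughly 13 bits of input precision; summing the two cross terms recovers most of it, which is why the code routes fp32 inputs through three TF32 dots when the user has disabled TF32 (`allow_tf32=False`) but the hardware supports it (`supports_tf32`, i.e. sm_80+ per the compiler.py change below).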
third_party/nvidia/backend/compiler.py (2 additions, 2 deletions)
```diff
@@ -62,7 +62,7 @@ class CUDAOptions:
     ptx_version: int = None
     enable_fp_fusion: bool = True
     allow_fp8e4nv: bool = False
-    allow_tf32: bool = False
+    supports_tf32: bool = False
     max_num_imprecise_acc_default: bool = None
     extern_libs: dict = None
     debug: bool = False
@@ -95,7 +95,7 @@ def __init__(self, target: tuple) -> None:
     def parse_options(self, opts) -> Any:
         args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts}
         args["allow_fp8e4nv"] = self.capability >= 89
-        args["allow_tf32"] = self.capability >= 80
+        args["supports_tf32"] = self.capability >= 80
         args["max_num_imprecise_acc_default"] = 2**30 if self.capability == 90 else 0
         return CUDAOptions(**args)

```
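As a side note on the pattern in `parse_options`: the incoming opts dict is filtered against the dataclass's declared fields, then capability-derived flags are overlaid. A hypothetical trimmed-down sketch (here `Options` stands in for `CUDAOptions`, and `num_warps` is just an example field):

```python
from dataclasses import dataclass

@dataclass
class Options:  # stand-in for CUDAOptions
    num_warps: int = 4
    supports_tf32: bool = False
    allow_fp8e4nv: bool = False

def parse_options(opts: dict, capability: int) -> Options:
    # Keep only keys that are declared dataclass fields; ignore the rest.
    args = {k: opts[k] for k in Options.__dataclass_fields__ if k in opts}
    # Overlay flags derived from the GPU's compute capability.
    args["allow_fp8e4nv"] = capability >= 89  # fp8e4nv needs sm_89+
    args["supports_tf32"] = capability >= 80  # TF32 tensor cores need sm_80+
    return Options(**args)

print(parse_options({"num_warps": 8, "debug": True}, capability=80))
# -> Options(num_warps=8, supports_tf32=True, allow_fp8e4nv=False)
```

The rename from `allow_tf32` to `supports_tf32` makes the distinction explicit: this field records what the hardware can do, while the user-facing `allow_tf32` argument to `dot` records what the user permits.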
