diff --git a/.github/workflows/torch-inductor/scripts/check_perf.py b/.github/workflows/torch-inductor/scripts/check_perf.py index 447076187cf7..a966b7be2588 100644 --- a/.github/workflows/torch-inductor/scripts/check_perf.py +++ b/.github/workflows/torch-inductor/scripts/check_perf.py @@ -3,8 +3,7 @@ from collections import namedtuple # Create a named tuple for the output of the benchmark -BenchmarkOutput = namedtuple( - 'BenchmarkOutput', ['dev', 'name', 'batch_size', 'speedup', 'latency']) +BenchmarkOutput = namedtuple('BenchmarkOutput', ['dev', 'name', 'batch_size', 'speedup', 'latency']) def parse_output(file_path: str) -> dict: @@ -19,13 +18,11 @@ def parse_output(file_path: str) -> dict: batch_size = row[2] speedup = float(row[3]) latency = float(row[4]) - entries[name] = BenchmarkOutput( - dev, name, batch_size, speedup, latency) + entries[name] = BenchmarkOutput(dev, name, batch_size, speedup, latency) return entries -def compare(baseline: dict, new: dict, threshold: float, - geomean_threshold: float) -> bool: +def compare(baseline: dict, new: dict, threshold: float, geomean_threshold: float) -> bool: baseline_geomean = 1.0 new_geomean = 1.0 for key in new: @@ -41,19 +38,16 @@ def compare(baseline: dict, new: dict, threshold: float, continue if new_latency < baseline_latency * (1 - threshold): - print( - f"New benchmark {key} is faster than baseline: {new_latency} vs {baseline_latency}") + print(f"New benchmark {key} is faster than baseline: {new_latency} vs {baseline_latency}") elif new_latency > baseline_latency * (1 + threshold): - print( - f"New benchmark {key} is slower than baseline: {new_latency} vs {baseline_latency}") + print(f"New benchmark {key} is slower than baseline: {new_latency} vs {baseline_latency}") else: - print( - f"New benchmark {key} is within threshold: {new_latency} vs {baseline_latency}") + print(f"New benchmark {key} is within threshold: {new_latency} vs {baseline_latency}") baseline_geomean *= baseline[key].speedup new_geomean *= new[key].speedup - baseline_geomean = baseline_geomean ** (1 / len(baseline)) - new_geomean = new_geomean ** (1 / len(new)) + baseline_geomean = baseline_geomean**(1 / len(baseline)) + new_geomean = new_geomean**(1 / len(new)) print(f"Baseline geomean: {baseline_geomean}") print(f"New geomean: {new_geomean}") assert new_geomean >= baseline_geomean * (1 - geomean_threshold), \ diff --git a/.isort.cfg b/.isort.cfg deleted file mode 100644 index 833801cca28e..000000000000 --- a/.isort.cfg +++ /dev/null @@ -1,4 +0,0 @@ -[settings] -known_local_folder=triton -line_length=88 -py_version=36 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d1f5d985aed..7c877cb15752 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,20 +30,13 @@ repos: ^docs/conf.py$ ) - - repo: https://github.com/PyCQA/isort - rev: 5.12.0 + - repo: https://github.com/jlebar/yapf + rev: bf301f5ef7777e137b97219842629ca78eb5ef2a hooks: - - id: isort - exclude: '^python/triton/runtime/.*' + - id: yapf + args: ["-p", "-i"] stages: [commit, push, manual] - - repo: https://github.com/pre-commit/mirrors-autopep8 - rev: v1.6.0 - hooks: - - id: autopep8 - exclude: '^python/triton/runtime/.*' - args: ["-i"] - stages: [commit, push, manual] - repo: https://github.com/pre-commit/mirrors-clang-format rev: v16.0.6 hooks: diff --git a/docs/conf.py b/docs/conf.py index 54ca524685b7..b40acaa4bc2a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,7 +22,6 @@ # -- General configuration ------------------------------------------------ - import 
os import shutil import sys @@ -121,12 +120,9 @@ def documenter(app, obj, parent): return old_documenter(app, obj, parent) sphinx.ext.autosummary.get_documenter = documenter - sphinx.util.inspect.unwrap_all = forward_jit_fn( - sphinx.util.inspect.unwrap_all) - sphinx.util.inspect.signature = forward_jit_fn( - sphinx.util.inspect.signature) - sphinx.util.inspect.object_description = forward_jit_fn( - sphinx.util.inspect.object_description) + sphinx.util.inspect.unwrap_all = forward_jit_fn(sphinx.util.inspect.unwrap_all) + sphinx.util.inspect.signature = forward_jit_fn(sphinx.util.inspect.signature) + sphinx.util.inspect.object_description = forward_jit_fn(sphinx.util.inspect.object_description) # Auto Doc @@ -139,7 +135,8 @@ def documenter(app, obj, parent): 'sphinx.ext.coverage', 'sphinx.ext.napoleon', 'sphinx_multiversion', - 'myst_parser'] + 'myst_parser', +] autosummary_generate = True # versioning config @@ -294,6 +291,6 @@ def documenter(app, obj, parent): # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ - (master_doc, 'Triton', 'Triton Documentation', author, - 'Triton', 'One line description of project.', 'Miscellaneous'), + (master_doc, 'Triton', 'Triton Documentation', author, 'Triton', 'One line description of project.', + 'Miscellaneous'), ] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000000..525f303efb6f --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,16 @@ +[build-system] +requires = ["setuptools>=40.8.0", "wheel", "cmake>=3.18", "ninja>=1.11.1"] + +[tool.yapf] +based_on_style = "pep8" +column_limit = 120 +disable_split_list_with_comment = true +each_dict_entry_on_separate_line=false +split_before_named_assigns = false +split_complex_comprehension = true + +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +ignore = ["E501", "E701", "E731", "E741"] diff --git a/python/examples/copy_strided.py b/python/examples/copy_strided.py index 34cf12630205..79457d6a4feb 100644 --- a/python/examples/copy_strided.py +++ b/python/examples/copy_strided.py @@ -4,8 +4,8 @@ # triton kernel @triton.jit -def kernel(X, stride_xm, - Z, stride_zn, +def kernel(X, stride_xm, # + Z, stride_zn, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) diff --git a/python/examples/empty.py b/python/examples/empty.py index df313fb85869..bff6d1e9499e 100644 --- a/python/examples/empty.py +++ b/python/examples/empty.py @@ -10,4 +10,4 @@ def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr): X = torch.randn(1, device="cuda") -pgm = kernel[(1,)](X, 1, 1, BLOCK=1024) +pgm = kernel[(1, )](X, 1, 1, BLOCK=1024) diff --git a/python/setup.py b/python/setup.py index acda03fc1f21..2758f14ce031 100644 --- a/python/setup.py +++ b/python/setup.py @@ -55,6 +55,7 @@ class Package(NamedTuple): lib_flag: str syspath_var_name: str + # pybind11 @@ -63,6 +64,7 @@ def get_pybind11_package_info(): url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.11.1.tar.gz" return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH") + # llvm @@ -121,6 +123,7 @@ def get_thirdparty_packages(triton_cache_path): thirdparty_cmake_args.append(f"-D{p.lib_flag}={package_dir}/lib") return thirdparty_cmake_args + # ---- package data --- @@ -153,6 +156,7 @@ def download_and_copy(src_path, variable, version, url_func): os.makedirs(os.path.split(dst_path)[0], exist_ok=True) shutil.copy(src_path, dst_path) + # ---- cmake extension ---- @@ -170,18 +174,21 @@ def 
get_cmake_dir():


 class CMakeClean(clean):
+
     def initialize_options(self):
         clean.initialize_options(self)
         self.build_temp = get_cmake_dir()


 class CMakeBuildPy(build_py):
+
     def run(self) -> None:
         self.run_command('build_ext')
         return super().run()


 class CMakeExtension(Extension):
+
     def __init__(self, name, path, sourcedir=""):
         Extension.__init__(self, name, sources=[])
         self.sourcedir = os.path.abspath(sourcedir)
@@ -204,7 +211,8 @@ def run(self):
         try:
             out = subprocess.check_output(["cmake", "--version"])
         except OSError:
-            raise RuntimeError("CMake must be installed to build the following extensions: " + ", ".join(e.name for e in self.extensions))
+            raise RuntimeError("CMake must be installed to build the following extensions: " +
+                               ", ".join(e.name for e in self.extensions))
         match = re.search(r"version\s*(?P<major>\d+)\.(?P<minor>\d+)([\d.]+)?", out.decode())
         cmake_major, cmake_minor = int(match.group("major")), int(match.group("minor"))

@@ -231,8 +239,10 @@ def build_extension(self, ext):
         # python directories
         python_include_dir = sysconfig.get_path("platinclude")
         cmake_args = [
-            "-G", "Ninja",  # Ninja is much faster than make
-            "-DCMAKE_MAKE_PROGRAM=" + ninja_dir,  # Pass explicit path to ninja otherwise cmake may cache a temporary path
+            "-G",
+            "Ninja",  # Ninja is much faster than make
+            "-DCMAKE_MAKE_PROGRAM=" +
+            ninja_dir,  # Pass explicit path to ninja otherwise cmake may cache a temporary path
             "-DCMAKE_EXPORT_COMPILE_COMMANDS=ON",
             "-DLLVM_ENABLE_WERROR=ON",
             "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir,
@@ -266,12 +276,14 @@ def build_extension(self, ext):
         build_args += ['-j' + max_jobs]

         if check_env_flag("TRITON_BUILD_WITH_CLANG_LLD"):
-            cmake_args += ["-DCMAKE_C_COMPILER=clang",
-                           "-DCMAKE_CXX_COMPILER=clang++",
-                           "-DCMAKE_LINKER=lld",
-                           "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
-                           "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
-                           "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld"]
+            cmake_args += [
+                "-DCMAKE_C_COMPILER=clang",
+                "-DCMAKE_CXX_COMPILER=clang++",
+                "-DCMAKE_LINKER=lld",
+                "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
+                "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
+                "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld",
+            ]

         # Note that asan doesn't work with binaries that use the GPU, so this is
         # only useful for tools like triton-opt that don't run code on the GPU.
@@ -303,19 +315,22 @@ def build_extension(self, ext): src_path="bin/ptxas", variable="TRITON_PTXAS_PATH", version="12.1.105", - url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2", + url_func=lambda arch, version: + f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2", ) download_and_copy( src_path="bin/cuobjdump", variable="TRITON_CUOBJDUMP_PATH", version="12.1.111", - url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2", + url_func=lambda arch, version: + f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2", ) download_and_copy( src_path="bin/nvdisasm", variable="TRITON_NVDISASM_PATH", version="12.1.105", - url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", + url_func=lambda arch, version: + f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", ) setup( @@ -339,9 +354,7 @@ def build_extension(self, ext): "triton/third_party", "triton/tools", ], - install_requires=[ - "filelock" - ], + install_requires=["filelock"], include_package_data=True, ext_modules=[CMakeExtension("triton", "triton/_C/")], cmdclass={"build_ext": CMakeBuild, "build_py": CMakeBuildPy, "clean": CMakeClean}, diff --git a/python/test/backend/test_device_backend.py b/python/test/backend/test_device_backend.py index cb86309caa5b..bc73d837657a 100644 --- a/python/test/backend/test_device_backend.py +++ b/python/test/backend/test_device_backend.py @@ -13,8 +13,7 @@ import triton import triton.language as tl -from triton.common.backend import (BaseBackend, compute_core_version_key, - register_backend) +from triton.common.backend import (BaseBackend, compute_core_version_key, register_backend) from triton.common.build import quiet from triton.compiler.make_launcher import make_so_cache_key from triton.runtime.cache import get_cache_manager @@ -81,6 +80,7 @@ def build_for_backend(name, src, srcdir): class ExtensionUtils: + def __new__(cls): if not hasattr(cls, 'instance'): cls.instance = super(ExtensionUtils, cls).__new__(cls) @@ -110,6 +110,7 @@ def __init__(self): class ExtensionDriver(DriverBase): + def __new__(cls): if not hasattr(cls, 'instance'): cls.instance = super(ExtensionDriver, cls).__new__(cls) @@ -256,13 +257,13 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): inp = torch.randn(10) out = torch.randn(10) - kernel[(10,)](inp, out, 10, XBLOCK=16) + kernel[(10, )](inp, out, 10, XBLOCK=16) spec = importlib.util.spec_from_file_location("__triton_launcher", ExtensionBackend.stub_so_path) mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) launch_counter = getattr(mod, "launch_counter") for _ in range(100): - kernel[(10,)](inp, out, 10, XBLOCK=16) + kernel[(10, )](inp, out, 10, XBLOCK=16) assert launch_counter() > 0 diff --git a/python/test/backend/third_party_backends/conftest.py b/python/test/backend/third_party_backends/conftest.py index 62ee6c68976e..d939bc001646 100644 --- a/python/test/backend/third_party_backends/conftest.py +++ b/python/test/backend/third_party_backends/conftest.py @@ -4,9 +4,7 @@ def pytest_addoption(parser): - parser.addoption( - "--backend", action="store", default="", help="Codegen backend" - ) + parser.addoption("--backend", action="store", default="", help="Codegen backend") 
@pytest.fixture diff --git a/python/test/backend/third_party_backends/test_xpu_backend.py b/python/test/backend/third_party_backends/test_xpu_backend.py index f00261f193de..e6850efdd867 100644 --- a/python/test/backend/third_party_backends/test_xpu_backend.py +++ b/python/test/backend/third_party_backends/test_xpu_backend.py @@ -24,10 +24,10 @@ def kernel(x_ptr, y_ptr, out_ptr): if has_ipex: for _ in range(1000): - x = torch.randn((65536,), device="xpu", dtype=torch.float32) - y = torch.randn((65536,), device="xpu", dtype=torch.float32) - z = torch.zeros((65536,), device="xpu", dtype=torch.float32) - kernel[(65536,)](x, y, z, num_warps=32) + x = torch.randn((65536, ), device="xpu", dtype=torch.float32) + y = torch.randn((65536, ), device="xpu", dtype=torch.float32) + z = torch.zeros((65536, ), device="xpu", dtype=torch.float32) + kernel[(65536, )](x, y, z, num_warps=32) assert torch.all(x + y == z) else: return diff --git a/python/test/regression/test_cast_matmul.py b/python/test/regression/test_cast_matmul.py index 8b63a23dbe70..1477bc5a41bd 100644 --- a/python/test/regression/test_cast_matmul.py +++ b/python/test/regression/test_cast_matmul.py @@ -15,16 +15,12 @@ out_dtypes = ["float16", "float32"] -@pytest.mark.parametrize( - "M, K, N, w_dtype, x_dtype, out_dtype", - [ - (M, K, N, w, x, o) - for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)] - for w in input_dtypes - for x in input_dtypes - for o in out_dtypes - ] -) +@pytest.mark.parametrize("M, K, N, w_dtype, x_dtype, out_dtype", + [(M, K, N, w, x, o) # + for (M, K, N) in [(128, 128, 128), (1280, 768, 1024)] # + for w in input_dtypes + for x in input_dtypes # + for o in out_dtypes]) def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype): if x_dtype == w_dtype: pytest.skip("skip same dtype") @@ -44,15 +40,14 @@ def test_cast_matmul(M, K, N, w_dtype, x_dtype, out_dtype): grid = ((cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N)), 1) @jit - def matmul_kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - dot_out_dtype: tl.constexpr, - allow_tf32: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr - ): + def matmul_kernel(A, B, C, M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + dot_out_dtype: tl.constexpr, # + allow_tf32: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, # + BLOCK_K: tl.constexpr, GROUP_M: tl.constexpr): # matrix multiplication pid = tl.program_id(0) grid_m = tl.cdiv(M, BLOCK_M) @@ -91,16 +86,15 @@ def matmul_kernel(A, B, C, M, N, K, mask = (rm < M)[:, None] & (rn < N)[None, :] tl.store(C, acc, mask=mask) - matmul_kernel[grid](a, b, out_triton, M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - out_triton.stride(0), out_triton.stride(1), - dot_out_dtype=triton_dtype, - allow_tf32=allow_tf32, - GROUP_M=8, - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_K=BLOCK_K, - ) + matmul_kernel[grid]( + a, b, out_triton, M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + out_triton.stride(0), out_triton.stride(1), dot_out_dtype=triton_dtype, # + allow_tf32=allow_tf32, # + GROUP_M=8, # + BLOCK_M=BLOCK_M, # + BLOCK_N=BLOCK_N, # + BLOCK_K=BLOCK_K) torch.testing.assert_close(out_torch, out_triton, atol=0.3, rtol=0.01) diff --git a/python/test/regression/test_functional_regressions.py b/python/test/regression/test_functional_regressions.py index b873db7a3fa0..e0eb5660186b 100644 --- a/python/test/regression/test_functional_regressions.py +++ 
b/python/test/regression/test_functional_regressions.py @@ -14,18 +14,14 @@ def chained_matmul_reference(a, b, c): return torch.einsum('MN,NK->MK', intermediate, c) @triton.jit - def chained_matmul_kernel( - A, # shape: (m, k) - B, # shape: (n, k) - C, # shape: (n, k) - out, # shape: (m, k) - m, n, k: tl.constexpr, - block_m: tl.constexpr, - block_n: tl.constexpr, - block_k: tl.constexpr): - - tl.static_assert(block_k == k, - f"expected block_k == k but got {block_k} != {k}") + def chained_matmul_kernel(A, # shape: (m, k) + B, # shape: (n, k) + C, # shape: (n, k) + out, # shape: (m, k) + m, n, k: tl.constexpr, # + block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr): + + tl.static_assert(block_k == k, f"expected block_k == k but got {block_k} != {k}") block_ix = tl.program_id(0) a_tile = (block_ix * block_m + tl.arange(0, block_m))[:, None] * block_k \ @@ -55,35 +51,33 @@ def chained_matmul_kernel( m, n, k = 32, 64, 128 block_m, block_n, block_k = 16, 32, k - grid = (triton.cdiv(m, block_m),) - a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16, - device='cuda') - b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16, - device='cuda') + grid = (triton.cdiv(m, block_m), ) + a = torch.randint(low=0, high=2, size=(m, k), dtype=torch.float16, device='cuda') + b = torch.randint(low=0, high=2, size=(n, k), dtype=torch.float16, device='cuda') c = torch.randint_like(b, low=0, high=2) triton_result = torch.zeros_like(a) torch_result = chained_matmul_reference(a, b, c) - chained_matmul_kernel[grid](a, b, c, triton_result, m, n, k, - block_m=block_m, block_n=block_n, - block_k=block_k) + chained_matmul_kernel[grid]( + a, b, c, triton_result, m, n, k, # + block_m=block_m, block_n=block_n, block_k=block_k) assert (torch_result == triton_result).all() def test_vecmat(): + @triton.jit def batched_vecmat( - # inputs - A, # shape: [dim_m, dim_k] - B, # shape: [dim_m, dim_n, dim_k] - # dimensions + # inputs + A, # shape: [dim_m, dim_k] + B, # shape: [dim_m, dim_n, dim_k] + # dimensions dim_m, dim_n, dim_k, - # outputs - output, - # block information - block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr - ): + # outputs + output, + # block information + block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr): m_index = tl.program_id(0) n_index = tl.program_id(1) # Output tile @@ -125,9 +119,10 @@ def batched_vecmat( grid = (M // block_m, N // block_n) - batched_vecmat[grid](A_tri, B_tri, M, N, K, C_tri, - block_m=block_m, block_n=block_n, block_k=block_k, - num_warps=4, num_stages=1) + batched_vecmat[grid]( + A_tri, B_tri, M, N, K, C_tri, # + block_m=block_m, block_n=block_n, block_k=block_k, # + num_warps=4, num_stages=1) A_expanded = A[:, np.newaxis, :] A_broadcasted = np.broadcast_to(A_expanded, (M, N, K)) @@ -137,18 +132,18 @@ def batched_vecmat( np.testing.assert_allclose(C_ref, C_tri.cpu().numpy(), rtol=0.01, atol=1e-3) -@pytest.mark.parametrize("type", ["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"]) +@pytest.mark.parametrize("type", + ["pre_load", "post_load", "post_pre_mixed", "post_load_two_iters", "post_load_three_iters"]) def test_iv_dependent_matmul(type): + @triton.jit - def kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - type: tl.constexpr - ): + def kernel(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + 
stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + type: tl.constexpr): pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) pid_m = pid // num_pid_n @@ -216,15 +211,16 @@ def kernel( b = torch.rand((K, N), device='cuda') torch_output = torch.mm(a, b) - triton_output = torch.empty_like( - torch_output, device=torch_output.device) + triton_output = torch.empty_like(torch_output, device=torch_output.device) def grid(META): - return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) + return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) num_stages = 4 if type == "post_load_three_iters" else 3 - kernel[grid](a, b, triton_output, M, N, K, a.stride(0), a.stride(1), - b.stride(0), b.stride(1), triton_output.stride(0), triton_output.stride(1), - BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, - type=type, num_stages=num_stages) + kernel[grid]( + a, b, triton_output, M, N, K, # + a.stride(0), a.stride(1), b.stride(0), b.stride(1), # + triton_output.stride(0), triton_output.stride(1), # + BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type, # + num_stages=num_stages) torch.testing.assert_close(torch_output, triton_output, rtol=1e-2, atol=1e-2) diff --git a/python/test/regression/test_performance.py b/python/test/regression/test_performance.py index ccb146e6b1d1..bce935b07b9c 100644 --- a/python/test/regression/test_performance.py +++ b/python/test/regression/test_performance.py @@ -26,7 +26,6 @@ def print_perf(cur_ms, cur_util, ref_util): mem_clocks = {'v100': 877, 'a100': 1215} matmul_data = { - # NOTE: 'a100': { # square (512, 512, 512): {'float16': 0.108, 'float32': 0.097, 'int8': 0.05}, @@ -49,10 +48,9 @@ def print_perf(cur_ms, cur_util, ref_util): } -@pytest.mark.parametrize('M, N, K, dtype_str', - [(M, N, K, dtype_str) - for M, N, K in matmul_data[DEVICE_NAME].keys() - for dtype_str in ['float16']]) +@pytest.mark.parametrize('M, N, K, dtype_str', [(M, N, K, dtype_str) + for M, N, K in matmul_data[DEVICE_NAME].keys() + for dtype_str in ['float16']]) def test_matmul(M, N, K, dtype_str): stream = torch.cuda.Stream() torch.cuda.set_stream(stream) @@ -86,8 +84,7 @@ def test_matmul(M, N, K, dtype_str): @triton.jit -def _add(x_ptr, y_ptr, output_ptr, n_elements, - BLOCK_SIZE: tl.constexpr): +def _add(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -136,11 +133,11 @@ def test_elementwise(N, dtype_str): print_perf(ms, cur_gpu_util, ref_gpu_util) triton.testing.assert_close(cur_gpu_util, ref_gpu_util, atol=0.02, rtol=0.01) + ####################### # Flash-Attention ####################### - flash_attention_data = { "a100": { (4, 48, 4096, 64, True, True, 'forward', 'float16'): 0.542, @@ -221,8 +218,7 @@ def test_flash_attention(Z, H, N_CTX, D_HEAD, seq_par, causal, mode, dtype_str): @triton.jit -def _sum(x_ptr, y_ptr, output_ptr, n_elements, - BLOCK_SIZE: tl.constexpr): +def _sum(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE offsets = block_start + tl.arange(0, BLOCK_SIZE) @@ -260,8 +256,8 @@ def test_reductions(N, dtype_str): y = torch.randn_like(z) else: info = torch.iinfo(dtype) - x = torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda') - y = 
torch.randint(info.min, info.max, (N,), dtype=dtype, device='cuda') + x = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda') + y = torch.randint(info.min, info.max, (N, ), dtype=dtype, device='cuda') grid = lambda args: (triton.cdiv(N, args['BLOCK_SIZE']), ) fn = lambda: _sum[grid](x, y, z, N, BLOCK_SIZE=1024) ms = triton.testing.do_bench_cudagraph(fn) diff --git a/python/test/tools/compare_files.py b/python/test/tools/compare_files.py index 1c8de084dcf9..d74e1da02a71 100644 --- a/python/test/tools/compare_files.py +++ b/python/test/tools/compare_files.py @@ -9,6 +9,7 @@ class ComparisonResult: + def __init__(self, name: str, numComparisons: int, diffs: List[str] = None, errors: List[str] = None): self.name = name self.numComparisons = numComparisons @@ -142,7 +143,8 @@ def doFilesMatch(path1: str, path2: str) -> bool: return True -def compareMatchingFiles(name: str, nameToHashes1: Dict[str, List[str]], nameToHashes2: Dict[str, List[str]], args) -> ComparisonResult: +def compareMatchingFiles(name: str, nameToHashes1: Dict[str, List[str]], nameToHashes2: Dict[str, List[str]], + args) -> ComparisonResult: """ Compare files with the given name in all hashes in both paths Return the first mismatching files as a tuple (file1, file2), otherwise, return an empty tuple diff --git a/python/test/unit/hopper/test_flashattention.py b/python/test/unit/hopper/test_flashattention.py index 60006613b625..fc8db664c9f0 100644 --- a/python/test/unit/hopper/test_flashattention.py +++ b/python/test/unit/hopper/test_flashattention.py @@ -18,7 +18,6 @@ # CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - """ Fused Attention =============== @@ -35,18 +34,15 @@ @triton.jit -def _fwd_kernel( - Q, K, V, sm_scale, - L, M, - Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, N_CTX, D0, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): +def _fwd_kernel(Q, K, V, sm_scale, # + L, M, # + Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, D0, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr): start_m = tl.program_id(0) off_hz = tl.program_id(1) @@ -61,31 +57,38 @@ def _fwd_kernel( stride_qh_2d = stride_qh // stride_qm // stride_qk - q_tile_ptr = tl.make_block_ptr(base=Q, - shape=(D0, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=( - off_hz * stride_qh_2d + start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - k_tile_ptr = tl.make_block_ptr(base=K, - shape=(D0, BLOCK_DMODEL), - strides=(stride_kn, stride_kk), - offsets=(off_hz * stride_qh_2d, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0)) - v_tile_ptr = tl.make_block_ptr(base=V, - shape=(D0, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(off_hz * stride_qh_2d, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(1, 0)) - out_tile_ptr = tl.make_block_ptr(base=Out, - shape=(D0, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) + q_tile_ptr = tl.make_block_ptr( + base=Q, + shape=(D0, 
BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_tile_ptr = tl.make_block_ptr( + base=K, + shape=(D0, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(off_hz * stride_qh_2d, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + v_tile_ptr = tl.make_block_ptr( + base=V, + shape=(D0, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(off_hz * stride_qh_2d, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0), + ) + out_tile_ptr = tl.make_block_ptr( + base=Out, + shape=(D0, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) # load q: it will stay in SRAM throughout q = tl.load(q_tile_ptr) @@ -96,8 +99,7 @@ def _fwd_kernel( qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) qk += tl.dot(q, tl.trans(k)) qk *= sm_scale - qk = tl.where(offs_m[:, None] >= ( - start_n + offs_n[None, :]), qk, float("-inf")) + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) # compute new m m_curr = tl.maximum(tl.max(qk, 1), m_prev) # correct old l @@ -133,11 +135,9 @@ def _fwd_kernel( @triton.jit -def _bwd_preprocess( - Out, DO, L, - NewDO, Delta, - BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, -): +def _bwd_preprocess(Out, DO, L, # + NewDO, Delta, # + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_n = tl.arange(0, D_HEAD) # load @@ -153,19 +153,14 @@ def _bwd_preprocess( @triton.jit -def _bwd_kernel( - Q, K, V, sm_scale, Out, DO, - DQ, DK, DV, - L, M, - D, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - Z, H, N_CTX, D0, - num_block, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): +def _bwd_kernel(Q, K, V, sm_scale, Out, DO, # + DQ, DK, DV, # + L, M, # + D, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + Z, H, N_CTX, D0, # + num_block, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr): off_hz = tl.program_id(0) off_z = off_hz // H off_h = off_hz % H @@ -173,55 +168,62 @@ def _bwd_kernel( stride_qz_2d = stride_qz // stride_qm // stride_qk stride_qh_2d = stride_qh // stride_qm // stride_qk - q_tile_ptr = tl.make_block_ptr(base=Q, - shape=(D0, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=( - off_z * stride_qz_2d + off_h * stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - k_tile_ptr = tl.make_block_ptr(base=K, - shape=(D0, BLOCK_DMODEL), - strides=(stride_kn, stride_kk), - offsets=( - off_z * stride_qz_2d + off_h * stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - v_tile_ptr = tl.make_block_ptr(base=V, - shape=(D0, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=( - off_z * stride_qz_2d + off_h * stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - do_tile_ptr = tl.make_block_ptr(base=DO, - shape=(D0, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=( - off_z * stride_qz_2d + off_h * stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - dq_tile_ptr = tl.make_block_ptr(base=DQ, - shape=(D0, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=( - off_z * stride_qz_2d + off_h * 
stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - dk_tile_ptr = tl.make_block_ptr(base=DK, - shape=(D0, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=( - off_z * stride_qz_2d + off_h * stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) - dv_tile_ptr = tl.make_block_ptr(base=DV, - shape=(D0, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=( - off_z * stride_qz_2d + off_h * stride_qh_2d, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0)) + q_tile_ptr = tl.make_block_ptr( + base=Q, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + k_tile_ptr = tl.make_block_ptr( + base=K, + shape=(D0, BLOCK_DMODEL), + strides=(stride_kn, stride_kk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + v_tile_ptr = tl.make_block_ptr( + base=V, + shape=(D0, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + do_tile_ptr = tl.make_block_ptr( + base=DO, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + dq_tile_ptr = tl.make_block_ptr( + base=DQ, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + dk_tile_ptr = tl.make_block_ptr( + base=DK, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) + dv_tile_ptr = tl.make_block_ptr( + base=DV, + shape=(D0, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(off_z * stride_qz_2d + off_h * stride_qh_2d, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0), + ) # offset pointers for batch/head DQ += off_z * stride_qz + off_h * stride_qh for start_n in range(0, num_block): @@ -250,8 +252,7 @@ def _bwd_kernel( # recompute p = softmax(qk, dim=-1).T # NOTE: `do` is pre-divided by `l`; no normalization here qk = tl.dot(q, tl.trans(k)) - qk = tl.where(offs_m_curr[:, None] >= ( - offs_n[None, :]), qk, float("-inf")) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) m = tl.load(m_ptrs + offs_m_curr) p = tl.exp(qk * sm_scale - m[:, None]) # compute dv @@ -301,29 +302,21 @@ def forward(ctx, q, k, v, sm_scale): assert Lk in {16, 32, 64, 128} o = torch.empty_like(q) grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1) - L = torch.empty( - (q.shape[0] * q.shape[1], q.shape[2]), - device=q.device, - dtype=torch.float32) - m = torch.empty( - (q.shape[0] * q.shape[1], q.shape[2]), - device=q.device, - dtype=torch.float32) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 D0 = q.shape[0] * q.shape[1] * q.shape[2] _fwd_kernel[grid]( - q, k, v, sm_scale, - L, m, - o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], 
q.shape[1], q.shape[2], D0, - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=Lk, num_warps=num_warps, - num_stages=2, - ) + q, k, v, sm_scale, # + L, m, # + o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], q.shape[2], D0, # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=Lk, # + num_warps=num_warps, num_stages=2) ctx.save_for_backward(q, k, v, o, L, m) ctx.grid = grid @@ -343,25 +336,22 @@ def backward(ctx, do): delta = torch.empty_like(l) D0 = q.shape[0] * q.shape[1] * q.shape[2] _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( - o, do, l, - do_scaled, delta, - BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, - ) - _bwd_kernel[(ctx.grid[1],)]( - q, k, v, ctx.sm_scale, - o, do_scaled, - dq, dk, dv, - l, m, - delta, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - q.shape[0], q.shape[1], q.shape[2], D0, - ctx.grid[0], - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, - num_stages=1, - ) + o, do, l, # + do_scaled, delta, # + BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL) + _bwd_kernel[(ctx.grid[1], )]( + q, k, v, ctx.sm_scale, # + o, do_scaled, # + dq, dk, dv, # + l, m, # + delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + q.shape[0], q.shape[1], q.shape[2], D0, # + ctx.grid[0], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + num_warps=8, num_stages=1) return dq, dk, dv, None @@ -380,15 +370,9 @@ def backward(ctx, do): @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="requires arch 9+") def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): torch.manual_seed(20) - q = torch.empty( - (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_( - mean=0.1, std=0.2).requires_grad_() - k = torch.empty( - (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_( - mean=0.4, std=0.2).requires_grad_() - v = torch.empty( - (Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_( - mean=0.3, std=0.2).requires_grad_() + q = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2).requires_grad_() + k = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2).requires_grad_() + v = torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2).requires_grad_() sm_scale = 0.2 dout = torch.randn_like(q) # reference implementation @@ -427,22 +411,25 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 # vary seq length for fixed head and batch=4 -configs = [triton.testing.Benchmark( - x_names=['N_CTX'], - x_vals=[2**i for i in range(10, 14)], - line_arg='provider', - line_vals=['triton'] + (['flash'] if HAS_FLASH else []), - line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', - args={ - 'H': N_HEADS, - 'BATCH': BATCH, - 'D_HEAD': D_HEAD, - 'dtype': torch.float16, - 'mode': mode} -) for mode in ['fwd', 'bwd']] +configs = [ + triton.testing.Benchmark( + x_names=['N_CTX'], + x_vals=[2**i for i in range(10, 14)], + line_arg='provider', + 
line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), + styles=[('red', '-'), ('blue', '-')], + ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', + args={ + 'H': N_HEADS, + 'BATCH': BATCH, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + }, + ) for mode in ['fwd', 'bwd'] +] @triton.testing.perf_report(configs) @@ -463,9 +450,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) return ms if provider == "flash": - lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) - cu_seqlens = torch.zeros( - (BATCH + 1,), device=device, dtype=torch.int32) + lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32) cu_seqlens[1:] = lengths.cumsum(0) qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) diff --git a/python/test/unit/hopper/test_gemm.py b/python/test/unit/hopper/test_gemm.py index 5c57fd17c25e..cc6ebd0b0efb 100644 --- a/python/test/unit/hopper/test_gemm.py +++ b/python/test/unit/hopper/test_gemm.py @@ -32,19 +32,30 @@ @triton.jit -def matmul_no_scf_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr -): - a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), - offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) - b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), - offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) +def matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr # + ): + a_block_ptr = tl.make_block_ptr( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), + order=(1, 0), + ) + b_block_ptr = tl.make_block_ptr( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), + order=(0, 1), + ) a = tl.load(a_block_ptr) b = tl.load(b_block_ptr) @@ -54,8 +65,8 @@ def matmul_no_scf_kernel( c = c.to(tl.float16) if USE_TMA_EPILOGUE: - c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), - offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) tl.store(c_block_ptr, c) else: offs_m = tl.arange(0, BLOCK_M) @@ -64,33 +75,30 @@ def matmul_no_scf_kernel( tl.store(c_ptrs, c) -@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS', - itertools.chain( - *[ - [ - # numCTAs = 1, no TMA multicast: - [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], - [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], - [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], - [64, 64, 16, 
1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], - [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - # static mask, cluster 4x1 - [256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], - [256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - # dynamic mask, cluster 2x2 - [128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], - [128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - # small M, N - [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], - ] for USE_TMA_EPILOGUE in [True, False] - for ENABLE_WS in [False, True] - ])) +@pytest.mark.parametrize( + 'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,ENABLE_WS', + itertools.chain(*[[ + # numCTAs = 1, no TMA multicast: + [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + # static mask, cluster 4x1 + [256, 64, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [256, 64, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + # dynamic mask, cluster 2x2 + [128, 128, 16, 4, 4, False, True, "float16", USE_TMA_EPILOGUE, ENABLE_WS], + [128, 128, 16, 4, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + # small M, N + [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, ENABLE_WS], + ] for USE_TMA_EPILOGUE in [True, False] for ENABLE_WS in [False, True]])) @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, ENABLE_WS): if (TRANS_A): @@ -107,46 +115,41 @@ def test_gemm_no_scf(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE else: c = torch.empty((M, N), device=a.device, dtype=torch.float32) - matmul_no_scf_kernel[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_cm=c.stride(0), stride_cn=c.stride(1), - BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, - num_warps=NUM_WARPS, - num_ctas=NUM_CTAS, - FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), - USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, - enable_warp_specialization=ENABLE_WS) + matmul_no_scf_kernel[(1, 1)]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=M, N=N, 
K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # + num_warps=NUM_WARPS, # + num_ctas=NUM_CTAS, # + FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # + USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, # + enable_warp_specialization=ENABLE_WS) a_f32 = a.to(torch.float32) b_f32 = b.to(torch.float32) golden = torch.matmul(a_f32, b_f32) torch.set_printoptions(profile="full") - assert_close( - c, - golden, - rtol=1e-2, - atol=1e-3, - check_dtype=False) + assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) @triton.jit -def matmul_kernel( - a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_wm, stride_wn, - stride_zm, stride_zn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, - ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, - DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, - A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, - B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, - W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr, - Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr -): +def matmul_kernel(a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_wm, stride_wn, # + stride_zm, stride_zn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, # + out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, # + ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, # + DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, # + A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, # + B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, # + W_ORDER_0: tl.constexpr, W_ORDER_1: tl.constexpr, # + Z_ORDER_0: tl.constexpr, Z_ORDER_1: tl.constexpr # + ): pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_N) num_pid_m = tl.cdiv(M, BLOCK_M) @@ -159,13 +162,31 @@ def matmul_kernel( block_offset_m = pid_m * BLOCK_M block_offset_n = pid_n * BLOCK_N - a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), - offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1)) - b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), - offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1)) + a_tile_ptr = tl.make_block_ptr( + base=a_ptr, + shape=(M, K), + strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), + block_shape=(BLOCK_M, BLOCK_K), + order=(A_ORDER_0, A_ORDER_1), + ) + b_tile_ptr = tl.make_block_ptr( + base=b_ptr, + shape=(K, N), + strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), + block_shape=(BLOCK_K, BLOCK_N), + order=(B_ORDER_0, B_ORDER_1), + ) # for chain-dot, BLOCK_N must always be equal to N, and each program loads the whole W matrix - w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn), - offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_N), order=(W_ORDER_0, W_ORDER_1)) + w_tile_ptr = tl.make_block_ptr( + base=w_ptr, + shape=(N, N), + strides=(stride_wm, stride_wn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_N), + order=(W_ORDER_0, W_ORDER_1), + ) z = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) offs_m = block_offset_m + tl.arange(0, BLOCK_M) @@ -204,139 +225,146 @@ def 
matmul_kernel( if USE_TMA_STORE: z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn), - offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(Z_ORDER_0, Z_ORDER_1)) + offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), + order=(Z_ORDER_0, Z_ORDER_1)) tl.store(z_block_ptr, z, boundary_check=(0, 1)) else: tl.store(z_ptrs, z, mask=mask) -@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS', - [ - # corner shapes - (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) - for shape_w_c in [ - [4096, 1, 1024, False, False, True], - [2048, 204, 1000, True, False, True], - [4096, 1, 1024, False, False, False], - [2048, 204, 1000, True, False, False], - ] - for out_dtype in ['float16', 'float32'] - for use_tma_store in [False, True] - for enable_ws in [False, True] - ] + [ - # softmax epilogue - (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [64, 64, 16, 4, 1, 64, 64, 64], - [128, 128, 64, 4, 1, None, None, None], - [16, 16, 64, 4, 1, 16, 16, 64], - [64, 64, 32, 8, 1, 64, 64, 64], - [128, 128, 64, 4, 1, 128, 128, 128], - ] - for epilogue in ['softmax'] - for out_dtype in ['float16', 'float32'] - for use_tma_store in [False, True] - for trans_a in [False,] - for trans_b in [True,] - for trans_output in [False,] - for num_stages in [3] - for enable_ws in [False, True] - ] + [ - # loop over epilogues besides of softmax - (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [64, 64, 16, 4, 1, 128, 128, 64], - *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4, 8] for num_ctas in [1, 2, 4]], - # for chain-dot - [128, 128, 64, 4, 1, None, None, None], - [64, 64, 16, 4, 1, None, None, None], - # small BLOCK_M and BLOCK_K - [16, 16, 64, 4, 1, 128, 128, 64], - *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4, 8] for num_ctas in [1, 2]], - # repeat - [64, 64, 32, 8, 1, 128, 256, 64], - [64, 64, 16, 8, 2, 128, 128, 64], - # irregular shape - [128, 128, 64, 4, 1, 500, 200, 128], - [128, 128, 64, 4, 2, 513, 193, 192], - ] - for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'] - for out_dtype in ['float16', 'float32'] - for use_tma_store in [False, True] - for trans_a in [False,] - for trans_b in [True,] - for trans_output in [False,] - for num_stages in [3] - for enable_ws in [False, True] - if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) - ] + [ - # loop over tile shapes and transpose combinations - (*shape_w_c, trans_a, trans_b, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [64, 64, 32, 4, 1, 128, 256, 64], - [128, 128, 16, 4, 4, 512, 256, 64], - [128, 256, 32, 4, 8, 256, 256, 192], - [512, 256, 32, 4, 8, 1024, 256, 192], - # BLOCK_K >= 128 - [64, 128, 128, 4, 1, 512, 256, 256], - [128, 128, 128, 4, 1, 256, 256, 192], - [128, 128, 128, 4, 2, 256, 256, 192], - # small BLOCK_M and BLOCK_K - [16, 32, 32, 4, 1, 128, 256, 64], - [32, 32, 16, 4, 1, 256, 256, 192], - [16, 32, 64, 4, 4, 512, 256, 64], - ] - for out_dtype in ['float32',] - for use_tma_store in [False,] - for trans_a in [False, True] - for trans_b in [False, True] - for trans_output in [False, True] - for num_stages in [3] - for 
enable_ws in [False, True] - ] + [ - # loop over instr shapes & pipeline stages - (64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws) - for n in [16, 32, 64, 128, 256] - for trans_output in [False,] - for out_dtype in ['float32',] - for use_tma_store in [False,] - for num_stages in [2, 4, 5, 7] - for enable_ws in [False, True] - ] + [ - # irregular shapes - (*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [128, 128, 64, 4, 1], - [256, 128, 64, 4, 2], - [128, 128, 128, 4, 2], - ] - for shape in [ - [512, 360, 1024], - [360, 4096, 512], - ] - for trans_output in [False,] - for out_dtype in ['float32',] - for use_tma_store in [False, True] - for num_stages in [3, 4] - for enable_ws in [False, True] - ]) -@pytest.mark.skipif(torch.cuda.get_device_capability() - [0] < 9, reason="Requires compute capability >= 9") -def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS): +@pytest.mark.parametrize( + 'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,TRANS_OUTPUT,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS', + [ + # corner shapes + (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) + for shape_w_c in [ + [4096, 1, 1024, False, False, True], + [2048, 204, 1000, True, False, True], + [4096, 1, 1024, False, False, False], + [2048, 204, 1000, True, False, False], + ] + for out_dtype in ['float16', 'float32'] # + for use_tma_store in [False, True] # + for enable_ws in [False, True] + ] + [ + # softmax epilogue + (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [64, 64, 16, 4, 1, 64, 64, 64], + [128, 128, 64, 4, 1, None, None, None], + [16, 16, 64, 4, 1, 16, 16, 64], + [64, 64, 32, 8, 1, 64, 64, 64], + [128, 128, 64, 4, 1, 128, 128, 128], + ] + for epilogue in ['softmax'] + for out_dtype in ['float16', 'float32'] + for use_tma_store in [False, True] + for trans_a in [False] + for trans_b in [True] + for trans_output in [False] + for num_stages in [3] + for enable_ws in [False, True] + ] + [ + # loop over epilogues besides of softmax + (*shape_w_c, trans_a, trans_b, trans_output, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [64, 64, 16, 4, 1, 128, 128, 64], + *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] + for num_warps in [4, 8] + for num_ctas in [1, 2, 4]], + # for chain-dot + [128, 128, 64, 4, 1, None, None, None], + [64, 64, 16, 4, 1, None, None, None], + # small BLOCK_M and BLOCK_K + [16, 16, 64, 4, 1, 128, 128, 64], + *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] + for num_warps in [4, 8] + for num_ctas in [1, 2]], + # repeat + [64, 64, 32, 8, 1, 128, 256, 64], + [64, 64, 16, 8, 2, 128, 128, 64], + # irregular shape + [128, 128, 64, 4, 1, 500, 200, 128], + [128, 128, 64, 4, 2, 513, 193, 192], + ] + for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'] + for out_dtype in ['float16', 'float32'] + for use_tma_store in [False, True] + for trans_a in [False] + for trans_b in [True] + for trans_output in [False] + for num_stages in [3] + for enable_ws in [False, True] + if not (epilogue == 'chain-dot' and (shape_w_c[6] is not None or shape_w_c[1] != shape_w_c[6])) + ] + [ + # loop over tile shapes and transpose combinations + (*shape_w_c, trans_a, trans_b, 
trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [64, 64, 32, 4, 1, 128, 256, 64], + [128, 128, 16, 4, 4, 512, 256, 64], + [128, 256, 32, 4, 8, 256, 256, 192], + [512, 256, 32, 4, 8, 1024, 256, 192], + # BLOCK_K >= 128 + [64, 128, 128, 4, 1, 512, 256, 256], + [128, 128, 128, 4, 1, 256, 256, 192], + [128, 128, 128, 4, 2, 256, 256, 192], + # small BLOCK_M and BLOCK_K + [16, 32, 32, 4, 1, 128, 256, 64], + [32, 32, 16, 4, 1, 256, 256, 192], + [16, 32, 64, 4, 4, 512, 256, 64], + ] + for out_dtype in ['float32'] + for use_tma_store in [False] + for trans_a in [False, True] + for trans_b in [False, True] + for trans_output in [False, True] + for num_stages in [3] + for enable_ws in [False, True] + ] + [ + # loop over instr shapes & pipeline stages + (64, n, 16, 4, 1, 512, 256, 256, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, + enable_ws) + for n in [16, 32, 64, 128, 256] + for trans_output in [False] + for out_dtype in ['float32'] + for use_tma_store in [False] + for num_stages in [2, 4, 5, 7] + for enable_ws in [False, True] + ] + [ + # irregular shapes + (*shape_w_c, *shape, False, True, trans_output, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [ + [128, 128, 64, 4, 1], + [256, 128, 64, 4, 2], + [128, 128, 128, 4, 2], + ] + for shape in [ + [512, 360, 1024], + [360, 4096, 512], + ] + for trans_output in [False] + for out_dtype in ['float32'] + for use_tma_store in [False, True] + for num_stages in [3, 4] + for enable_ws in [False, True] + ]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_gemm(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, TRANS_OUTPUT, epilogue, + out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS): if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ - '16-32-64-4-4-512-256-64-True-False', - '16-32-64-4-4-512-256-64-True-True', - '16-32-64-4-4-512-256-64-False-False', - '16-32-64-4-4-512-256-64-False-True', + '16-32-64-4-4-512-256-64-True-False', + '16-32-64-4-4-512-256-64-True-True', + '16-32-64-4-4-512-256-64-False-False', + '16-32-64-4-4-512-256-64-False-True', ]: pytest.skip('shapePerCTA[1] < 16 not supported') if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ - '16-32-64-4-1-256-256-256-False', - '16-32-64-4-2-256-256-256-False', - '16-32-64-4-2-256-256-256-True', - '16-32-64-8-2-256-256-256-False', - '16-32-64-8-2-256-256-256-True', + '16-32-64-4-1-256-256-256-False', + '16-32-64-4-2-256-256-256-False', + '16-32-64-4-2-256-256-256-True', + '16-32-64-8-2-256-256-256-False', + '16-32-64-8-2-256-256-256-True', ]: pytest.skip('Known legacy issue, ldmatrix can only support x4') enable_tma = os.environ.get('ENABLE_TMA', 'not found').lower() @@ -413,38 +441,38 @@ def process_epilogue(d, bias, w, epilogue): else: ref = d return ref + golden = process_epilogue(dot, bias, w, epilogue) def grid(META): - return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']),) - pgm = matmul_kernel[grid](a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_wm=w.stride(0), stride_wn=w.stride(1), - stride_zm=z.stride(0), stride_zn=z.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, - out_dtype=out_dtype, - USE_TMA_STORE=USE_TMA_STORE, - 
ADD_MATRIX=epilogue == 'add-matrix', - ADD_ROWS=epilogue == 'add-rows', - ADD_COLS=epilogue == 'add-cols', - DO_SOFTMAX=epilogue == 'softmax', - CHAIN_DOT=epilogue == 'chain-dot', - A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], - B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], - W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], - Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], - num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, - enable_warp_specialization=ENABLE_WS) + return (triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N']), ) + + pgm = matmul_kernel[grid]( + a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_wm=w.stride(0), stride_wn=w.stride(1), # + stride_zm=z.stride(0), stride_zn=z.stride(1), # + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # + out_dtype=out_dtype, # + USE_TMA_STORE=USE_TMA_STORE, # + ADD_MATRIX=epilogue == 'add-matrix', # + ADD_ROWS=epilogue == 'add-rows', # + ADD_COLS=epilogue == 'add-cols', # + DO_SOFTMAX=epilogue == 'softmax', # + CHAIN_DOT=epilogue == 'chain-dot', # + A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # + B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # + W_ORDER_0=w_order[0], W_ORDER_1=w_order[1], # + Z_ORDER_0=z_order[0], Z_ORDER_1=z_order[1], # + num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, # + enable_warp_specialization=ENABLE_WS) torch.set_printoptions(profile="full") golden = torch.nn.functional.normalize(golden) z = torch.nn.functional.normalize(z) - assert_close(z, golden, - rtol=1e-2, - atol=1e-3, - check_dtype=False) + assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) disable_mmav3 = os.environ.get('DISABLE_MMA_V3', 'not found').lower() if disable_mmav3 not in ["on", "true", "1"] and BLOCK_M >= 64 and NUM_CTAS == 1 and BLOCK_N <= 256: diff --git a/python/test/unit/hopper/test_gemm_fusion.py b/python/test/unit/hopper/test_gemm_fusion.py index 1fd53d5c4579..4b439efa80a6 100644 --- a/python/test/unit/hopper/test_gemm_fusion.py +++ b/python/test/unit/hopper/test_gemm_fusion.py @@ -27,16 +27,20 @@ @triton.jit -def gemm_fusion_kernel(A, B, C, E, - M, N, K, - stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek, +def gemm_fusion_kernel(A, B, C, E, # + M, N, K, # + stride_am, stride_ak, stride_bn, stride_bk, stride_cn, stride_ck, stride_em, stride_ek, # BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr): pid = tl.program_id(0) - a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) - b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) - c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), strides=(stride_cn, stride_ck), offsets=(0, 0), block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) - e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + a_tile_ptr = tl.make_block_ptr(base=A, shape=(M, K), strides=(stride_am, stride_ak), offsets=(pid * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=B, shape=(N, K), strides=(stride_bn, stride_bk), offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) + c_tile_ptr = tl.make_block_ptr(base=C, shape=(N, K), 
strides=(stride_cn, stride_ck), offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_K), order=(1, 0)) + e_tile_ptr = tl.make_block_ptr(base=E, shape=(M, K), strides=(stride_em, stride_ek), offsets=(pid * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) acc_e = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32) a = tl.load(a_tile_ptr) @@ -57,66 +61,70 @@ def gemm_fusion_kernel(A, B, C, E, def test_gemm_fusion(): M, N, K = 4096, 4096, 64 BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64 - A = torch.empty( - (M, K), dtype=torch.float16, device='cuda').normal_( - mean=0.1, std=0.2) - B = torch.empty( - (N, K), dtype=torch.float16, device='cuda').normal_( - mean=0.1, std=0.2) - C = torch.empty( - (N, K), dtype=torch.float16, device='cuda').normal_( - mean=0.1, std=0.2) + A = torch.empty((M, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + B = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) + C = torch.empty((N, K), dtype=torch.float16, device='cuda').normal_(mean=0.1, std=0.2) E = torch.empty((M, K), dtype=torch.float16, device='cuda') ref_out = torch.matmul(torch.matmul(A, B.T), C) num_warps = 4 grid = (triton.cdiv(M, BLOCK_M), 1) - gemm_fusion_kernel[grid](A, B, C, E, M, N, K, - A.stride(0), A.stride(1), B.stride(0), B.stride( - 1), C.stride(0), C.stride(1), E.stride(0), E.stride(1), - BLOCK_M, BLOCK_N, BLOCK_K, num_warps=num_warps) + gemm_fusion_kernel[grid]( + A, B, C, E, M, N, K, # + A.stride(0), A.stride(1), # + B.stride(0), B.stride(1), # + C.stride(0), C.stride(1), # + E.stride(0), E.stride(1), # + BLOCK_M, BLOCK_N, BLOCK_K, # + num_warps=num_warps) torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0) @triton.jit -def batched_gemm_fusion( - Q, K, V, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, - Z, NH, N_CTX, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): +def batched_gemm_fusion(Q, K, V, Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, NH, N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr): start_m = tl.program_id(0) off_hz = tl.program_id(1) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - q_tile_ptr = tl.make_block_ptr(base=Q, - shape=(Z, NH, N_CTX, BLOCK_DMODEL), - strides=(stride_qz, stride_qh, stride_qm, stride_qk), - offsets=(off_hz // NH, off_hz % NH, start_m, 0), - block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), - order=(3, 2, 1, 0)) - k_tile_ptr = tl.make_block_ptr(base=K, - shape=(Z, NH, N_CTX, BLOCK_DMODEL), - strides=(stride_kz, stride_kh, stride_kn, stride_kk), - offsets=(off_hz // NH, off_hz % NH, 0, 0), - block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL), - order=(3, 2, 1, 0)) - v_tile_ptr = tl.make_block_ptr(base=V, - shape=(Z, NH, N_CTX, BLOCK_DMODEL), - strides=(stride_vz, stride_vh, stride_vk, stride_vn), - offsets=(off_hz // NH, off_hz % NH, 0, 0), - block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL), - order=(3, 2, 1, 0)) - o_tile_ptr = tl.make_block_ptr(base=Out, - shape=(Z, NH, N_CTX, BLOCK_DMODEL), - strides=(stride_oz, stride_oh, stride_om, stride_on), - offsets=(off_hz // NH, off_hz % NH, start_m, 0), - block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), - order=(3, 2, 1, 0)) + q_tile_ptr = tl.make_block_ptr( + base=Q, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + 
strides=(stride_qz, stride_qh, stride_qm, stride_qk), + offsets=(off_hz // NH, off_hz % NH, start_m, 0), + block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + k_tile_ptr = tl.make_block_ptr( + base=K, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_kz, stride_kh, stride_kn, stride_kk), + offsets=(off_hz // NH, off_hz % NH, 0, 0), + block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + v_tile_ptr = tl.make_block_ptr( + base=V, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_vz, stride_vh, stride_vk, stride_vn), + offsets=(off_hz // NH, off_hz % NH, 0, 0), + block_shape=(1, 1, BLOCK_N, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) + o_tile_ptr = tl.make_block_ptr( + base=Out, + shape=(Z, NH, N_CTX, BLOCK_DMODEL), + strides=(stride_oz, stride_oh, stride_om, stride_on), + offsets=(off_hz // NH, off_hz % NH, start_m, 0), + block_shape=(1, 1, BLOCK_M, BLOCK_DMODEL), + order=(3, 2, 1, 0), + ) q = tl.load(q_tile_ptr, boundary_check=(0, 1, 2, 3)) q = tl.view(q, (BLOCK_M, BLOCK_DMODEL)) @@ -155,12 +163,13 @@ def test_batched_gemm_fusion(): ref_out = torch.matmul(torch.matmul(A, BT), C) num_warps = 4 grid = (triton.cdiv(N_CTX, BLOCK_M), B * NH) - batched_gemm_fusion[grid](A, B, C, E, - A.stride(0), A.stride(1), A.stride(2), A.stride(3), - B.stride(0), B.stride(1), B.stride(2), B.stride(3), - C.stride(0), C.stride(1), C.stride(2), C.stride(3), - E.stride(0), E.stride(1), E.stride(2), E.stride(3), - Z, NH, N_CTX, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps) + batched_gemm_fusion[grid]( + A, B, C, E, # + A.stride(0), A.stride(1), A.stride(2), A.stride(3), # + B.stride(0), B.stride(1), B.stride(2), B.stride(3), # + C.stride(0), C.stride(1), C.stride(2), C.stride(3), # + E.stride(0), E.stride(1), E.stride(2), E.stride(3), # + Z, NH, N_CTX, # + BLOCK_M, BLOCK_DMODEL, BLOCK_N, num_warps=num_warps) torch.testing.assert_close(ref_out, E, atol=1e-2, rtol=0) diff --git a/python/test/unit/hopper/test_mixed_io.py b/python/test/unit/hopper/test_mixed_io.py index cecabbaa732f..68ee474a495d 100644 --- a/python/test/unit/hopper/test_mixed_io.py +++ b/python/test/unit/hopper/test_mixed_io.py @@ -24,10 +24,8 @@ def add_kernel( offsets = block_start + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements - x_block_ptr = tl.make_block_ptr( - base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), - block_shape=(BLOCK_SIZE, ), order=(0, ) - ) + x_block_ptr = tl.make_block_ptr(base=x_ptr, shape=(n_elements, ), strides=(1, ), offsets=(pid * BLOCK_SIZE, ), + block_shape=(BLOCK_SIZE, ), order=(0, )) x = tl.load(x_block_ptr, boundary_check=(0, ), padding_option='zero') y = tl.load(y_ptr + offsets, mask=mask) @@ -36,9 +34,7 @@ def add_kernel( @pytest.mark.parametrize('SIZE,BLOCK_SIZE,dtype_str', - [(98432, 1024, dtype_str) - for dtype_str in ['float16', 'float32'] - ]) + [(98432, 1024, dtype_str) for dtype_str in ['float16', 'float32']]) def test_add(SIZE, BLOCK_SIZE, dtype_str): dtype = dtype_mapping[dtype_str] output = torch.empty(SIZE, device='cuda', dtype=dtype) @@ -46,7 +42,8 @@ def test_add(SIZE, BLOCK_SIZE, dtype_str): y = torch.randn(SIZE, device='cuda', dtype=dtype) def grid(meta): - return (triton.cdiv(SIZE, meta['BLOCK_SIZE']),) + return (triton.cdiv(SIZE, meta['BLOCK_SIZE']), ) + add_kernel[grid](x, y, output, SIZE, BLOCK_SIZE=BLOCK_SIZE) output_torch = x + y @@ -64,25 +61,20 @@ def load_reduce_kernel( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, ): - x_ptr = tl.make_block_ptr( - base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, 
stride_xn), - offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0) - ) + x_ptr = tl.make_block_ptr(base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) x = tl.load(x_ptr) y = tl.max(x, axis=1) tl.store(y_ptr + tl.arange(0, BLOCK_M), y) -@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str', - [(128, 64, dtype_str) - for dtype_str in ['float16'] - ]) +@pytest.mark.parametrize('BLOCK_M,BLOCK_N,dtype_str', [(128, 64, dtype_str) for dtype_str in ['float16']]) def test_load_reduce(BLOCK_M, BLOCK_N, dtype_str): dtype = dtype_mapping[dtype_str] x = torch.randn((BLOCK_M, BLOCK_N), device='cuda', dtype=dtype) y = torch.empty((BLOCK_M, ), device='cuda', dtype=dtype) - load_reduce_kernel[(1,)](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N) + load_reduce_kernel[(1, )](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N) golden = x.max(dim=1)[0] torch.set_printoptions(profile='full') diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py b/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py index 868c052d69a2..ea1776998ef6 100644 --- a/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py +++ b/python/test/unit/hopper/test_persistent_warp_specialized_fused-attention.py @@ -18,7 +18,6 @@ # CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - """ Fused Attention =============== @@ -40,18 +39,17 @@ key=['Q', 'K', 'V'], ) @triton.jit -def _fwd_kernel( - Q, K, V, sm_scale, - L, M, - Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, N_CTX, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): +def _fwd_kernel(Q, K, V, sm_scale, # + L, M, # + Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr # + ): start_m = tl.program_id(0) off_hz = tl.program_id(1) # initialize offsets @@ -116,11 +114,10 @@ def _fwd_kernel( @triton.jit -def _bwd_preprocess( - Out, DO, L, - NewDO, Delta, - BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, -): +def _bwd_preprocess(Out, DO, L, # + NewDO, Delta, # + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr # + ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_n = tl.arange(0, D_HEAD) # load @@ -136,19 +133,18 @@ def _bwd_preprocess( @triton.jit -def _bwd_kernel( - Q, K, V, sm_scale, Out, DO, - DQ, DK, DV, - L, M, - D, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - Z, H, N_CTX, - num_block, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, -): +def _bwd_kernel(Q, K, V, sm_scale, Out, DO, # + DQ, DK, DV, # + L, M, # + D, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + Z, H, N_CTX, # + num_block, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + ): off_hz = tl.program_id(0) off_z = off_hz 
// H off_h = off_hz % H @@ -240,16 +236,16 @@ def forward(ctx, q, k, v, sm_scale): assert num_warps == 4 _fwd_kernel[grid]( - q, k, v, sm_scale, - L, m, - o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], q.shape[2], - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=Lk, + q, k, v, sm_scale, # + L, m, # + o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, # + BLOCK_DMODEL=Lk # ) ctx.save_for_backward(q, k, v, o, L, m) @@ -269,24 +265,23 @@ def backward(ctx, do): do_scaled = torch.empty_like(do) delta = torch.empty_like(l) _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )]( - o, do, l, - do_scaled, delta, - BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL, - ) - _bwd_kernel[(ctx.grid[1],)]( - q, k, v, ctx.sm_scale, - o, do_scaled, - dq, dk, dv, - l, m, - delta, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - q.shape[0], q.shape[1], q.shape[2], - ctx.grid[0], - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8, - num_stages=1, + o, do, l, # + do_scaled, delta, # + BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL) + _bwd_kernel[(ctx.grid[1], )]( + q, k, v, ctx.sm_scale, # + o, do_scaled, # + dq, dk, dv, # + l, m, # + delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + ctx.grid[0], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, # + BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + num_warps=8, num_stages=1 # ) return dq, dk, dv, None @@ -339,19 +334,19 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16): BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 # vary seq length for fixed head and batch=4 -configs = [triton.testing.Benchmark( - x_names=['N_CTX'], - # x_vals=[2**i for i in range(10, 14)], - x_vals=[2**i for i in range(10, 11)], - line_arg='provider', - line_vals=['triton'] + (['flash'] if HAS_FLASH else []), - line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', - args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode} - # ) for mode in ['fwd', 'bwd']] -) for mode in ['fwd']] +configs = [ + triton.testing.Benchmark( + x_names=['N_CTX'], + # x_vals=[2**i for i in range(10, 14)], + x_vals=[2**i + for i in range(10, 11)], line_arg='provider', line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}', + args={'H': N_HEADS, 'BATCH': BATCH, 'D_HEAD': D_HEAD, 'dtype': torch.float16, 'mode': mode} + # ) for mode in ['fwd', 'bwd']] + ) + for mode in ['fwd'] +] @triton.testing.perf_report(configs) @@ -374,9 +369,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, provider, dtype=torch.f ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) return 
ms if provider == "flash": - lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) - cu_seqlens = torch.zeros( - (BATCH + 1,), device=device, dtype=torch.int32) + lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32) cu_seqlens[1:] = lengths.cumsum(0) qkv = torch.randn((BATCH * N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) fn = lambda: flash_attn_func(qkv, cu_seqlens, 0., N_CTX, causal=True) diff --git a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py index 32c04c33bc31..340709a6a4a2 100644 --- a/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py +++ b/python/test/unit/hopper/test_persistent_warp_specialized_gemm.py @@ -29,14 +29,14 @@ @triton.jit -def static_persistent_matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - NUM_SM: tl.constexpr, +def static_persistent_matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_SM: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -68,14 +68,14 @@ def static_persistent_matmul_kernel( @triton.jit -def static_persistent_tma_matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - NUM_SM: tl.constexpr, +def static_persistent_tma_matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_SM: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -88,8 +88,10 @@ def static_persistent_tma_matmul_kernel( block_offset_m = pre_pid_m * BLOCK_M block_offset_n = pre_pid_n * BLOCK_N - a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) - b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) for tile_id in range(start_tile, num_tiles, NUM_SM): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles @@ -114,21 +116,23 @@ def static_persistent_tma_matmul_kernel( pre_pid_n = pid_n -@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', - [(*shape, use_tma) - for shape in [ - [4096, 4096, 64, 64, 64, 16, 4, 1, False, True], - [4096, 4096, 64, 64, 64, 32, 4, 1, False, True], - [4096, 4096, 64, 256, 64, 16, 4, 1, False, True], - [4096, 4096, 64, 128, 128, 16, 4, 1, False, True], - # TODO: fix issue for 8-warp persistent kernel - # [4096, 4096, 64, 128, 128, 16, 8, 1, False, True], - # [4096, 4096, 64, 128, 256, 16, 8, 1, 
False, True], - ] - for use_tma in [False, True] - ]) +@pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', [( + *shape, use_tma +) for shape in [ + [4096, 4096, 64, 64, 64, 16, 4, 1, False, True], + [4096, 4096, 64, 64, 64, 32, 4, 1, False, True + ], + [4096, 4096, 64, 256, 64, 16, 4, 1, False, True + ], + [4096, 4096, 64, 128, 128, 16, 4, 1, False, True + ], + # TODO: fix issue for 8-warp persistent kernel + # [4096, 4096, 64, 128, 128, 16, 8, 1, False, True], + # [4096, 4096, 64, 128, 256, 16, 8, 1, False, True], +] for use_tma in [False, True]]) @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") -def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA): +def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, + TRANS_A, TRANS_B, USE_TMA): if (TRANS_A): a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T else: @@ -141,25 +145,33 @@ def test_user_defined_persistent_non_warp_specialized_gemm(M, N, K, BLOCK_M, BLO c = torch.empty((M, N), device=a.device, dtype=torch.float32) num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count - grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),) + grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) if USE_TMA: - static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS) + static_persistent_tma_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), + stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), + stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, + num_ctas=NUM_CTAS) else: - static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, num_ctas=NUM_CTAS) + static_persistent_matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, M=M, N=N, K=K, stride_am=a.stride(0), + stride_ak=a.stride(1), stride_bk=b.stride(0), stride_bn=b.stride(1), + stride_cm=c.stride(0), stride_cn=c.stride(1), BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, NUM_SM=num_SMs, num_warps=NUM_WARPS, + num_ctas=NUM_CTAS) th_c = torch.matmul(a, b) torch.testing.assert_close(th_c, c, atol=1e-2, rtol=0, check_dtype=False) @triton.jit -def warp_specialized_matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +def warp_specialized_matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # ): tid = tl.program_id(axis=0) n_tiles = tl.cdiv(N, BLOCK_N) @@ -193,13 +205,13 @@ def 
warp_specialized_matmul_kernel( @triton.jit -def tma_warp_specialized_matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +def tma_warp_specialized_matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # ): tid = tl.program_id(axis=0) n_tiles = tl.cdiv(N, BLOCK_N) @@ -232,8 +244,7 @@ def tma_warp_specialized_matmul_kernel( @pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', - [(*shape, use_tma) - for shape in [ + [(*shape, use_tma) for shape in [ [2048, 2048, 64, 64, 64, 16, 1, False, True], [4096, 4096, 64, 64, 64, 16, 1, False, True], [128, 4096, 64, 64, 64, 16, 1, False, True], @@ -257,9 +268,7 @@ def tma_warp_specialized_matmul_kernel( [4096, 4096, 128, 256, 128, 64, 4, False, True], [4096, 4096, 256, 128, 256, 64, 4, False, True], [4096, 4096, 256, 256, 256, 64, 4, False, True], - ] - for use_tma in [False, True] - ]) + ] for use_tma in [False, True]]) @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA): if (TRANS_A): @@ -274,29 +283,29 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K c = torch.empty((M, N), device=a.device, dtype=torch.float32) - grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),) + grid = lambda META: (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), ) if USE_TMA: tma_warp_specialized_matmul_kernel[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - BLOCK_M, BLOCK_N, BLOCK_K, - num_warps=4, - num_ctas=NUM_CTAS, + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + BLOCK_M, BLOCK_N, BLOCK_K, # + num_warps=4, # + num_ctas=NUM_CTAS, # enable_warp_specialization=True) else: warp_specialized_matmul_kernel[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - BLOCK_M, BLOCK_N, BLOCK_K, - num_warps=4, - num_ctas=NUM_CTAS, + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + BLOCK_M, BLOCK_N, BLOCK_K, # + num_warps=4, # + num_ctas=NUM_CTAS, # enable_warp_specialization=True) th_c = torch.matmul(a, b) @@ -304,14 +313,14 @@ def test_non_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K @triton.jit -def static_persistent_warp_specialized_matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - NUM_SM: tl.constexpr, +def static_persistent_warp_specialized_matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_SM: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -343,14 +352,14 @@ def static_persistent_warp_specialized_matmul_kernel( @triton.jit -def static_persistent_tma_warp_specialized_matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - 
stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - NUM_SM: tl.constexpr, +def static_persistent_tma_warp_specialized_matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + NUM_SM: tl.constexpr # ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) @@ -363,8 +372,10 @@ def static_persistent_tma_warp_specialized_matmul_kernel( block_offset_m = pre_pid_m * BLOCK_M block_offset_n = pre_pid_n * BLOCK_N - a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) - b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) for tile_id in range(start_tile, num_tiles, NUM_SM): pid_m = tile_id // n_tiles pid_n = tile_id % n_tiles @@ -390,8 +401,7 @@ def static_persistent_tma_warp_specialized_matmul_kernel( @pytest.mark.parametrize('M,N,K,BLOCK_M,BLOCK_N,BLOCK_K,NUM_CTAS,TRANS_A,TRANS_B,USE_TMA', - [(*shape, use_tma) - for shape in [ + [(*shape, use_tma) for shape in [ [2048, 2048, 64, 64, 64, 16, 1, False, True], [4096, 4096, 64, 64, 64, 16, 1, False, True], [128, 4096, 64, 64, 64, 16, 1, False, True], @@ -415,11 +425,10 @@ def static_persistent_tma_warp_specialized_matmul_kernel( [4096, 4096, 128, 256, 128, 64, 4, False, True], [4096, 4096, 256, 128, 256, 64, 4, False, True], [4096, 4096, 256, 256, 256, 64, 4, False, True], - ] - for use_tma in [False, True] - ]) + ] for use_tma in [False, True]]) @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") -def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, USE_TMA): +def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, NUM_CTAS, TRANS_A, TRANS_B, + USE_TMA): if (TRANS_A): a = .1 * torch.randn((K, M), device='cuda', dtype=torch.float16).T else: @@ -432,27 +441,22 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N c = torch.empty((M, N), device=a.device, dtype=torch.float32) num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count - grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),) + grid = lambda META: (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) if USE_TMA: static_persistent_tma_warp_specialized_matmul_kernel[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, - num_warps=4, num_ctas=NUM_CTAS, + a, b, c, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), c.stride(0), c.stride(1), BLOCK_M, + BLOCK_N, BLOCK_K, num_SMs, num_warps=4, num_ctas=NUM_CTAS, # enable_warp_specialization=True) else: static_persistent_warp_specialized_matmul_kernel[grid]( - a, 
b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, - num_warps=4, num_ctas=NUM_CTAS, + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + BLOCK_M, BLOCK_N, BLOCK_K, num_SMs, # + num_warps=4, num_ctas=NUM_CTAS, # enable_warp_specialization=True) th_c = torch.matmul(a, b) @@ -460,16 +464,15 @@ def test_user_defined_persistent_warp_specialized_gemm(M, N, K, BLOCK_M, BLOCK_N @triton.jit -def static_persistent_matmul_no_scf_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr, - NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr, -): +def static_persistent_matmul_no_scf_kernel(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + FLOAT16_OUTPUT: tl.constexpr, USE_TMA_EPILOGUE: tl.constexpr, # + NUM_SM: tl.constexpr, USE_TMA_LOAD: tl.constexpr # + ): start_tile = tl.program_id(axis=0) m_tiles = tl.cdiv(M, BLOCK_M) n_tiles = tl.cdiv(N, BLOCK_N) @@ -487,7 +490,8 @@ def static_persistent_matmul_no_scf_kernel( offsets=(0, block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) if USE_TMA_EPILOGUE: c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), - offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0)) for tile_id in range(start_tile, num_tiles, NUM_SM): pid_m = tile_id // n_tiles @@ -524,29 +528,27 @@ def static_persistent_matmul_no_scf_kernel( pre_pid_n = pid_n -@pytest.mark.parametrize('M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD', - itertools.chain( - *[ - [ - # numCTAs = 1, no TMA multicast: - [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [64, 64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - # small M, N - [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], - ] for USE_TMA_EPILOGUE in [True, False] - for USE_TMA_LOAD in [True, False] - ])) +@pytest.mark.parametrize( + 'M,N,K,NUM_CTAS,NUM_WARPS,TRANS_A,TRANS_B,OUTPUT_TYPE,USE_TMA_EPILOGUE,USE_TMA_LOAD', + itertools.chain(*[[ + # numCTAs = 1, no TMA multicast: + [64, 16, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 32, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 64, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 
64, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 64, 32, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [64, 64, 64, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [128, 128, 16, 1, 4, False, True, "float16", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [128, 128, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + # small M, N + [16, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [16, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [32, 16, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + [32, 32, 16, 1, 4, False, True, "float32", USE_TMA_EPILOGUE, USE_TMA_LOAD], + ] for USE_TMA_EPILOGUE in [True, False] for USE_TMA_LOAD in [True, False]])) @pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") -def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, USE_TMA_EPILOGUE, USE_TMA_LOAD): +def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_TYPE, + USE_TMA_EPILOGUE, USE_TMA_LOAD): if (TRANS_A): a = torch.randn((K, M), device='cuda', dtype=torch.float16).T else: @@ -564,46 +566,42 @@ def test_static_persistent_matmul_no_scf_kernel(M, N, K, NUM_CTAS, NUM_WARPS, TR num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count # TODO: set `enable_warp_specialization=False` will lead to compilation error. - static_persistent_matmul_no_scf_kernel[(num_SMs,)](a_ptr=a, b_ptr=b, c_ptr=c, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_cm=c.stride(0), stride_cn=c.stride(1), - BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs, - num_warps=NUM_WARPS, - num_ctas=NUM_CTAS, - FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), - USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, - USE_TMA_LOAD=USE_TMA_LOAD, - enable_warp_specialization=True) + static_persistent_matmul_no_scf_kernel[(num_SMs, )]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=M if M < 128 else M // 2, BLOCK_N=N if N < 128 else N // 2, BLOCK_K=K, NUM_SM=num_SMs, # + num_warps=NUM_WARPS, # + num_ctas=NUM_CTAS, # + FLOAT16_OUTPUT=(OUTPUT_TYPE == "float16"), # + USE_TMA_EPILOGUE=USE_TMA_EPILOGUE, # + USE_TMA_LOAD=USE_TMA_LOAD, # + enable_warp_specialization=True) a_f32 = a.to(torch.float32) b_f32 = b.to(torch.float32) golden = torch.matmul(a_f32, b_f32) torch.set_printoptions(profile="full") - assert_close( - c, - golden, - rtol=1e-2, - atol=1e-3, - check_dtype=False) + assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) @triton.jit -def full_static_persistent_matmul_kernel( - a_ptr, b_ptr, w_ptr, bias_ptr, z_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_wm, stride_wn, - stride_zm, stride_zn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, - ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, - DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, - A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, - B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, - NUM_SM: tl.constexpr -): +def full_static_persistent_matmul_kernel(a_ptr, b_ptr, 
w_ptr, bias_ptr, z_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_wm, stride_wn, # + stride_zm, stride_zn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, # + out_dtype: tl.constexpr, USE_TMA_STORE: tl.constexpr, # + ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, # + DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, # + A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, # + B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr, # + NUM_SM: tl.constexpr # + ): start_pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_N) num_pid_m = tl.cdiv(M, BLOCK_M) @@ -618,15 +616,18 @@ def full_static_persistent_matmul_kernel( pre_block_offset_m = pre_pid_m * BLOCK_M pre_block_offset_n = pre_pid_n * BLOCK_N a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), - offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), order=(A_ORDER_0, A_ORDER_1)) + offsets=(pre_block_offset_m, 0), block_shape=(BLOCK_M, BLOCK_K), + order=(A_ORDER_0, A_ORDER_1)) b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), - offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N), order=(B_ORDER_0, B_ORDER_1)) + offsets=(0, pre_block_offset_n), block_shape=(BLOCK_K, BLOCK_N), + order=(B_ORDER_0, B_ORDER_1)) w_tile_ptr = tl.make_block_ptr(base=w_ptr, shape=(N, N), strides=(stride_wm, stride_wn), offsets=(0, pre_block_offset_n), block_shape=(BLOCK_N, BLOCK_N), order=(0, 1)) if USE_TMA_STORE: z_block_ptr = tl.make_block_ptr(base=z_ptr, shape=(M, N), strides=(stride_zm, stride_zn), - offsets=(pre_block_offset_m, pre_block_offset_n), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + offsets=(pre_block_offset_m, pre_block_offset_n), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) for tile_id in range(start_pid, num_tiles, NUM_SM): group_id = tile_id // num_pid_in_group @@ -694,136 +695,120 @@ def full_static_persistent_matmul_kernel( pre_pid_n = pid_n -@pytest.mark.parametrize('BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS', - [ - # corner shapes - (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) - for shape_w_c in [ - [4096, 1, 1024, False, False], - [2048, 204, 1000, True, False], - [16, 524288, 32, False, True], - ] - for out_dtype in ['float16', 'float32'] - for use_tma_store in [False, True] - for enable_ws in [True] - ] + [ - # softmax epilogue - (*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) - # softmax works for one CTA - for shape_w_c in [ - [64, 64, 16, 4, 1, 64, 64, 64], - [128, 128, 64, 4, 1, None, None, None], - [16, 16, 64, 4, 1, 16, 16, 64], - # TODO: enable when num_warps != 4 is supported. 
- # [64, 64, 32, 8, 1, 64, 64, 64], - [128, 128, 64, 4, 1, 128, 128, 128], - ] - for epilogue in ['softmax'] - for out_dtype in ['float16', 'float32'] - for use_tma_store in [False, True] - for trans_a in [False,] - for trans_b in [True,] - for num_stages in [3] - for enable_ws in [True] - ] + [ - # loop over tile shapes and transpose combinations - (*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [64, 64, 32, 4, 1, 128, 256, 64], - [128, 128, 16, 4, 4, 512, 256, 64], - [128, 256, 32, 4, 8, 256, 256, 192], - [512, 256, 32, 4, 8, 1024, 256, 192], - # BLOCK_K >= 128 - [64, 128, 128, 4, 1, 512, 256, 256], - [128, 128, 128, 4, 1, 256, 256, 192], - [128, 128, 128, 4, 2, 256, 256, 192], - # small BLOCK_M and BLOCK_K - [16, 32, 32, 4, 1, 128, 256, 64], - [32, 32, 16, 4, 1, 256, 256, 192], - [16, 32, 64, 4, 4, 512, 256, 64], - ] - for out_dtype in ['float32',] - for use_tma_store in [False,] - for trans_a in [False, True] - for trans_b in [False, True] - for num_stages in [3] - for enable_ws in [True] - ] + [ - # loop over epilogues besides of softmax - (*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [64, 64, 16, 4, 1, 128, 128, 64], - *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]], - # for chain-dot - [128, 128, 64, 4, 1, None, None, None], - [64, 64, 16, 4, 1, None, None, None], - # small BLOCK_M and BLOCK_K - [16, 16, 64, 4, 1, 128, 128, 64], - *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]], - # # TODO: enable when num_warps != 4 is supported. - # # repeat - # # [64, 64, 32, 8, 1, 128, 256, 64], - # # [64, 64, 16, 8, 2, 128, 128, 64], - # irregular shape - [128, 128, 64, 4, 1, 500, 200, 128], - [128, 128, 64, 4, 1, 513, 193, 192], - ] - for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'] - for out_dtype in ['float16', 'float32'] - for use_tma_store in [False, True] - for trans_a in [False,] - for trans_b in [True,] - for num_stages in [3] - for enable_ws in [True] - if not (epilogue == 'chain-dot' and (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1])) - ] + [ - # loop over instr shapes & pipeline stages - (64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws) - for n in [16, 32, 64, 128, 256] - for out_dtype in ['float32'] - for use_tma_store in [False,] - for num_stages in [2, 4, 5, 7] - for enable_ws in [True] - ] + [ - # irregular shapes - (*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws) - for shape_w_c in [ - [128, 128, 64, 4, 1], - [256, 128, 64, 4, 2], - [128, 128, 128, 4, 2] - ] - for shape in [ - [512, 360, 1024], - [360, 4096, 512], - ] - for out_dtype in ['float32'] - for use_tma_store in [False, True] - for num_stages in [3, 4] - for enable_ws in [True] - ] - ) -@pytest.mark.skipif(torch.cuda.get_device_capability() - [0] < 9, reason="Requires compute capability >= 9") -def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS): - if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS])) in [ - '128-128-128-4-1-256-256-192-none-float32-True-3-True', - ]: +@pytest.mark.parametrize( + 
'BLOCK_M,BLOCK_N,BLOCK_K,NUM_WARPS,NUM_CTAS,M,N,K,TRANS_A,TRANS_B,epilogue,out_dtype,USE_TMA_STORE,NUM_STAGES,ENABLE_WS', + [ + # corner shapes + (128, 128, 64, 4, 1, *shape_w_c, 'none', out_dtype, use_tma_store, 3, enable_ws) for shape_w_c in [ + [4096, 1, 1024, False, False], + [2048, 204, 1000, True, False], + [16, 524288, 32, False, True], + ] for out_dtype in ['float16', 'float32'] for use_tma_store in [False, True] for enable_ws in [True] + ] + [ + # softmax epilogue + (*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) + # softmax works for one CTA + for shape_w_c in [ + [64, 64, 16, 4, 1, 64, 64, 64], + [128, 128, 64, 4, 1, None, None, None], + [16, 16, 64, 4, 1, 16, 16, 64], + # TODO: enable when num_warps != 4 is supported. + # [64, 64, 32, 8, 1, 64, 64, 64], + [128, 128, 64, 4, 1, 128, 128, 128], + ] + for epilogue in ['softmax'] + for out_dtype in ['float16', 'float32'] + for use_tma_store in [False, True] + for trans_a in [False] + for trans_b in [True] + for num_stages in [3] + for enable_ws in [True] + ] + [ + # loop over tile shapes and transpose combinations + (*shape_w_c, trans_a, trans_b, 'none', out_dtype, use_tma_store, num_stages, enable_ws) for shape_w_c in [ + [64, 64, 32, 4, 1, 128, 256, 64], + [128, 128, 16, 4, 4, 512, 256, 64], + [128, 256, 32, 4, 8, 256, 256, 192], + [512, 256, 32, 4, 8, 1024, 256, 192], + # BLOCK_K >= 128 + [64, 128, 128, 4, 1, 512, 256, 256], + [128, 128, 128, 4, 1, 256, 256, 192], + [128, 128, 128, 4, 2, 256, 256, 192], + # small BLOCK_M and BLOCK_K + [16, 32, 32, 4, 1, 128, 256, 64], + [32, 32, 16, 4, 1, 256, 256, 192], + [16, 32, 64, 4, 4, 512, 256, 64], + ] for out_dtype in ['float32'] for use_tma_store in [False] for trans_a in [False, True] for trans_b in + [False, True] for num_stages in [3] for enable_ws in [True] + ] + [ + # loop over epilogues besides of softmax + (*shape_w_c, trans_a, trans_b, epilogue, out_dtype, use_tma_store, num_stages, enable_ws) for shape_w_c in [ + [64, 64, 16, 4, 1, 128, 128, 64], + *[[256, 64, 16, num_warps, num_ctas, 256, 256, 64] for num_warps in [4] for num_ctas in [1, 2, 4]], + # for chain-dot + [128, 128, 64, 4, 1, None, None, None], + [64, 64, 16, 4, 1, None, None, None], + # small BLOCK_M and BLOCK_K + [16, 16, 64, 4, 1, 128, 128, 64], + *[[16, 32, 64, num_warps, num_ctas, 256, 256, 256] for num_warps in [4] for num_ctas in [1, 2]], + # # TODO: enable when num_warps != 4 is supported. 
+ # # repeat + # # [64, 64, 32, 8, 1, 128, 256, 64], + # # [64, 64, 16, 8, 2, 128, 128, 64], + # irregular shape + [128, 128, 64, 4, 1, 500, 200, 128], + [128, 128, 64, 4, 1, 513, 193, 192], + ] for epilogue in ['none', 'add-matrix', 'add-rows', 'add-cols', 'chain-dot'] for out_dtype in + ['float16', 'float32'] for use_tma_store in [False, True] for trans_a in [False] for trans_b in [True] for + num_stages in [3] for enable_ws in [True] if not (epilogue == 'chain-dot' and + (shape_w_c[5] is not None or shape_w_c[0] != shape_w_c[1])) + ] + [ + # loop over instr shapes & pipeline stages + (64, n, 16, 4, 1, 512, 256, 256, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for n in [16, 32, 64, 128, 256] + for out_dtype in ['float32'] + for use_tma_store in [False] + for num_stages in [2, 4, 5, 7] + for enable_ws in [True] + ] + [ + # irregular shapes + (*shape_w_c, *shape, False, True, 'none', out_dtype, use_tma_store, num_stages, enable_ws) + for shape_w_c in [[128, 128, 64, 4, 1], [256, 128, 64, 4, 2], [128, 128, 128, 4, 2]] + for shape in [ + [512, 360, 1024], + [360, 4096, 512], + ] + for out_dtype in ['float32'] + for use_tma_store in [False, True] + for num_stages in [3, 4] + for enable_ws in [True] + ]) +@pytest.mark.skipif(torch.cuda.get_device_capability()[0] < 9, reason="Requires compute capability >= 9") +def test_full_static_persistent_matmul_kernel(BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B, + epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, ENABLE_WS): + if '-'.join( + map(str, [ + BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, epilogue, out_dtype, USE_TMA_STORE, NUM_STAGES, + ENABLE_WS + ])) in [ + '128-128-128-4-1-256-256-192-none-float32-True-3-True', + ]: pytest.skip('out of resource: shared memory, Required: 263168') if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_A, TRANS_B])) in [ - '16-32-64-4-4-512-256-64-True-False', - '16-32-64-4-4-512-256-64-True-True', - '16-32-64-4-4-512-256-64-False-False', - '16-32-64-4-4-512-256-64-False-True', + '16-32-64-4-4-512-256-64-True-False', + '16-32-64-4-4-512-256-64-True-True', + '16-32-64-4-4-512-256-64-False-False', + '16-32-64-4-4-512-256-64-False-True', ]: pytest.skip('shapePerCTA[1] < 16 not supported') if '-'.join(map(str, [BLOCK_M, BLOCK_N, BLOCK_K, NUM_WARPS, NUM_CTAS, M, N, K, TRANS_B])) in [ - '16-32-64-4-1-256-256-256-False', - '16-32-64-4-2-256-256-256-False', - '16-32-64-4-2-256-256-256-True', - '16-32-64-8-2-256-256-256-False', - '16-32-64-8-2-256-256-256-True', + '16-32-64-4-1-256-256-256-False', + '16-32-64-4-2-256-256-256-False', + '16-32-64-4-2-256-256-256-True', + '16-32-64-8-2-256-256-256-False', + '16-32-64-8-2-256-256-256-True', ]: pytest.skip('Known legacy issue, ldmatrix can only support x4') @@ -893,37 +878,36 @@ def process_epilogue(d, bias, w, epilogue): else: ref = d return ref + golden = process_epilogue(dot, bias, w, epilogue) num_SMs = torch.cuda.get_device_properties('cuda').multi_processor_count def grid(META): - return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])),) + return (min(META['NUM_SM'], triton.cdiv(M, META['BLOCK_M']) * triton.cdiv(N, META['BLOCK_N'])), ) + full_static_persistent_matmul_kernel[grid]( - a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_wm=w.stride(0), stride_wn=w.stride(1), - stride_zm=z.stride(0), 
stride_zn=z.stride(1), - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, - out_dtype=out_dtype, - USE_TMA_STORE=USE_TMA_STORE, - ADD_MATRIX=epilogue == 'add-matrix', - ADD_ROWS=epilogue == 'add-rows', - ADD_COLS=epilogue == 'add-cols', - DO_SOFTMAX=epilogue == 'softmax', - CHAIN_DOT=epilogue == 'chain-dot', - A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], - B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], - num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, - enable_warp_specialization=ENABLE_WS, + a_ptr=a, b_ptr=b, w_ptr=w, bias_ptr=bias, z_ptr=z, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_wm=w.stride(0), stride_wn=w.stride(1), # + stride_zm=z.stride(0), stride_zn=z.stride(1), # + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K, GROUP_SIZE_M=8, # + out_dtype=out_dtype, # + USE_TMA_STORE=USE_TMA_STORE, # + ADD_MATRIX=epilogue == 'add-matrix', # + ADD_ROWS=epilogue == 'add-rows', # + ADD_COLS=epilogue == 'add-cols', # + DO_SOFTMAX=epilogue == 'softmax', # + CHAIN_DOT=epilogue == 'chain-dot', # + A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # + B_ORDER_0=b_order[0], B_ORDER_1=b_order[1], # + num_warps=NUM_WARPS, num_ctas=NUM_CTAS, num_stages=NUM_STAGES, # + enable_warp_specialization=ENABLE_WS, # NUM_SM=num_SMs) torch.set_printoptions(profile="full") golden = torch.nn.functional.normalize(golden) z = torch.nn.functional.normalize(z) - assert_close(z, golden, - rtol=1e-2, - atol=1e-3, - check_dtype=False) + assert_close(z, golden, rtol=1e-2, atol=1e-3, check_dtype=False) diff --git a/python/test/unit/hopper/test_tma_store_gemm.py b/python/test/unit/hopper/test_tma_store_gemm.py index 6d912d89caed..b2fc3e8745dc 100644 --- a/python/test/unit/hopper/test_tma_store_gemm.py +++ b/python/test/unit/hopper/test_tma_store_gemm.py @@ -19,7 +19,6 @@ # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
- import pytest import torch from torch.testing import assert_close @@ -29,21 +28,21 @@ @triton.jit -def matmul_tma_load_store( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - OUTPUT_F16: tl.constexpr +def matmul_tma_load_store( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + OUTPUT_F16: tl.constexpr # ): - a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), - offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) - b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), - offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) - c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), - offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), order=(0, 1)) + c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)) a = tl.load(a_block_ptr) b = tl.load(b_block_ptr) @@ -78,15 +77,15 @@ def test_tma_load_store(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_F if OUTPUT_F16: c = torch.empty((M, N), device=a.device, dtype=torch.float16) - matmul_tma_load_store[(1, 1)](a_ptr=a, b_ptr=b, c_ptr=c, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_cm=c.stride(0), stride_cn=c.stride(1), - BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, - num_warps=NUM_WARPS, - num_ctas=NUM_CTAS, - OUTPUT_F16=OUTPUT_F16) + matmul_tma_load_store[(1, 1)]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=M, BLOCK_N=N, BLOCK_K=K, # + num_warps=NUM_WARPS, num_ctas=NUM_CTAS, # + OUTPUT_F16=OUTPUT_F16) golden = torch.matmul(a, b) torch.set_printoptions(profile="full") assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) diff --git a/python/test/unit/hopper/ttgir_tests/test_tma.py b/python/test/unit/hopper/ttgir_tests/test_tma.py index d48d2aa42986..0ee725b4bf9c 100644 --- a/python/test/unit/hopper/ttgir_tests/test_tma.py +++ b/python/test/unit/hopper/ttgir_tests/test_tma.py @@ -54,17 +54,13 @@ def test_tma_wgmma_64_64_16_f16(TTGIR, TRANS_A, TRANS_B): ttgir_path = os.path.dirname(__file__) + "/" + TTGIR kernel = triton.compile(ttgir_path) - kernel[(1, 1, 1)](a.data_ptr(), b.data_ptr(), c.data_ptr(), - SIZE_M, SIZE_N, SIZE_K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0)) + kernel[(1, 1, 1)]( # + a.data_ptr(), b.data_ptr(), c.data_ptr(), # + SIZE_M, SIZE_N, SIZE_K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0)) golden = torch.matmul(a, b) torch.set_printoptions(profile="full", sci_mode=False) - assert_close( - c, - golden, - rtol=1e-2, - atol=1e-3, - check_dtype=False) + assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) diff --git 
a/python/test/unit/language/assert_helper.py b/python/test/unit/language/assert_helper.py index 6b6742e4a99c..f97697b6c307 100644 --- a/python/test/unit/language/assert_helper.py +++ b/python/test/unit/language/assert_helper.py @@ -48,17 +48,17 @@ def test_assert(func: str): x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda') y = torch.zeros(shape, dtype=x.dtype, device="cuda") if func == "device_assert": - kernel_device_assert[(1,)](x, y, BLOCK=shape[0]) + kernel_device_assert[(1, )](x, y, BLOCK=shape[0]) if func == "device_assert_passes": # Assert passes; no error. - kernel_assert_passes[(1,)](x, y, BLOCK=shape[0]) + kernel_assert_passes[(1, )](x, y, BLOCK=shape[0]) elif func == "no_debug": # TRITON_DEBUG=1 can override the debug flag - kernel_device_assert_no_debug[(1,)](x, y, BLOCK=shape[0]) + kernel_device_assert_no_debug[(1, )](x, y, BLOCK=shape[0]) elif func == "assert": - kernel_assert[(1,)](x, y, BLOCK=shape[0]) + kernel_assert[(1, )](x, y, BLOCK=shape[0]) elif func == "static_assert": - kernel_static_assert[(1,)](x, y, BLOCK=shape[0]) + kernel_static_assert[(1, )](x, y, BLOCK=shape[0]) elif func == "double_assert": # Launching a different kernel after the first one asserted used to # segfault. What seems to have happened is: @@ -70,8 +70,8 @@ def test_assert(func: str): # - Now the GPU is in an error state. We need to detect this inside # the kernel-launch/loading code and bail out properly. If we don't, # we segfault. - kernel_device_assert[(1,)](x, y, BLOCK=shape[0]) - kernel_assert_passes[(1,)](x, y, BLOCK=shape[0]) + kernel_device_assert[(1, )](x, y, BLOCK=shape[0]) + kernel_assert_passes[(1, )](x, y, BLOCK=shape[0]) assert_close(y, x) @@ -131,11 +131,11 @@ def test_assert_nested(caller: str, callee: str): x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda') y = torch.zeros(shape, dtype=x.dtype, device="cuda") if caller == "none": - kernel_device_assert_nested[(1,)](x, y, BLOCK=shape[0], jit_debug=callee) + kernel_device_assert_nested[(1, )](x, y, BLOCK=shape[0], jit_debug=callee) elif caller == "true": - kernel_device_assert_nested_true[(1,)](x, y, BLOCK=shape[0], jit_debug=callee) + kernel_device_assert_nested_true[(1, )](x, y, BLOCK=shape[0], jit_debug=callee) elif caller == "false": - kernel_device_assert_nested_false[(1,)](x, y, BLOCK=shape[0], jit_debug=callee) + kernel_device_assert_nested_false[(1, )](x, y, BLOCK=shape[0], jit_debug=callee) assert_close(y, x) diff --git a/python/test/unit/language/conftest.py b/python/test/unit/language/conftest.py index f9e96688b921..7a02d322b49f 100644 --- a/python/test/unit/language/conftest.py +++ b/python/test/unit/language/conftest.py @@ -4,9 +4,7 @@ def pytest_addoption(parser): - parser.addoption( - "--device", action="store", default='cuda' - ) + parser.addoption("--device", action="store", default='cuda') @pytest.fixture diff --git a/python/test/unit/language/print_helper.py b/python/test/unit/language/print_helper.py index fca027355c39..6776f09c1ccd 100644 --- a/python/test/unit/language/print_helper.py +++ b/python/test/unit/language/print_helper.py @@ -36,14 +36,14 @@ def kernel_device_print_large( @triton.jit def kernel_print_multiple_args(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) - y = tl.full((BLOCK,), 1, tl.int32) + y = tl.full((BLOCK, ), 1, tl.int32) print("", x, y) @triton.jit def kernel_device_print_multiple_args(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) - y = tl.full((BLOCK,), 1, tl.int32) + y = tl.full((BLOCK, ), 1, tl.int32) tl.device_print("", 
x, y) tl.store(Y + tl.arange(0, BLOCK), y) @@ -72,21 +72,21 @@ def test_print(func: str, data_type: str): x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type)) y = torch.zeros(shape, dtype=x.dtype, device="cuda") if func == "device_print": - kernel_device_print[(1,)](x, y, BLOCK=shape[0]) + kernel_device_print[(1, )](x, y, BLOCK=shape[0]) elif func == "print": - kernel_print[(1,)](x, y, BLOCK=shape[0]) + kernel_print[(1, )](x, y, BLOCK=shape[0]) elif func == "device_print_large": kernel_device_print_large[(1, 2)](BLOCK_M=64, BLOCK_N=128) elif func == "print_multiple_args": - kernel_print_multiple_args[(1,)](x, y, BLOCK=shape[0]) + kernel_print_multiple_args[(1, )](x, y, BLOCK=shape[0]) elif func == "device_print_multiple_args": - kernel_device_print_multiple_args[(1,)](x, y, BLOCK=shape[0]) + kernel_device_print_multiple_args[(1, )](x, y, BLOCK=shape[0]) elif func == "static_print": - kernel_static_print[(1,)](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4()) + kernel_static_print[(1, )](x, y, BLOCK=shape[0], PLACEHOLDER=uuid.uuid4()) elif func == "no_arg_print": - kernel_no_arg_print[(1,)](num_warps=4) + kernel_no_arg_print[(1, )](num_warps=4) elif func == "print_no_arg": - kernel_print_no_arg[(1,)](num_warps=4) + kernel_print_no_arg[(1, )](num_warps=4) else: assert f"Unknown kernel: {func}" diff --git a/python/test/unit/language/test_annotations.py b/python/test/unit/language/test_annotations.py index 0e18c950c47d..26bb40664904 100644 --- a/python/test/unit/language/test_annotations.py +++ b/python/test/unit/language/test_annotations.py @@ -1,4 +1,3 @@ - from __future__ import annotations import torch @@ -14,8 +13,8 @@ def _kernel(X: torch.Tensor, N: int, BLOCK_SIZE: tl.constexpr): pass x = torch.empty(1, device=device) - _kernel[(1,)](x, x.shape[0], 32) + _kernel[(1, )](x, x.shape[0], 32) try: - _kernel[(1,)](x.shape[0], x.shape[0], 32) + _kernel[(1, )](x.shape[0], x.shape[0], 32) except AttributeError: pass diff --git a/python/test/unit/language/test_block_pointer.py b/python/test/unit/language/test_block_pointer.py index 3cc4bdced339..f4ee3414d0c2 100644 --- a/python/test/unit/language/test_block_pointer.py +++ b/python/test/unit/language/test_block_pointer.py @@ -17,10 +17,12 @@ def block_copy_kernel(a_ptr, b_ptr, N, BLOCK_SIZE: tl.constexpr, padding_option: tl.store(b_block_ptr, a, boundary_check=(0, )) -@pytest.mark.parametrize("dtype_str, n, padding_option", - [(dtype_str, n, padding) for dtype_str in ("bool", "int16", "float16") - for n in (64, 128, 256, 512, 1024) - for padding in ("zero", "nan")]) +@pytest.mark.parametrize("dtype_str, n, padding_option", [ # + (dtype_str, n, padding) + for dtype_str in ("bool", "int16", "float16") + for n in (64, 128, 256, 512, 1024) + for padding in ("zero", "nan") # +]) def test_block_copy(dtype_str, n, padding_option): capability = torch.cuda.get_device_capability() if capability[0] >= 9: @@ -35,31 +37,31 @@ def test_block_copy(dtype_str, n, padding_option): a = torch.randn((n, ), device="cuda", dtype=dtype) b = torch.zeros((n, ), device="cuda", dtype=dtype) - grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]),) + grid = lambda meta: (triton.cdiv(n, meta["BLOCK_SIZE"]), ) block_copy_kernel[grid](a_ptr=a, b_ptr=b, N=n, BLOCK_SIZE=64, padding_option=padding_option) - assert torch.all(a[0: n // 2] == b[0: n // 2]) + assert torch.all(a[0:n // 2] == b[0:n // 2]) if padding_option == "zero": - assert torch.all(b[n // 2: n] == 0) + assert torch.all(b[n // 2:n] == 0) else: - assert torch.all(torch.isnan(b[n 
// 2: n])) + assert torch.all(torch.isnan(b[n // 2:n])) @triton.jit -def matmul_no_scf_with_advance_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr +def matmul_no_scf_with_advance_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr # ): offs_m = tl.arange(0, BLOCK_M) offs_n = tl.arange(0, BLOCK_N) - a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), - offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) - b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), - offsets=(0, 0), block_shape=(BLOCK_K, BLOCK_N), order=(1, 0)) + a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0), + block_shape=(BLOCK_M, BLOCK_K), order=(1, 0)) + b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0), + block_shape=(BLOCK_K, BLOCK_N), order=(1, 0)) # Below two lines are just for testing negative offsets for the `advance` API, which could be removed a_block_ptr = tl.advance(a_block_ptr, (BLOCK_M, -BLOCK_K)) a_block_ptr = tl.advance(a_block_ptr, (-BLOCK_M, BLOCK_K)) @@ -71,14 +73,12 @@ def matmul_no_scf_with_advance_kernel( tl.store(c_ptrs, c) -@pytest.mark.parametrize("shape, num_warps", [ - (shape, num_warps) - for shape in [ +@pytest.mark.parametrize("shape, num_warps", [ # + (shape, num_warps) for shape in [ [64, 64, 16], [64, 64, 32], [64, 64, 64], - ] - for num_warps in [4, 8] + ] for num_warps in [4, 8] ]) def test_block_ptr_matmul_no_scf(shape, num_warps): capability = torch.cuda.get_device_capability() @@ -91,12 +91,13 @@ def test_block_ptr_matmul_no_scf(shape, num_warps): c = torch.empty((m, n), device="cuda", dtype=torch.float32) grid = lambda META: (1, ) - matmul_no_scf_with_advance_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, - M=m, N=n, K=k, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_cm=c.stride(0), stride_cn=c.stride(1), - BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, - num_warps=num_warps) + matmul_no_scf_with_advance_kernel[grid]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=m, N=n, K=k, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1), # + BLOCK_M=m, BLOCK_N=n, BLOCK_K=k, # + num_warps=num_warps) golden = torch.matmul(a, b) torch.testing.assert_close(c, golden, check_dtype=False) diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 3e1aa51dd49e..1c92ef6896b8 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -61,8 +61,7 @@ def numpy_random(shape, dtype_str, rs: Optional[RandomState] = None, low=None, h elif dtype_str in float_dtypes: return rs.normal(0, 1, shape).astype(dtype_str) elif dtype_str == 'bfloat16': - return (rs.normal(0, 1, shape).astype('float32').view('uint32') - & np.uint32(0xffff0000)).view('float32') + return (rs.normal(0, 1, shape).astype('float32').view('uint32') & np.uint32(0xffff0000)).view('float32') elif dtype_str in ['bool', 'int1', 'bool_']: return rs.normal(0, 1, shape) > 0.0 else: @@ -135,6 +134,7 @@ def check_type_supported(dtype, device): class MmaLayout: + def __init__(self, version, 
warps_per_cta, ctas_per_cga, cta_split_num, cta_order, instr_shape): self.version = version self.warps_per_cta = str(warps_per_cta) @@ -148,6 +148,7 @@ def __str__(self): class BlockedLayout: + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): self.sz_per_thread = str(size_per_thread) self.threads_per_warp = str(threads_per_warp) @@ -162,6 +163,7 @@ def __str__(self): class SharedLayout: + def __init__(self, vec, per_phase, max_phase, order, ctas_per_cga, cta_split_num, cta_order): self.vec = str(vec) self.per_phase = str(per_phase) @@ -182,6 +184,7 @@ def test_empty_kernel(dtype_x, device): @triton.jit def kernel(X, SIZE: tl.constexpr): pass + check_type_supported(dtype_x, device) x = to_triton(numpy_random(SIZE, dtype_str=dtype_x), device=device, dst_type=dtype_x) kernel[(1, )](x, SIZE=SIZE, num_warps=4) @@ -246,7 +249,8 @@ def _binary_op_dtype_override(a: str, b: str) -> Optional[np.dtype]: return overrides.get(key) -def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, y_low=None, y_high=None): +def _test_binary(dtype_x, dtype_y, expr, numpy_expr=None, mode_x='real', mode_y='real', device='cuda', num_ctas=1, + y_low=None, y_high=None): check_type_supported(dtype_x, device) # early return if dtype_x is not supported check_type_supported(dtype_y, device) SIZE = 128 @@ -278,8 +282,7 @@ def kernel(Z, X, Y, SIZE: tl.constexpr): x_tri = to_triton(x, device=device, dst_type=dtype_x) y_tri = to_triton(y, device=device, dst_type=dtype_y) z_tri = to_triton(np.empty(SIZE, dtype=z_ref.dtype), device=device) - kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, - num_warps=4, num_ctas=num_ctas) + kernel[(1, )](z_tri, x_tri, y_tri, SIZE=SIZE, num_warps=4, num_ctas=num_ctas) np.testing.assert_allclose(z_ref, to_numpy(z_tri), err_msg=expr, rtol=0.01) @@ -308,12 +311,13 @@ def _mod_operation_ill_conditioned(dtype_x, dtype_y) -> bool: ('uint64', 'float64'), ] + # --------------- # test binary ops # --------------- -@pytest.mark.parametrize("dtype_x, dtype_y, op", [ +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) for op in ['+', '-', '*', '/', '%'] for dtype_x in dtypes_with_bfloat16 @@ -325,7 +329,8 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): if op == '%' and dtype_x in int_dtypes + uint_dtypes and dtype_y in int_dtypes + uint_dtypes: # LLVM has 'numpy.fmod', not 'numpy.remainder', semantics on integer remainders. numpy_expr = 'np.fmod(x, y)' - elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', 'bfloat16'): + elif op in ('/', '%') and dtype_x in ('int16', 'float16', 'bfloat16') and dtype_y in ('int16', 'float16', + 'bfloat16'): # Triton promotes 16-bit floating-point / and % to 32-bit because there # are no native div or FRem operations on float16. Since we have to # convert anyway, we may as well take the accuracy bump. 
@@ -338,27 +343,14 @@ def test_bin_op(dtype_x, dtype_y, op, num_ctas, device): numpy_expr = None if op == '%' and _mod_operation_ill_conditioned(dtype_x, dtype_y): with pytest.raises(AssertionError, match='Not equal to tolerance'): - _test_binary( - dtype_x, - dtype_y, - expr, - numpy_expr, - device=device, - num_ctas=num_ctas) - elif (op in ('%', '/') - and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or - (dtype_x in uint_dtypes and dtype_y in int_dtypes))): + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) + elif (op in ('%', '/') and ((dtype_x in int_dtypes and dtype_y in uint_dtypes) or + (dtype_x in uint_dtypes and dtype_y in int_dtypes))): with pytest.raises(triton.CompilationError) as exc_info: _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) assert re.match('Cannot use .* because they have different signedness', str(exc_info.value.__cause__)) else: - _test_binary( - dtype_x, - dtype_y, - expr, - numpy_expr, - device=device, - num_ctas=num_ctas) + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) @pytest.mark.parametrize("dtype, order", [(dtype, order) for dtype in dtypes_with_bfloat16 for order in [0, 1]]) @@ -380,14 +372,15 @@ def kernel(x, y, ORDER: tl.constexpr, SIZE: tl.constexpr): x_tri = to_triton(x, dst_type=dtype, device=device) y_tri = to_triton(y, dst_type=dtype, device=device) y = x - kernel[1,](x_tri, y_tri, order, SIZE) + kernel[ + 1, + ](x_tri, y_tri, order, SIZE) np.testing.assert_allclose(y, to_numpy(y_tri)) -@pytest.mark.parametrize("dtype_x, dtype_y", - [(dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes] - + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes] - ) +@pytest.mark.parametrize("dtype_x, dtype_y", [ # + (dtype_x, dtype_y) for dtype_x in int_dtypes for dtype_y in int_dtypes +] + [(dtype_x, dtype_y) for dtype_x in uint_dtypes for dtype_y in uint_dtypes]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_floordiv(dtype_x, dtype_y, num_ctas, device): # Triton has IEEE, not numpy/torch, semantics for %, and those carry @@ -395,13 +388,7 @@ def test_floordiv(dtype_x, dtype_y, num_ctas, device): # reference result for //. expr = 'x // y' numpy_expr = '((x - np.fmod(x, y)) / y)' - _test_binary( - dtype_x, - dtype_y, - expr, - numpy_expr, - device=device, - num_ctas=num_ctas) + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) def test_unsigned_name_mangling(device='cuda'): @@ -430,10 +417,7 @@ def kernel(O1, O2, X, Y, SIZE: tl.constexpr): # triton result x_tri = to_triton(x, device=device, dst_type=dtype_x) y_tri = to_triton(y, device=device, dst_type=dtype_y) - actual = tuple( - to_triton(np.empty_like(e), device=device) - for e in expect - ) + actual = tuple(to_triton(np.empty_like(e), device=device) for e in expect) kernel[(1, )](actual[0], actual[1], x_tri, y_tri, SIZE=SIZE, num_warps=4) # Bitwise op, so expect exact equality @@ -443,7 +427,7 @@ def kernel(O1, O2, X, Y, SIZE: tl.constexpr): # test bitwise ops # --------------- -@pytest.mark.parametrize("dtype_x, dtype_y, op", [ +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) for op in ['&', '|', '^'] for dtype_x in dtypes + dtypes_with_bfloat16 @@ -464,16 +448,10 @@ def test_bitwise_op(dtype_x, dtype_y, op, num_ctas, device): # The CompilationError must have been caused by a C++ exception with this text. 
assert re.match('invalid operands of type', str(exc_info.value.__cause__)) else: - _test_binary( - dtype_x, - dtype_y, - expr, - numpy_expr, - device=device, - num_ctas=num_ctas) + _test_binary(dtype_x, dtype_y, expr, numpy_expr, device=device, num_ctas=num_ctas) -@pytest.mark.parametrize("dtype_x, dtype_y, op", [ +@pytest.mark.parametrize("dtype_x, dtype_y, op", [ # (dtype_x, dtype_y, op) for op in ['<<', '>>'] for dtype_x in int_dtypes + uint_dtypes @@ -497,22 +475,14 @@ def test_shift_op(dtype_x, dtype_y, op, num_ctas, device): ops = ['==', '!=', '>', '<', '>=', '<='] -@pytest.mark.parametrize("dtype_x, dtype_y, op, mode_x, mode_y", - # real - [ - (dtype_x, dtype_y, op, 'real', 'real') - for op in ops - for dtype_x in dtypes - for dtype_y in dtypes - ] - # NaNs - + [('float32', 'float32', op, mode_x, mode_y) - for op in ops - for mode_x, mode_y in [('nan', 'real'), - ('real', 'nan'), - ('nan', 'nan')] - - ]) +@pytest.mark.parametrize( + "dtype_x, dtype_y, op, mode_x, mode_y", + # real + [(dtype_x, dtype_y, op, 'real', 'real') for op in ops for dtype_x in dtypes for dtype_y in dtypes] + # NaNs + + [('float32', 'float32', op, mode_x, mode_y) + for op in ops + for mode_x, mode_y in [('nan', 'real'), ('real', 'nan'), ('nan', 'nan')]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): expr = f'x {op} y' @@ -530,6 +500,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device): # --------------- @pytest.mark.parametrize("dtype", dtypes_with_bfloat16) def test_broadcast(dtype, device): + @triton.jit def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr): offset1 = tl.arange(0, M) @@ -550,9 +521,10 @@ def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.con y_tri = to_triton(y, device=device, dst_type=dtype) y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype) - broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) + broadcast_kernel[(1, )](x_tri, y_tri, y_broadcasted_tri, M=M, N=N) assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() + # ---------- # test slice # ---------- @@ -580,7 +552,7 @@ def slice_kernel(XBLOCK: tl.constexpr): t = scalar[None, None] tl.static_assert(t.shape == [1, 1]) - slice_kernel[(1,)](XBLOCK=32) + slice_kernel[(1, )](XBLOCK=32) # ------------------ @@ -596,13 +568,14 @@ def _kernel(dst): dst[10:] with pytest.raises(triton.CompilationError, match='unsupported tensor index'): - _kernel[(1,)](dst=dst) + _kernel[(1, )](dst=dst) # ---------------- # test expand_dims # ---------------- def test_expand_dims(device): + @triton.jit def expand_dims_kernel(dummy, N: tl.constexpr): offset1 = tl.arange(0, N) @@ -641,10 +614,11 @@ def expand_dims_kernel(dummy, N: tl.constexpr): N = 32 dummy_tensor = torch.empty((), device=device) - expand_dims_kernel[(1,)](dummy_tensor, N) + expand_dims_kernel[(1, )](dummy_tensor, N) def test_expand_dims_error_cases(device): + @triton.jit def dim_out_of_range1(dummy, N: tl.constexpr): offset1 = tl.arange(0, N) @@ -682,19 +656,19 @@ def duplicate_dim2(dummy, N: tl.constexpr): dummy_tensor = torch.empty((), device=device) with pytest.raises(triton.CompilationError, match="invalid axis -3"): - dim_out_of_range1[(1,)](dummy_tensor, N) + dim_out_of_range1[(1, )](dummy_tensor, N) with pytest.raises(triton.CompilationError, match="invalid axis 2"): - dim_out_of_range2[(1,)](dummy_tensor, N) + dim_out_of_range2[(1, 
)](dummy_tensor, N) with pytest.raises(triton.CompilationError, match="invalid axis 1"): - dim_out_of_range3[(1,)](dummy_tensor, N) + dim_out_of_range3[(1, )](dummy_tensor, N) with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"): - duplicate_dim1[(1,)](dummy_tensor, N) + duplicate_dim1[(1, )](dummy_tensor, N) with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"): - duplicate_dim2[(1,)](dummy_tensor, N) + duplicate_dim2[(1, )](dummy_tensor, N) # ---------------------------- @@ -708,7 +682,7 @@ def _kernel(dst): pid = tl.program_id(20) with pytest.raises(triton.CompilationError, match=r"program_id axis must be 0, 1, or 2 but got 20"): - _kernel[(1,)](dst) + _kernel[(1, )](dst) # --------------- @@ -724,10 +698,8 @@ def test_where(dtype, num_ctas, device): check_type_supported(dtype, device) @triton.jit - def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, - BLOCK_SIZE: tl.constexpr, - TEST_POINTERS: tl.constexpr, - TEST_SCALAR_POINTERS: tl.constexpr): + def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, + TEST_POINTERS: tl.constexpr, TEST_SCALAR_POINTERS: tl.constexpr): offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = offsets < n_elements decide = tl.load(cond_ptr + offsets, mask=mask) @@ -756,17 +728,20 @@ def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements, y_tri = to_triton(y, device=device, dst_type=dtype) z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device=device, dst_type=dtype) - grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),) - where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False, num_ctas=num_ctas) + grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']), ) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=False, num_ctas=num_ctas) assert (z == to_numpy(z_tri)).all() if select_ptrs: - where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=True) + where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, + TEST_SCALAR_POINTERS=True) z = np.where(cond[0], x, y) assert (z == to_numpy(z_tri)).all() @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_where_broadcast(num_ctas, device): + @triton.jit def where_kernel(cond_ptr, a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): xoffsets = tl.arange(0, BLOCK_SIZE)[:, None] @@ -795,44 +770,45 @@ def where_scalar_condition(a_ptr, out_ptr, BLOCK_SIZE: tl.constexpr): cond_tri = to_triton(mask, device=device) x_tri = to_triton(x, device=device, dst_type=dtype) z_tri = to_triton(np.empty((SIZE, SIZE), dtype=z.dtype), device=device, dst_type=dtype) - where_kernel[(1,)](cond_tri, x_tri, z_tri, SIZE) + where_kernel[(1, )](cond_tri, x_tri, z_tri, SIZE) assert (z == to_numpy(z_tri)).all() - where_scalar_condition[(1,)](x_tri, z_tri, SIZE, num_ctas=num_ctas) + where_scalar_condition[(1, )](x_tri, z_tri, SIZE, num_ctas=num_ctas) z = np.where(0, x, 0) assert (z == to_numpy(z_tri)).all() + # --------------- # test unary ops # --------------- -@pytest.mark.parametrize("dtype_x, expr", [ - (dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16 -] + [ - (dtype_x, ' ~x') for dtype_x in int_dtypes -]) +@pytest.mark.parametrize("dtype_x, expr", + [(dtype_x, ' -x') for dtype_x in dtypes_with_bfloat16] + 
[(dtype_x, ' ~x') + for dtype_x in int_dtypes]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_unary_op(dtype_x, expr, num_ctas, device): _test_unary(dtype_x, expr, device=device, num_ctas=num_ctas) + # ---------------- # test math ops # ---------------- -@pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) for dtype_x in ["float32", "float64"] for expr in ['exp', 'log', 'cos', 'sin'] for x in ['x', '3.0']]) +@pytest.mark.parametrize("dtype_x, expr, x", [(dtype_x, expr, x) + for dtype_x in ["float32", "float64"] + for expr in ['exp', 'log', 'cos', 'sin'] + for x in ['x', '3.0']]) def test_math_op(dtype_x, expr, device, x): _test_unary(dtype_x, f'tl.{expr}({x})', f'np.{expr}({x}) ', device=device) + # ---------------- # test abs # ---------------- -@pytest.mark.parametrize("dtype_x", [ - (dtype_x) - for dtype_x in dtypes_with_bfloat16 -]) +@pytest.mark.parametrize("dtype_x", [(dtype_x) for dtype_x in dtypes_with_bfloat16]) def test_abs(dtype_x, device): _test_unary(dtype_x, 'tl.abs(x)', 'np.abs(x) ', device=device) @@ -856,7 +832,7 @@ def abs_kernel(X, Z, SIZE: tl.constexpr): f8 = triton.reinterpret(f8_tensor, in_dtype) n_elements = f8_tensor.numel() out_f8 = torch.empty_like(f8_tensor) - abs_kernel[(1,)](f8, triton.reinterpret(out_f8, in_dtype), n_elements) + abs_kernel[(1, )](f8, triton.reinterpret(out_f8, in_dtype), n_elements) f32_tensor = convert_float_to_float32(f8_tensor, in_dtype) expect = f32_tensor.abs() @@ -881,13 +857,9 @@ def make_ptr_str(name, shape): # TODO: handle `%4 = triton_gpu.convert_layout %3 : (tensor<32xi32, #blocked0>) -> tensor<32xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>>`` -@pytest.mark.parametrize("expr, dtype_str", [ - (f'x[{s}]', d) - for s in ['None, :', ':, None', - 'None, :, :', - ':, :, None'] - for d in ['int32', 'uint32', 'uint16'] -]) +@pytest.mark.parametrize("expr, dtype_str", [(f'x[{s}]', d) + for s in ['None, :', ':, None', 'None, :, :', ':, :, None'] + for d in ['int32', 'uint32', 'uint16']]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_index1d(expr, dtype_str, num_ctas, device): rank_x = expr.count(':') @@ -931,8 +903,7 @@ def generate_kernel(shape_x, shape_z): def catch_compilation_error(kernel): try: - kernel[(1, )](z_tri, x_tri, num_warps=1, - SIZE=shape_x[0], num_ctas=num_ctas) + kernel[(1, )](z_tri, x_tri, num_warps=1, SIZE=shape_x[0], num_ctas=num_ctas) except triton.CompilationError as e: np.testing.assert_(True) except BaseException: @@ -955,6 +926,7 @@ def tuples_fn(a, b): def test_tuples(device): + @triton.jit def with_fn(X, Y, A, B, C): x = tl.load(X) @@ -1063,7 +1035,7 @@ def kernel(X, Y, Z): z = torch.ones((16, 16), device=device, dtype=torch.float32) else: z = torch.tensor([0.0], device=device, dtype=torch.float32) - kernel[(1,)](x, y, z, num_warps=1) + kernel[(1, )](x, y, z, num_warps=1) if mode == "simple": assert torch.equal(z, x + y) elif mode == "call_graph" or mode == "dynamic" or mode == "multi_values": @@ -1076,18 +1048,31 @@ def kernel(X, Y, Z): # --------------- # test atomics # --------------- -@pytest.mark.parametrize("op, dtype_x_str, mode, sem", itertools.chain.from_iterable([ - [ +@pytest.mark.parametrize( + "op, dtype_x_str, mode, sem", + itertools.chain.from_iterable([[ ('add', 'float16', mode, sem), - ('add', 'uint32', mode, sem), ('add', 'int32', mode, sem), ('add', 'float32', mode, sem), - ('add', 'uint64', mode, sem), ('add', 'int64', mode, sem), ('add', 'float64', mode, sem), - ('max', 'uint32', mode, sem), ('max', 'int32', mode, sem), ('max', 'float32', 
mode, sem), - ('max', 'uint64', mode, sem), ('max', 'int64', mode, sem), ('max', 'float64', mode, sem), - ('min', 'uint32', mode, sem), ('min', 'int32', mode, sem), ('min', 'float32', mode, sem), - ('min', 'uint64', mode, sem), ('min', 'int64', mode, sem), ('min', 'float64', mode, sem), + ('add', 'uint32', mode, sem), + ('add', 'int32', mode, sem), + ('add', 'float32', mode, sem), + ('add', 'uint64', mode, sem), + ('add', 'int64', mode, sem), + ('add', 'float64', mode, sem), + ('max', 'uint32', mode, sem), + ('max', 'int32', mode, sem), + ('max', 'float32', mode, sem), + ('max', 'uint64', mode, sem), + ('max', 'int64', mode, sem), + ('max', 'float64', mode, sem), + ('min', 'uint32', mode, sem), + ('min', 'int32', mode, sem), + ('min', 'float32', mode, sem), + ('min', 'uint64', mode, sem), + ('min', 'int64', mode, sem), + ('min', 'float64', mode, sem), ] - for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos'] - for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']])) + for mode in ['all_neg', 'all_pos', 'min_neg', 'max_pos'] + for sem in [None, 'acquire', 'release', 'acq_rel', 'relaxed']])) def test_atomic_rmw(op, dtype_x_str, mode, sem, device): check_cuda_only(device) @@ -1146,18 +1131,22 @@ def kernel(X, Z): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_atomic_rmw_predicate(num_ctas, device): + @triton.jit def kernel(X): val = tl.program_id(0) if val < 64: tl.atomic_max(X, val) - x = torch.zeros((1,), device=device, dtype=torch.int32) - kernel[(4096,)](x, num_ctas=num_ctas) + + x = torch.zeros((1, ), device=device, dtype=torch.int32) + kernel[(4096, )](x, num_ctas=num_ctas) assert x.item() == 63 -@pytest.mark.parametrize("shape, axis, num_ctas", - [(shape, axis, num_ctas) for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] for axis in [0, 1] for num_ctas in num_ctas_list]) +@pytest.mark.parametrize("shape, axis, num_ctas", [(shape, axis, num_ctas) + for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)] + for axis in [0, 1] + for num_ctas in num_ctas_list]) def test_tensor_atomic_rmw(shape, axis, num_ctas, device): shape0, shape1 = shape # triton kernel @@ -1172,6 +1161,7 @@ def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr) tl.atomic_add(Z + off0, z) else: tl.atomic_add(Z + off1, z) + rs = RandomState(17) x = numpy_random((shape0, shape1), dtype_str="float32", rs=rs) # reference result @@ -1180,7 +1170,7 @@ def kernel(Z, X, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr) x_tri = to_triton(x, device=device) z_shape = (shape0, ) if axis == 1 else (shape1, ) z_tri = to_triton(np.zeros(z_shape, dtype="float32"), device=device) - kernel[(1,)](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas) + kernel[(1, )](z_tri, x_tri, axis, shape0, shape1, num_ctas=num_ctas) np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4) @@ -1196,8 +1186,9 @@ def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr): val = offs.to(tl.float32) x = X + offs tl.atomic_min(x, val) + x = torch.ones((8, 8), device=device, dtype=torch.float32) - kernel[(2,)](x, shape[0], shape[1], num_ctas=num_ctas) + kernel[(2, )](x, shape[0], shape[1], num_ctas=num_ctas) assert torch.min(x).item() == 0.0 @@ -1209,8 +1200,8 @@ def test_atomic_cas(sem, num_ctas, device): def change_value(Lock): tl.atomic_cas(Lock, 0, 1) - Lock = torch.zeros((1,), device=device, dtype=torch.int32) - change_value[(1,)](Lock) + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + change_value[(1, )](Lock) assert (Lock[0] == 1) @@ -1226,10 
+1217,10 @@ def serialized_add(data, Lock, SEM: tl.constexpr): # release lock tl.atomic_xchg(Lock, 0) - Lock = torch.zeros((1,), device=device, dtype=torch.int32) - data = torch.zeros((128,), device=device, dtype=torch.float32) - ref = torch.full((128,), 64.0) - h = serialized_add[(64,)](data, Lock, SEM=sem, num_ctas=num_ctas) + Lock = torch.zeros((1, ), device=device, dtype=torch.int32) + data = torch.zeros((128, ), device=device, dtype=torch.float32) + ref = torch.full((128, ), 64.0) + h = serialized_add[(64, )](data, Lock, SEM=sem, num_ctas=num_ctas) sem_str = "acq_rel" if sem is None else sem np.testing.assert_allclose(to_numpy(data), to_numpy(ref)) if is_hip(): @@ -1240,6 +1231,7 @@ def serialized_add(data, Lock, SEM: tl.constexpr): @pytest.mark.parametrize("sem", [None, 'acquire', 'release', 'acq_rel', 'relaxed']) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_tensor_atomic_cas(sem, num_ctas, device): + @triton.jit def change_value(X, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(axis=0) @@ -1252,7 +1244,7 @@ def change_value(X, BLOCK_SIZE: tl.constexpr): X = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1], device=device, dtype=torch.int64) Y = torch.tensor([2, 1, 2, 1, 2, 1, 2, 1], device=device, dtype=torch.int64) - change_value[(2,)](X, 4) + change_value[(2, )](X, 4) assert (torch.equal(X, Y)) @@ -1261,31 +1253,24 @@ def change_value(X, BLOCK_SIZE: tl.constexpr): # --------------- -@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", [ - (dtype_x, dtype_z, False, 1024) - for dtype_x in dtypes - for dtype_z in dtypes -] + [ - ('float32', 'bfloat16', False, 1024), - ('bfloat16', 'float32', False, 1024), - ('float32', 'int32', True, 1024), - ('float32', 'int1', False, 1024), - ('int8', 'bfloat16', False, 1024), -] + [ - (f'uint{x}', f'int{x}', True, 1024) for x in [8, 16, 32, 64] -] + [ - (f'int{x}', f'uint{x}', True, 1024) for x in [8, 16, 32, 64] -] + (([ - (dtype_x, dtype_z, False, size) - for dtype_x in torch_float8_dtypes - for dtype_z in ["float16", "float32", "bfloat16"] - for size in [1024, 32] -] + [ - (dtype_x, dtype_z, False, size) - for dtype_z in torch_float8_dtypes - for dtype_x in ["float16", "float32", "bfloat16"] - for size in [1024, 32] -]) if torch.__version__ >= "2.1" else [])) +@pytest.mark.parametrize("dtype_x, dtype_z, bitcast, size", + [(dtype_x, dtype_z, False, 1024) for dtype_x in dtypes for dtype_z in dtypes] + [ + ('float32', 'bfloat16', False, 1024), + ('bfloat16', 'float32', False, 1024), + ('float32', 'int32', True, 1024), + ('float32', 'int1', False, 1024), + ('int8', 'bfloat16', False, 1024), + ] + [(f'uint{x}', f'int{x}', True, 1024) + for x in [8, 16, 32, 64]] + [(f'int{x}', f'uint{x}', True, 1024) + for x in [8, 16, 32, 64]] + + (([(dtype_x, dtype_z, False, size) + for dtype_x in torch_float8_dtypes + for dtype_z in ["float16", "float32", "bfloat16"] + for size in [1024, 32]] + # + [(dtype_x, dtype_z, False, size) + for dtype_z in torch_float8_dtypes + for dtype_x in ["float16", "float32", "bfloat16"] + for size in [1024, 32]]) if torch.__version__ >= "2.1" else [])) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_cast(dtype_x, dtype_z, bitcast, size, num_ctas, device): # bfloat16 on cc < 80 will not be tested @@ -1324,14 +1309,15 @@ def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr): dtype_z_np = dtype_z if dtype_z != 'int1' else 'bool_' # triton result if dtype_z.startswith('bfloat'): - z_tri = torch.empty((size,), dtype=getattr(torch, dtype_z), device=device) + z_tri = torch.empty((size, ), dtype=getattr(torch, 
dtype_z), device=device) elif dtype_z.startswith('float8'): - z_tri = torch.empty((size,), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z)) + z_tri = torch.empty((size, ), dtype=torch.half, device=device).to(dtype=getattr(torch, dtype_z)) else: z_tri = to_triton(np.empty((size, ), dtype=getattr(np, dtype_z_np)), device=device) kernel[(1, )](x_tri, z_tri, BITCAST=bitcast, SIZE=size, num_warps=1, num_ctas=num_ctas) # torch result - if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith('float8') or dtype_x.startswith('float8'): + if dtype_z.startswith('bfloat') or dtype_x.startswith('bfloat') or dtype_z.startswith( + 'float8') or dtype_x.startswith('float8'): assert bitcast is False z_ref = x_tri.to(z_tri.dtype) torch.testing.assert_close(z_ref, z_tri, rtol=0, atol=0) @@ -1343,7 +1329,8 @@ def kernel(X, Z, BITCAST: tl.constexpr, SIZE: tl.constexpr): np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=0, atol=0) -@pytest.mark.parametrize("dtype_str, num_warps", [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]]) +@pytest.mark.parametrize("dtype_str, num_warps", + [(dtype_str, num_warps) for dtype_str in int_dtypes + float_dtypes for num_warps in [4, 8]]) def test_cat(dtype_str, num_warps, device): check_type_supported(dtype_str, device) @@ -1358,7 +1345,7 @@ def kernel(X, Y, Z, N: tl.constexpr): x = torch.arange(0, 128, device=device).to(getattr(torch, dtype_str)) y = torch.arange(-128, 0, device=device).to(getattr(torch, dtype_str)) z_ref = torch.cat([x, y], dim=0).sum() - z = torch.zeros((256,), dtype=getattr(torch, dtype_str), device=device) + z = torch.zeros((256, ), dtype=getattr(torch, dtype_str), device=device) kernel[(1, )](x, y, z, N=128, num_warps=num_warps) assert z.sum() == z_ref # check if there's no duplicate value in z @@ -1369,8 +1356,8 @@ def kernel(X, Y, Z, N: tl.constexpr): @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_store_constant(dtype_str, num_ctas, device): check_type_supported(dtype_str, device) - """Tests that boolean True is stored as 1""" + @triton.jit def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) @@ -1383,12 +1370,13 @@ def kernel(output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): block_size = 128 ref = torch.ones([block_size], dtype=getattr(torch, dtype_str), device=device) output = torch.zeros([block_size], dtype=getattr(torch, dtype_str), device=device) - kernel[(1,)](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) + kernel[(1, )](output, block_size, BLOCK_SIZE=block_size, num_ctas=num_ctas) assert torch.all(output == ref) def test_load_store_same_ptr(device): + @triton.jit() def kernel(in_out_ptr): pid = tl.program_id(axis=0) @@ -1397,11 +1385,11 @@ def kernel(in_out_ptr): tl.store(in_out_ptr + pid, out) for _ in range(1000): - x = torch.ones((65536,), device=device, dtype=torch.float32) + x = torch.ones((65536, ), device=device, dtype=torch.float32) if is_hip(): - kernel[(65536,)](x, num_warps=16) # threads per Warp for ROCM is 64 + kernel[(65536, )](x, num_warps=16) # threads per Warp for ROCM is 64 else: - kernel[(65536,)](x, num_warps=32) + kernel[(65536, )](x, num_warps=32) assert torch.all(x == 2) @@ -1416,13 +1404,15 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None): exp = ((fp >> dtype.fp_mantissa_width) & ((1 << exp_width) - 1)).int() frac = (fp & ((1 << dtype.fp_mantissa_width) - 1)).int() - output = torch.where(exp == 0, - # 
subnormal - ((-1.0) ** sign) * (2.0 ** (1 - exp_bias)) * (frac / (2.0 ** dtype.fp_mantissa_width)), - # normal - ((-1.0) ** sign) * (2.0 ** (exp - exp_bias)) * (1.0 + frac / (2.0 ** dtype.fp_mantissa_width))).float() + output = torch.where( + exp == 0, + # subnormal + ((-1.0)**sign) * (2.0**(1 - exp_bias)) * (frac / (2.0**dtype.fp_mantissa_width)), + # normal + ((-1.0)**sign) * (2.0**(exp - exp_bias)) * (1.0 + frac / (2.0**dtype.fp_mantissa_width))).float() - extended_exp = ((1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width + extended_exp = ( + (1 << (tl.float32.primitive_bitwidth - tl.float32.fp_mantissa_width - 1)) - 1) << tl.float32.fp_mantissa_width # special cases, exp is 0b11..1 if dtype in [tl.float8e4nv, tl.float8e4b15]: # float8e4m3nv does not have infinities @@ -1430,8 +1420,9 @@ def convert_float_to_float32(fp: torch.tensor, dtype=None): output[fp == 0b11111111] = torch.nan else: output = torch.where(exp == (1 << exp_width) - 1, - ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp | (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))).view(torch.float32), - output) + ((sign << (tl.float32.primitive_bitwidth - 1)) | extended_exp | + (frac << (tl.float32.fp_mantissa_width - dtype.fp_mantissa_width))) # + .view(torch.float32), output) return output @@ -1440,7 +1431,7 @@ def test_convert_float16_to_float32(in_dtype, device): """Tests that check convert_float_to_float32 function""" check_type_supported(in_dtype, device) - f16_input = torch.tensor(range(-int(2 ** (16 - 1)), int(2 ** (16 - 1))), dtype=torch.int16).view(in_dtype) + f16_input = torch.tensor(range(-int(2**(16 - 1)), int(2**(16 - 1))), dtype=torch.int16).view(in_dtype) f32_output = convert_float_to_float32(f16_input) nan = f16_input.isnan() @@ -1469,6 +1460,7 @@ def serialize_fp8(np_data, in_dtype): else: return np_data + # inverse of `serialize_fp8` @@ -1514,13 +1506,13 @@ def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): tri_fp8 = torch.from_numpy(serialize_fp8(ref_fp8, in_dtype)).cuda() # check that non-subnormal fp8 are correctly converted to fp16 tri_fp16 = torch.empty(256, dtype=out_dtype, device="cuda") - copy_kernel[(1,)](triton.reinterpret(tri_fp8, in_dtype), tri_fp16, tri_fp16.shape[0], BLOCK_SIZE=1024) + copy_kernel[(1, )](triton.reinterpret(tri_fp8, in_dtype), tri_fp16, tri_fp16.shape[0], BLOCK_SIZE=1024) ref_fp8 = torch.from_numpy(ref_fp8).cuda() ref_fp16 = convert_float_to_float32(ref_fp8, in_dtype) assert torch.all(tri_fp16[~is_subnormal] == ref_fp16[~is_subnormal]) # check that values are properly converted back to float8 ref_fp8 = torch.empty_like(tri_fp16, dtype=torch.int8) - copy_kernel[(1,)](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024) + copy_kernel[(1, )](tri_fp16, triton.reinterpret(ref_fp8, in_dtype), tri_fp16.shape[0], BLOCK_SIZE=1024) if in_dtype == tl.float8e4b15: assert torch.all(tri_fp8[:127] == ref_fp8[:127]) assert torch.all(tri_fp8[128:255] == ref_fp8[128:255]) @@ -1529,6 +1521,7 @@ def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr): else: assert torch.all(tri_fp8[~is_subnormal] == ref_fp8[~is_subnormal]) + # --------------- # test reduce # --------------- @@ -1542,16 +1535,15 @@ def get_reduced_dtype(dtype_str, op): return dtype_str -@pytest.mark.parametrize("op, dtype_str, shape", - [(op, dtype, shape) - for op in ['min', 'max', - 'min-with-indices', - 'max-with-indices', - 'argmin-tie-break-left', - 
'argmax-tie-break-left', - 'sum'] - for dtype in dtypes_with_bfloat16 - for shape in [32, 64, 128, 512]]) +@pytest.mark.parametrize("op, dtype_str, shape", [(op, dtype, shape) for op in [ + 'min', + 'max', + 'min-with-indices', + 'max-with-indices', + 'argmin-tie-break-left', + 'argmax-tie-break-left', + 'sum', +] for dtype in dtypes_with_bfloat16 for shape in [32, 64, 128, 512]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_reduce1d(op, dtype_str, shape, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested @@ -1577,14 +1569,18 @@ def kernel(X, Z, BLOCK: tl.constexpr): # input rs = RandomState(17) # limit the range of integers so that the sum does not overflow - x = numpy_random((shape,), dtype_str=dtype_str, rs=rs) - numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, - 'max-with-indices': np.max, - 'min-with-indices': np.min, - 'argmin-tie-break-fast': np.argmin, - 'argmin-tie-break-left': np.argmin, - 'argmax-tie-break-fast': np.argmax, - 'argmax-tie-break-left': np.argmax}[op] + x = numpy_random((shape, ), dtype_str=dtype_str, rs=rs) + numpy_op = { + 'sum': np.sum, + 'max': np.max, + 'min': np.min, + 'max-with-indices': np.max, + 'min-with-indices': np.min, + 'argmin-tie-break-fast': np.argmin, + 'argmin-tie-break-left': np.argmin, + 'argmax-tie-break-fast': np.argmax, + 'argmax-tie-break-left': np.argmax, + }[op] if 'tie-break-left' in op: x[3:10] = numpy_op(x) x_tri = to_triton(x, device=device) @@ -1600,9 +1596,8 @@ def kernel(X, Z, BLOCK: tl.constexpr): else: z_ref = numpy_op(x).astype(getattr(np, z_dtype_str)) # triton result - z_tri = to_triton(numpy_random((1,), dtype_str=z_dtype_str, rs=rs), - device=device, dst_type=z_tri_dtype_str) - kernel[(1,)](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas) + z_tri = to_triton(numpy_random((1, ), dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) + kernel[(1, )](x_tri, z_tri, BLOCK=shape, num_ctas=num_ctas) z_tri = to_numpy(z_tri) # compare if op == 'sum': @@ -1617,12 +1612,10 @@ def kernel(X, Z, BLOCK: tl.constexpr): # TODO: [Qingyi] Fix argmin / argmax -reduce_configs1 = [ - (op, dtype, (1, 1024), axis) for dtype in dtypes_with_bfloat16 - for op in ['min', 'max', 'sum', 'argmin', 'argmax'] - for axis in [1] -] - +reduce_configs1 = [(op, dtype, (1, 1024), axis) + for dtype in dtypes_with_bfloat16 + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for axis in [1]] # shape (128, 256) and (32, 1024) are not enabled on sm86 because the required shared memory # exceeds the limit of 99KB @@ -1632,24 +1625,16 @@ def kernel(X, Z, BLOCK: tl.constexpr): if torch.cuda.is_available() and 'V100' in torch.cuda.get_device_name(0): reduce2d_shapes += [(128, 256) and (32, 1024)] - -reduce_configs2 = [ - (op, 'float32', shape, axis) - for op in ['min', 'max', 'sum', 'argmin', 'argmax'] - for shape in reduce2d_shapes - for axis in [0, 1] -] + [ - (op, 'float32', [16, 32], None) - for op in ['min', 'max', 'sum'] -] +reduce_configs2 = [(op, 'float32', shape, axis) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in reduce2d_shapes + for axis in [0, 1]] + [(op, 'float32', [16, 32], None) for op in ['min', 'max', 'sum']] reduce3d_shapes = [(2, 32, 16), (32, 2, 16), (32, 16, 2)] -reduce_configs3 = [ - (op, 'float32', shape, axis) - for op in ['min', 'max', 'sum', 'argmin', 'argmax'] - for shape in reduce3d_shapes - for axis in [0, 1, 2] -] +reduce_configs3 = [(op, 'float32', shape, axis) + for op in ['min', 'max', 'sum', 'argmin', 'argmax'] + for shape in 
reduce3d_shapes + for axis in [0, 1, 2]] @pytest.mark.parametrize("op, dtype_str, shape, axis", reduce_configs1 + reduce_configs2 + reduce_configs3) @@ -1662,12 +1647,14 @@ def test_reduce(op, dtype_str, shape, axis, num_ctas, device): # triton kernel @triton.jit - def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, AXIS: tl.constexpr): + def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, IS_3D: tl.constexpr, + AXIS: tl.constexpr): range_m = tl.arange(0, BLOCK_M) range_n = tl.arange(0, BLOCK_N) range_k = tl.arange(0, BLOCK_K) if IS_3D: - x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + range_k[None, None, :]) + x = tl.load(X + range_m[:, None, None] * BLOCK_N * BLOCK_K + range_n[None, :, None] * BLOCK_K + + range_k[None, None, :]) else: x = tl.load(X + range_m[:, None] * BLOCK_N + range_n[None, :]) z = GENERATE_TEST_HERE @@ -1694,8 +1681,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const # limit the range of integers so that the sum does not overflow x = numpy_random(shape, dtype_str=dtype_str, rs=rs) x_tri = to_triton(x, device=device) - numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, - 'argmin': np.argmin, 'argmax': np.argmax}[op] + numpy_op = {'sum': np.sum, 'max': np.max, 'min': np.min, 'argmin': np.argmin, 'argmax': np.argmax}[op] z_dtype_str = get_reduced_dtype(dtype_str, op) z_tri_dtype_str = z_dtype_str # numpy result @@ -1709,13 +1695,12 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const z_ref = numpy_op(x, axis=axis).astype(getattr(np, z_dtype_str)) # triton result ret_numel = 1 if axis is None else shape[1 - axis] - z_shape = (1,) if axis is None else tuple(shape_i for i, shape_i in enumerate(shape) if i != axis) - z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), - device=device, dst_type=z_tri_dtype_str) + z_shape = (1, ) if axis is None else tuple(shape_i for i, shape_i in enumerate(shape) if i != axis) + z_tri = to_triton(numpy_random(z_shape, dtype_str=z_dtype_str, rs=rs), device=device, dst_type=z_tri_dtype_str) BLOCK_K = 1 if len(shape) == 2 else shape[2] IS_3D = bool(len(shape) == 3) - kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], - BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, num_ctas=num_ctas) + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], BLOCK_K=BLOCK_K, IS_3D=IS_3D, AXIS=axis, + num_ctas=num_ctas) z_tri = to_numpy(z_tri) # compare if op == 'sum': @@ -1735,14 +1720,12 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.const scan2d_shapes = [(8, 32), (16, 32), (32, 16), (2, 1024), (1024, 2), (32, 32), (1, 1024)] -scan_configs = [ - (op, type, shape, axis, num_warps) - for num_warps in [4, 16] - for type in ['int32', 'float32'] - for axis in [1, 0] - for shape in scan2d_shapes - for op in ['cumsum', 'cumprod', 'get_first_element'] -] +scan_configs = [(op, type, shape, axis, num_warps) + for num_warps in [4, 16] + for type in ['int32', 'float32'] + for axis in [1, 0] + for shape in scan2d_shapes + for op in ['cumsum', 'cumprod', 'get_first_element']] @triton.jit @@ -1788,7 +1771,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexp z_ref[:, 1:] = x[:, 0:1] # triton result z_tri = to_triton(z, device=device) - kernel[(1,)](x_tri, z_tri, BLOCK_M=shape[0], BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) + kernel[(1, )](x_tri, z_tri, BLOCK_M=shape[0], 
BLOCK_N=shape[1], AXIS=axis, num_warps=num_warps) z_tri = to_numpy(z_tri) # compare if dtype_str == 'float32': @@ -1806,7 +1789,6 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexp BlockedLayout([4, 1], [4, 8], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([2, 2], [4, 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([2, 2], [8, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 4], [4, 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([4, 1], [4, 8], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]), @@ -1820,6 +1802,7 @@ def kernel(X, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.constexp @pytest.mark.parametrize("N", [512, 1024, 2048]) @pytest.mark.parametrize("num_pid_n", [2, 4]) def test_locality(op, BLOCK_N, N, num_pid_n): + @triton.jit def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): start_m = tl.program_id(0) @@ -1833,6 +1816,7 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): x = tl.load(Xs) local = ACCUMULATE_PATCH tl.store(Y + off_m * num_pid_n + pid_n, local) + initialize_patch = { 'sum': 'tl.zeros([BLOCK_M], dtype=tl.float32)', 'max': 'tl.full([BLOCK_M], float("-inf"), dtype=tl.float32)', @@ -1854,7 +1838,8 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): x = torch.randn((BLOCK_M, N), dtype=torch.float32, device="cuda") y = torch.randn((BLOCK_M, num_pid_n), dtype=torch.float32, device="cuda") h = kernel[(1, num_pid_n, 1)](x, y, N, BLOCK_M, BLOCK_N) - assert h.asm['ttgir'].count('"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work" + assert h.asm['ttgir'].count( + '"tt.reduce"') == 2, "tt.reduce should be called twice, otherwise the optimization didn't work" y_ref = numpy_op(x.cpu().numpy(), axis=1, keepdims=True) y_tri = numpy_op(y.cpu().numpy(), axis=1, keepdims=True) np.testing.assert_allclose(y_tri, y_ref, rtol=0.01, atol=1e-3) @@ -1921,9 +1906,12 @@ def test_scan_layouts(M, N, src_layout, axis, device): BlockedLayout([1, 4], [8, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [8, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), BlockedLayout([4, 4], [2, 16], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), - MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]), - MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], instr_shape=[16, 16, 16]), + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(2, 0), warps_per_cta=[2, 2], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]), + MmaLayout(version=(3, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[1, 0], + instr_shape=[16, 16, 16]), ] @@ -1941,13 +1929,10 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op, ty = {"int32": "i32", "float32": "f32", "float16": "f16"}[dtype_str] arith_op = { - "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, + "max": {"int32": "arith.maxsi", "float32": "arith.maximumf", "float16": "arith.maximumf"}, # "sum": {"int32": "arith.addi", "float32": 
"arith.addf", "float16": "arith.addf"} }[reduce_op][dtype_str] - numpy_op = { - "max": np.max, - "sum": np.sum - }[reduce_op] + numpy_op = {"max": np.max, "sum": np.sum}[reduce_op] rdims_1d = f"{N}" if axis == 0 else f"{M}" rdims_2d = f"1x{N}" if axis == 0 else f"{M}x1" store_range = "%7" if axis == 0 else "%1" @@ -2024,7 +2009,8 @@ def test_reduce_layouts(M, N, src_layout, axis, reduce2d, dtype_str, reduce_op, layouts = [ BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]) + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) ] @@ -2073,7 +2059,8 @@ def test_store_op(M, src_layout, device): layouts = [ BlockedLayout([1, 4], [1, 32], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), BlockedLayout([1, 4], [1, 32], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], instr_shape=[16, 8]) + MmaLayout(version=(2, 0), warps_per_cta=[4, 1], ctas_per_cga=[1, 1], cta_split_num=[1, 1], cta_order=[0, 1], + instr_shape=[16, 8]) ] @@ -2196,7 +2183,7 @@ def test_chain_reduce(M, N, src_layout, op, device, first_axis): rs = RandomState(17) x = rs.randint(0, 4, (M, N)).astype('int32') - z = np.zeros((1,)).astype('int32') + z = np.zeros((1, )).astype('int32') x_tri = torch.tensor(x, device=device) z_tri = torch.tensor(z, device=device) @@ -2228,7 +2215,7 @@ def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): out_mean = torch.empty((), device=device) out_var = torch.empty((), device=device) - var_mean_kernel[(1,)](x, out_mean, out_var, BLOCK=SIZE) + var_mean_kernel[(1, )](x, out_mean, out_var, BLOCK=SIZE) expect_var, expect_mean = torch.var_mean(x, dim=0, correction=0) torch.testing.assert_close(out_mean, expect_mean) @@ -2240,12 +2227,11 @@ def var_mean_kernel(X, out_mean, out_var, BLOCK: tl.constexpr): # --------------- -@pytest.mark.parametrize("dtype_str, shape, perm", - [(dtype, shape, perm) - # TODO: bfloat16 - for dtype in ['float8e4b15', 'float16', 'float32'] - for shape in [(64, 64), (128, 128)] - for perm in [(1, 0)]]) +@pytest.mark.parametrize("dtype_str, shape, perm", [(dtype, shape, perm) + # TODO: bfloat16 + for dtype in ['float8e4b15', 'float16', 'float32'] + for shape in [(64, 64), (128, 128)] + for perm in [(1, 0)]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_permute(dtype_str, shape, perm, num_ctas, device): check_type_supported(dtype_str, device) # bfloat16 on cc < 80 will not be tested @@ -2254,28 +2240,24 @@ def test_permute(dtype_str, shape, perm, num_ctas, device): # triton kernel @triton.jit - def kernel(X, stride_xm, stride_xn, - Z, stride_zm, stride_zn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): + def kernel(X, stride_xm, stride_xn, Z, stride_zm, stride_zn, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * stride_xn Zs = Z + off_m[:, None] * stride_zm + off_n[None, :] * stride_zn tl.store(Zs, tl.load(Xs)) + # input x = numpy_random(shape, dtype_str=dtype_str) # triton result z_tri = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) z_tri_contiguous = to_triton(np.empty_like(x), device=device, dst_type=dtype_str) x_tri = to_triton(x, 
device=device, dst_type=dtype_str) - pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), - z_tri, z_tri.stride(1), z_tri.stride(0), - BLOCK_M=shape[0], BLOCK_N=shape[1], - num_ctas=num_ctas) - pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), x_tri.stride(0), - z_tri_contiguous, z_tri_contiguous.stride(0), z_tri_contiguous.stride(1), - BLOCK_M=shape[0], BLOCK_N=shape[1], - num_ctas=num_ctas) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), z_tri, z_tri.stride(1), z_tri.stride(0), + BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) + pgm_contiguous = kernel[(1, 1)](x_tri, x_tri.stride(1), + x_tri.stride(0), z_tri_contiguous, z_tri_contiguous.stride(0), + z_tri_contiguous.stride(1), BLOCK_M=shape[0], BLOCK_N=shape[1], num_ctas=num_ctas) # numpy result if dtype_str == 'float8e4b15': ty = tl.float8e4b15 @@ -2305,37 +2287,25 @@ def kernel(X, stride_xm, stride_xn, # --------------- -@pytest.mark.parametrize("M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype", - [(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype) - for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] - for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] - for allow_tf32 in [True, False] - for in_dtype, out_dtype in [('float16', 'float16'), - ('float16', 'float32'), - ('float32', 'float32')] - if not (allow_tf32 and (in_dtype in ['float16']))] - - + [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype) - for shape_nw in [[128, 256, 32, 8], - [128, 16, 32, 4], - [32, 128, 64, 4], - [128, 128, 64, 4], - [64, 128, 128, 4], - [32, 128, 64, 2], - [64, 64, 32, 4], - [32, 32, 128, 16], - [128, 128, 64, 2], - [64, 128, 128, 2]] - for allow_tf32 in [True] - for col_a in [True, False] - for col_b in [True, False] - for in_dtype, out_dtype in [('int8', 'int8'), - ('float16', 'float16'), - ('float16', 'float32'), - ('float32', 'float32')]] + - - [(64, 64, 64, 4, col_a, col_b, 'none', False, 'float32', 'float32') - for col_a in [True, False] for col_b in [True, False]]) +@pytest.mark.parametrize( + "M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype", + [(*shape, 4, False, False, epilogue, allow_tf32, in_dtype, out_dtype) + for shape in [(64, 64, 64), (32, 32, 32), (16, 16, 16)] + for epilogue in ['none', 'trans', 'add-matrix', 'add-rows', 'add-cols', 'softmax', 'chain-dot'] + for allow_tf32 in [True, False] + for in_dtype, out_dtype in [('float16', 'float16'), ('float16', 'float32'), ('float32', 'float32')] + if not (allow_tf32 and (in_dtype in ['float16']))] + + [(*shape_nw, col_a, col_b, 'none', allow_tf32, in_dtype, out_dtype) + for shape_nw in [[128, 256, 32, 8], [128, 16, 32, 4], [32, 128, 64, 4], [128, 128, 64, 4], [64, 128, 128, 4], + [32, 128, 64, 2], [64, 64, 32, 4], [32, 32, 128, 16], [128, 128, 64, 2], [64, 128, 128, 2]] + for allow_tf32 in [True] + for col_a in [True, False] + for col_b in [True, False] + for in_dtype, out_dtype in [('int8', 'int8'), ('float16', 'float16'), ('float16', 'float32'), ('float32', + 'float32')]] + + [(64, 64, 64, 4, col_a, col_b, 'none', False, 'float32', 'float32') + for col_a in [True, False] + for col_b in [True, False]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, out_dtype, num_ctas, device): check_cuda_only(device) @@ -2384,16 +2354,10 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, in_dtype, o # triton kernel @triton.jit - 
def kernel(X, stride_xm, stride_xk, - Y, stride_yk, stride_yn, - W, stride_wn, stride_wl, - Z, stride_zm, stride_zn, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - ADD_MATRIX: tl.constexpr, ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, - ALLOW_TF32: tl.constexpr, - DO_SOFTMAX: tl.constexpr, CHAIN_DOT: tl.constexpr, - COL_A: tl.constexpr, COL_B: tl.constexpr, - out_dtype: tl.constexpr = tl.float32): + def kernel(X, stride_xm, stride_xk, Y, stride_yk, stride_yn, W, stride_wn, stride_wl, Z, stride_zm, stride_zn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ADD_MATRIX: tl.constexpr, + ADD_ROWS: tl.constexpr, ADD_COLS: tl.constexpr, ALLOW_TF32: tl.constexpr, DO_SOFTMAX: tl.constexpr, + CHAIN_DOT: tl.constexpr, COL_A: tl.constexpr, COL_B: tl.constexpr, out_dtype: tl.constexpr = tl.float32): off_m = tl.arange(0, BLOCK_M) off_n = tl.arange(0, BLOCK_N) off_l = tl.arange(0, BLOCK_N) @@ -2423,6 +2387,7 @@ def kernel(X, stride_xm, stride_xk, w = tl.load(Ws) z = tl.dot(z.to(w.dtype), w, allow_tf32=ALLOW_TF32, out_dtype=out_dtype) tl.store(Zs, z) + # input rs = RandomState(17) if col_a: @@ -2464,20 +2429,12 @@ def kernel(X, stride_xm, stride_xk, else: out_dtype = tl.float32 - pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), - y_tri, y_tri.stride(0), y_tri.stride(1), - w_tri, w_tri.stride(0), w_tri.stride(1), - z_tri, z_tri.stride(0), z_tri.stride(1), - COL_A=col_a, COL_B=col_b, - BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, - ADD_MATRIX=epilogue == 'add-matrix', - ADD_ROWS=epilogue == 'add-rows', - ADD_COLS=epilogue == 'add-cols', - DO_SOFTMAX=epilogue == 'softmax', - CHAIN_DOT=epilogue == 'chain-dot', - ALLOW_TF32=allow_tf32, - num_warps=num_warps, num_ctas=num_ctas, - out_dtype=out_dtype) + pgm = kernel[(1, 1)](x_tri, x_tri.stride(0), x_tri.stride(1), y_tri, y_tri.stride(0), y_tri.stride(1), w_tri, + w_tri.stride(0), w_tri.stride(1), z_tri, z_tri.stride(0), z_tri.stride(1), COL_A=col_a, + COL_B=col_b, BLOCK_M=M, BLOCK_K=K, BLOCK_N=N, ADD_MATRIX=epilogue == 'add-matrix', + ADD_ROWS=epilogue == 'add-rows', ADD_COLS=epilogue == 'add-cols', + DO_SOFTMAX=epilogue == 'softmax', CHAIN_DOT=epilogue == 'chain-dot', ALLOW_TF32=allow_tf32, + num_warps=num_warps, num_ctas=num_ctas, out_dtype=out_dtype) if epilogue == 'softmax' and (in_dtype != 'float32' or allow_tf32): if is_hip(): @@ -2496,8 +2453,7 @@ def kernel(X, stride_xm, stride_xk, assert "bar.sync" not in red_code # torch result if in_dtype == 'int8': - z_ref = np.matmul(x.astype(np.float32), - y.astype(np.float32())).astype(np.int32) + z_ref = np.matmul(x.astype(np.float32), y.astype(np.float32())).astype(np.int32) else: z_ref = np.matmul(x, y) @@ -2557,9 +2513,8 @@ def test_dot_mulbroadcastred(in_dtype, device): pytest.skip("Requires sm >= 80 to run") @triton.jit - def kernel(Z, X, Y, - M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, - BM: tl.constexpr, BN: tl.constexpr, BK: tl.constexpr): + def kernel(Z, X, Y, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, BM: tl.constexpr, BN: tl.constexpr, + BK: tl.constexpr): pidn = tl.program_id(1) pidm = tl.program_id(0) offm = tl.arange(0, BM)[:, None] @@ -2575,6 +2530,7 @@ def kernel(Z, X, Y, t = tl.sum(x * y, axis=1) acc = t + acc tl.store(Z + ((pidm * BM * N) + (pidn * BN) + (offm * N) + offn), acc) + M, N, K = 256, 192, 160 BM, BN, BK = 128, 32, 32 rs = RandomState(17) @@ -2605,7 +2561,7 @@ def kernel(Z, X, Y, @pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16']) -@pytest.mark.parametrize("shape", [(), (1,), 
(128,)]) +@pytest.mark.parametrize("shape", [(), (1, ), (128, )]) def test_full(dtype_str, shape, device): if dtype_str in uint_dtypes and not hasattr(torch, dtype_str): # PyTorch only has unsigned 8, but not 16, 32, or 64 @@ -2633,30 +2589,28 @@ def kernel_dynamic(out, val, dtype: tl.constexpr): 'SHAPE': str(list(shape)), }) out_static = torch.zeros((128), dtype=dtype, device=device) - kernel_static_patched[(1,)](out_static) + kernel_static_patched[(1, )](out_static) assert torch.all(out_static == 2) kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))}) out_dynamic = torch.zeros((128), dtype=dtype, device=device) - kernel_dynamic_patched[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str)) + kernel_dynamic_patched[(1, )](out_dynamic, 2, getattr(triton.language, dtype_str)) assert torch.all(out_dynamic == 2) -@pytest.mark.parametrize("literal, dtype_str", - [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), - ('float("inf")', "f32"), ('float("-inf")', "f32"), - ('float("nan")', "f32"), ('float("-nan")', "f32"), - (0., "f32"), - (5, "i32"), (2**40, "i64"),]) +@pytest.mark.parametrize("literal, dtype_str", [(1e+50, "f64"), (1e+10, "f32"), (1.0, "f32"), ('float("inf")', "f32"), + ('float("-inf")', "f32"), ('float("nan")', "f32"), + ('float("-nan")', "f32"), (0., "f32"), (5, "i32"), (2**40, "i64")]) def test_constexpr(literal, dtype_str, device): + @triton.jit def kernel(out_ptr): val = GENERATE_TEST_HERE tl.store(out_ptr.to(tl.pointer_type(val.dtype)), val) kernel_patched = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{literal}"}) - out = torch.zeros((1,), dtype=torch.float32, device=device) - h = kernel_patched[(1,)](out) + out = torch.zeros((1, ), dtype=torch.float32, device=device) + h = kernel_patched[(1, )](out) assert re.search(r"arith.constant .* : " + dtype_str, h.asm["ttir"]) is not None @@ -2675,14 +2629,16 @@ def _kernel(out, ALLOW_TF32: tl.constexpr): c = tl.dot(a, b, allow_tf32=ALLOW_TF32) out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :] tl.store(out_ptr, c) + kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"}) a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device=device) out_ref = torch.matmul(a, b) out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device=device) - kernel[(1,)](out, ALLOW_TF32=allow_tf32) + kernel[(1, )](out, ALLOW_TF32=allow_tf32) assert torch.all(out == out_ref) + # --------------- # test arange # --------------- @@ -2695,21 +2651,25 @@ def test_arange(start, num_ctas, device): z_tri = torch.empty(BLOCK, dtype=torch.int32, device=device) @triton.jit - def _kernel(z, BLOCK: tl.constexpr, - START: tl.constexpr, END: tl.constexpr): + def _kernel(z, BLOCK: tl.constexpr, START: tl.constexpr, END: tl.constexpr): off = tl.arange(0, BLOCK) val = tl.arange(START, END) tl.store(z + off, val) - _kernel[(1,)](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK, num_ctas=num_ctas) + + _kernel[(1, )](z_tri, START=start, END=start + BLOCK, BLOCK=BLOCK, num_ctas=num_ctas) z_ref = torch.arange(start, BLOCK + start, dtype=torch.int32, device=device) np.testing.assert_allclose(to_numpy(z_tri), to_numpy(z_ref)) + # --------------- # test load # --------------- -@pytest.mark.parametrize("dtype_str, size, size_diff", [(dtype_str, size, size_diff) for dtype_str in torch_dtypes for size in [128, 512] for size_diff in [0, 1, 2, 3, 4]]) +@pytest.mark.parametrize("dtype_str, size, size_diff", 
[(dtype_str, size, size_diff) + for dtype_str in torch_dtypes + for size in [128, 512] + for size_diff in [0, 1, 2, 3, 4]]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_masked_load(dtype_str, size, size_diff, num_ctas, device): dtype = getattr(torch, dtype_str) @@ -2718,12 +2678,12 @@ def test_masked_load(dtype_str, size, size_diff, num_ctas, device): input_size = size - size_diff output_size = size if dtype_str == 'bool': - input = torch.randint(0, 2, (input_size,), dtype=dtype, device=device) + input = torch.randint(0, 2, (input_size, ), dtype=dtype, device=device) elif dtype_str in int_dtypes or dtype_str in uint_dtypes: - input = torch.randint(0, 127, (input_size,), dtype=dtype, device=device) + input = torch.randint(0, 127, (input_size, ), dtype=dtype, device=device) else: input = torch.rand(input_size, dtype=dtype, device=device) - output = torch.zeros((output_size,), dtype=dtype, device=device) + output = torch.zeros((output_size, ), dtype=dtype, device=device) @triton.jit def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): @@ -2736,12 +2696,13 @@ def _kernel(in_ptr, out_ptr, in_size: tl.constexpr, out_size: tl.constexpr): mask_str = "mask=in_offsets < in_size, other=1" if size_diff > 0 else "None" kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.load(in_ptr + in_offsets, {mask_str})"}) - kernel[(1,)](input, output, input_size, output_size, num_ctas=num_ctas) + kernel[(1, )](input, output, input_size, output_size, num_ctas=num_ctas) - reference_out = torch.cat((input, torch.ones((size_diff,), dtype=dtype, device=device))) + reference_out = torch.cat((input, torch.ones((size_diff, ), dtype=dtype, device=device))) # print((output - reference_out).nonzero()) torch.testing.assert_close(output, reference_out) + # Testing masked loads with an intermate copy to shared memory run. 
@@ -2762,9 +2723,7 @@ def test_masked_load_shared_memory(dtype, device): out = torch.zeros((M, N), dtype=dtype, device=device) @triton.jit - def _kernel(in1_ptr, in2_ptr, output_ptr, - in_stride, in2_stride, out_stride, - in_numel, in2_numel, out_numel, + def _kernel(in1_ptr, in2_ptr, output_ptr, in_stride, in2_stride, out_stride, in_numel, in2_numel, out_numel, M: tl.constexpr, N: tl.constexpr, K: tl.constexpr): M_offsets = tl.arange(0, M) @@ -2785,14 +2744,8 @@ def _kernel(in1_ptr, in2_ptr, output_ptr, output_offsets = M_offsets[:, None] * out_stride + N_offsets[None, :] tl.store(output_ptr + output_offsets, o, mask=output_offsets < M * N) - pgm = _kernel[(1,)](in1, in2, out, - in1.stride()[0], - in2.stride()[0], - out.stride()[0], - in1.numel(), - in2.numel(), - out.numel(), - M=M, N=N, K=K) + pgm = _kernel[(1, )](in1, in2, out, in1.stride()[0], in2.stride()[0], out.stride()[0], in1.numel(), in2.numel(), + out.numel(), M=M, N=N, K=K) reference_out = torch.matmul(in1, in2) torch.testing.assert_close(out, reference_out, atol=1e-2, rtol=0) @@ -2809,7 +2762,7 @@ def _kernel(dst, src, CACHE: tl.constexpr): x = tl.load(src + offsets, cache_modifier=CACHE) tl.store(dst + offsets, x) - pgm = _kernel[(1,)](dst, src, CACHE=cache) + pgm = _kernel[(1, )](dst, src, CACHE=cache) if is_hip(): return @@ -2837,8 +2790,8 @@ def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(src + offsets, mask=offsets < N) tl.store(dst + offsets, x, mask=offsets < N) - pgm = _kernel[(1,)]( - dst, src, N=N, BLOCK_SIZE=block_size) + + pgm = _kernel[(1, )](dst, src, N=N, BLOCK_SIZE=block_size) if is_hip(): return @@ -2865,7 +2818,8 @@ def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr): tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024) x = tl.load(src + offsets, mask=offsets < N) tl.store(dst + offsets, x, mask=offsets < N) - pgm = _kernel[(1,)](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints) + + pgm = _kernel[(1, )](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints) if is_hip(): return @@ -2875,6 +2829,7 @@ def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr): else: assert "ld.global.v4.b32" not in ptx + # --------------- # test store # --------------- @@ -2893,7 +2848,7 @@ def _kernel(dst, src, CACHE: tl.constexpr): if is_hip(): return - pgm = _kernel[(1,)](dst, src, CACHE=cache) + pgm = _kernel[(1, )](dst, src, CACHE=cache) ptx = pgm.asm['ptx'] if cache == '': assert 'st.global.wb' not in ptx @@ -2921,6 +2876,7 @@ def _kernel(dst, src, CACHE: tl.constexpr): assert 'st.global.cs' not in ptx assert 'st.global.wt' in ptx + # --------------- # test if # --------------- @@ -2954,52 +2910,56 @@ def _kernel(ret0, ret1, value=3): tl.store(ret0, _impl()) tl.store(ret1, _impl(value)) - _kernel[(1,)](ret0, ret1, value) + _kernel[(1, )](ret0, ret1, value) assert ret0.item() == 10 assert ret1.item() == value - _kernel[(1,)](ret0, ret1) + _kernel[(1, )](ret0, ret1) assert ret0.item() == 10 assert ret1.item() == 3 + # --------------- # test noop # ---------------- def test_noop(device): + @triton.jit def kernel(x): pass - x = to_triton(numpy_random((1,), dtype_str='int32'), device=device) + + x = to_triton(numpy_random((1, ), dtype_str='int32'), device=device) kernel[(1, )](x) @pytest.mark.parametrize("device", ['cuda', 'cpu', 'cpu_pinned']) def test_pointer_arguments(device): + @triton.jit def kernel(x): pass + pin_memory = 'pinned' in device x = torch.empty(1024, 
device=device.split('_')[0], pin_memory=pin_memory) if device == "cpu": with pytest.raises(ValueError): - kernel[(1,)](x) + kernel[(1, )](x) else: kernel[(1, )](x) -@pytest.mark.parametrize("value, value_type", [ - (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), - (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), - (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64') -]) +@pytest.mark.parametrize("value, value_type", [(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'), + (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'), + (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')]) def test_value_specialization(value: int, value_type: str, device) -> None: spec_type = None def cache_hook(*args, **kwargs): nonlocal spec_type spec_type = kwargs["compile"]["signature"][0] + JITFunction.cache_hook = cache_hook @triton.jit @@ -3012,15 +2972,13 @@ def kernel(VALUE, X): JITFunction.cache_hook = None assert spec_type == value_type + # -------------------- # value specialization # -------------------- -@pytest.mark.parametrize( - "value, overflow", - [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)] -) +@pytest.mark.parametrize("value, overflow", [(2**64 - 1, False), (2**64, True), (-2**63, False), (-2**63 - 1, True)]) def test_value_specialization_overflow(value: int, overflow: bool, device) -> None: @triton.jit @@ -3040,12 +2998,14 @@ def kernel(VALUE, X): # test constexpr # ---------------- + @pytest.mark.parametrize("op", ['+', '-', '*', '/', '%', '<', '>', '<<', '>>', '&', '^', '|']) @pytest.mark.parametrize("is_lhs_constexpr", [False, True]) @pytest.mark.parametrize("is_rhs_constexpr", [True, False]) def test_bin_op_constexpr(op, is_lhs_constexpr, is_rhs_constexpr, device): if is_hip(): - if (is_rhs_constexpr, is_lhs_constexpr, op) in [(False, False, "<<"), (False, False, ">>"), (False, True, "<<")]: + if (is_rhs_constexpr, is_lhs_constexpr, op) in [(False, False, "<<"), (False, False, ">>"), + (False, True, "<<")]: pytest.skip(f"test_bin_op_constexpr[{is_lhs_constexpr}-{is_rhs_constexpr}-{op}] is not supported in HIP") @triton.jit @@ -3058,19 +3018,19 @@ def kernel(Z, X, Y): if op in ['<<', '>>', '&', '^', '|']: # int op x_str = "3" if is_lhs_constexpr else "x" y_str = "4" if is_rhs_constexpr else "y" - x = numpy_random((1,), dtype_str="int32") - y = numpy_random((1,), dtype_str="int32") + x = numpy_random((1, ), dtype_str="int32") + y = numpy_random((1, ), dtype_str="int32") else: x_str = "3.14" if is_lhs_constexpr else "x" y_str = "4.13" if is_rhs_constexpr else "y" - x = numpy_random((1,), dtype_str="float32") - y = numpy_random((1,), dtype_str="float32") + x = numpy_random((1, ), dtype_str="float32") + y = numpy_random((1, ), dtype_str="float32") kernel = patch_kernel(kernel, {'GENERATE_TEST_HERE': f"{x_str} {op} {y_str}"}) z = np.array(eval(f"{x_str} {op} {y_str}")) x_tri = to_triton(x, device=device) y_tri = to_triton(y, device=device) - z_tri = to_triton(np.empty((1,), dtype=z.dtype), device=device) - kernel[(1,)](z_tri, x_tri, y_tri) + z_tri = to_triton(np.empty((1, ), dtype=z.dtype), device=device) + kernel[(1, )](z_tri, x_tri, y_tri) np.testing.assert_allclose(z, to_numpy(z_tri)) @@ -3082,7 +3042,7 @@ def kernel(X): tl.store(X + off, off) x_tri = to_triton(np.empty((256, ), dtype=np.int32), device=device) - kernel[(1,)](x_tri) + kernel[(1, )](x_tri) np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256)) @@ -3095,7 +3055,7 @@ def kernel(X, s): tl.store(X + off, val) x_tri = 
to_triton(np.empty((256, ), dtype=np.int32), device=device) - kernel[(1,)](x_tri, 32) + kernel[(1, )](x_tri, 32) np.testing.assert_equal(to_numpy(x_tri), np.arange(0, 256) % 8) @@ -3105,12 +3065,15 @@ def static_assert_func(): def test_constexpr_propagation(): + @triton.jit def _kernel(COND: tl.constexpr): NEW_COND = COND if NEW_COND: static_assert_func() - _kernel[(1,)](False) + + _kernel[(1, )](False) + # ------------- # test call @@ -3151,11 +3114,11 @@ def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): vecmul_kernel(ptr, n_elements, num2, type) size = 1024 - rand_val = numpy_random((size,), dtype_str="float32") + rand_val = numpy_random((size, ), dtype_str="float32") rand_val_tri = to_triton(rand_val, device=device) err_msg = "" try: - kernel[(size // 128,)](rand_val_tri, size, 3, 5, type, num_ctas=num_ctas) + kernel[(size // 128, )](rand_val_tri, size, 3, 5, type, num_ctas=num_ctas) except Exception as e: err_msg = str(e) @@ -3165,13 +3128,18 @@ def kernel(ptr, n_elements, num1, num2, type: tl.constexpr): ans = rand_val * 1 * 2 * 1 * 2 * 3 * 4 np.testing.assert_equal(to_numpy(rand_val_tri), ans) + # ------------- # test if # ------------- -@pytest.mark.parametrize("if_type", ["if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", "if_and_static"]) +@pytest.mark.parametrize("if_type", [ + "if", "if_and_dynamic", "if_exp_static", "if_exp_dynamic", "if_exp_dynamic_constexpr", "if_exp_dynamic_void", + "if_and_static" +]) def test_if(if_type, device): + @triton.jit def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr, StaticVaue: tl.constexpr): pid = tl.program_id(0) @@ -3207,7 +3175,7 @@ def kernel(Cond, XTrue, XFalse, Ret, IfType: tl.constexpr, BoolVar: tl.constexpr x_false = torch.tensor([1.51], dtype=torch.float32, device=device) ret = torch.zeros(1, dtype=torch.float32, device=device) - kernel[(1,)](cond, x_true, x_false, ret, if_type, True, 1) + kernel[(1, )](cond, x_true, x_false, ret, if_type, True, 1) assert torch.equal(ret, x_true) @@ -3219,23 +3187,22 @@ def _kernel(dst): pass with pytest.raises(AssertionError, match='must be a power of 2'): - _kernel[(1,)](dst=dst, num_warps=3) - _kernel[(1,)](dst=dst, num_warps=1) - _kernel[(1,)](dst=dst, num_warps=2) - _kernel[(1,)](dst=dst, num_warps=4) + _kernel[(1, )](dst=dst, num_warps=3) + _kernel[(1, )](dst=dst, num_warps=1) + _kernel[(1, )](dst=dst, num_warps=2) + _kernel[(1, )](dst=dst, num_warps=4) + # ------------- # test extern # ------------- -@pytest.mark.parametrize("dtype_str, expr, lib_path", - [('int32', 'math.ffs', ''), - ('float32', 'math.log2', ''), - ('float32', 'math.scalbn', ''), - ('float32', 'math.pow', tl.math.libdevice_path()), - ('float64', 'math.pow_dtype', tl.math.libdevice_path()), - ('float64', 'math.norm4d', '')]) +@pytest.mark.parametrize("dtype_str, expr, lib_path", [('int32', 'math.ffs', ''), ('float32', 'math.log2', ''), + ('float32', 'math.scalbn', ''), + ('float32', 'math.pow', tl.math.libdevice_path()), + ('float64', 'math.pow_dtype', tl.math.libdevice_path()), + ('float64', 'math.norm4d', '')]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_math_tensor(dtype_str, expr, lib_path, num_ctas, device): @@ -3283,8 +3250,8 @@ def kernel(X, Y, BLOCK: tl.constexpr): x_tri = to_triton(x, device=device) # triton result - y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device) - kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}, num_ctas=num_ctas) + 
y_tri = to_triton(numpy_random((shape[0], ), dtype_str=dtype_str, rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}, num_ctas=num_ctas) # compare if expr == 'math.ffs': np.testing.assert_equal(y_ref, to_numpy(y_tri)) @@ -3292,10 +3259,8 @@ def kernel(X, Y, BLOCK: tl.constexpr): np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) -@pytest.mark.parametrize("dtype_str, expr, lib_path", - [('float32', 'math.pow', ''), - ('float64', 'math.pow_dtype', ''), - ('float64', 'math.pow', tl.math.libdevice_path())]) +@pytest.mark.parametrize("dtype_str, expr, lib_path", [('float32', 'math.pow', ''), ('float64', 'math.pow_dtype', ''), + ('float64', 'math.pow', tl.math.libdevice_path())]) @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_math_scalar(dtype_str, expr, lib_path, num_ctas, device): @@ -3308,7 +3273,7 @@ def kernel(X, Y, BLOCK: tl.constexpr): shape = (128, ) rs = RandomState(17) # limit the range of integers so that the sum does not overflow - x = numpy_random((1,), dtype_str=dtype_str, rs=rs) + x = numpy_random((1, ), dtype_str=dtype_str, rs=rs) y_ref = np.zeros(shape, dtype=x.dtype) # numpy does not allow negative factors in power, so we use abs() @@ -3323,8 +3288,8 @@ def kernel(X, Y, BLOCK: tl.constexpr): # triton result x_tri = to_triton(x, device=device)[0].item() - y_tri = to_triton(numpy_random((shape[0],), dtype_str=dtype_str, rs=rs), device=device) - kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}, num_ctas=num_ctas) + y_tri = to_triton(numpy_random((shape[0], ), dtype_str=dtype_str, rs=rs), device=device) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], extern_libs={'libdevice': lib_path}, num_ctas=num_ctas) # compare np.testing.assert_allclose(y_ref, to_numpy(y_tri), rtol=0.01) @@ -3333,6 +3298,7 @@ def kernel(X, Y, BLOCK: tl.constexpr): # test inline asm # ----------------------- + @pytest.mark.parametrize("num_ctas", num_ctas_list) def test_inline_asm(num_ctas, device): check_cuda_only(device) @@ -3345,7 +3311,8 @@ def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) y = tl.load(Y + tl.arange(0, BLOCK)) s = tl.full([BLOCK], n, tl.int32) - z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, is_pure=True, pack=1) + z = tl.inline_asm_elementwise("shf.l.wrap.b32 $0, $1, $2, $3;", "=r,r, r, r", [x, y, s], dtype=tl.int32, + is_pure=True, pack=1) tl.store(Z + tl.arange(0, BLOCK), z) shape = (128, ) @@ -3356,7 +3323,7 @@ def kernel(X, Y, Z, n: tl.constexpr, BLOCK: tl.constexpr): y_tri = to_triton(y, device=device) n = 17 z_tri = to_triton(numpy_random(shape, dtype_str='uint32', rs=rs), device=device) - kernel[(1,)](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas) + kernel[(1, )](x_tri, y_tri, z_tri, n, BLOCK=shape[0], num_ctas=num_ctas) y_ref = (y << n) | (x >> (32 - n)) # compare np.testing.assert_equal(y_ref, to_numpy(z_tri)) @@ -3373,9 +3340,11 @@ def test_inline_asm_packed(num_ctas, device): def kernel(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) # shift 4x8bits values together. 
- y = tl.inline_asm_elementwise("and.b32 $0, $1, 0x1F1F1F1F; \ - shl.b32 $0, $0, 3;", - "=r,r", [x,], dtype=tl.int8, is_pure=True, pack=4) + y = tl.inline_asm_elementwise( + "and.b32 $0, $1, 0x1F1F1F1F; \ + shl.b32 $0, $0, 3;", "=r,r", [ + x, + ], dtype=tl.int8, is_pure=True, pack=4) tl.store(Y + tl.arange(0, BLOCK), y) shape = (512, ) @@ -3383,19 +3352,19 @@ def kernel(X, Y, BLOCK: tl.constexpr): x = numpy_random(shape, dtype_str='uint8', rs=rs) x_tri = to_triton(x, device=device) y_tri = to_triton(numpy_random(shape, dtype_str='uint8', rs=rs), device=device) - kernel[(1,)](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) + kernel[(1, )](x_tri, y_tri, BLOCK=shape[0], num_ctas=num_ctas) y_ref = x << 3 # compare np.testing.assert_equal(y_ref, to_numpy(y_tri)) + # ----------------------- # test control flow # ----------------------- @pytest.mark.parametrize("lo, hi, iv", [(2**35, 2**35 + 20, 1), (2**35, 2**35 + 20, 2), (2**35, 2**35 + 20, 3), - (15, -16, -1), (15, -16, -2), (15, -16, -3), - (-18, -22, -1), (22, 18, -1)]) + (15, -16, -1), (15, -16, -2), (15, -16, -3), (-18, -22, -1), (22, 18, -1)]) def test_for_iv(lo, hi, iv, device): @triton.jit @@ -3408,8 +3377,8 @@ def kernel(Out, lo, hi, iv: tl.constexpr): lo = 2**35 hi = 2**35 + 20 - out = to_triton(np.zeros((1,), dtype=np.int64), device=device) - kernel[(1,)](out, lo, hi, iv) + out = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + kernel[(1, )](out, lo, hi, iv) assert out[0] == sum(range(lo, hi, iv)) @@ -3423,17 +3392,17 @@ def kernel(Cond, TrueVal, FalseVal, Out): val = tl.load(FalseVal) tl.store(Out, val) - out = to_triton(np.zeros((1,), dtype=np.int32), device=device) - true_val = to_triton(np.full((1,), 1, dtype=np.int32), device=device) - false_val = to_triton(np.full((1,), 2, dtype=np.int32), device=device) - cond = to_triton(np.zeros((1,), dtype=np.int32), device=device) + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + true_val = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + false_val = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + cond = to_triton(np.zeros((1, ), dtype=np.int32), device=device) # True cond[0] = True - kernel[(1,)](cond, true_val, false_val, out) + kernel[(1, )](cond, true_val, false_val, out) assert to_numpy(out)[0] == true_val[0] # False cond[0] = False - kernel[(1,)](cond, true_val, false_val, out) + kernel[(1, )](cond, true_val, false_val, out) assert to_numpy(out)[0] == false_val[0] @@ -3452,15 +3421,15 @@ def kernel(ExitEarly, Out, cond: tl.constexpr, mode: tl.constexpr): return tl.store(Out, 1) - out = to_triton(np.zeros((1,), dtype=np.int32), device=device) - exit_early = to_triton(np.zeros((1,), dtype=np.int32), device=device) + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + exit_early = to_triton(np.zeros((1, ), dtype=np.int32), device=device) # exit early path taken exit_early[0] = 1 - kernel[(1,)](exit_early, out, True, mode) + kernel[(1, )](exit_early, out, True, mode) assert to_numpy(out)[0] == 0 # exit early path not taken exit_early[0] = 0 - kernel[(1,)](exit_early, out, False, mode) + kernel[(1, )](exit_early, out, False, mode) assert to_numpy(out)[0] == 1 @@ -3496,10 +3465,11 @@ def add_fn_static_cond(x, cond: tl.constexpr): # TODO(Keren): if_exp -@pytest.mark.parametrize("call_type", ["attribute", "attribute_jit", - "jit", "jit_if", "jit_expr", - "jit_static_cond", "jit_noinline", "jit_extern"]) +@pytest.mark.parametrize( + "call_type", + ["attribute", "attribute_jit", "jit", "jit_if", "jit_expr", 
"jit_static_cond", "jit_noinline", "jit_extern"]) def test_if_call(call_type, device): + @triton.jit def kernel(Out, call_type: tl.constexpr): pid = tl.program_id(0) @@ -3558,8 +3528,8 @@ def kernel(Out, call_type: tl.constexpr): tl.store(Out, o) - out = to_triton(np.zeros((1,), dtype=np.int32), device=device) - kernel[(1,)](out, call_type) + out = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + kernel[(1, )](out, call_type) assert to_numpy(out)[0] == 1 @@ -3583,14 +3553,14 @@ def kernel(Cond1, Cond2, Cond3, Val1, Val2, Val3, Out): val = tl.load(Val3) tl.store(Out, val) - out = to_triton(np.full((1,), -1, dtype=np.int32), device=device) - cond1 = to_triton(np.full((1,), _cond1, dtype=np.int32), device=device) - cond2 = to_triton(np.full((1,), _cond2, dtype=np.int32), device=device) - cond3 = to_triton(np.full((1,), _cond3, dtype=np.int32), device=device) - val1 = to_triton(np.full((1,), 1, dtype=np.int32), device=device) - val2 = to_triton(np.full((1,), 2, dtype=np.int32), device=device) - val3 = to_triton(np.full((1,), 3, dtype=np.int32), device=device) - kernel[(1,)](cond1, cond2, cond3, val1, val2, val3, out) + out = to_triton(np.full((1, ), -1, dtype=np.int32), device=device) + cond1 = to_triton(np.full((1, ), _cond1, dtype=np.int32), device=device) + cond2 = to_triton(np.full((1, ), _cond2, dtype=np.int32), device=device) + cond3 = to_triton(np.full((1, ), _cond3, dtype=np.int32), device=device) + val1 = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + val2 = to_triton(np.full((1, ), 2, dtype=np.int32), device=device) + val3 = to_triton(np.full((1, ), 3, dtype=np.int32), device=device) + kernel[(1, )](cond1, cond2, cond3, val1, val2, val3, out) targets = { (True, True, True): val1[0], (True, True, False): val1[0], @@ -3619,19 +3589,20 @@ def kernel(InitI, Bound, CutOff, OutI, OutInitI, OutJ): tl.store(OutI, curr_i) tl.store(OutJ, j) - out_i = to_triton(np.zeros((1,), dtype=np.int32), device=device) - out_j = to_triton(np.zeros((1,), dtype=np.int32), device=device) - init_i = to_triton(np.full((1,), 1, dtype=np.int32), device=device) - out_init_i = to_triton(np.full((1,), 0, dtype=np.int32), device=device) - bound = to_triton(np.full((1,), 10, dtype=np.int32), device=device) - cut_off = to_triton(np.full((1,), 5, dtype=np.int32), device=device) - kernel[(1,)](init_i, bound, cut_off, out_i, out_init_i, out_j) + out_i = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + out_j = to_triton(np.zeros((1, ), dtype=np.int32), device=device) + init_i = to_triton(np.full((1, ), 1, dtype=np.int32), device=device) + out_init_i = to_triton(np.full((1, ), 0, dtype=np.int32), device=device) + bound = to_triton(np.full((1, ), 10, dtype=np.int32), device=device) + cut_off = to_triton(np.full((1, ), 5, dtype=np.int32), device=device) + kernel[(1, )](init_i, bound, cut_off, out_i, out_init_i, out_j) assert out_init_i[0] == init_i[0] assert out_i[0] == init_i[0] + 1 assert out_j[0] == bound[0] def test_while2(device): + @triton.jit def nested_while(data, countPtr): for i in range(10): @@ -3641,8 +3612,8 @@ def nested_while(data, countPtr): count = count - 2 counter = torch.tensor([8], dtype=torch.int32, device=device) - data = torch.zeros((1,), device=device, dtype=torch.float32) - nested_while[(1,)](data, counter) + data = torch.zeros((1, ), device=device, dtype=torch.float32) + nested_while[(1, )](data, counter) assert data[0] == 40 @@ -3670,6 +3641,7 @@ def nested_while(data, countPtr): # test extra # ----------------------- + def test_num_threads(device): if is_hip(): 
pytest.skip("test_num_threads is not supported in HIP") @@ -3682,8 +3654,8 @@ def kernel(Out): tl.store(Out + offs, 1) num_threads = 256 - out = to_triton(np.zeros((num_threads,), dtype=np.int32), device=device) - kernel[(1,)](out, num_warps=num_threads // 32) + out = to_triton(np.zeros((num_threads, ), dtype=np.int32), device=device) + kernel[(1, )](out, num_warps=num_threads // 32) assert torch.sum(out) == 256 @@ -3701,9 +3673,9 @@ def kernel(Out1, Out2): end = tl.extra.cuda.globaltimer() tl.store(Out2, end - start) - out1 = to_triton(np.zeros((128,), dtype=np.int64), device=device) - out2 = to_triton(np.zeros((1,), dtype=np.int64), device=device) - h = kernel[(1,)](out1, out2) + out1 = to_triton(np.zeros((128, ), dtype=np.int64), device=device) + out2 = to_triton(np.zeros((1, ), dtype=np.int64), device=device) + h = kernel[(1, )](out1, out2) assert out2[0] > 0 assert h.asm["ptx"].count("%globaltimer") == 2 @@ -3717,17 +3689,17 @@ def test_smid(device): def kernel(Out): tl.store(Out + tl.program_id(0), tl.extra.cuda.smid()) - out = to_triton(np.zeros((1024,), dtype=np.int32), device=device) - h = kernel[(out.shape[0],)](out) + out = to_triton(np.zeros((1024, ), dtype=np.int32), device=device) + h = kernel[(out.shape[0], )](out) assert out.sort()[0].unique().shape[0] > 0 assert h.asm["ptx"].count("%smid") == 1 + # ----------------------- # test layout conversions # ----------------------- # TODO: backend should be tested separately - layouts = [ # MmaLayout(1, [1, 4], [1, 1], [0, 1]), # MmaLayout((2, 0), [1, 4], [1, 1], [1, 1], [0, 1], [16, 8]), @@ -3827,15 +3799,17 @@ def test_convert2d(M, N, src_layout, interm_layout, dst_layout, dtype, device): def test_load_scalar_with_mask(device): + @triton.jit def kernel(Input, Index, Out, N: int): index = tl.load(Index) scalar = tl.load(Input + index, mask=index < N, other=0) tl.store(Out, scalar, mask=index < N) + Index = torch.tensor([0], dtype=torch.int32, device=device) Input = torch.tensor([0], dtype=torch.int32, device=device) Out = torch.empty_like(Index, device=device) - kernel[(1,)](Input, Index, Out, Index.numel()) + kernel[(1, )](Input, Index, Out, Index.numel()) assert Out.data[0] == 0 @@ -3843,6 +3817,7 @@ def kernel(Input, Index, Out, N: int): # maybe delete it later after ptxas has been fixed @pytest.mark.parametrize("dtype_str", ['float16', 'int16']) def test_ptx_cast(dtype_str, device): + @triton.jit def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): xoffset = tl.program_id(0) * XBLOCK @@ -3874,15 +3849,17 @@ def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.co s0 = 4 buf11 = -torch.ones((6 * s0, 197, 197), device=device, dtype=torch_dtype) buf14 = -torch.ones((s0, 6, 197, 197), device=device, dtype=torch_dtype) - kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2) + kernel[(4728, )](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2) assert buf14.to(torch.float32).mean() == -2.0 + # ----------------------- # test fp8 -> fp32 dot # ----------------------- def f8_to_f16(x, dtype): + @triton.jit def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): pid = tl.program_id(0) @@ -3892,21 +3869,21 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): tl.store(Y + offs, x, mask=mask) ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) - grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) dtype = getattr(tl, dtype) 
kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) return ret @triton.jit -def matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - low_precision_acc: tl.constexpr, +def matmul_kernel( # + a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + low_precision_acc: tl.constexpr # ): pid = tl.program_id(axis=0) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -3944,9 +3921,7 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device): a = to_triton(A, device='cuda', dst_type=in_type_str) b = to_triton(B, device='cuda', dst_type=in_type_str) grid = (triton.cdiv(M, BLOCK_M), 1) - matmul_kernel[grid](a, b, C, M, N, K, - a.stride(0), a.stride(1), b.stride(0), b.stride( - 1), C.stride(0), C.stride(1), + matmul_kernel[grid](a, b, C, M, N, K, a.stride(0), a.stride(1), b.stride(0), b.stride(1), C.stride(0), C.stride(1), BLOCK_M, BLOCK_N, BLOCK_K, low_precision_acc, num_warps=num_warps) torch_a = torch.from_numpy(A) th_a = f8_to_f16(torch_a.cuda(), in_type_str) @@ -3960,6 +3935,7 @@ def test_fp8_dot_acc(in_type_str, low_precision_acc, device): else: torch.testing.assert_close(ref_out, C) + # ----------------------- # test enable_fp_fusion # ----------------------- @@ -3973,8 +3949,8 @@ def mul_add(data): ptrs = data + tl.arange(0, 128) tl.store(ptrs, tl.load(ptrs) * 1.5 + 1.0) - data = torch.randn((128,), device='cuda', dtype=torch.float32) - h = mul_add[(1,)](data, enable_fp_fusion=enable_fp_fusion) + data = torch.randn((128, ), device='cuda', dtype=torch.float32) + h = mul_add[(1, )](data, enable_fp_fusion=enable_fp_fusion) found_fma = re.search(r'(mad|fma)\.r[nzmp]\.(ftz\.)?f32', h.asm["ptx"]) is not None assert found_fma == enable_fp_fusion diff --git a/python/test/unit/language/test_line_info.py b/python/test/unit/language/test_line_info.py index fc73f2bf374a..026b204837fd 100644 --- a/python/test/unit/language/test_line_info.py +++ b/python/test/unit/language/test_line_info.py @@ -10,9 +10,7 @@ @triton.jit -def kernel_single(X, - Y, - BLOCK: tl.constexpr): +def kernel_single(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) tl.store(Y + tl.arange(0, BLOCK), x) @@ -23,9 +21,7 @@ def device_inline(x): @triton.jit -def kernel_call(X, - Y, - BLOCK: tl.constexpr): +def kernel_call(X, Y, BLOCK: tl.constexpr): x = tl.load(X + tl.arange(0, BLOCK)) y = device_inline(x) tl.store(Y + tl.arange(0, BLOCK), y) @@ -91,13 +87,13 @@ def test_line_info(func: str): y = torch.zeros(shape, dtype=x.dtype, device="cuda") kernel_info = {} if func == "single": - kernel_info = kernel_single[(1,)](x, y, BLOCK=shape[0]) + kernel_info = kernel_single[(1, )](x, y, BLOCK=shape[0]) elif func == "call": - kernel_info = kernel_call[(1,)](x, y, BLOCK=shape[0]) + kernel_info = kernel_call[(1, )](x, y, BLOCK=shape[0]) elif func == "call_noinline": - kernel_info = kernel_call_noinline[(1,)](x, y, BLOCK=shape[0]) + kernel_info = kernel_call_noinline[(1, )](x, y, BLOCK=shape[0]) elif func == "multi_files": - kernel_info = kernel_multi_files[(1,)](x, y, BLOCK=shape[0]) + kernel_info = kernel_multi_files[(1, )](x, y, BLOCK=shape[0]) file_lines = extract_file_lines(kernel_info.asm["cubin"]) if func == "single": diff --git a/python/test/unit/language/test_random.py 
b/python/test/unit/language/test_random.py index e12adff1e7ee..7f6784d0b9a0 100644 --- a/python/test/unit/language/test_random.py +++ b/python/test/unit/language/test_random.py @@ -12,6 +12,7 @@ class PhiloxConfig: + def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, DTYPE): self.PHILOX_ROUND_A = np.array(PHILOX_ROUND_A, dtype=DTYPE) self.PHILOX_ROUND_B = np.array(PHILOX_ROUND_B, dtype=DTYPE) @@ -40,6 +41,7 @@ def __init__(self, PHILOX_ROUND_A, PHILOX_ROUND_B, PHILOX_KEY_A, PHILOX_KEY_B, D class CustomPhilox4x: + def __init__(self, seed, config): self._config = config seed = self._into_pieces(seed) @@ -92,6 +94,7 @@ def advance(self, n_steps): class CustomPhilox(CustomPhilox4x): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.buffer = [] @@ -111,10 +114,9 @@ def random_raw(self): # test generation of random uint32 -@pytest.mark.parametrize('size, seed', - [(size, seed) for size in ['10', '4,53', '10000'] - for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]] - ) +@pytest.mark.parametrize('size, seed', [(size, seed) + for size in ['10', '4,53', '10000'] + for seed in [0, 42, 124, 54, 0xffffffff, 0xdeadbeefcafeb0ba]]) def test_randint(size, seed, device): size = list(map(int, size.split(','))) @@ -123,10 +125,11 @@ def kernel(X, N, seed): offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) rand = tl.randint(seed, offset) tl.store(X + offset, rand, mask=offset < N) + # triton result x = torch.empty(size, dtype=torch.int32, device=device) N = x.numel() - grid = (triton.cdiv(N, BLOCK),) + grid = (triton.cdiv(N, BLOCK), ) kernel[grid](x, N, seed) out_tri = x.cpu().numpy().astype(np.uint32).flatten().tolist() # reference result @@ -134,44 +137,44 @@ def kernel(X, N, seed): out_ref = [gen.random_raw()[0] for _ in out_tri] assert out_tri == out_ref + # test uniform PRNG -@pytest.mark.parametrize('size, seed', - [(size, seed) for size in [1000000] - for seed in [0, 42, 124, 54]] - ) +@pytest.mark.parametrize('size, seed', [(size, seed) for size in [1000000] for seed in [0, 42, 124, 54]]) def test_rand(size, seed, device): + @triton.jit def kernel(X, N, seed): offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) rand = tl.rand(seed, offset) tl.store(X + offset, rand, mask=offset < N) + # triton result x = torch.empty(size, dtype=torch.float32, device=device) N = x.numel() - grid = (triton.cdiv(N, BLOCK),) + grid = (triton.cdiv(N, BLOCK), ) kernel[grid](x, N, seed) assert all((x >= 0) & (x <= 1)) assert scipy.stats.kstest(x.tolist(), 'uniform', args=(0, 1)).statistic < 0.01 + # test normal PRNG -@pytest.mark.parametrize('size, seed', - [(size, seed) for size in [1000000] - for seed in [0, 42, 124, 54]] - ) +@pytest.mark.parametrize('size, seed', [(size, seed) for size in [1000000] for seed in [0, 42, 124, 54]]) def test_randn(size, seed, device): + @triton.jit def kernel(X, N, seed): offset = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) rand = tl.randn(seed, offset) tl.store(X + offset, rand, mask=offset < N) + # triton result x = torch.empty(size, dtype=torch.float32, device=device) N = x.numel() - grid = (triton.cdiv(N, BLOCK),) + grid = (triton.cdiv(N, BLOCK), ) kernel[grid](x, N, seed) assert abs(x.mean()) < 1e-2 assert abs(x.std() - 1) < 1e-2 @@ -179,7 +182,9 @@ def kernel(X, N, seed): # tl.rand() should never produce >=1.0 + def test_rand_limits(device): + @triton.jit def kernel(input, output, n: tl.constexpr): idx = tl.arange(0, n) @@ -192,7 +197,7 @@ def kernel(input, output, n: tl.constexpr): 
torch.iinfo(torch.int32).max, ], dtype=torch.int32, device=device) output = torch.empty(2, dtype=torch.float32, device=device) - kernel[(1,)](min_max_int32, output, 2) + kernel[(1, )](min_max_int32, output, 2) assert output[0] == output[1] assert 1.0 - torch.finfo(torch.float32).eps <= output[0].item() < 1.0 diff --git a/python/test/unit/language/test_subprocess.py b/python/test/unit/language/test_subprocess.py index 8a4c8957d6e5..08bc63a3e2ed 100644 --- a/python/test/unit/language/test_subprocess.py +++ b/python/test/unit/language/test_subprocess.py @@ -17,16 +17,15 @@ # TODO: Print with multiple operands -@pytest.mark.parametrize("func_type, data_type", - [("device_print", data_type) for data_type in torch_types] + [ - ("print", "int32"), - ("static_print", "int32"), - ("no_arg_print", "int32"), - ("print_no_arg", "int32"), - ("device_print_large", "int32"), - ("print_multiple_args", "int32"), - ("device_print_multiple_args", "int32"), - ]) +@pytest.mark.parametrize("func_type, data_type", [("device_print", data_type) for data_type in torch_types] + [ + ("print", "int32"), + ("static_print", "int32"), + ("no_arg_print", "int32"), + ("print_no_arg", "int32"), + ("device_print_large", "int32"), + ("print_multiple_args", "int32"), + ("device_print_multiple_args", "int32"), +]) def test_print(func_type: str, data_type: str): proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False) outs, _ = proc.communicate() @@ -71,7 +70,8 @@ def test_print(func_type: str, data_type: str): @pytest.mark.parametrize("func_type", assert_types) def test_assert(func_type: str): os.environ["TRITON_DEBUG"] = "1" - proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False) + proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, + shell=False) _, errs = proc.communicate() errs = errs.splitlines() num_errs = 0 @@ -91,7 +91,8 @@ def test_assert(func_type: str): @pytest.mark.parametrize("caller_type, callee_type", nested_types) def test_assert_nested(caller_type, callee_type): - proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False) + proc = subprocess.Popen([sys.executable, assert_path, caller_type, callee_type], stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=False) _, errs = proc.communicate() errs = errs.splitlines() num_errs = 0 diff --git a/python/test/unit/operators/test_blocksparse.py b/python/test/unit/operators/test_blocksparse.py index 7e6f820a374d..acc5e30c68e8 100644 --- a/python/test/unit/operators/test_blocksparse.py +++ b/python/test/unit/operators/test_blocksparse.py @@ -68,8 +68,7 @@ def test_matmul(MODE, TRANS_A, TRANS_B, BLOCK, DTYPE, Z=3, H=2, M=512, N=384, K= b_ref = do_mask(b_ref) if is_dds else b_ref a_ref.retain_grad() b_ref.retain_grad() - c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, - b_ref.transpose(2, 3) if TRANS_B else b_ref) + c_ref = torch.matmul(a_ref.transpose(2, 3) if TRANS_A else a_ref, b_ref.transpose(2, 3) if TRANS_B else b_ref) c_ref.backward(dc_ref) c_ref = do_sparsify(c_ref) if is_sdd else c_ref da_ref = do_sparsify(a_ref.grad) if is_dsd else a_ref.grad @@ -172,7 +171,7 @@ def test_attention_fwd_bwd( value.retain_grad() attn_out = triton_attention(layout, block, query=query, key=key, value=value, scale=scale) # ad hoc loss - loss = (attn_out ** 2).mean() + loss = (attn_out**2).mean() 
loss.backward() grads = [query.grad, key.grad, value.grad] @@ -189,7 +188,7 @@ def test_attention_fwd_bwd( probs = torch.softmax(scores, dim=-1) torch_attn_out = torch.einsum("bhst,bhtd->bhsd", probs, torch_v) # ad hoc loss - torch_loss = (torch_attn_out ** 2).mean() + torch_loss = (torch_attn_out**2).mean() torch_loss.backward() torch_grads = [torch_q.grad, torch_k.grad, torch_v.grad] @@ -209,8 +208,10 @@ def triton_attention( value: torch.Tensor, scale: float, ): - sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, device=value.device) - sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, device=value.device) + sparse_dot_sdd_nt = triton.ops.blocksparse.matmul(layout, block, "sdd", trans_a=False, trans_b=True, + device=value.device) + sparse_dot_dsd_nn = triton.ops.blocksparse.matmul(layout, block, "dsd", trans_a=False, trans_b=False, + device=value.device) sparse_softmax = triton.ops.blocksparse.softmax(layout, block, device=value.device) w = sparse_dot_sdd_nt(query, key) diff --git a/python/test/unit/operators/test_cross_entropy.py b/python/test/unit/operators/test_cross_entropy.py index f6ae42ac3e9a..5bffd2ad835d 100644 --- a/python/test/unit/operators/test_cross_entropy.py +++ b/python/test/unit/operators/test_cross_entropy.py @@ -5,14 +5,13 @@ import triton.ops -@pytest.mark.parametrize("M, N, dtype, mode", - [ - (M, N, dtype, mode) for M in [1024, 821] - for N in [512, 857, 1871, 2089, 8573, 31000] - for dtype in ['float16', 'float32'] - for mode in ['forward', 'backward'] - ] - ) +@pytest.mark.parametrize("M, N, dtype, mode", [ # + (M, N, dtype, mode) + for M in [1024, 821] + for N in [512, 857, 1871, 2089, 8573, 31000] + for dtype in ['float16', 'float32'] + for mode in ['forward', 'backward'] +]) def test_op(M, N, dtype, mode): capability = torch.cuda.get_device_capability() if capability[0] < 8 and dtype == "bfloat16": diff --git a/python/test/unit/operators/test_flash_attention.py b/python/test/unit/operators/test_flash_attention.py index 48cb57ce0d2a..2faf85a51cc6 100644 --- a/python/test/unit/operators/test_flash_attention.py +++ b/python/test/unit/operators/test_flash_attention.py @@ -5,10 +5,12 @@ import triton.ops -@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(2, 4, 512, 16), - (2, 4, 512, 32), - (2, 4, 512, 64), - (2, 4, 512, 128)]) +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ # + (2, 4, 512, 16), + (2, 4, 512, 32), + (2, 4, 512, 64), + (2, 4, 512, 128), +]) @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) @pytest.mark.parametrize('causal', [True, False]) @pytest.mark.parametrize('seq_par', [True, False]) @@ -63,24 +65,21 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype, causal, seq_par): BATCH, N_HEADS, N_CTX, D_HEAD = 4, 48, 4096, 64 # vary seq length for fixed head and batch=4 -configs = [triton.testing.Benchmark( - x_names=['N_CTX'], - x_vals=[2**i for i in range(10, 14)], - line_arg='provider', - line_vals=['triton'] + (['flash'] if HAS_FLASH else []), - line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), - styles=[('red', '-'), ('blue', '-')], - ylabel='ms', - plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{casual}-{seq_par}', - args={ - 'H': N_HEADS, - 'BATCH': BATCH, - 'D_HEAD': D_HEAD, - 'dtype': torch.float16, - 'mode': mode, - 'casual': casual, - 'seq_par': seq_par} -) for mode in ['fwd', 'bwd'] for casual in [True, False] for seq_par in [True, False]] +configs = [ + triton.testing.Benchmark( + 
x_names=['N_CTX'], x_vals=[2**i for i in range(10, 14)], line_arg='provider', + line_vals=['triton'] + (['flash'] if HAS_FLASH else []), + line_names=['Triton'] + (['Flash'] if HAS_FLASH else []), styles=[('red', '-'), ('blue', '-')], ylabel='ms', + plot_name=f'fused-attention-batch{BATCH}-head{N_HEADS}-d{D_HEAD}-{mode}-{casual}-{seq_par}', args={ + 'H': N_HEADS, + 'BATCH': BATCH, + 'D_HEAD': D_HEAD, + 'dtype': torch.float16, + 'mode': mode, + 'casual': casual, + 'seq_par': seq_par, + }) for mode in ['fwd', 'bwd'] for casual in [True, False] for seq_par in [True, False] +] @triton.testing.perf_report(configs) @@ -101,9 +100,8 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, mode, casual, seq_par, provid ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) return ms if provider == "flash": - lengths = torch.full((BATCH,), fill_value=N_CTX, device=device) - cu_seqlens = torch.zeros( - (BATCH + 1,), device=device, dtype=torch.int32) + lengths = torch.full((BATCH, ), fill_value=N_CTX, device=device) + cu_seqlens = torch.zeros((BATCH + 1, ), device=device, dtype=torch.int32) cu_seqlens[1:] = lengths.cumsum(0) fn = lambda: flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=sm_scale, causal=casual) if mode == 'bwd': diff --git a/python/test/unit/operators/test_inductor.py b/python/test/unit/operators/test_inductor.py index 579d0ad935da..2fdfe235e824 100644 --- a/python/test/unit/operators/test_inductor.py +++ b/python/test/unit/operators/test_inductor.py @@ -8,7 +8,8 @@ def test_normalization_with_remat(): @triton.jit - def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): + def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel, rnumel, XBLOCK: tl.constexpr, + RBLOCK: tl.constexpr): xnumel = 512 rnumel = 4096 xoffset = tl.program_id(0) * XBLOCK @@ -52,7 +53,7 @@ def triton_(in_out_ptr0, in_out_ptr1, in_ptr0, in_ptr1, in_ptr2, in_ptr3, xnumel arg115_1 = torch.rand(64, device="cuda") arg8_1 = torch.rand(64, device="cuda") arg9_1 = torch.rand(64, device="cuda") - triton_[(512,)](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048) + triton_[(512, )](buf14, buf16, arg114_1, arg115_1, arg8_1, arg9_1, 512, 4096, 1, 2048) torch.testing.assert_close(buf16.mean().item(), buf14.mean().item(), atol=1e-7, rtol=0) @@ -148,7 +149,7 @@ def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr): inp = torch.ones(8, 2048, 8, 8, device="cuda", dtype=torch.half) out = torch.ones_like(inp) * 3 numel = inp.numel() - triton_[(numel // 1024,)](inp, out, 1024) + triton_[(numel // 1024, )](inp, out, 1024) out_ref = torch.ones_like(inp) out_ref[:, :, 1:7, 0::7] = 2 / 3 out_ref[:, :, 0::7, 1:7] = 2 / 3 @@ -159,6 +160,7 @@ def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr): @pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128]) @pytest.mark.parametrize("num_warps", [1, 4]) def test_scan2d_broadcast(RBLOCK, num_warps): + @triton.jit(debug=True) def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): rindex = tl.arange(0, RBLOCK)[None, :] @@ -172,12 +174,13 @@ def fn(in_ptr, out_ptr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr): XBLOCK = 4 input = torch.randint(0, 10, (1, RBLOCK), dtype=torch.int64, device='cuda') output = torch.empty((XBLOCK, RBLOCK), dtype=torch.int64, device='cuda') - fn[(1,)](input, output, XBLOCK, RBLOCK, num_warps=num_warps) + fn[(1, )](input, output, XBLOCK, RBLOCK, num_warps=num_warps) ref = input.cumsum(1).broadcast_to((XBLOCK, RBLOCK)) 
torch.testing.assert_close(output, ref) def test_scan2d_for(): + @triton.jit def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr): rbase = tl.arange(0, RBLOCK)[None, :] @@ -190,6 +193,6 @@ def fn(out_ptr0, rnumel, RBLOCK: tl.constexpr): RBLOCK = 8 out0 = torch.empty(RBLOCK, device="cuda", dtype=torch.int64) - fn[(1,)](out0, RBLOCK, RBLOCK) + fn[(1, )](out0, RBLOCK, RBLOCK) ref = torch.arange(RBLOCK, device="cuda", dtype=torch.int64) + 1 torch.testing.assert_close(out0, ref) diff --git a/python/test/unit/operators/test_matmul.py b/python/test/unit/operators/test_matmul.py index 642b0982b45a..801cb8a41962 100644 --- a/python/test/unit/operators/test_matmul.py +++ b/python/test/unit/operators/test_matmul.py @@ -19,7 +19,7 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): tl.store(Y + offs, x, mask=mask) ret = torch.empty(x.shape, dtype=torch.float16, device=x.device) - grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),) + grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']), ) dtype = getattr(tl, dtype) kernel[grid](ret, triton.reinterpret(x, dtype), ret.numel(), BLOCK_SIZE=1024) return ret @@ -28,87 +28,88 @@ def kernel(Y, X, N, BLOCK_SIZE: tl.constexpr): @pytest.mark.parametrize( "BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM", itertools.chain( - *[ - [ - # 1 warp - (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - # 2 warp - (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - # 4 warp - (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - # 8 warp - (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), - # variable input - (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True), - (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True), - (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True), - (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, 
BT, DTYPE, DTYPE, True, True), - ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] - ], + *[[ + # 1 warp + (16, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 16, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 32, 16, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 16, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 32, 32, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 16, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (16, 64, 64, 1, 1, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + # 2 warp + (64, 32, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 64, 64, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 32, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 64, 16, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 32, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 128, 32, 1, 2, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + # 4 warp + (128, 64, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (64, 128, 16, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 32, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 128, 32, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (128, 32, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (32, 128, 64, 1, 4, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + # 8 warp + (128, 256, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (256, 128, 16, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + (256, 128, 32, 1, 8, 2, None, None, None, AT, BT, DTYPE, DTYPE, True, True), + # variable input + (128, 128, 32, 1, 4, 2, 256, 384, 160, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, 2, 107, 233, 128, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, 2, 107, 233, 83, AT, BT, DTYPE, DTYPE, True, True), + (128, 256, 64, 1, 8, 3, 256, 512, 160, AT, BT, DTYPE, DTYPE, True, True), + ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True]], # n-stage - *[ - [ - (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True), - (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True), - (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True), - (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True), - (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True), - ] for DTYPE in ["float16", "bfloat16", "float32"] for AT in [False, True] for BT in [False, True] for STAGES in [4] - ], + *[[ + (16, 16, 16, 1, 1, STAGES, 32, 32, 80, AT, BT, DTYPE, DTYPE, True, True), + (64, 32, 64, 1, 2, STAGES, 128, 64, 128, AT, BT, DTYPE, DTYPE, True, True), + (128, 64, 16, 1, 4, STAGES, 256, 128, 80, AT, BT, DTYPE, DTYPE, True, True), + (256, 128, 32, 1, 8, STAGES, 512, 256, 160, AT, BT, DTYPE, DTYPE, True, True), + (128, 128, 32, 1, 4, STAGES, 256, 256, 160, AT, BT, DTYPE, DTYPE, True, True), + ] + for DTYPE in ["float16", "bfloat16", "float32"] + for AT in [False, True] + for BT in [False, True] + for STAGES in [4]], # 
mixed-precision - *[ - [ - (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), - (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), - (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), - ] for ADTYPE, BDTYPE in [("float8e4nv", "float8e5"), - ("float8e4nv", "float8e4nv"), - ("float8e5", "float8e4nv"), - ("float8e5", "float8e5"), - ("float8e4b15", "float8e4b15"), - ("float8e4nv", "float16"), - ("float16", "float8e5"), - ("float16", "float32"), - ("float32", "float16"), - ("bfloat16", "float32"), - ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False] - ], + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, True, FASTACCUM), + ] for ADTYPE, BDTYPE in [ + ("float8e4nv", "float8e5"), + ("float8e4nv", "float8e4nv"), + ("float8e5", "float8e4nv"), + ("float8e5", "float8e5"), + ("float8e4b15", "float8e4b15"), + ("float8e4nv", "float16"), + ("float16", "float8e5"), + ("float16", "float32"), + ("float32", "float16"), + ("bfloat16", "float32"), + ("float32", "bfloat16"), + ] for AT in [False, True] for BT in [False, True] for FASTACCUM in [True, False]], # mixed-precision block layout - *[ - [ - (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), - (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), - (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True), - ] for ADTYPE, BDTYPE in [("float8e4nv", "float16"), - ("float16", "float8e5"), - ("float16", "float32"), - ("float32", "float16"), - ("bfloat16", "float32"), - ("float32", "bfloat16")] for AT in [False, True] for BT in [False, True] - ], + *[[ + (32, 32, 32, 1, 1, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), + (128, 256, 32, 1, 8, 2, None, None, None, AT, BT, ADTYPE, BDTYPE, False, True), + (32, 64, 32, 1, 1, 2, 64, 128, 32, AT, BT, ADTYPE, BDTYPE, False, True), + ] for ADTYPE, BDTYPE in [ + ("float8e4nv", "float16"), + ("float16", "float8e5"), + ("float16", "float32"), + ("float32", "float16"), + ("bfloat16", "float32"), + ("float32", "bfloat16"), + ] for AT in [False, True] for BT in [False, True]], ), ) -def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, F8_FASTACCUM): +def test_op(BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, NWARP, NSTAGE, M, N, K, AT, BT, ADTYPE, BDTYPE, ALLOW_TF32, + F8_FASTACCUM): capability = torch.cuda.get_device_capability() if capability[0] < 7: pytest.skip("Only test tl.dot() on devices with sm >= 70") @@ -147,7 +148,7 @@ def init_input(m, n, dtype): return torch.randint(-128, 127, (m, n), device="cuda", dtype=torch.int8) dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16, "float32": torch.float32}[dtype] exponents = torch.randint(-10, 0, size=(m, n)) - ret = (2. 
** exponents).to(dtype).to("cuda") + ret = (2.**exponents).to(dtype).to("cuda") return ret # allocate/transpose inputs diff --git a/python/test/unit/runtime/test_autotuner.py b/python/test/unit/runtime/test_autotuner.py index 0c02830ec44a..198717a32663 100644 --- a/python/test/unit/runtime/test_autotuner.py +++ b/python/test/unit/runtime/test_autotuner.py @@ -17,7 +17,8 @@ def _kernel(dst, src, N, BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(src + offsets, mask=offsets < N) tl.store(dst + offsets, x, mask=offsets < N) - grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) _kernel[grid](dst, src, N) _kernel[grid](dst=dst, src=src, N=N) @@ -34,6 +35,7 @@ def _kernel(src, N, BLOCK_SIZE: tl.constexpr): offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) x = tl.load(src + offsets, mask=offsets < N) + 1 tl.store(src + offsets, x, mask=offsets < N) - grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']),) + + grid = lambda META: (triton.cdiv(N, META['BLOCK_SIZE']), ) _kernel[grid](src, N) triton.testing.assert_close(src, torch.ones_like(src)) diff --git a/python/test/unit/runtime/test_cache.py b/python/test/unit/runtime/test_cache.py index f75fa7c32800..c8e2e91fa4ed 100644 --- a/python/test/unit/runtime/test_cache.py +++ b/python/test/unit/runtime/test_cache.py @@ -80,11 +80,12 @@ def test_reuse(): def inc_counter(*args, **kwargs): nonlocal counter counter += 1 + JITFunction.cache_hook = inc_counter reset_tmp_dir() x = torch.empty(1, dtype=torch.int32, device='cuda') for i in range(10): - kernel[(1,)](x, 1, BLOCK=1024) + kernel[(1, )](x, 1, BLOCK=1024) assert counter == 1 @@ -95,17 +96,19 @@ def test_specialize(mode): def inc_counter(*args, **kwargs): nonlocal counter counter += 1 + JITFunction.cache_hook = inc_counter reset_tmp_dir() x = torch.empty(1, dtype=torch.int32, device='cuda') function = {'enable': kernel, 'disable': kernel_nospec}[mode] target = {'enable': 4, 'disable': 1}[mode] for i in [1, 2, 4, 8, 16, 32]: - function[(1,)](x, i, BLOCK=512) + function[(1, )](x, i, BLOCK=512) assert counter == target def test_annotation(): + @triton.jit def kernel(X, i: tl.int32): tl.store(X, i) @@ -113,14 +116,15 @@ def kernel(X, i: tl.int32): x = torch.empty(1, dtype=torch.int32, device='cuda') device = torch.cuda.current_device() - kernel[(1,)](x, 1) - kernel[(1,)](x, 8) - kernel[(1,)](x, 16) - kernel[(1,)](x, 17) + kernel[(1, )](x, 1) + kernel[(1, )](x, 8) + kernel[(1, )](x, 16) + kernel[(1, )](x, 17) assert len(kernel.cache[device]) == 4 def test_constexpr_not_callable() -> None: + @triton.jit def kernel(X, c: tl.constexpr): tl.store(X, 2) @@ -141,11 +145,11 @@ def kernel(X, c: tl.constexpr): def test_jit_warmup_cache() -> None: + @triton.jit def kernel_add(a, b, o, N: tl.constexpr): idx = tl.arange(0, N) - tl.store(o + idx, - tl.load(a + idx) + tl.load(b + idx)) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) args = [ torch.randn(32, dtype=torch.float32, device="cuda"), @@ -155,31 +159,31 @@ def kernel_add(a, b, o, N: tl.constexpr): ] device = torch.cuda.current_device() assert len(kernel_add.cache[device]) == 0 - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add.cache[device]) == 1 - kernel_add.warmup(*args, grid=(1,)) + kernel_add.warmup(*args, grid=(1, )) assert len(kernel_add.cache[device]) == 1 - 
kernel_add.warmup(*args, grid=(1,)) + kernel_add.warmup(*args, grid=(1, )) assert len(kernel_add.cache[device]) == 1 def test_jit_debug() -> None: + @triton.jit def kernel_add(a, b, o, N: tl.constexpr): idx = tl.arange(0, N) tl.device_assert(idx < 32, "idx < 32") - tl.store(o + idx, - tl.load(a + idx) + tl.load(b + idx)) + tl.store(o + idx, tl.load(a + idx) + tl.load(b + idx)) device = torch.cuda.current_device() assert len(kernel_add.cache[device]) == 0 - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add.cache[device]) == 1 kernel_add.debug = False - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add.cache[device]) == 2 kernel_add.debug = True - kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add.cache[device]) == 3 bins = list(kernel_add.cache[device].values()) assert bins[2].asm['ttir'] != bins[1].asm['ttir'] @@ -192,13 +196,14 @@ def add_fn(a, b, o, N: tl.constexpr): def test_jit_noinline() -> None: + @triton.jit def kernel_add_device(a, b, o, N: tl.constexpr): add_fn(a, b, o, N) device = torch.cuda.current_device() assert len(kernel_add_device.cache[device]) == 0 - kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add_device.cache[device]) == 1 bins = list(kernel_add_device.cache[device].values()) inline_ttir = bins[0].asm['ttir'] @@ -206,7 +211,7 @@ def kernel_add_device(a, b, o, N: tl.constexpr): add_fn.hash = None kernel_add_device.hash = None kernel_add_device.cache[device].clear() - kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,)) + kernel_add_device.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1, )) assert len(kernel_add_device.cache[device]) == 1 bins = list(kernel_add_device.cache[device].values()) noinline_ttir = bins[0].asm['ttir'] @@ -214,6 +219,7 @@ def kernel_add_device(a, b, o, N: tl.constexpr): def test_memory_leak() -> None: + @triton.jit def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): xnumel = 10 diff --git a/python/test/unit/runtime/test_launch.py b/python/test/unit/runtime/test_launch.py index d3f9fd01bda8..00009f230f0b 100644 --- a/python/test/unit/runtime/test_launch.py +++ b/python/test/unit/runtime/test_launch.py @@ -31,11 +31,11 @@ def kernel(in_ptr0, out_ptr0, xnumel, XBLOCK: tl.constexpr): try: inp = torch.randn(10, device='cuda') out = torch.randn(10, device='cuda') - kernel[(10,)](inp, out, 10, XBLOCK=16) + kernel[(10, )](inp, out, 10, XBLOCK=16) gc.collect() begin, _ = tracemalloc.get_traced_memory() for _ in range(100): - kernel[(10,)](inp, out, 10, XBLOCK=16) + kernel[(10, )](inp, out, 10, XBLOCK=16) gc.collect() end, _ = tracemalloc.get_traced_memory() assert end - begin < 30000 diff --git a/python/test/unit/runtime/test_subproc.py b/python/test/unit/runtime/test_subproc.py index 05e36eee8f1c..f1039d011e2c 100644 --- a/python/test/unit/runtime/test_subproc.py +++ b/python/test/unit/runtime/test_subproc.py @@ -17,14 +17,17 @@ def reset_tmp_dir(): shutil.rmtree(tmpdir, ignore_errors=True) -instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", 
"equal_to_1", "ids_of_folded_args", "divisible_by_8"]) +instance_descriptor = namedtuple("instance_descriptor", + ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"]) def compile_fn(config, cc): + @triton.jit def kernel_sub(a, b, o, N: tl.constexpr): idx = tl.arange(0, N) tl.store(o + idx, tl.load(a + idx) - tl.load(b + idx) * 777) + triton.compile( fn=kernel_sub, signature={0: "*fp32", 1: "*fp32", 2: "*fp32"}, @@ -42,15 +45,14 @@ def test_compile_in_subproc() -> None: config = instance_descriptor(tuple(range(4)), (), (), ()) multiprocessing.set_start_method('fork') - proc = multiprocessing.Process( - target=compile_fn, - args=(config, cc)) + proc = multiprocessing.Process(target=compile_fn, args=(config, cc)) proc.start() proc.join() assert proc.exitcode == 0 def compile_fn_dot(config, cc): + @triton.jit def kernel_dot(Z): offs = tl.arange(0, 16)[:, None] * 16 + tl.arange(0, 16)[None, :] @@ -75,9 +77,7 @@ def test_compile_in_forked_subproc() -> None: config = instance_descriptor(tuple(range(1)), (), (), ()) assert multiprocessing.get_start_method() == 'fork' - proc = multiprocessing.Process( - target=compile_fn_dot, - args=(config, cc)) + proc = multiprocessing.Process(target=compile_fn_dot, args=(config, cc)) proc.start() proc.join() assert proc.exitcode == 0 diff --git a/python/test/unit/tools/test_aot.py b/python/test/unit/tools/test_aot.py index cd6900321471..92b5562e9527 100644 --- a/python/test/unit/tools/test_aot.py +++ b/python/test/unit/tools/test_aot.py @@ -173,8 +173,7 @@ def gen_test_bin(dir, M, N, K, exe="test", algo_id=0): with open(os.path.join(dir, "test.c"), "w") as file: file.write(src) subprocess.run( - ["gcc"] - + [ + ["gcc"] + [ "test.c", "-I", cuda_include_dir(), @@ -206,9 +205,7 @@ def write_triton_kernels(dir, src, util_src): return kernel_path -def _compile_kernel( - dir, signature, kernel_name, out_name, out_path, num_warps, grid, kernel_path -): +def _compile_kernel(dir, signature, kernel_name, out_name, out_path, num_warps, grid, kernel_path): compiler_path = os.path.join(triton.tools.__path__[0], "compile.py") subprocess.run( @@ -276,9 +273,7 @@ def link_aot_kernels(dir): # link all desired configs h_files = glob.glob(os.path.join(dir, "*.h")) - subprocess.run( - [sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=dir - ) + subprocess.run([sys.executable, linker_path] + h_files + ["-o", "kernel"], check=True, cwd=dir) def generate_matmul_test_data(dir, M, N, K): @@ -315,9 +310,7 @@ def test_compile_link_matmul_no_specialization(): # run test case env = os.environ.copy() env["LD_LIBRARY_PATH"] = tmp_dir - subprocess.run( - ["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir - ) + subprocess.run(["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir) # read data and compare against reference c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) @@ -334,9 +327,7 @@ def test_compile_link_matmul(): BM, BN, BK = 16, 16, 16 kernel_path = write_triton_kernels(tmp_dir, kernel_src, kernel_utils_src) - compile_aot_kernels( - tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"] - ) + compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]) link_aot_kernels(tmp_dir) # compile test case @@ -350,9 +341,7 @@ def test_compile_link_matmul(): # run test case env = os.environ.copy() env["LD_LIBRARY_PATH"] = tmp_dir - subprocess.run( - ["./test", a_path, b_path, c_path], env=env, check=True, cwd=tmp_dir - ) + subprocess.run(["./test", a_path, b_path, c_path], 
env=env, check=True, cwd=tmp_dir) # read data and compare against reference c = np.genfromtxt(c_path, delimiter=",", dtype=np.int32) @@ -413,9 +402,7 @@ def test_compile_link_autotune_matmul(): for ts in tile_sizes: BM, BN, BK = ts[0], ts[1], ts[2] - compile_aot_kernels( - tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"] - ) + compile_aot_kernels(tmp_dir, kernel_path, dtype, BM, BN, BK, ha_hb_hints=["", ":16"]) link_aot_kernels(tmp_dir) diff --git a/python/triton/__init__.py b/python/triton/__init__.py index aca365dad396..55484acd5bf2 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -45,12 +45,12 @@ "tools", ] - # ------------------------------------- # misc. utilities that don't fit well # into any specific module # ------------------------------------- + def cdiv(x: int, y: int): return (x + y - 1) // y diff --git a/python/triton/common/backend.py b/python/triton/common/backend.py index 1aa8c9fe481d..f56cc7c9b1b2 100644 --- a/python/triton/common/backend.py +++ b/python/triton/common/backend.py @@ -1,4 +1,3 @@ - import functools import hashlib import importlib @@ -16,6 +15,7 @@ class BaseBackend: + def __init__(self, device_type: str) -> None: self.device_type = device_type @@ -154,7 +154,7 @@ def compute_core_version_key(): libtriton_hash = hashlib.sha1() with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f: while True: - chunk = f.read(1024 ** 2) + chunk = f.read(1024**2) if not chunk: break libtriton_hash.update(chunk) diff --git a/python/triton/common/build.py b/python/triton/common/build.py index eb1096402c83..0795b45b14fb 100644 --- a/python/triton/common/build.py +++ b/python/triton/common/build.py @@ -86,9 +86,15 @@ def _build(name, src, srcdir): py_include_dir = sysconfig.get_paths(scheme=scheme)["include"] if is_hip(): - ret = subprocess.check_call([cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", f"-L{hip_lib_dir}", "-lamdhip64", "-o", so]) + ret = subprocess.check_call([ + cc, src, f"-I{hip_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", + f"-L{hip_lib_dir}", "-lamdhip64", "-o", so + ]) else: - cc_cmd = [cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", "-o", so] + cc_cmd = [ + cc, src, "-O3", f"-I{cu_include_dir}", f"-I{py_include_dir}", f"-I{srcdir}", "-shared", "-fPIC", "-lcuda", + "-o", so + ] cc_cmd += [f"-L{dir}" for dir in cuda_lib_dirs] ret = subprocess.check_call(cc_cmd) diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index 89f46a1fe1ca..fd0665e1e549 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,5 +1,8 @@ -from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, - get_arch_default_num_warps, instance_descriptor) +from .compiler import (CompiledKernel, compile, get_arch_default_num_stages, get_arch_default_num_warps, + instance_descriptor) from .errors import CompilationError -__all__ = ["compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", "get_arch_default_num_stages"] +__all__ = [ + "compile", "instance_descriptor", "CompiledKernel", "CompilationError", "get_arch_default_num_warps", + "get_arch_default_num_stages" +] diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 4d5ade791f93..03982a47de09 100644 --- a/python/triton/compiler/code_generator.py +++ 
b/python/triton/compiler/code_generator.py @@ -10,8 +10,7 @@ from ..language import constexpr, tensor # ideally we wouldn't need any runtime component from ..runtime import JITFunction -from .errors import (CompilationError, CompileTimeAssertionFailure, - UnsupportedLanguageConstruct) +from .errors import (CompilationError, CompileTimeAssertionFailure, UnsupportedLanguageConstruct) def mangle_ty(ty): @@ -68,7 +67,10 @@ def _check_fn_args(node, fn, args): if fn.noinline: for idx, arg in enumerate(args): if not _is_constexpr(arg) and not _is_triton_scalar(arg): - raise UnsupportedLanguageConstruct(fn.src, node, f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}') + raise UnsupportedLanguageConstruct( + fn.src, node, + f'Function {fn.__name__} is marked noinline, but was called with non-scalar argument {fn.arg_names[idx]}:{arg}' + ) def _get_fn_file_line(fn): @@ -89,6 +91,7 @@ def _get_fn_file_line(fn): class enter_sub_region: + def __init__(self, generator): self.generator = generator @@ -109,6 +112,7 @@ def __exit__(self, *args, **kwargs): # Check if the given syntax node has an "early" return class ContainsReturnChecker(ast.NodeVisitor): + def __init__(self, gscope): self.gscope = gscope @@ -199,9 +203,10 @@ def visit_Call(self, node: ast.Call) -> bool: class CodeGenerator(ast.NodeVisitor): - def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, - module=None, is_kernel=False, function_types: Optional[Dict] = None, - debug=False, noinline=False, file_name: Optional[str] = None, begin_line=0): + + def __init__(self, context, prototype, gscope, attributes, constants, function_name, target, module=None, + is_kernel=False, function_types: Optional[Dict] = None, debug=False, noinline=False, + file_name: Optional[str] = None, begin_line=0): self.context = context self.builder = ir.builder(context) self.file_name = file_name @@ -237,8 +242,10 @@ def __init__(self, context, prototype, gscope, attributes, constants, function_n )) def _define_name_lookup(self): + def local_lookup(name: str, absent): - value = self.lscope.get(name, absent) # this needs to be re-fetched from `self` every time, because it gets switched occasionally + # this needs to be re-fetched from `self` every time, because it gets switched occasionally + value = self.lscope.get(name, absent) if value is not absent and name not in self.local_defs: self.global_uses[name] = value return value @@ -255,8 +262,7 @@ def name_lookup(name: str) -> Any: return name_lookup - def set_value(self, name: str, - value: Union[tensor, constexpr]) -> None: + def set_value(self, name: str, value: Union[tensor, constexpr]) -> None: ''' This function: called by visit_Assign() & visit_FunctionDef() to store left value (lvalue) 1. 
record local defined name (FIXME: should consider control flow) @@ -338,7 +344,8 @@ def visit_FunctionDef(self, node): self.visit(init_node) # initialize function visibility = "public" if self.is_kernel else "private" - self.fn = self.builder.get_or_insert_function(self.module, self.function_name, self.prototype.to_ir(self.builder), visibility, self.noinline) + self.fn = self.builder.get_or_insert_function(self.module, self.function_name, + self.prototype.to_ir(self.builder), visibility, self.noinline) self.module.push_back(self.fn) entry = self.fn.add_entry_block() arg_values = [] @@ -469,12 +476,23 @@ def visit_BinOp(self, node): rhs = self.visit(node.right) method_name = self._method_name_for_bin_op.get(type(node.op)) if method_name is None: - raise UnsupportedLanguageConstruct(None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + raise UnsupportedLanguageConstruct( + None, node, "AST binary operator '{}' is not (currently) implemented.".format(node.op.__name__)) return self._apply_binary_method(method_name, lhs, rhs) + _method_name_for_bin_op: Dict[Type[ast.operator], str] = { - ast.Add: '__add__', ast.Sub: '__sub__', ast.Mult: '__mul__', ast.Div: '__truediv__', - ast.FloorDiv: '__floordiv__', ast.Mod: '__mod__', ast.Pow: '__pow__', - ast.LShift: '__lshift__', ast.RShift: '__rshift__', ast.BitAnd: '__and__', ast.BitOr: '__or__', ast.BitXor: '__xor__', + ast.Add: '__add__', + ast.Sub: '__sub__', + ast.Mult: '__mul__', + ast.Div: '__truediv__', + ast.FloorDiv: '__floordiv__', + ast.Mod: '__mod__', + ast.Pow: '__pow__', + ast.LShift: '__lshift__', + ast.RShift: '__rshift__', + ast.BitAnd: '__and__', + ast.BitOr: '__or__', + ast.BitXor: '__xor__', } def visit_then_else_blocks(self, node, liveins, then_block, else_block): @@ -508,7 +526,8 @@ def visit_then_else_blocks(self, node, liveins, then_block, else_block): if name in then_defs or name in else_defs: names.append(name) ret_types.append(then_defs[name].type if name in then_defs else else_defs[name].type) - ir_ret_types.append(then_defs[name].handle.get_type() if name in then_defs else else_defs[name].handle.get_type()) + ir_ret_types.append(then_defs[name].handle.get_type() if name in + then_defs else else_defs[name].handle.get_type()) # variable defined in then but not in else if name in then_defs and name not in else_defs: else_defs[name] = liveins[name] @@ -602,8 +621,7 @@ def visit_If(self, node): contains_return = ContainsReturnChecker(self.gscope).visit(node) if self.scf_stack and contains_return: raise UnsupportedLanguageConstruct( - None, node, - "Cannot have `return` statements inside `while` or `for` statements in triton " + None, node, "Cannot have `return` statements inside `while` or `for` statements in triton " "(note that this also applies to `return` statements that are inside functions " "transitively called from within `while`/`for` statements)") elif self.scf_stack or not contains_return: @@ -612,10 +630,13 @@ def visit_If(self, node): self.visit_if_top_level(cond, node) else: cond = _unwrap_if_constexpr(cond) - if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: raise UnsupportedLanguageConstruct( - None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( - ', '.join(_.__name__ for _ in _condition_types), type(cond).__name__)) + None, node, + "`if` 
conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) if cond: self.visit_compound_statement(node.body) else: @@ -662,10 +683,13 @@ def visit_IfExp(self, node): return language.core.tensor(if_op.get_result(0), ret_type) if ret_type_ir else None else: cond = _unwrap_if_constexpr(cond) - if type(cond) not in _condition_types: # not isinstance - we insist the real thing, no subclasses and no ducks + # not isinstance - we insist the real thing, no subclasses and no ducks + if type(cond) not in _condition_types: raise UnsupportedLanguageConstruct( - None, node, "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( - ', '.join(_.__name__ for _ in _condition_types), type(cond).__name__)) + None, node, + "`if` conditionals can only accept values of type {{{}}}, not objects of type {}".format( + ', '.join(_.__name__ for _ in _condition_types), + type(cond).__name__)) if cond: return self.visit(node.body) else: @@ -687,8 +711,10 @@ def visit_Compare(self, node): return constexpr(lhs_value is not rhs_value) method_name = self._method_name_for_comp_op.get(type(node.ops[0])) if method_name is None: - raise UnsupportedLanguageConstruct(None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) + raise UnsupportedLanguageConstruct( + None, node, "AST comparison operator '{}' is not (currently) implemented.".format(node.ops[0].__name__)) return self._apply_binary_method(method_name, lhs, rhs) + _method_name_for_comp_op: Dict[Type[ast.cmpop], str] = { ast.Eq: '__eq__', ast.NotEq: '__ne__', ast.Lt: '__lt__', ast.LtE: '__le__', ast.Gt: '__gt__', ast.GtE: '__ge__' } @@ -697,11 +723,15 @@ def visit_UnaryOp(self, node): op = self.visit(node.operand) fn = self._method_name_for_unary_op.get(type(node.op)) if fn is None: - raise UnsupportedLanguageConstruct(None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__)) + raise UnsupportedLanguageConstruct( + None, node, "AST unary operator '{}' is not (currently) implemented.".format(node.op.__name__)) if _is_triton_tensor(op): return getattr(op, fn)(_builder=self.builder) return getattr(op, fn)() - _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = {ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__'} + + _method_name_for_unary_op: Dict[Type[ast.unaryop], str] = { + ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' + } def visit_While(self, node): with enter_sub_region(self) as sr: @@ -796,9 +826,7 @@ def visit_For(self, node): iter_args = [self.visit(arg) for arg in node.iter.args] if IteratorClass == language.static_range: iterator = IteratorClass(*iter_args) - static_range = range(iterator.start.value, - iterator.end.value, - iterator.step.value) + static_range = range(iterator.start.value, iterator.end.value, iterator.step.value) for i in static_range: self.lscope[node.target.id] = constexpr(i) self.visit_compound_statement(node.body) @@ -935,8 +963,7 @@ def visit_Assert(self, node) -> Any: def call_JitFunction(self, fn: JITFunction, args, kwargs): args = inspect.getcallargs(fn.fn, *args, **kwargs) args = [args[name] for name in fn.arg_names] - args = [arg if _is_triton_tensor(arg) - else constexpr(arg) for arg in args] + args = [arg if _is_triton_tensor(arg) else constexpr(arg) for arg in args] # generate function def attributes = dict() constexprs = [i for i, arg 
in enumerate(args) if _is_constexpr(arg)] @@ -954,8 +981,9 @@ def call_JitFunction(self, fn: JITFunction, args, kwargs): debug = self.debug if fn.debug is None else fn.debug file_name, begin_line = _get_fn_file_line(fn) generator = CodeGenerator(self.context, prototype, gscope, attributes, constants, module=self.module, - function_name=fn_name, function_types=self.function_ret_types, debug=debug, noinline=fn.noinline, - file_name=file_name, begin_line=begin_line, target=self.builder.target) + function_name=fn_name, function_types=self.function_ret_types, debug=debug, + noinline=fn.noinline, file_name=file_name, begin_line=begin_line, + target=self.builder.target) generator.visit(fn.parse()) callee_ret_type = generator.last_ret_type self.function_ret_types[fn_name] = callee_ret_type @@ -983,7 +1011,7 @@ def visit_Call(self, node): kws = dict(self.visit(keyword) for keyword in node.keywords) args = [self.visit(arg) for arg in node.args] - if fn is language.core.device_assert: # TODO: this should not be so hardcoded + if fn is language.core.device_assert: # TODO: this should not be so hardcoded if not self.debug: return if isinstance(fn, JITFunction): @@ -1004,16 +1032,21 @@ def visit_Constant(self, node): def visit_BoolOp(self, node: ast.BoolOp): if len(node.values) != 2: - raise UnsupportedLanguageConstruct(None, node, "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") + raise UnsupportedLanguageConstruct( + None, node, + "chained boolean operators (A or B or C) are not supported; use parentheses to split the chain.") lhs = self.visit(node.values[0]) rhs = self.visit(node.values[1]) method_name = self._method_name_for_bool_op.get(type(node.op)) if method_name is None: - raise UnsupportedLanguageConstruct(None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) + raise UnsupportedLanguageConstruct( + None, node, "AST boolean operator '{}' is not (currently) implemented.".format(node.op.__name__)) return self._apply_binary_method(method_name, lhs, rhs) + _method_name_for_bool_op: Dict[Type[ast.boolop], str] = {ast.And: 'logical_and', ast.Or: 'logical_or'} if sys.version_info < (3, 8): + def visit_NameConstant(self, node): return constexpr(node.value) @@ -1046,7 +1079,9 @@ def visit_JoinedStr(self, node): evaluated = self.visit(value.value) if not _is_constexpr(evaluated): raise UnsupportedLanguageConstruct( - None, node, "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + str(type(evaluated))) + None, node, + "Cannot evaluate f-string containing non-constexpr conversion values, found conversion of type " + + str(type(evaluated))) values[i] = ("{}" if conversion_code < 0 else "{!" + chr(conversion_code) + "}").format(evaluated.value) else: raise AssertionError("encountered unexpected node of type {} in a JoinedStr node".format(type(value))) @@ -1088,7 +1123,9 @@ def execute_static_assert(self, node: ast.Call) -> None: passed = _unwrap_if_constexpr(self.visit(node.args[0])) if not isinstance(passed, bool): - raise NotImplementedError("Assertion condition could not be determined at compile-time. Make sure that it depends only on `constexpr` values") + raise NotImplementedError( + "Assertion condition could not be determined at compile-time. 
Make sure that it depends only on `constexpr` values" + ) if not passed: if arg_count == 1: message = "" @@ -1175,10 +1212,9 @@ def ast_to_ttir(fn, signature, specialization, constants, debug, target): file_name, begin_line = _get_fn_file_line(fn) prototype = language.function_type([], arg_types) - generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, - function_name=function_name, attributes=new_attrs, - is_kernel=True, debug=debug, file_name=file_name, begin_line=begin_line, - target=target) + generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, + attributes=new_attrs, is_kernel=True, debug=debug, file_name=file_name, + begin_line=begin_line, target=target) try: generator.visit(fn.parse()) except CompilationError as e: diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index ac751d5a4e12..1f0cb4925004 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -11,10 +11,8 @@ from dataclasses import dataclass -from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, - compile_ptx_to_cubin, get_env_vars, get_num_warps, - get_shared_memory_size, ir, runtime, - translate_llvmir_to_ptx, +from .._C.libtriton.triton import (ClusterInfo, TMAInfos, add_external_libs, compile_ptx_to_cubin, get_env_vars, + get_num_warps, get_shared_memory_size, ir, runtime, translate_llvmir_to_ptx, translate_triton_gpu_to_llvmir) from ..common.backend import get_backend, get_cuda_version_key, path_to_ptxas from ..common.build import is_hip @@ -23,13 +21,11 @@ from ..runtime.autotuner import OutOfResources from ..runtime.cache import get_cache_manager, get_dump_manager, get_override_manager from ..runtime.driver import driver -from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, - get_device_capability) +from ..runtime.jit import (JITFunction, get_cuda_stream, get_current_device, get_device_capability) from ..tools.disasm import get_sass from .code_generator import ast_to_ttir from .make_launcher import make_stub -from .utils import (InfoFromBackendForTensorMap, TensorMapManager, - get_ids_of_tensormaps, parse_tma_info) +from .utils import (InfoFromBackendForTensorMap, TensorMapManager, get_ids_of_tensormaps, parse_tma_info) @dataclass @@ -44,6 +40,7 @@ def _is_cuda(target): class LazyDict(dict): + def __getitem__(self, key): val = dict.__getitem__(self, key) if callable(val): @@ -94,8 +91,8 @@ def ttir_to_ttgir(mod, num_warps, num_ctas, target): return mod -def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, - cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue): +def optimize_ttgir(mod, num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, + enable_persistent, optimize_epilogue): is_cuda = _is_cuda(target) if is_cuda: capability = target.capability @@ -173,6 +170,7 @@ def ttgir_to_llir(mod, extern_libs, target, tma_infos): # PTX translation + @functools.lru_cache() def ptx_get_version(cuda_version) -> int: ''' @@ -253,7 +251,8 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs): enable_persistent = kwargs.get("enable_persistent", False) debug = kwargs.get("debug", False) # Get unique key for the compiled code - get_conf_key = lambda conf: (sorted(conf.divisible_by_16), sorted(conf.equal_to_1), sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8)) + get_conf_key = lambda conf: (sorted(conf.divisible_by_16), 
sorted(conf.equal_to_1), + sorted(conf.ids_of_folded_args), sorted(conf.divisible_by_8)) configs_key = [get_conf_key(conf) for conf in configs] env_vars_list = [f"{env_vars[k]}" for k in sorted(env_vars.keys())] key = f"{fn.cache_key}-{version_key}-{''.join(signature.values())}-{configs_key}-{constants}-{num_warps}-{num_stages}-{num_ctas}-{num_stages}-{enable_warp_specialization}-{enable_persistent}-{debug}-{target}-{env_vars_list}" @@ -299,12 +298,14 @@ def make_hash(fn, target, env_vars, device_backend, **kwargs): def _get_jsonable_constants(constants): + def _is_jsonable(x): try: json.dumps(x) return True except (TypeError, OverflowError): return False + serialized_constants = {} for constant in constants: if _is_jsonable(constants[constant]): @@ -319,7 +320,9 @@ def parse_mlir_module(path, context): return module -instance_descriptor = namedtuple("instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], defaults=[set(), set(), set(), set()]) +instance_descriptor = namedtuple("instance_descriptor", + ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"], + defaults=[set(), set(), set(), set()]) def get_cuda_capability(capability): @@ -355,10 +358,8 @@ def get_arch_default_num_stages(device_type, capability=None): def add_cuda_stages(target, extern_libs, stages): - stages["ptx"] = (lambda path: Path(path).read_text(), - lambda src: llir_to_ptx(src, target)) - stages["cubin"] = (lambda path: Path(path).read_bytes(), - lambda src: ptx_to_cubin(src, target)) + stages["ptx"] = (lambda path: Path(path).read_text(), lambda src: llir_to_ptx(src, target)) + stages["cubin"] = (lambda path: Path(path).read_bytes(), lambda src: ptx_to_cubin(src, target)) def compile(fn, **kwargs): @@ -401,7 +402,8 @@ def compile(fn, **kwargs): # build architecture descriptor if device_type == "cuda": _device_backend = get_backend(device_type) - target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, enable_fp_fusion=enable_fp_fusion) + target = CudaTargetDescriptor(capability=get_cuda_capability(capability), num_warps=num_warps, + enable_fp_fusion=enable_fp_fusion) else: _device_backend = get_backend(device_type) assert _device_backend @@ -409,11 +411,12 @@ def compile(fn, **kwargs): # build compilation stages stages = dict() stages["ast"] = (lambda path: fn, None) - stages["ttir"] = (lambda path: parse_mlir_module(path, context), - lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target)) + stages["ttir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttir( + ast_to_ttir(src, signature, configs[0], constants, debug=debug, target=target), target)) if is_cuda: - stages["ttgir"] = (lambda path: parse_mlir_module(path, context), - lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, enable_warp_specialization, enable_persistent, optimize_epilogue)) + stages["ttgir"] = (lambda path: parse_mlir_module(path, context), lambda src: optimize_ttgir( + ttir_to_ttgir(src, num_warps, num_ctas, target), num_stages, num_warps, num_ctas, target, cluster_info, + enable_warp_specialization, enable_persistent, optimize_epilogue)) stages["llir"] = (lambda path: Path(path).read_text(), lambda src: ttgir_to_llir(src, extern_libs, target, tma_infos)) add_cuda_stages(target, extern_libs, stages) @@ -451,7 +454,8 @@ def compile(fn, **kwargs): if ir_name == 'ttgir': num_warps_matches = 
re.findall(ttgir_num_warps_pattern, src) assert len(num_warps_matches) == 1, "Expected exactly one match for num_warps" - assert "num_warps" not in kwargs or int(num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile" + assert "num_warps" not in kwargs or int( + num_warps_matches[0]) == num_warps, "num_warps in ttgir does not match num_warps in compile" num_warps = int(num_warps_matches[0]) param_tys = [convert_type_repr(ty) for ty in types] signature = {k: v for k, v in enumerate(param_tys)} @@ -461,8 +465,10 @@ def compile(fn, **kwargs): fn_cache_manager = get_cache_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs)) # managers used to dump and override IR for debugging enable_override = os.environ.get("TRITON_KERNEL_OVERRIDE", "0") == "1" - fn_override_manager = get_override_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True)) - fn_dump_manager = get_dump_manager(make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True)) + fn_override_manager = get_override_manager( + make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True)) + fn_dump_manager = get_dump_manager( + make_hash(fn, target, get_env_vars(), _device_backend, **kwargs, ignore_version=True)) # determine name and extension type of provided function if isinstance(fn, JITFunction): @@ -475,9 +481,7 @@ def compile(fn, **kwargs): metadata_filename = f"{name}.json" # The group is addressed by the metadata - metadata_group = fn_cache_manager.get_group( - metadata_filename - ) or {} + metadata_group = fn_cache_manager.get_group(metadata_filename) or {} metadata_path = metadata_group.get(metadata_filename) @@ -485,17 +489,18 @@ def compile(fn, **kwargs): with open(metadata_path) as f: metadata = json.load(f) if 'tensormaps_info' in metadata: - metadata['tensormaps_info'] = [ - InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']] + metadata['tensormaps_info'] = [InfoFromBackendForTensorMap(e) for e in metadata['tensormaps_info']] else: - metadata = {"num_warps": num_warps, - "num_ctas": num_ctas, - "num_stages": num_stages, - "enable_warp_specialization": enable_warp_specialization, - "enable_persistent": enable_persistent, - "constants": _get_jsonable_constants(constants), - "debug": debug, - "target": target, } + metadata = { + "num_warps": num_warps, + "num_ctas": num_ctas, + "num_stages": num_stages, + "enable_warp_specialization": enable_warp_specialization, + "enable_persistent": enable_persistent, + "constants": _get_jsonable_constants(constants), + "debug": debug, + "target": target, + } metadata.update(get_env_vars()) if ext == "ptx": assert "shared" in kwargs, "ptx compilation must provide shared memory size" @@ -567,10 +572,7 @@ def compile(fn, **kwargs): ids_of_folded_args = tuple([int(k) for k in configs[0].ids_of_folded_args]) if isinstance(fn, JITFunction) else () if "clusterDims" not in metadata: - metadata["clusterDims"] = [ - cluster_info.clusterDimX, - cluster_info.clusterDimY, - cluster_info.clusterDimZ] + metadata["clusterDims"] = [cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ] if len(tma_infos) > 0: metadata["tensormaps_info"] = parse_tma_info(tma_infos, ids_of_folded_args) @@ -584,7 +586,10 @@ def compile(fn, **kwargs): fn.tensormaps_info = metadata["tensormaps_info"] ids_of_const_exprs = tuple(fn.constexprs) if isinstance(fn, JITFunction) else () - ids = {"ids_of_tensormaps": ids_of_tensormaps, 
"ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": ids_of_const_exprs} + ids = { + "ids_of_tensormaps": ids_of_tensormaps, "ids_of_folded_args": ids_of_folded_args, "ids_of_const_exprs": + ids_of_const_exprs + } # cache manager if is_cuda: so_path = make_stub(name, signature, constants, ids, enable_warp_specialization=enable_warp_specialization) @@ -592,7 +597,8 @@ def compile(fn, **kwargs): so_path = _device_backend.make_launcher_stub(name, signature, constants, ids) # write-back metadata, if it didn't come from the cache if metadata_path is None: - metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, binary=False) + metadata_group[metadata_filename] = fn_cache_manager.put(json.dumps(metadata, default=vars), metadata_filename, + binary=False) fn_cache_manager.put_group(metadata_filename, metadata_group) # return handle to compiled kernel @@ -640,10 +646,7 @@ def _init_handles(self): if self.device_type in ["cuda"]: device = get_current_device() - bin_path = { - driver.HIP: "hsaco_path", - driver.CUDA: "cubin" - }[driver.backend] + bin_path = {driver.HIP: "hsaco_path", driver.CUDA: "cubin"}[driver.backend] max_shared = driver.utils.get_device_properties(device)["max_shared_mem"] fn_load_binary = driver.utils.load_binary else: @@ -691,4 +694,5 @@ def runner(*args, stream=None): self.c_wrapper(grid[0], grid[1], grid[2], self.num_warps, self.num_ctas, self.clusterDims[0], self.clusterDims[1], self.clusterDims[2], self.shared, stream, self.cu_function, CompiledKernel.launch_enter_hook, CompiledKernel.launch_exit_hook, self, *args_expand) + return runner diff --git a/python/triton/compiler/make_launcher.py b/python/triton/compiler/make_launcher.py index 856240675db0..52a8f74a11eb 100644 --- a/python/triton/compiler/make_launcher.py +++ b/python/triton/compiler/make_launcher.py @@ -40,6 +40,7 @@ def make_stub(name, signature, constants, ids, **kwargs): else: return cache_path + # ----- source code generation -------- @@ -100,7 +101,10 @@ def format_of(ty): # generate glue code folded_without_constexprs = [c for c in ids['ids_of_folded_args'] if c not in ids['ids_of_const_exprs']] - params = [i for i in signature.keys() if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs)] + params = [ + i for i in signature.keys() + if i >= desc_start_idx or (i not in constants and i not in folded_without_constexprs) + ] src = f""" #include \"cuda.h\" #include diff --git a/python/triton/compiler/utils.py b/python/triton/compiler/utils.py index d4b24a93ee49..ef629c75a6bc 100644 --- a/python/triton/compiler/utils.py +++ b/python/triton/compiler/utils.py @@ -158,19 +158,21 @@ def getTMADescArgIdx(self): # dtype:cuda.CUtensorMapDataType | int def bytes_from_type(self, dtype): - return {driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2, - 
driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4, - driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4}[dtype] + return { + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT8"]: 1, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT16"]: 2, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_UINT64"]: 8, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_INT64"]: 8, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT16"]: 2, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT64"]: 8, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_BFLOAT16"]: 2, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_FLOAT32_FTZ"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32"]: 4, + driver.utils.CUtensorMapDataType["CU_TENSOR_MAP_DATA_TYPE_TFLOAT32_FTZ"]: 4 + }[dtype] def getTensorMapDataType(self): return self.tensorDataType @@ -259,22 +261,29 @@ def tensormap(self, args): self.getInterleave(), self.getSwizzle(), self.getL2Promotion(), - self.getOobFill() + self.getOobFill(), ) # make hashable to use as partial key in cache def __hash__(self): - return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), tuple(self.globalStridesArgIdx), self.tensorDataType, - self.tensorRank, tuple(self.boxDims), tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill)) + return hash((self.ids_of_folded_args, self.globalAddressArgIdx, tuple(self.globalDimsArgIdx), + tuple(self.globalStridesArgIdx), self.tensorDataType, self.tensorRank, tuple(self.boxDims), + tuple(self.elementStrides), self.interleave, self.swizzle, self.l2Promotion, self.oobFill)) def __eq__(self, other): if not isinstance(other, self.__class__): return False - return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, self.l2Promotion, self.oobFill) == ( - other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, other.oobFill) + return (self.ids_of_folded_args, self.globalAddressArgIdx, self.globalDimsArgIdx, self.globalStridesArgIdx, + self.tensorDataType, self.tensorRank, self.boxDims, self.elementStrides, self.interleave, self.swizzle, + self.l2Promotion, + self.oobFill) == (other.ids_of_folded_args, other.globalAddressArgIdx, other.globalDimsArgIdx, + other.globalStridesArgIdx, other.tensorDataType, other.tensorRank, other.boxDims, + other.elementStrides, other.interleave, other.swizzle, other.l2Promotion, + other.oobFill) class TensorMapManager: + def __init__(self): self.tensormaps_device = {} @@ -286,8 +295,7 @@ def __getitem__(self, key: tuple): t_tensormap = e.tensormap(args) TENSORMAP_SIZE_IN_BYTES = 128 t_tensormap_device = driver.utils.cuMemAlloc(TENSORMAP_SIZE_IN_BYTES) - driver.utils.cuMemcpyHtoD( - t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES) + 
driver.utils.cuMemcpyHtoD(t_tensormap_device, t_tensormap, TENSORMAP_SIZE_IN_BYTES) self.tensormaps_device[key] = t_tensormap_device return int(self.tensormaps_device[key]) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 6b719c684cfd..27394acf918a 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -109,7 +109,6 @@ uint32_to_uniform_float, ) - __all__ = [ "TRITON_MAX_TENSOR_NUMEL", "abs", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 5b86ce5589ac..2919689a0cc4 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -22,10 +22,8 @@ def builtin(fn: T) -> T: @wraps(fn) def wrapper(*args, **kwargs): if "_builder" not in kwargs or kwargs["_builder"] is None: - raise ValueError( - "Did you forget to add @triton.jit ? " - "(`_builder` argument must be provided outside of JIT functions.)" - ) + raise ValueError("Did you forget to add @triton.jit ? " + "(`_builder` argument must be provided outside of JIT functions.)") return fn(*args, **kwargs) setattr(wrapper, TRITON_BUILTIN, True) @@ -54,7 +52,7 @@ def _to_tensor(x, builder): else: raise RuntimeError(f'Nonrepresentable integer {x}.') elif isinstance(x, float): - min_float32 = 2 ** -126 + min_float32 = 2**-126 max_float32 = (2 - 2**-23) * 2**127 abs_x = __builtins__['abs'](x) if abs_x == float("inf") or\ @@ -229,7 +227,7 @@ def __ne__(self, other: dtype): return not self.__eq__(other) def __hash__(self): - return hash((self.name,)) + return hash((self.name, )) @property def scalar(self): @@ -279,6 +277,7 @@ def __repr__(self): class pointer_type(dtype): + def __init__(self, element_ty: dtype, address_space: int = 1): if not isinstance(element_ty, dtype): raise TypeError('element_ty is a {type(element_ty).__name__}.') @@ -313,6 +312,7 @@ def scalar(self): class block_type(dtype): + def __init__(self, element_ty: dtype, shape: List): self.element_ty = element_ty @@ -363,6 +363,7 @@ def scalar(self): class function_type(dtype): + def __init__(self, ret_types: List[dtype], param_types: List[dtype]) -> None: self.ret_types = ret_types self.param_types = param_types @@ -511,7 +512,7 @@ def __invert__(self): return constexpr(~self.value) def __pow__(self, other): - return constexpr(self.value ** other.value) + return constexpr(self.value**other.value) def __rshift__(self, other): return constexpr(self.value >> other.value) @@ -527,6 +528,7 @@ def __call__(self, *args, **kwds): class tensor: + def __init__(self, handle, type: dtype): # IR handle self.handle = handle @@ -993,6 +995,7 @@ def expand_dims(input, axis, _builder=None): ret = semantic.expand_dims(ret, a, _builder) return ret + # ----------------------- # Linear Algebra # ----------------------- @@ -1141,6 +1144,7 @@ def advance(base: tensor, offsets, _builder=None): """ return semantic.advance(base, offsets, _builder) + # ----------------------- # Atomic Memory Operations # ----------------------- @@ -1253,6 +1257,7 @@ def atomic_xor(pointer, val, mask=None, sem=None, scope=None, _builder=None): # Conditioning # ----------------------- + @builtin def where(condition, x, y, _builder=None): """ @@ -1280,6 +1285,7 @@ def where(condition, x, y, _builder=None): # Math # ----------------------- + @builtin def umulhi(x, y, _builder=None): """ @@ -1373,6 +1379,7 @@ def abs(x, _builder=None): # Reductions # ----------------------- + def _add_reduction_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]: def 
_decorator(func: T) -> T: @@ -1411,8 +1418,7 @@ def reduce(input, axis, combine_fn, _builder=None, _generator=None): """ if isinstance(input, tensor): - return reduce((input,), axis, combine_fn, - _builder=_builder, _generator=_generator)[0] + return reduce((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0] def make_combine_region(reduce_op): in_scalar_tys = [t.type.scalar for t in input] @@ -1422,14 +1428,14 @@ def make_combine_region(reduce_op): with _insertion_guard(_builder): param_types = [ty.to_ir(_builder) for ty in prototype.param_types] block = _builder.create_block_with_parent(region, param_types) - args = [tensor(block.arg(i), ty) - for i, ty in enumerate(prototype.param_types)] + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] else: handles = [r.handle for r in results] _builder.create_reduce_ret(*handles) + if axis is not None: axis = _constexpr_to_value(axis) return semantic.reduction(input, axis, make_combine_region, _builder) @@ -1459,8 +1465,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None index = expand_dims(index, axes_to_expand, _builder=_builder) index = broadcast_to(index, input.shape, _builder=_builder) - rvalue, rindices = reduce((input, index), axis, combine_fn, - _builder=_builder, _generator=_generator) + rvalue, rindices = reduce((input, index), axis, combine_fn, _builder=_builder, _generator=_generator) return rvalue, rindices @@ -1468,6 +1473,7 @@ def _reduce_with_indices(input, axis, combine_fn, _builder=None, _generator=None # Scans # ----------------------- + def _add_scan_docstr(name: str, return_indices_arg: str = None, tie_break_arg: str = None) -> Callable[[T], T]: def _decorator(func: T) -> T: @@ -1492,8 +1498,7 @@ def associative_scan(input, axis, combine_fn, _builder=None, _generator=None): """ if isinstance(input, tensor): - return associative_scan((input,), axis, combine_fn, - _builder=_builder, _generator=_generator)[0] + return associative_scan((input, ), axis, combine_fn, _builder=_builder, _generator=_generator)[0] def make_combine_region(scan_op): in_scalar_tys = [t.type.scalar for t in input] @@ -1503,17 +1508,18 @@ def make_combine_region(scan_op): with _insertion_guard(_builder): param_types = [ty.to_ir(_builder) for ty in prototype.param_types] block = _builder.create_block_with_parent(region, param_types) - args = [tensor(block.arg(i), ty) - for i, ty in enumerate(prototype.param_types)] + args = [tensor(block.arg(i), ty) for i, ty in enumerate(prototype.param_types)] results = _generator.call_JitFunction(combine_fn, args, kwargs={}) if isinstance(results, tensor): handles = [results.handle] else: handles = [r.handle for r in results] _builder.create_scan_ret(*handles) + axis = _constexpr_to_value(axis) return semantic.associative_scan(input, axis, make_combine_region, _builder) + # ----------------------- # Compiler Hint Ops # ----------------------- @@ -1576,6 +1582,8 @@ def max_constancy(input, values, _builder=None): raise TypeError(f"values element {i} must have type `constexpr[int]`, got `constexpr[{type(d.value)}]") values = [x.value for x in values] return semantic.max_constancy(input, values) + + # ----------------------- # Debugging functions # ----------------------- @@ -1715,12 +1723,12 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur broadcast_arg = dispatch_args[0] # Get the 
broadcast shape over all the arguments for i, item in enumerate(dispatch_args): - _, broadcast_arg = semantic.binary_op_type_checking_impl( - item, broadcast_arg, _builder, arithmetic_check=False) + _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder, + arithmetic_check=False) # Change the shape of each argument based on the broadcast shape for i in range(len(dispatch_args)): - dispatch_args[i], _ = semantic.binary_op_type_checking_impl( - dispatch_args[i], broadcast_arg, _builder, arithmetic_check=False) + dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder, + arithmetic_check=False) ret_shape = broadcast_arg.shape res_ty = block_type(dtype, ret_shape) call = _builder.create_inline_asm(asm, constraints, [t.handle for t in args], res_ty.to_ir(_builder), is_pure, pack) @@ -1733,7 +1741,6 @@ def inline_asm_elementwise(asm: str, constraints: str, args: list, dtype, is_pur class static_range: - """ Iterator that counts upward forever. @@ -1777,7 +1784,9 @@ def __next__(self): # Extern functions # ----------------------- -def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, is_pure: bool, _builder=None): + +def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, ret_shape: tuple, + is_pure: bool, _builder=None): ''' Dispatch a function to a library :param func: the function to dispatch @@ -1819,7 +1828,8 @@ def dispatch(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dic return tensor(func(lib_name, lib_path, symbol, arg_list, ret_type.to_ir(_builder), is_pure), ret_type) -def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _builder=None): +def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _builder=None): ''' Dispatch an elementwise function to a library :param lib_name: the name of the library @@ -1848,12 +1858,12 @@ def extern_elementwise(lib_name: str, lib_path: str, args: list, arg_type_symbol broadcast_arg = dispatch_args[0] # Get the broadcast shape over all the arguments for i, item in enumerate(dispatch_args): - _, broadcast_arg = semantic.binary_op_type_checking_impl( - item, broadcast_arg, _builder, arithmetic_check=arithmetic_check) + _, broadcast_arg = semantic.binary_op_type_checking_impl(item, broadcast_arg, _builder, + arithmetic_check=arithmetic_check) # Change the shape of each argument based on the broadcast shape for i in range(len(dispatch_args)): - dispatch_args[i], _ = semantic.binary_op_type_checking_impl( - dispatch_args[i], broadcast_arg, _builder, arithmetic_check=arithmetic_check) + dispatch_args[i], _ = semantic.binary_op_type_checking_impl(dispatch_args[i], broadcast_arg, _builder, + arithmetic_check=arithmetic_check) if not all_scalar: ret_shape = broadcast_arg.shape func = getattr(_builder, "create_extern_elementwise") diff --git a/python/triton/language/extra/cuda.py b/python/triton/language/extra/cuda.py index 8c4114739309..9400ae797887 100644 --- a/python/triton/language/extra/cuda.py +++ b/python/triton/language/extra/cuda.py @@ -3,16 +3,14 @@ @core.extern def globaltimer(_builder=None): - return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], - dtype=core.int64, is_pure=False, - pack=1, _builder=_builder) + return core.inline_asm_elementwise("mov.u64 $0, %globaltimer;", "=l", [], dtype=core.int64, is_pure=False, pack=1, + _builder=_builder) 
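# ---------------------------------------------------------------------------
# Editor's illustrative sketch (not part of the diff). The reformatted
# `globaltimer()` wrapper above emits a single PTX instruction through
# `inline_asm_elementwise`, with `is_pure=False` so repeated reads are not
# folded away by the compiler. A minimal usage sketch, assuming these helpers
# are importable from `triton.language.extra.cuda` (the module path shown in
# this diff) and that a CUDA device is available:
import torch
import triton
import triton.language as tl
from triton.language.extra import cuda as tl_cuda


@triton.jit
def timestamp_kernel(out_ptr):
    # Each program stores the value of %globaltimer (a nanosecond counter)
    # into its own slot of the output buffer.
    pid = tl.program_id(0)
    tl.store(out_ptr + pid, tl_cuda.globaltimer())


out = torch.empty(4, device="cuda", dtype=torch.int64)
# Grid tuples use the `(n, )` spelling that yapf enforces throughout this PR.
timestamp_kernel[(4, )](out)
# ---------------------------------------------------------------------------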
 @core.extern
 def smid(_builder=None):
-    return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [],
-                                       dtype=core.int32, is_pure=True,
-                                       pack=1, _builder=_builder)
+    return core.inline_asm_elementwise("mov.u32 $0, %smid;", "=r", [], dtype=core.int32, is_pure=True, pack=1,
+                                       _builder=_builder)


 @core.builtin
diff --git a/python/triton/language/math.py b/python/triton/language/math.py
index 6f8b0aced0e7..1cbad660d780 100644
--- a/python/triton/language/math.py
+++ b/python/triton/language/math.py
@@ -18,25 +18,27 @@ def libdevice_path():

 @core.extern
 def clz(arg0, _builder=None):
-    return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ],
-                                   {(core.dtype("int32"),): ("__nv_clz", core.dtype("int32")),
-                                    (core.dtype("int64"),): ("__nv_clzll", core.dtype("int32")),
-                                    }, is_pure=True, _builder=_builder)
+    return core.extern_elementwise(
+        "libdevice", libdevice_path(), [arg0], {
+            (core.dtype("int32"), ): ("__nv_clz", core.dtype("int32")),
+            (core.dtype("int64"), ): ("__nv_clzll", core.dtype("int32")),
+        }, is_pure=True, _builder=_builder)


 @core.extern
 def popc(arg0, _builder=None):
-    return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ],
-                                   {(core.dtype("int32"),): ("__nv_popc", core.dtype("int32")),
-                                    (core.dtype("int64"),): ("__nv_popcll", core.dtype("int32")),
-                                    }, is_pure=True, _builder=_builder)
+    return core.extern_elementwise(
+        "libdevice", libdevice_path(), [arg0], {
+            (core.dtype("int32"), ): ("__nv_popc", core.dtype("int32")),
+            (core.dtype("int64"), ): ("__nv_popcll", core.dtype("int32")),
+        }, is_pure=True, _builder=_builder)


 @core.extern
 def byte_perm(arg0, arg1, arg2, _builder=None):
-    return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ],
-                                   {(core.dtype("int32"), core.dtype("int32"), core.dtype("int32"),): ("__nv_byte_perm", core.dtype("int32")),
-                                    }, is_pure=True, _builder=_builder)
+    return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2], {
+        (core.dtype("int32"), core.dtype("int32"), core.dtype("int32")): ("__nv_byte_perm", core.dtype("int32")),
+    }, is_pure=True, _builder=_builder)


 @core.extern
@@ -73,1471 +75,1602 @@ def max(arg0, arg1, _builder=None):

 @core.extern
 def mulhi(arg0, arg1, _builder=None):
-    return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ],
-                                   {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mulhi", core.dtype("int32")),
-                                    (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umulhi", core.dtype("uint32")),
-                                    (core.dtype("int64"), core.dtype("int64"),): ("__nv_mul64hi", core.dtype("int64")),
-                                    (core.dtype("uint64"), core.dtype("uint64"),): ("__nv_umul64hi", core.dtype("uint64")),
-                                    }, is_pure=True, _builder=_builder)
+    return core.extern_elementwise(
+        "libdevice", libdevice_path(), [arg0, arg1], {
+            (core.dtype("int32"), core.dtype("int32")): ("__nv_mulhi", core.dtype("int32")),
+            (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umulhi", core.dtype("uint32")),
+            (core.dtype("int64"), core.dtype("int64")): ("__nv_mul64hi", core.dtype("int64")),
+            (core.dtype("uint64"), core.dtype("uint64")): ("__nv_umul64hi", core.dtype("uint64")),
+        }, is_pure=True, _builder=_builder)


 @core.extern
 def mul24(arg0, arg1, _builder=None):
-    return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ],
-                                   {(core.dtype("int32"), core.dtype("int32"),): ("__nv_mul24", core.dtype("int32")),
-                                    (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_umul24", core.dtype("uint32")),
-                                    }, is_pure=True, _builder=_builder)
+    return core.extern_elementwise(
+        "libdevice", libdevice_path(), [arg0, arg1], {
+            (core.dtype("int32"), core.dtype("int32")): ("__nv_mul24", core.dtype("int32")),
+            (core.dtype("uint32"), core.dtype("uint32")): ("__nv_umul24", core.dtype("uint32")),
+        }, is_pure=True, _builder=_builder)


 @core.extern
 def brev(arg0, _builder=None):
-    return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ],
-                                   {(core.dtype("int32"),): ("__nv_brev", core.dtype("int32")),
-                                    (core.dtype("int64"),): ("__nv_brevll", core.dtype("int64")),
-                                    }, is_pure=True, _builder=_builder)
+    return core.extern_elementwise(
+        "libdevice", libdevice_path(), [arg0], {
+            (core.dtype("int32"), ): ("__nv_brev", core.dtype("int32")),
+            (core.dtype("int64"), ): ("__nv_brevll", core.dtype("int64")),
+        }, is_pure=True, _builder=_builder)


 @core.extern
 def sad(arg0, arg1, arg2, _builder=None):
-    return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ],
-                                   {(core.dtype("int32"), core.dtype("int32"), core.dtype("uint32"),): ("__nv_sad", core.dtype("int32")),
-                                    (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32"),): ("__nv_usad", core.dtype("uint32")),
-                                    }, is_pure=True, _builder=_builder)
+    return core.extern_elementwise(
+        "libdevice", libdevice_path(), [arg0, arg1, arg2], {
+            (core.dtype("int32"), core.dtype("int32"), core.dtype("uint32")): ("__nv_sad", core.dtype("int32")),
+            (core.dtype("uint32"), core.dtype("uint32"), core.dtype("uint32")): ("__nv_usad", core.dtype("uint32")),
+        }, is_pure=True, _builder=_builder)


 @core.extern
 def abs(arg0, _builder=None):
-    return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ],
-                                   {(core.dtype("int32"),): ("__nv_abs", core.dtype("int32")),
-                                    (core.dtype("int64"),): ("__nv_llabs", core.dtype("int64")),
-                                    (core.dtype("fp32"),): ("__nv_fabsf", core.dtype("fp32")),
-                                    (core.dtype("fp64"),): ("__nv_fabs", core.dtype("fp64")),
-                                    }, is_pure=True, _builder=_builder)
+    return core.extern_elementwise(
+        "libdevice", libdevice_path(), [arg0], {
+            (core.dtype("int32"), ): ("__nv_abs", core.dtype("int32")),
+            (core.dtype("int64"), ): ("__nv_llabs", core.dtype("int64")),
+            (core.dtype("fp32"), ): ("__nv_fabsf", core.dtype("fp32")),
+            (core.dtype("fp64"), ): ("__nv_fabs", core.dtype("fp64")),
+        }, is_pure=True, _builder=_builder)


 @core.extern
 def floor(arg0, _builder=None):
-    return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ],
-                                   {(core.dtype("fp32"),): ("__nv_floorf", core.dtype("fp32")),
-                                    (core.dtype("fp64"),): ("__nv_floor", core.dtype("fp64")),
-                                    }, is_pure=True, _builder=_builder)
+    return core.extern_elementwise(
+        "libdevice", libdevice_path(), [arg0], {
+            (core.dtype("fp32"), ): ("__nv_floorf", core.dtype("fp32")),
+            (core.dtype("fp64"), ): ("__nv_floor", core.dtype("fp64")),
+        }, is_pure=True, _builder=_builder)


 @core.extern
 def rcp64h(arg0, _builder=None):
-    return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ],
-                                   {(core.dtype("fp64"),): ("__nv_rcp64h", core.dtype("fp64")),
-                                    }, is_pure=True, _builder=_builder)
+    return core.extern_elementwise("libdevice", libdevice_path(), [arg0], {
+        (core.dtype("fp64"), ): ("__nv_rcp64h", core.dtype("fp64")),
+    }, is_pure=True, _builder=_builder)


 @core.extern
 def rsqrt(arg0, _builder=None):
-    return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ],
-                                   {(core.dtype("fp32"),): ("__nv_rsqrtf", core.dtype("fp32")),
-                                    (core.dtype("fp64"),): ("__nv_rsqrt", core.dtype("fp64")),
-                                    }, is_pure=True, _builder=_builder)
+    return core.extern_elementwise(
+        "libdevice", libdevice_path(), [arg0], {
+
(core.dtype("fp32"), ): ("__nv_rsqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rsqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ceil(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_ceil", core.dtype("fp64")), - (core.dtype("fp32"),): ("__nv_ceilf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_ceil", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_ceilf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def trunc(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_trunc", core.dtype("fp64")), - (core.dtype("fp32"),): ("__nv_truncf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_trunc", core.dtype("fp64")), + (core.dtype("fp32"), ): ("__nv_truncf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def exp2(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_exp2f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_exp2", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_exp2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def saturatef(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_saturatef", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_saturatef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fma_rn(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rn", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def fma_rz(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rz", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) 
@core.extern def fma_rd(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_rd", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def fma_ru(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf_ru", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma_ru", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def fast_dividef(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_fdividef", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_fdividef", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def div_rn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rn", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def div_rz(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rz", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def div_rd(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_rd", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", 
libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def div_ru(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdiv_ru", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_ddiv_ru", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdiv_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_ddiv_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rcp_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_frcp_rn", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_drcp_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rcp_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_frcp_rz", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_drcp_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rcp_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_frcp_rd", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_drcp_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rcp_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_frcp_ru", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_drcp_ru", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_frcp_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_drcp_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sqrt_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fsqrt_rn", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_dsqrt_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rn", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sqrt_rz(arg0, _builder=None): - 
return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fsqrt_rz", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_dsqrt_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rz", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sqrt_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fsqrt_rd", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_dsqrt_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_rd", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sqrt_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fsqrt_ru", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_dsqrt_ru", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fsqrt_ru", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_dsqrt_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sqrt(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_sqrtf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_sqrt", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_sqrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sqrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def add_rn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rn", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def add_rz(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rz", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def add_rd(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_rd", core.dtype("fp64")), - (core.dtype("fp32"), 
core.dtype("fp32"),): ("__nv_fadd_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def add_ru(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dadd_ru", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fadd_ru", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dadd_ru", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fadd_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def mul_rn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rn", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rn", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def mul_rz(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rz", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rz", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def mul_rd(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_rd", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dmul_rd", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmul_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def mul_ru(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dmul_ru", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmul_ru", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [ + arg0, + arg1, + ], { + ( + core.dtype("fp64"), + core.dtype("fp64"), + ): ("__nv_dmul_ru", core.dtype("fp64")), + ( + core.dtype("fp32"), + core.dtype("fp32"), + ): ("__nv_fmul_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def 
double2float_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2float_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def double2float_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2float_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def double2float_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2float_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def double2float_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2float_ru", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def double2int_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2int_rn", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2int_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2int_rz", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2int_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2int_rd", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2int_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2int_ru", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2uint_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2uint_rn", core.dtype("int32")), - 
}, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2uint_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2uint_rz", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2uint_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2uint_rd", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2uint_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2uint_ru", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def int2double_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_int2double_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_int2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def uint2double_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint32"),): ("__nv_uint2double_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def float2int_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2int_rn", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2int_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2int_rz", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2int_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2int_rd", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_rd", core.dtype("int32")), + }, 
is_pure=True, _builder=_builder) @core.extern def float2int_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2int_ru", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2int_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2uint_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2uint_rn", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rn", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2uint_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2uint_rz", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rz", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2uint_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2uint_rd", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_rd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2uint_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2uint_ru", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2uint_ru", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def int2float_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_int2float_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def int2float_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_int2float_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def int2float_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_int2float_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def int2float_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_int2float_ru", 
core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_int2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def uint2float_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint32"),): ("__nv_uint2float_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def uint2float_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint32"),): ("__nv_uint2float_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def uint2float_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint32"),): ("__nv_uint2float_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def uint2float_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint32"),): ("__nv_uint2float_ru", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint32"), ): ("__nv_uint2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def hiloint2double(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hiloint2double", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hiloint2double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def double2loint(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2loint", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2loint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def double2hiint(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2hiint", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2hiint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def float2ll_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ll_rn", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + 
(core.dtype("fp32"), ): ("__nv_float2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def float2ll_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ll_rz", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def float2ll_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ll_rd", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def float2ll_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ll_ru", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def float2ull_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ull_rn", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def float2ull_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ull_rz", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def float2ull_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ull_rd", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def float2ull_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float2ull_ru", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ll_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ll_rn", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ll_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - 
{(core.dtype("fp64"),): ("__nv_double2ll_rz", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ll_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ll_rd", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ll_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ll_ru", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ll_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ull_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ull_rn", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rn", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ull_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ull_rz", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rz", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ull_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ull_rd", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_rd", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def double2ull_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double2ull_ru", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double2ull_ru", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def ll2float_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2float_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ll2float_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2float_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): 
("__nv_ll2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ll2float_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2float_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ll2float_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2float_ru", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ull2float_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2float_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ull2float_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2float_rz", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rz", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ull2float_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2float_rd", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_rd", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ull2float_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2float_ru", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2float_ru", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ll2double_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2double_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ll2double_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2double_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ll2double_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - 
{(core.dtype("int64"),): ("__nv_ll2double_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ll2double_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_ll2double_ru", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_ll2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ull2double_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2double_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ull2double_rz(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2double_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ull2double_rd(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2double_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ull2double_ru(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint64"),): ("__nv_ull2double_ru", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("uint64"), ): ("__nv_ull2double_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def int_as_float(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_int_as_float", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int32"), ): ("__nv_int_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def float_as_int(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float_as_int", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_int", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def uint_as_float(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("uint32"),): ("__nv_uint_as_float", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + 
(core.dtype("uint32"), ): ("__nv_uint_as_float", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def float_as_uint(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_float_as_uint", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_float_as_uint", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def longlong_as_double(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int64"),): ("__nv_longlong_as_double", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("int64"), ): ("__nv_longlong_as_double", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def double_as_longlong(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_double_as_longlong", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_double_as_longlong", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def fast_sinf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_sinf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_sinf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_cosf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_cosf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_cosf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_log2f(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_log2f", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log2f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_logf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_logf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_logf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_expf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_expf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_expf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_tanf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - 
{(core.dtype("fp32"),): ("__nv_fast_tanf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_tanf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_exp10f(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_exp10f", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_exp10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_log10f(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_fast_log10f", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_fast_log10f", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def fast_powf(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fast_powf", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fast_powf", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def hadd(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("int32"),): ("__nv_hadd", core.dtype("int32")), - (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_uhadd", core.dtype("uint32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_hadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_uhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) @core.extern def rhadd(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("int32"),): ("__nv_rhadd", core.dtype("int32")), - (core.dtype("uint32"), core.dtype("uint32"),): ("__nv_urhadd", core.dtype("uint32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("int32"), core.dtype("int32")): ("__nv_rhadd", core.dtype("int32")), + (core.dtype("uint32"), core.dtype("uint32")): ("__nv_urhadd", core.dtype("uint32")), + }, is_pure=True, _builder=_builder) @core.extern def sub_rn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rn", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rn", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sub_rz(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", 
libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rz", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rz", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rz", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rz", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sub_rd(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_rd", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_rd", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_rd", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_rd", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sub_ru(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fsub_ru", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_dsub_ru", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fsub_ru", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_dsub_ru", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rsqrt_rn(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_frsqrt_rn", core.dtype("fp32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_frsqrt_rn", core.dtype("fp32")), + }, is_pure=True, _builder=_builder) @core.extern def ffs(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("int32"),): ("__nv_ffs", core.dtype("int32")), - (core.dtype("int64"),): ("__nv_ffsll", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("int32"), ): ("__nv_ffs", core.dtype("int32")), + (core.dtype("int64"), ): ("__nv_ffsll", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def rint(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_rintf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_rint", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_rintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def llrint(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_llrintf", core.dtype("int64")), - (core.dtype("fp64"),): ("__nv_llrint", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + 
"libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_llrintf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llrint", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def nearbyint(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_nearbyintf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_nearbyint", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_nearbyintf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_nearbyint", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def isnan(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_isnanf", core.dtype("int32")), - (core.dtype("fp64"),): ("__nv_isnand", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_isnanf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isnand", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def signbit(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_signbitf", core.dtype("int32")), - (core.dtype("fp64"),): ("__nv_signbitd", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [ + arg0, + ], { + (core.dtype("fp32"), ): ("__nv_signbitf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_signbitd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def copysign(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_copysignf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_copysign", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_copysignf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_copysign", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def finitef(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_finitef", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_finitef", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def isinf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_isinff", core.dtype("int32")), - (core.dtype("fp64"),): ("__nv_isinfd", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_isinff", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_isinfd", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def nextafter(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - 
{(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_nextafterf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_nextafter", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_nextafterf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_nextafter", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sin(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_sinf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_sin", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_sinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def cos(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_cosf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_cos", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_cosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sinpi(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_sinpif", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_sinpi", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_sinpif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinpi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def cospi(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_cospif", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_cospi", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_cospif", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cospi", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def tan(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_tanf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_tan", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_tanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def log2(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_log2f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_log2", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_log2f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log2", core.dtype("fp64")), + 
}, is_pure=True, _builder=_builder) @core.extern def exp(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_expf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_exp", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_expf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def exp10(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_exp10f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_exp10", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_exp10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_exp10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def cosh(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_coshf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_cosh", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_coshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def sinh(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_sinhf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_sinh", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_sinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_sinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def tanh(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_tanhf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_tanh", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_tanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def atan2(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_atan2f", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_atan2", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_atan2f", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_atan2", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def atan(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_atanf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_atan", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + 
"libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_atanf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atan", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def asin(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_asinf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_asin", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_asinf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asin", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def acos(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_acosf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_acos", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_acosf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acos", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def log(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_logf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_log", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_logf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def log10(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_log10f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_log10", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_log10f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log10", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def log1p(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_log1pf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_log1p", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_log1pf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_log1p", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def acosh(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_acoshf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_acosh", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_acoshf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_acosh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def asinh(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_asinhf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_asinh", core.dtype("fp64")), - 
}, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_asinhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_asinh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def atanh(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_atanhf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_atanh", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_atanhf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_atanh", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def expm1(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_expm1f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_expm1", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_expm1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_expm1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def hypot(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_hypotf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_hypot", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_hypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_hypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rhypot(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rhypotf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rhypot", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_rhypotf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_rhypot", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def norm3d(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm3df", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm3d", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_norm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_norm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rnorm3d(arg0, arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm3df", core.dtype("fp32")), - (core.dtype("fp64"), 
core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm3d", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_rnorm3df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_rnorm3d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def norm4d(arg0, arg1, arg2, arg3, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, arg3, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_norm4df", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_norm4d", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_norm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_norm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rnorm4d(arg0, arg1, arg2, arg3, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, arg3, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_rnorm4df", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_rnorm4d", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2, arg3], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): + ("__nv_rnorm4df", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): + ("__nv_rnorm4d", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def cbrt(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_cbrtf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_cbrt", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_cbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def rcbrt(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_rcbrtf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_rcbrt", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_rcbrtf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_rcbrt", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def j0(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_j0f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_j0", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_j0f", core.dtype("fp32")), + 
(core.dtype("fp64"), ): ("__nv_j0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def j1(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_j1f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_j1", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_j1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_j1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def y0(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_y0f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_y0", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_y0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def y1(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_y1f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_y1", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_y1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_y1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def yn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_ynf", core.dtype("fp32")), - (core.dtype("int32"), core.dtype("fp64"),): ("__nv_yn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_ynf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_yn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def jn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("int32"), core.dtype("fp32"),): ("__nv_jnf", core.dtype("fp32")), - (core.dtype("int32"), core.dtype("fp64"),): ("__nv_jn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("int32"), core.dtype("fp32")): ("__nv_jnf", core.dtype("fp32")), + (core.dtype("int32"), core.dtype("fp64")): ("__nv_jn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def cyl_bessel_i0(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_cyl_bessel_i0", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i0f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i0", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def cyl_bessel_i1(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_cyl_bessel_i1f", 
core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_cyl_bessel_i1", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_cyl_bessel_i1f", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_cyl_bessel_i1", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def erf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_erff", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_erf", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_erff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def erfinv(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_erfinvf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_erfinv", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_erfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def erfc(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_erfcf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_erfc", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfc", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def erfcx(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_erfcxf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_erfcx", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcxf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcx", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def erfcinv(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_erfcinvf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_erfcinv", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_erfcinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_erfcinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def normcdfinv(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_normcdfinvf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_normcdfinv", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdfinvf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdfinv", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def 
normcdf(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_normcdff", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_normcdf", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_normcdff", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_normcdf", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def lgamma(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_lgammaf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_lgamma", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_lgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_lgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ldexp(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_ldexpf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("int32"),): ("__nv_ldexp", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_ldexpf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_ldexp", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def scalbn(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_scalbnf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("int32"),): ("__nv_scalbn", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_scalbnf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_scalbn", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def fmod(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmodf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fmod", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmodf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fmod", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def remainder(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_remainderf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_remainder", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_remainderf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_remainder", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def fma(arg0, 
arg1, arg2, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, arg2, ], - {(core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fmaf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fma", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1, arg2], { + (core.dtype("fp32"), core.dtype("fp32"), core.dtype("fp32")): ("__nv_fmaf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64"), core.dtype("fp64")): ("__nv_fma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def pow(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("int32"),): ("__nv_powif", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("int32"),): ("__nv_powi", core.dtype("fp64")), - (core.dtype("fp32"), core.dtype("fp32"),): ("__nv_powf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_pow", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("int32")): ("__nv_powif", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("int32")): ("__nv_powi", core.dtype("fp64")), + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_powf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_pow", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def tgamma(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_tgammaf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_tgamma", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_tgammaf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_tgamma", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def round(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_roundf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_round", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_roundf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_round", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def llround(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_llroundf", core.dtype("int64")), - (core.dtype("fp64"),): ("__nv_llround", core.dtype("int64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_llroundf", core.dtype("int64")), + (core.dtype("fp64"), ): ("__nv_llround", core.dtype("int64")), + }, is_pure=True, _builder=_builder) @core.extern def fdim(arg0, arg1, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, arg1, ], - {(core.dtype("fp32"), core.dtype("fp32"),): ("__nv_fdimf", core.dtype("fp32")), - (core.dtype("fp64"), core.dtype("fp64"),): ("__nv_fdim", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + 
"libdevice", libdevice_path(), [arg0, arg1], { + (core.dtype("fp32"), core.dtype("fp32")): ("__nv_fdimf", core.dtype("fp32")), + (core.dtype("fp64"), core.dtype("fp64")): ("__nv_fdim", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def ilogb(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_ilogbf", core.dtype("int32")), - (core.dtype("fp64"),): ("__nv_ilogb", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_ilogbf", core.dtype("int32")), + (core.dtype("fp64"), ): ("__nv_ilogb", core.dtype("int32")), + }, is_pure=True, _builder=_builder) @core.extern def logb(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp32"),): ("__nv_logbf", core.dtype("fp32")), - (core.dtype("fp64"),): ("__nv_logb", core.dtype("fp64")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise( + "libdevice", libdevice_path(), [arg0], { + (core.dtype("fp32"), ): ("__nv_logbf", core.dtype("fp32")), + (core.dtype("fp64"), ): ("__nv_logb", core.dtype("fp64")), + }, is_pure=True, _builder=_builder) @core.extern def isfinited(arg0, _builder=None): - return core.extern_elementwise("libdevice", libdevice_path(), [arg0, ], - {(core.dtype("fp64"),): ("__nv_isfinited", core.dtype("int32")), - }, is_pure=True, _builder=_builder) + return core.extern_elementwise("libdevice", libdevice_path(), [arg0], { + (core.dtype("fp64"), ): ("__nv_isfinited", core.dtype("int32")), + }, is_pure=True, _builder=_builder) diff --git a/python/triton/language/random.py b/python/triton/language/random.py index 7af60855b040..a8fb520e0268 100644 --- a/python/triton/language/random.py +++ b/python/triton/language/random.py @@ -91,6 +91,7 @@ def randint4x(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): # two_to_the_minus_32: tl.constexpr = 2.328306e-10 # return x * two_to_the_minus_32 + @jit def uint32_to_uniform_float(x): """ @@ -134,6 +135,7 @@ def rand4x(seed, offsets, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT): u4 = uint32_to_uniform_float(i4) return u1, u2, u3, u4 + # ------------------- # randn # ------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index e59e6bd92700..d74cbb150254 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -16,10 +16,12 @@ def _is_cuda(target): from ..compiler.compiler import CudaTargetDescriptor return isinstance(target, CudaTargetDescriptor) + # Create custom exception that prints message "hello" class IncompatibleTypeErrorImpl(Exception): + def __init__(self, type_a, type_b): self.type_a = type_a self.type_b = type_b @@ -31,6 +33,7 @@ def __init__(self, type_a, type_b): # Programming Model # ===----------------------------------------------------------------------===## + def program_id(axis: int, builder: ir.builder) -> tl.tensor: if axis not in (0, 1, 2): raise ValueError(f"program_id axis must be 0, 1, or 2 but got {axis}") @@ -42,6 +45,7 @@ def num_programs(axis: int, builder: ir.builder) -> tl.tensor: raise ValueError(f"num_programs axis must be 0, 1, or 2 but got {axis}") return tl.tensor(builder.create_get_num_programs(axis), tl.int32) + # ===----------------------------------------------------------------------===// # Implicit Casting Utilities # 
===----------------------------------------------------------------------===// @@ -92,10 +96,12 @@ def computation_type_impl(a_ty: tl.dtype, b_ty: tl.dtype, div_or_mod: bool) -> t # 5 ) both operands are integer and undergo # integer promotion if div_or_mod and a_ty.int_signedness != b_ty.int_signedness: - raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + " because they have different signedness;" + raise ValueError("Cannot use /, #, or % with " + a_ty.__repr__() + " and " + b_ty.__repr__() + + " because they have different signedness;" "this is unlikely to result in a useful answer. Cast them to the same signedness.") return integer_promote_impl(a_ty, b_ty) + # ===----------------------------------------------------------------------===// # Binary Operators # ===----------------------------------------------------------------------===// @@ -113,12 +119,9 @@ def check_ptr_type_impl(type_a: tl.dtype, type_b: tl.dtype, allow_ptr_a: bool) - raise IncompatibleTypeErrorImpl(type_a, type_b) -def binary_op_type_checking_impl(lhs: tl.tensor, - rhs: tl.tensor, - builder: ir.builder, - allow_lhs_ptr=False, allow_rhs_ptr=False, - arithmetic_check=True, div_or_mod=False - ) -> Tuple[tl.tensor, tl.tensor]: +def binary_op_type_checking_impl(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder, allow_lhs_ptr=False, + allow_rhs_ptr=False, arithmetic_check=True, + div_or_mod=False) -> Tuple[tl.tensor, tl.tensor]: # implicit broadcasting lhs, rhs = broadcast_impl_value(lhs, rhs, builder) # implicit typecasting @@ -133,9 +136,7 @@ def binary_op_type_checking_impl(lhs: tl.tensor, return lhs, rhs -def add(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def add(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar @@ -159,15 +160,12 @@ def add(input: tl.tensor, assert False -def sub(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def sub(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, True, False) scalar_ty = input.type.scalar # ptr - offset if scalar_ty.is_ptr(): - return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), - input.type) + return tl.tensor(builder.create_addptr(input.handle, minus(other, builder).handle), input.type) # float - float if scalar_ty.is_floating(): return tl.tensor(builder.create_fsub(input.handle, other.handle), input.type) @@ -177,9 +175,7 @@ def sub(input: tl.tensor, assert False -def mul(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def mul(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float * float @@ -191,9 +187,7 @@ def mul(input: tl.tensor, assert False -def truediv(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def truediv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar @@ -219,9 +213,7 @@ def truediv(input: tl.tensor, return tl.tensor(builder.create_fdiv(input.handle, other.handle), input.type) -def floordiv(input: tl.tensor, - 
other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def floordiv(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar @@ -236,10 +228,7 @@ def floordiv(input: tl.tensor, assert False -def fdiv(input: tl.tensor, - other: tl.tensor, - ieee_rounding: bool, - builder: ir.builder) -> tl.tensor: +def fdiv(input: tl.tensor, other: tl.tensor, ieee_rounding: bool, builder: ir.builder) -> tl.tensor: input_scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar if not input_scalar_ty.is_floating() or not other_scalar_ty.is_floating(): @@ -249,18 +238,14 @@ def fdiv(input: tl.tensor, return tl.tensor(ret, input.type) -def mod(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def mod(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder, False, False, True, True) scalar_ty = input.type.scalar other_scalar_ty = other.type.scalar # float % float if scalar_ty.is_floating(): # input - input.div(other, rounding_mode="floor") * other - ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder), - other, builder), - builder) + ret = sub(input, mul(floor(fdiv(input, other, False, builder), builder), other, builder), builder) return ret # % int elif scalar_ty.is_int(): @@ -274,13 +259,13 @@ def mod(input: tl.tensor, return tl.tensor(builder.create_urem(input.handle, other.handle), input.type) assert False + ############## # bitwise ops ############## -def bitwise_op_type_checking_impl(input: tl.tensor, - other: tl.tensor, +def bitwise_op_type_checking_impl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor]: input, other = binary_op_type_checking_impl(input, other, builder, False, False, False) input_sca_ty = input.type.scalar @@ -295,23 +280,17 @@ def bitwise_op_type_checking_impl(input: tl.tensor, return input, other -def and_(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def and_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_and(input.handle, other.handle), input.type) -def or_(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def or_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_or(input.handle, other.handle), input.type) -def xor_(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def xor_(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_xor(input.handle, other.handle), input.type) @@ -338,26 +317,21 @@ def not_(input: tl.tensor, builder: ir.builder): return invert(input, builder) -def lshr(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def lshr(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_lshr(input.handle, other.handle), input.type) -def ashr(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def ashr(input: tl.tensor, other: 
tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_ashr(input.handle, other.handle), input.type) -def shl(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def shl(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = bitwise_op_type_checking_impl(input, other, builder) return tl.tensor(builder.create_shl(input.handle, other.handle), input.type) + # ===----------------------------------------------------------------------===// # Unary Operators # ===----------------------------------------------------------------------===// @@ -367,8 +341,7 @@ def plus(input: tl.tensor) -> tl.tensor: return input -def minus(input: tl.tensor, - builder: ir.builder) -> tl.tensor: +def minus(input: tl.tensor, builder: ir.builder) -> tl.tensor: input_sca_ty = input.type.scalar if input_sca_ty.is_ptr(): raise ValueError("wrong type argument to unary minus (" + input_sca_ty.__repr__() + ")") @@ -376,8 +349,7 @@ def minus(input: tl.tensor, return sub(_0, input, builder) -def invert(input: tl.tensor, - builder: tl.tensor) -> tl.tensor: +def invert(input: tl.tensor, builder: tl.tensor) -> tl.tensor: input_sca_ty = input.type.scalar if input_sca_ty.is_ptr() or input_sca_ty.is_floating(): raise ValueError("wrong type argument to unary invert (" + input_sca_ty.__repr__() + ")") @@ -395,9 +367,7 @@ def _bool_like(v: tl.tensor) -> tl.block_type: return tl.block_type(tl.int1, shape) -def greater_than(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def greater_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float > float @@ -412,9 +382,7 @@ def greater_than(input: tl.tensor, assert False -def greater_equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def greater_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float >= float @@ -429,9 +397,7 @@ def greater_equal(input: tl.tensor, assert False -def less_than(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def less_than(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float < float @@ -446,9 +412,7 @@ def less_than(input: tl.tensor, assert False -def less_equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def less_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float < float @@ -463,9 +427,7 @@ def less_equal(input: tl.tensor, assert False -def equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) scalar_ty = input.type.scalar # float == float @@ -477,9 +439,7 @@ def equal(input: tl.tensor, assert False -def not_equal(input: tl.tensor, - other: tl.tensor, - builder: ir.builder) -> tl.tensor: +def not_equal(input: tl.tensor, other: tl.tensor, builder: ir.builder) -> tl.tensor: input, other = binary_op_type_checking_impl(input, other, builder) 
scalar_ty = input.type.scalar # float == float @@ -490,6 +450,7 @@ def not_equal(input: tl.tensor, return tl.tensor(builder.create_icmpNE(input.handle, other.handle), _bool_like(input)) assert False + # ===----------------------------------------------------------------------===// # Block Creation # ===----------------------------------------------------------------------===// @@ -532,6 +493,7 @@ def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.te # Shape Manipulation # ===----------------------------------------------------------------------===// + def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: assert not value.type.is_block(), "Cannot splat a block tensor" if len(shape) == 0: @@ -540,9 +502,7 @@ def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_splat(value.handle, shape), ret_ty) -def view(input: tl.tensor, - dst_shape: List[int], - builder: ir.builder) -> tl.tensor: +def view(input: tl.tensor, dst_shape: List[int], builder: ir.builder) -> tl.tensor: # TODO: disable when TritonToTritonGPU handles views properly # assert len(input.shape) == len(dst_shape) @@ -555,9 +515,7 @@ def view(input: tl.tensor, return tl.tensor(builder.create_view(input.handle, dst_shape), ret_ty) -def reshape(input: tl.tensor, - dst_shape: List[int], - builder: ir.builder) -> tl.tensor: +def reshape(input: tl.tensor, dst_shape: List[int], builder: ir.builder) -> tl.tensor: raise ValueError("`reshape` is not supported yet. Please use `view` instead if applicable. " "Note that view may reorder elements in an implementation- and context- dependent way.") @@ -587,9 +545,7 @@ def trans(input: tl.tensor, builder: ir.builder) -> tl.tensor: return tl.tensor(builder.create_trans(input.handle), ret_type) -def broadcast_impl_shape(input: tl.tensor, - shape: List[int], - builder: ir.builder) -> tl.tensor: +def broadcast_impl_shape(input: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor: if not input.type.is_block(): ret_ty = tl.block_type(input.type, shape) return tl.tensor(builder.create_splat(input.handle, shape), ret_ty) @@ -607,9 +563,7 @@ def broadcast_impl_shape(input: tl.tensor, return tl.tensor(builder.create_broadcast(input.handle, shape), ret_ty) -def broadcast_impl_value(lhs: tl.tensor, - rhs: tl.tensor, - builder: ir.builder) -> tl.tensor: +def broadcast_impl_value(lhs: tl.tensor, rhs: tl.tensor, builder: ir.builder) -> tl.tensor: lhs_ty = lhs.type rhs_ty = rhs.type @@ -629,13 +583,15 @@ def broadcast_impl_value(lhs: tl.tensor, if len(lhs_shape) < len(rhs_shape): # Add new axes to lhs for dim in range(len(lhs_shape), len(rhs_shape)): - lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) + lhs = tl.tensor(builder.create_expand_dims(lhs.handle, 0), + tl.block_type(lhs_ty.scalar, [1] + lhs_shape)) lhs_ty = lhs.type lhs_shape = lhs_ty.get_block_shapes() elif len(rhs_shape) < len(lhs_shape): # Add new axes to rhs for dim in range(len(rhs_shape), len(lhs_shape)): - rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) + rhs = tl.tensor(builder.create_expand_dims(rhs.handle, 0), + tl.block_type(rhs_ty.scalar, [1] + rhs_shape)) rhs_ty = rhs.type rhs_shape = rhs_ty.get_block_shapes() assert len(rhs_shape) == len(lhs_shape) @@ -661,14 +617,13 @@ def broadcast_impl_value(lhs: tl.tensor, # (scalar, scalar) => returns original blocks return lhs, rhs + ####### # cast ####### -def 
bitcast(input: tl.tensor, - dst_ty: tl.dtype, - builder: ir.builder) -> tl.tensor: +def bitcast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: src_ty = input.type if src_ty.is_block(): dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes()) @@ -684,13 +639,10 @@ def bitcast(input: tl.tensor, if src_bits != dst_bits: raise ValueError("Cannot bitcast data-type of size " + str(src_bits) + " to " "data-type of size " + str(dst_bits)) - return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_bitcast(input.handle, dst_ty.to_ir(builder)), dst_ty) -def cast(input: tl.tensor, - dst_ty: tl.dtype, - builder: ir.builder) -> tl.tensor: +def cast(input: tl.tensor, dst_ty: tl.dtype, builder: ir.builder) -> tl.tensor: src_ty = input.type if isinstance(dst_ty, tl.constexpr): dst_ty = dst_ty.value @@ -709,8 +661,7 @@ def cast(input: tl.tensor, # Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64 if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \ (src_sca_ty.is_floating() and dst_sca_ty.is_fp8()): - return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) # bf16 <=> (not fp32) if (src_sca_ty.is_fp16() and not dst_sca_ty.is_fp32()) or \ @@ -724,9 +675,7 @@ def cast(input: tl.tensor, dst_sca_ty.is_floating() and \ src_sca_ty.primitive_bitwidth > dst_sca_ty.primitive_bitwidth if truncate_fp: - return tl.tensor(builder.create_fp_trunc(input.handle, - dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_fp_trunc(input.handle, dst_ty.to_ir(builder)), dst_ty) # Standard floating types' casting: extension # fp32 => fp64 @@ -736,9 +685,7 @@ def cast(input: tl.tensor, dst_sca_ty.is_floating() and \ src_sca_ty.primitive_bitwidth < dst_sca_ty.primitive_bitwidth if ext_fp: - return tl.tensor(builder.create_fp_ext(input.handle, - dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_fp_ext(input.handle, dst_ty.to_ir(builder)), dst_ty) # Casting between integer types if src_sca_ty.is_int() and dst_sca_ty.is_int() and \ @@ -749,9 +696,7 @@ def cast(input: tl.tensor, _0 = tl.tensor(builder.get_null_value(ty), input.dtype) return not_equal(input, _0, builder) else: - return tl.tensor(builder.create_int_cast(input.handle, - dst_ty.to_ir(builder), sign_extend), - dst_ty) + return tl.tensor(builder.create_int_cast(input.handle, dst_ty.to_ir(builder), sign_extend), dst_ty) # Casting standard floating types to integer types if src_sca_ty.is_standard_floating() and dst_sca_ty.is_int(): @@ -760,35 +705,24 @@ def cast(input: tl.tensor, _0 = tl.tensor(builder.get_null_value(ty), input.dtype) return not_equal(input, _0, builder) elif dst_sca_ty.is_int_signed(): - return tl.tensor(builder.create_fp_to_si(input.handle, - dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_fp_to_si(input.handle, dst_ty.to_ir(builder)), dst_ty) else: - return tl.tensor(builder.create_fp_to_ui(input.handle, - dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_fp_to_ui(input.handle, dst_ty.to_ir(builder)), dst_ty) # Casting integer types to standard floating types if src_sca_ty.is_int() and dst_sca_ty.is_standard_floating(): if src_sca_ty.is_bool() or not src_sca_ty.is_int_signed(): - return tl.tensor(builder.create_ui_to_fp(input.handle, - dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_ui_to_fp(input.handle, 
dst_ty.to_ir(builder)), dst_ty) else: - return tl.tensor(builder.create_si_to_fp(input.handle, - dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_si_to_fp(input.handle, dst_ty.to_ir(builder)), dst_ty) # Casting pointer types to integer types if src_sca_ty.is_ptr() and dst_sca_ty.is_int(): bitwidth = dst_sca_ty.int_bitwidth if bitwidth == 64: - return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), - dst_ty) + return tl.tensor(builder.create_ptr_to_int(input.handle, dst_ty.to_ir(builder)), dst_ty) if bitwidth == 1: - return not_equal(cast(input, tl.int64, builder), - tl.tensor(builder.get_int64(0), tl.int64), - builder) + return not_equal(cast(input, tl.int64, builder), tl.tensor(builder.get_int64(0), tl.int64), builder) # Casting integer types to pointer types if src_sca_ty.is_int() and dst_sca_ty.is_ptr(): @@ -800,6 +734,7 @@ def cast(input: tl.tensor, assert False, f'cannot cast {input} to {dst_ty}' + # ===----------------------------------------------------------------------===// # Memory Operators # ===----------------------------------------------------------------------===// @@ -918,8 +853,8 @@ def _load_block_pointer(ptr, mask, other, boundary_check, padding, cache, evicti boundary_check = _canonicalize_boundary_check(boundary_check, dst_ty.get_block_shapes()) # Build IR - return tl.tensor(builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, - is_volatile), dst_ty) + return tl.tensor( + builder.create_tensor_pointer_load(ptr.handle, boundary_check, padding, cache, eviction, is_volatile), dst_ty) def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_volatile, builder): @@ -975,19 +910,13 @@ def _load_legacy(ptr, mask, other, boundary_check, padding, cache, eviction, is_ if not mask: return tl.tensor(builder.create_load(ptr.handle, cache, eviction, is_volatile), dst_ty) else: - return tl.tensor(builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, - eviction, is_volatile), dst_ty) - - -def load(ptr: tl.tensor, - mask: Optional[tl.tensor], - other: Optional[tl.tensor], - boundary_check, - padding_option: str, - cache_modifier: str, - eviction_policy: str, - is_volatile: bool, - builder: ir.builder) -> tl.tensor: + return tl.tensor( + builder.create_masked_load(ptr.handle, mask.handle, other.handle if other else None, cache, eviction, + is_volatile), dst_ty) + + +def load(ptr: tl.tensor, mask: Optional[tl.tensor], other: Optional[tl.tensor], boundary_check, padding_option: str, + cache_modifier: str, eviction_policy: str, is_volatile: bool, builder: ir.builder) -> tl.tensor: # Cache, eviction and padding options cache = _str_to_load_cache_modifier(cache_modifier) eviction = _str_to_eviction_policy(eviction_policy) @@ -1012,7 +941,8 @@ def _store_block_pointer(ptr, val, mask, boundary_check, cache, eviction, builde if not val.type.is_block(): val = broadcast_impl_shape(val, block_shape, builder) assert val.type.is_block(), "Value argument must be block type or a scalar" - assert block_shape == val.type.get_block_shapes(), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" + assert block_shape == val.type.get_block_shapes( + ), f"Block shape({block_shape}) and value shape({val.type.get_block_shapes()}) mismatch" assert ptr.type.element_ty.element_ty == val.type.element_ty, f"Block element type({ptr.type.element_ty.element_ty}) and value element type({val.type.element_ty}) mismatch" elt_ty = 
ptr.type.element_ty.element_ty @@ -1070,13 +1000,8 @@ def _store_legacy(ptr, val, mask, boundary_check, cache, eviction, builder): return tl.tensor(builder.create_masked_store(ptr.handle, val.handle, mask.handle, cache, eviction), tl.void) -def store(ptr: tl.tensor, - val: tl.tensor, - mask: Optional[tl.tensor], - boundary_check, - cache_modifier: str, - eviction_policy: str, - builder: ir.builder) -> tl.tensor: +def store(ptr: tl.tensor, val: tl.tensor, mask: Optional[tl.tensor], boundary_check, cache_modifier: str, + eviction_policy: str, builder: ir.builder) -> tl.tensor: # Cache and eviction options cache = _str_to_store_cache_modifier(cache_modifier) eviction = _str_to_eviction_policy(eviction_policy) @@ -1094,12 +1019,7 @@ def store(ptr: tl.tensor, ######### -def atomic_cas(ptr: tl.tensor, - cmp: tl.tensor, - val: tl.tensor, - sem: str, - scope: str, - builder: ir.builder) -> tl.tensor: +def atomic_cas(ptr: tl.tensor, cmp: tl.tensor, val: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: sem = _str_to_sem(sem) scope = _str_to_scope(scope) element_ty = ptr.type.scalar.element_ty @@ -1108,10 +1028,7 @@ def atomic_cas(ptr: tl.tensor, return tl.tensor(builder.create_atomic_cas(ptr.handle, cmp.handle, val.handle, sem, scope), val.type) -def atom_red_typechecking_impl(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - op: str, +def atom_red_typechecking_impl(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, op: str, builder: ir.builder) -> Tuple[tl.tensor, tl.tensor, tl.tensor]: if not ptr.type.scalar.is_ptr(): raise ValueError("Pointer argument of store instruction is " + ptr.type.__repr__()) @@ -1136,12 +1053,7 @@ def atom_red_typechecking_impl(ptr: tl.tensor, return ptr, val, mask -def atomic_max(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: str, - scope: str, - builder: ir.builder) -> tl.tensor: +def atomic_max(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'max', builder) sem = _str_to_sem(sem) scope = _str_to_scope(scope) @@ -1149,21 +1061,11 @@ def atomic_max(ptr: tl.tensor, # direct call to atomic_max for integers if sca_ty.is_int(): if sca_ty.is_int_signed(): - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, - ptr.handle, - val.handle, - mask.handle, - sem, - scope), - val.type) + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) else: - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, - ptr.handle, - val.handle, - mask.handle, - sem, - scope), - val.type) + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, ptr.handle, val.handle, mask.handle, sem, scope), val.type) # for float # return atomic_smax(i_ptr, i_val) if val >= 0 # return atomic_umin(i_ptr, i_val) if val < 0 @@ -1177,18 +1079,17 @@ def atomic_max(ptr: tl.tensor, i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder) pos = greater_equal(val, zero, builder) neg = less_than(val, zero, builder) - pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, and_(mask, pos, builder).handle, sem, scope), i_val.type) - neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, and_(mask, neg, builder).handle, sem, scope), i_val.type) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MAX, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret 
= tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, i_ptr.handle, i_val.handle, + and_(mask, neg, builder).handle, sem, scope), i_val.type) ret = where(pos, pos_ret, neg_ret, builder) return bitcast(ret, sca_ty, builder) -def atomic_min(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: str, - scope: str, - builder: ir.builder) -> tl.tensor: +def atomic_min(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'min', builder) sem = _str_to_sem(sem) scope = _str_to_scope(scope) @@ -1196,21 +1097,11 @@ def atomic_min(ptr: tl.tensor, # direct call to atomic_min for integers if sca_ty.is_int(): if sca_ty.is_int_signed(): - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, - ptr.handle, - val.handle, - mask.handle, - sem, - scope), - val.type) + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) else: - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, - ptr.handle, - val.handle, - mask.handle, - sem, - scope), - val.type) + return tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMIN, ptr.handle, val.handle, mask.handle, sem, scope), val.type) # for float # return atomic_smin(i_ptr, i_val) if val >= 0 # return atomic_umax(i_ptr, i_val) if val < 0 @@ -1224,30 +1115,17 @@ def atomic_min(ptr: tl.tensor, i_ptr = bitcast(ptr, tl.pointer_type(itype, 1), builder) pos = greater_equal(val, zero, builder) neg = less_than(val, zero, builder) - pos_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, - i_ptr.handle, - i_val.handle, - and_(mask, pos, builder).handle, - sem, - scope), - i_val.type) - neg_ret = tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, - i_ptr.handle, - i_val.handle, - and_(mask, neg, builder).handle, - sem, - scope), - i_val.type) + pos_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.MIN, i_ptr.handle, i_val.handle, + and_(mask, pos, builder).handle, sem, scope), i_val.type) + neg_ret = tl.tensor( + builder.create_atomic_rmw(ir.ATOMIC_OP.UMAX, i_ptr.handle, i_val.handle, + and_(mask, neg, builder).handle, sem, scope), i_val.type) ret = where(pos, pos_ret, neg_ret, builder) return bitcast(ret, sca_ty, builder) -def atomic_add(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: str, - scope: str, - builder: ir.builder) -> tl.tensor: +def atomic_add(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'add', builder) sem = _str_to_sem(sem) scope = _str_to_scope(scope) @@ -1256,52 +1134,38 @@ def atomic_add(ptr: tl.tensor, return tl.tensor(builder.create_atomic_rmw(op, ptr.handle, val.handle, mask.handle, sem, scope), val.type) -def atomic_and(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: str, - scope: str, - builder: ir.builder) -> tl.tensor: +def atomic_and(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'and', builder) sem = _str_to_sem(sem) scope = _str_to_scope(scope) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.AND, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) -def atomic_or(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: 
str, - scope: str, - builder: ir.builder) -> tl.tensor: +def atomic_or(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'or', builder) sem = _str_to_sem(sem) scope = _str_to_scope(scope) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.OR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) -def atomic_xor(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: str, - scope: str, - builder: ir.builder) -> tl.tensor: +def atomic_xor(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xor', builder) sem = _str_to_sem(sem) scope = _str_to_scope(scope) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XOR, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) -def atomic_xchg(ptr: tl.tensor, - val: tl.tensor, - mask: tl.tensor, - sem: str, - scope: str, +def atomic_xchg(ptr: tl.tensor, val: tl.tensor, mask: tl.tensor, sem: str, scope: str, builder: ir.builder) -> tl.tensor: ptr, val, mask = atom_red_typechecking_impl(ptr, val, mask, 'xchg', builder) sem = _str_to_sem(sem) scope = _str_to_scope(scope) - return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), val.type) + return tl.tensor(builder.create_atomic_rmw(ir.ATOMIC_OP.XCHG, ptr.handle, val.handle, mask.handle, sem, scope), + val.type) + # ===----------------------------------------------------------------------===// # Linear Algebra @@ -1321,13 +1185,9 @@ def mfma_supported(M, N, K, allow_tf32, ret_scalar_ty) -> bool: return True -def dot(lhs: tl.tensor, - rhs: tl.tensor, - acc: tl.tensor, - allow_tf32: bool, - max_num_imprecise_acc: int, - out_dtype: tl.dtype, - builder: ir.builder) -> tl.tensor: +def dot(lhs: tl.tensor, rhs: tl.tensor, acc: tl.tensor, allow_tf32: bool, max_num_imprecise_acc: int, + out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): # Checks for non-cuda archs if not _is_cuda(target): @@ -1335,22 +1195,30 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): return # Checks for cuda arch if target.capability < 90: - assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv(), "Dot op does not support fp8e4nv on CUDA arch < 90" + assert not lhs_dtype.is_fp8e4nv() and not rhs_dtype.is_fp8e4nv( + ), "Dot op does not support fp8e4nv on CUDA arch < 90" if lhs_dtype.is_fp8() and rhs_dtype.is_fp8(): return assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" 
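        # Summary of the operand-dtype rules the surrounding assertions encode
        # (drawn only from the checks visible in this hunk):
        #   * capability <  90: fp8e4nv is rejected on either operand; any fp8/fp8
        #     pairing is accepted; otherwise the two dtypes must match exactly.
        #   * capability >= 90 (the branch below): fp8e4b15 and fp8e4b15x4 are rejected;
        #     integer operands must both be int8 or uint8; fp8 operands must each be
        #     fp8e4nv or fp8e5; the remaining fp16 / bf16 / fp32 cases must match exactly.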
else: - assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15(), "Dot op does not support fp8e4b15 on CUDA arch >= 90" - assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4(), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90" + assert not lhs_dtype.is_fp8e4b15() and not rhs_dtype.is_fp8e4b15( + ), "Dot op does not support fp8e4b15 on CUDA arch >= 90" + assert not lhs_dtype.is_fp8e4b15x4() and not rhs_dtype.is_fp8e4b15x4( + ), "Dot op does not support fp8e4b15x4 on CUDA arch >= 90" if lhs_dtype.is_int() or rhs_dtype.is_int(): assert lhs_dtype == rhs_dtype, f"Both operands must be same type. First operand ({lhs_dtype}) and second operand ({rhs_dtype})" - assert lhs_dtype.is_int8() or lhs_dtype.is_uint8(), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" + assert lhs_dtype.is_int8() or lhs_dtype.is_uint8( + ), f"Both operands must be either int8 or uint8. Operand type ({lhs_dtype})" elif lhs_dtype.is_fp8() or rhs_dtype.is_fp8(): - assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})" - assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5(), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})" + assert lhs_dtype.is_fp8e4nv() or lhs_dtype.is_fp8e5( + ), f"Only supports fp8e4nv or fp8e5. First operand ({lhs_dtype})" + assert rhs_dtype.is_fp8e4nv() or rhs_dtype.is_fp8e5( + ), f"Only supports fp8e4nv or fp8e5. Second operand ({rhs_dtype})" else: - assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1(), f"Unsupported dtype {lhs_dtype}" - assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1(), f"Unsupported dtype {rhs_dtype}" + assert lhs_dtype.is_fp16() or lhs_dtype.is_bf16() or lhs_dtype.is_fp32() or lhs_dtype.is_int1( + ), f"Unsupported dtype {lhs_dtype}" + assert rhs_dtype.is_fp16() or rhs_dtype.is_bf16() or rhs_dtype.is_fp32() or rhs_dtype.is_int1( + ), f"Unsupported dtype {rhs_dtype}" assert lhs_dtype == rhs_dtype, f"First input ({lhs_dtype}) and second input ({rhs_dtype}) must have the same dtype!" assert lhs.type.is_block() and rhs.type.is_block() @@ -1359,7 +1227,8 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): assert len(lhs.shape) == 2, f"First input shape ({lhs.shape}) is not two dimensional!" assert len(rhs.shape) == 2, f"Second input shape ({rhs.shape}) is not two dimensional!" - assert lhs.shape[1].value == rhs.shape[0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})" + assert lhs.shape[1].value == rhs.shape[ + 0].value, f"First input shape ({lhs.shape}) and second input shape {rhs.shape} are not compatible for matmul (second index of first shape ({lhs.shape[1].value}) must be equal to first index of second shape ({rhs.shape[0].value})" assert lhs.shape[0].value >= 16 and lhs.shape[1].value >= 16 \ and rhs.shape[1].value >= 16, \ f"All values in both first input shape ({lhs.shape}) and second input shape ({rhs.shape}) must be >= 16!" @@ -1370,7 +1239,8 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): _0 = builder.get_int32(0) ret_scalar_ty = tl.int32 elif out_dtype.is_bf16(): - raise ValueError("out_dtype=bfloat16 is unsupported. Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") + raise ValueError( + "out_dtype=bfloat16 is unsupported. 
Please use out_dtype=float32/float16 and cast with `.to(tl.bfloat16)`") elif lhs.type.scalar.is_fp32() or lhs.type.scalar.is_bf16(): _0 = builder.get_fp32(0) ret_scalar_ty = tl.float32 @@ -1391,10 +1261,10 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): else: _0 = builder.create_splat(builder.get_fp32(0), [M, N]) ret_ty = tl.block_type(ret_cast_scalar_ty, [M, N]) - ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), - ret_ty) + ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty) return cast(ret, ret_scalar_ty, builder) - if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32: + if is_hip() and mfma_supported(M, N, lhs.type.shape[1], allow_tf32, + ret_scalar_ty) and ret_scalar_ty.primitive_bitwidth < 32: if lhs.type.scalar.is_int(): ret_dot_scalar_ty = tl.int32 _0 = builder.create_splat(builder.get_int32(0), [M, N]) @@ -1402,8 +1272,7 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): ret_dot_scalar_ty = tl.float32 _0 = builder.create_splat(builder.get_fp32(0), [M, N]) ret_ty = tl.block_type(ret_dot_scalar_ty, [M, N]) - ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), - ret_ty) + ret = tl.tensor(builder.create_dot(lhs.handle, rhs.handle, _0, allow_tf32), ret_ty) return cast(ret, ret_scalar_ty, builder) ret_ty = tl.block_type(ret_scalar_ty, [M, N]) if acc is None: @@ -1413,23 +1282,21 @@ def assert_dtypes_valid(lhs_dtype, rhs_dtype, target): assert acc.type == ret_ty # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 - if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() and ret_scalar_ty.is_fp32()): + if not (_is_cuda(builder.target) and builder.target.capability == 90 and lhs.dtype.is_fp8() and rhs.dtype.is_fp8() + and ret_scalar_ty.is_fp32()): max_num_imprecise_acc = 0 if max_num_imprecise_acc is None: max_num_imprecise_acc = 2**30 - return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), - ret_ty) + return tl.tensor(builder.create_dot(lhs.handle, rhs.handle, acc_handle, allow_tf32, max_num_imprecise_acc), ret_ty) # ===----------------------------------------------------------------------===// # Indexing # ===----------------------------------------------------------------------===// -def where(condition: tl.tensor, - x: tl.tensor, - y: tl.tensor, - builder: ir.builder) -> tl.tensor: + +def where(condition: tl.tensor, x: tl.tensor, y: tl.tensor, builder: ir.builder) -> tl.tensor: condition = cast(condition, tl.int1, builder) if condition.type.is_block(): condition, x = broadcast_impl_value(condition, x, builder) @@ -1442,14 +1309,13 @@ def where(condition: tl.tensor, ret_ty = x.type return tl.tensor(builder.create_select(condition.handle, x.handle, y.handle), ret_ty) + # ===----------------------------------------------------------------------===// # Reduction # ===----------------------------------------------------------------------=== -def reduction( - inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder -) -> Tuple[tl.tensor, ...]: +def reduction(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder) -> Tuple[tl.tensor, ...]: if axis is None: new_inputs = [] for i in range(len(inputs)): @@ -1475,10 +1341,7 @@ def wrap_tensor(x, scalar_ty): region_builder_fn(reduce_op) reduce_op.verify() - return tuple( - wrap_tensor(reduce_op.get_result(i), 
inputs[i].type.scalar) - for i in range(len(inputs)) - ) + return tuple(wrap_tensor(reduce_op.get_result(i), inputs[i].type.scalar) for i in range(len(inputs))) # ===----------------------------------------------------------------------=== @@ -1486,9 +1349,8 @@ def wrap_tensor(x, scalar_ty): # ===----------------------------------------------------------------------=== -def associative_scan( - inputs: Sequence[tl.tensor], axis: int, region_builder_fn, builder: ir.builder -) -> Tuple[tl.tensor, ...]: +def associative_scan(inputs: Sequence[tl.tensor], axis: int, region_builder_fn, + builder: ir.builder) -> Tuple[tl.tensor, ...]: if len(inputs) != 1: raise ValueError("Current implementation only support single tensor input") shape = inputs[0].type.shape @@ -1501,16 +1363,14 @@ def wrap_tensor(x, scalar_ty): region_builder_fn(scan_op) scan_op.verify() - return tuple( - wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar) - for i in range(len(inputs)) - ) + return tuple(wrap_tensor(scan_op.get_result(i), inputs[i].type.scalar) for i in range(len(inputs))) # ===----------------------------------------------------------------------=== # Math # ===----------------------------------------------------------------------=== + def _check_dtype(dtypes: List[str]) -> T: """ We're following libdevice's convention to check accepted data types for math functions. @@ -1519,7 +1379,9 @@ def _check_dtype(dtypes: List[str]) -> T: We should let the users know that they are using and invoke explicit cast to convert the data type to the supported one. """ + def wrapper(fn): + @wraps(fn) def check(*args, **kwargs): # concatenate args and kwargs @@ -1528,6 +1390,7 @@ def check(*args, **kwargs): if arg.type.scalar.name not in dtypes: raise ValueError(f"Expected dtype {dtypes} but got {arg.type.scalar.name}") return fn(*args, **kwargs) + return check return wrapper @@ -1631,8 +1494,8 @@ def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl. 
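# Context for the hunk below: `device_assert` splats a scalar condition into a
# one-element block before emitting the assert, so callers may pass either form.
# A minimal usage sketch via the public `tl.device_assert` wrapper -- illustrative
# only, assuming a CUDA device with device-side assertions enabled (debug mode).
import torch
import triton
import triton.language as tl


@triton.jit
def _check_positive(X, N, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    x = tl.load(X + offs, mask=offs < N, other=1.0)
    # both a scalar and a block condition are accepted here
    tl.device_assert(x > 0, "expected strictly positive inputs")


x = torch.rand(1024, device="cuda") + 0.1
_check_positive[(triton.cdiv(1024, 256), )](x, 1024, BLOCK=256)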
def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor: cond_ty = cond.type if not cond_ty.is_block(): - cond_ty = tl.block_type(cond_ty.scalar, (1,)) - cond = tl.tensor(builder.create_splat(cond.handle, (1,)), cond_ty) + cond_ty = tl.block_type(cond_ty.scalar, (1, )) + cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty) return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void) diff --git a/python/triton/language/standard.py b/python/triton/language/standard.py index 8ef52cb9cfd6..acda0ca7acc6 100644 --- a/python/triton/language/standard.py +++ b/python/triton/language/standard.py @@ -123,6 +123,7 @@ def maximum(x, y): """ return math.max(x, y) + # max and argmax @@ -149,8 +150,7 @@ def _argmax_combine_tie_break_fast(value1, index1, value2, index2): @jit -@core._add_reduction_docstr("maximum", - return_indices_arg="return_indices", +@core._add_reduction_docstr("maximum", return_indices_arg="return_indices", tie_break_arg="return_indices_tie_break_left") def max(input, axis=None, return_indices=False, return_indices_tie_break_left=True): input = core._promote_reduction_input(input) @@ -175,6 +175,7 @@ def argmax(input, axis, tie_break_left=True): (_, ret) = max(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left) return ret + # min and argmin @@ -201,8 +202,7 @@ def _argmin_combine_tie_break_fast(value1, index1, value2, index2): @jit -@core._add_reduction_docstr("minimum", - return_indices_arg="return_indices", +@core._add_reduction_docstr("minimum", return_indices_arg="return_indices", tie_break_arg="return_indices_tie_break_left") def min(input, axis=None, return_indices=False, return_indices_tie_break_left=True): input = core._promote_reduction_input(input) @@ -222,8 +222,7 @@ def min(input, axis=None, return_indices=False, return_indices_tie_break_left=Tr @jit -@core._add_reduction_docstr("minimum index", - tie_break_arg="tie_break_left") +@core._add_reduction_docstr("minimum index", tie_break_arg="tie_break_left") def argmin(input, axis, tie_break_left=True): _, ret = min(input, axis, return_indices=True, return_indices_tie_break_left=tie_break_left) return ret @@ -233,6 +232,7 @@ def argmin(input, axis, tie_break_left=True): def _sum_combine(a, b): return a + b + # sum @@ -247,6 +247,7 @@ def sum(input, axis=None): def _xor_combine(a, b): return a ^ b + # xor sum @@ -258,8 +259,8 @@ def xor_sum(input, axis=None, _builder=None, _generator=None): raise ValueError("xor_sum only supported for integers") input = core._promote_reduction_input(input, _builder=_builder) - return core.reduce(input, axis, _xor_combine, - _builder=_builder, _generator=_generator) + return core.reduce(input, axis, _xor_combine, _builder=_builder, _generator=_generator) + # cumsum @@ -271,6 +272,7 @@ def cumsum(input, axis=0): input = core._promote_reduction_input(input) return core.associative_scan(input, axis, _sum_combine) + # cumprod diff --git a/python/triton/ops/blocksparse/matmul.py b/python/triton/ops/blocksparse/matmul.py index eaf4f2f40dee..098e1543809e 100644 --- a/python/triton/ops/blocksparse/matmul.py +++ b/python/triton/ops/blocksparse/matmul.py @@ -17,15 +17,14 @@ 'EVEN_K': lambda nargs: nargs['K'] % nargs['TILE_K'] == 0, }) @jit -def _sdd_kernel( - A, B, C, - stride_za, stride_ha, stride_ma, stride_ak, - stride_zb, stride_hb, stride_bk, stride_nb, - stride_zc, stride_hc, stride_mc, stride_nc, - K, grid_offset, lut, - TILE_M: tl.constexpr, TILE_N: 
tl.constexpr, TILE_K: tl.constexpr, - BLOCK: tl.constexpr, EVEN_K: tl.constexpr -): +def _sdd_kernel(A, B, C, # + stride_za, stride_ha, stride_ma, stride_ak, # + stride_zb, stride_hb, stride_bk, stride_nb, # + stride_zc, stride_hc, stride_mc, stride_nc, # + K, grid_offset, lut, # + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, # + BLOCK: tl.constexpr, EVEN_K: tl.constexpr # + ): # ------------ # # - Prologue - # # ------------ # @@ -104,13 +103,13 @@ def sdd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, widths, out= c = out grid = [c.shape[1], 1, c.shape[0]] _sdd_kernel[grid]( - a, b, c, - a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), - b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), - c.stride(0), c.stride(1), c.stride(2), c.stride(3), - Ka, 0, lut, - TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, - num_warps=4, + a, b, c, # + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), # + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), # + c.stride(0), c.stride(1), c.stride(2), c.stride(3), # + Ka, 0, lut, # + TILE_M=block, TILE_N=block, TILE_K=32, BLOCK=block, num_stages=4, # + num_warps=4 # ) return c @@ -120,6 +119,7 @@ def sdd_lut(layout, block, device): lut = lut.contiguous() return lut, None + # ----------------------------- # Dense = Sparse x Dense (DSD) # This operation uses a look-up table that contains pre-computed pointer increments @@ -128,15 +128,14 @@ def sdd_lut(layout, block, device): @jit -def _dsd_kernel( - A, B, C, - stride_az, stride_ha, stride_am, stride_ak, - stride_zb, stride_hb, stride_bk, stride_bn, - stride_zc, stride_hc, stride_cm, stride_cn, - DS0, DS1, lut, - TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr -): +def _dsd_kernel(A, B, C, # + stride_az, stride_ha, stride_am, stride_ak, # + stride_zb, stride_hb, stride_bk, stride_bn, # + stride_zc, stride_hc, stride_cm, stride_cn, # + DS0, DS1, lut, # + TILE_M: tl.constexpr, TILE_N: tl.constexpr, TILE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, BLOCK: tl.constexpr # + ): # ------------ # # - Prologue - # # ------------ # @@ -229,13 +228,13 @@ def dsd_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=N # compute output grid = lambda meta: [cdiv(BS3, meta['TILE_N']), width, BS0] _dsd_kernel[grid]( - a, b, c, - a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), - b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), - c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), - BS3, AS1, lut, - TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, - num_warps=4, GROUP_SIZE_M=4, + a, b, c, # + a.stride(0), a.stride(1), a.stride(3 if trans_a else 2), a.stride(2 if trans_a else 3), # + b.stride(0), b.stride(1), b.stride(3 if trans_b else 2), b.stride(2 if trans_b else 3), # + c.stride(0), c.stride(1), c.stride(3 if trans_c else 2), c.stride(2 if trans_c else 3), # + BS3, AS1, lut, # + TILE_M=block, TILE_N=TILE_N, TILE_K=min(block, 32), BLOCK=block, num_stages=4, # + num_warps=4, GROUP_SIZE_M=4 # ) # exit() return c @@ -337,6 +336,7 @@ def dsd_lut(layout, block, step, trans, device): # create locks return lut, width + # ----------------------------- # Dense = Dense x Sparse (DDS) # 
----------------------------- @@ -346,6 +346,7 @@ def dsd_lut(layout, block, step, trans, device): def dds_matmul(a, b, trans_a, trans_b, trans_c, spdims, block, lut, width, out=None): return dsd_matmul(b, a, not trans_b, not trans_a, not trans_c, spdims, block, lut, width, out=out) + ############## # MAIN API # ############## @@ -356,10 +357,8 @@ class _matmul(torch.autograd.Function): fn = {'sdd': sdd_matmul, 'dsd': dsd_matmul, 'dds': dds_matmul} @staticmethod - def forward( - ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, - c_lut, c_width, da_lut, da_width, db_lut, db_width, out - ): + def forward(ctx, a, b, trans_a, trans_b, trans_c, mode, spdims, block, c_lut, c_width, da_lut, da_width, db_lut, + db_width, out): c = _matmul.fn[mode](a, b, trans_a, trans_b, trans_c, spdims, block, c_lut, c_width, out=out) # save for backward ctx.save_for_backward(a, b) @@ -385,15 +384,13 @@ def backward(ctx, dc): # gradients w.r.t. a if ctx.needs_input_grad[0]: mode_da = mode[1] + mode[0] + mode[2] - da = _matmul.fn[mode_da]( - dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, ctx.da_lut, ctx.da_width, - ) + da = _matmul.fn[mode_da](dc, b, ctx.trans_c, not ctx.trans_b, ctx.trans_a, ctx.spdims, ctx.block, + ctx.da_lut, ctx.da_width) # gradients w.r.t. b if ctx.needs_input_grad[1]: mode_db = mode[2] + mode[1] + mode[0] - db = _matmul.fn[mode_db]( - a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, ctx.db_lut, ctx.db_width, - ) + db = _matmul.fn[mode_db](a, dc, not ctx.trans_a, ctx.trans_c, ctx.trans_b, ctx.spdims, ctx.block, + ctx.db_lut, ctx.db_width) dout = dc if ctx.has_out else None return da, db, None, None, None, \ None, None, None, None, \ @@ -427,11 +424,9 @@ def __init__(self, layout, block, mode, device, trans_a=False, trans_b=False, tr self.db_lut, self.db_width = sdd_lut(layout, block, device) def __call__(self, a, b, out=None): - c = _matmul.apply( - a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, - self.c_lut, self.c_width, - self.da_lut, self.da_width, - self.db_lut, self.db_width, - out - ) + c = _matmul.apply(a, b, self.trans_a, self.trans_b, self.trans_c, self.mode, self.spdims, self.block, # + self.c_lut, self.c_width, # + self.da_lut, self.da_width, # + self.db_lut, self.db_width, # + out) return c diff --git a/python/triton/ops/blocksparse/softmax.py b/python/triton/ops/blocksparse/softmax.py index c045b11a539c..bcffff26bb51 100644 --- a/python/triton/ops/blocksparse/softmax.py +++ b/python/triton/ops/blocksparse/softmax.py @@ -18,14 +18,13 @@ def num_warps(n): @jit -def _blocksparse_softmax_fwd( - Out, A, stride_xz, LUT, - R, extent, stride_zr, stride_hr, # relative attention - scale, is_causal, - ROW_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - IS_DENSE: tl.constexpr, -): +def _blocksparse_softmax_fwd(Out, A, stride_xz, LUT, # + R, extent, stride_zr, stride_hr, # relative attention + scale, is_causal, # + ROW_SIZE: tl.constexpr, # + BLOCK_SIZE: tl.constexpr, # + IS_DENSE: tl.constexpr # + ): h = tl.program_id(0) m = tl.program_id(1) z = tl.program_id(2) @@ -73,18 +72,16 @@ def _blocksparse_softmax_fwd( @jit -def _blocksparse_softmax_bwd( - DA, stride_zdx, - DOut, stride_zdout, - Out, stride_zout, - scale, - LUT, - DR, extent, stride_zr, stride_hr, stride_er, - is_causal, - ROW_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - IS_DENSE: tl.constexpr, -): +def _blocksparse_softmax_bwd(DA, stride_zdx, # + DOut, stride_zdout, # + Out, stride_zout, # + scale, # + LUT, # + DR, extent, 
stride_zr, stride_hr, stride_er, # + is_causal, # + ROW_SIZE: tl.constexpr, # + BLOCK_SIZE: tl.constexpr, # + IS_DENSE: tl.constexpr): h = tl.program_id(0) m = tl.program_id(1) z = tl.program_id(2) @@ -133,6 +130,7 @@ def _blocksparse_softmax_bwd( class _softmax(torch.autograd.Function): + @staticmethod def make_lut(layout, block, device): _empty = torch.tensor([], dtype=torch.int64, device=layout.device) @@ -151,10 +149,7 @@ def make_lut(layout, block, device): return lut, int(total_sizes.max()) @staticmethod - def forward( - ctx, a, scale, rel_logits, is_causal, - spdims, block, lut, maxlut, is_dense - ): + def forward(ctx, a, scale, rel_logits, is_causal, spdims, block, lut, maxlut, is_dense): if scale is not None and isinstance(scale, torch.Tensor): assert scale.device.type == "cpu" scale = scale.item() @@ -165,14 +160,14 @@ def forward( # enqueue kernel out = torch.empty_like(a) _blocksparse_softmax_fwd[grid]( - out, a, a.stride(0), lut, - rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn - scale, - is_causal, - BLOCK_SIZE=block, - ROW_SIZE=next_power_of_2(maxlut), - IS_DENSE=is_dense, - num_warps=num_warps(maxlut) + out, a, a.stride(0), lut, # + rel_logits, rel_shape[-1], rel_strides[0], rel_strides[1], # relative attn# + scale, # + is_causal, # + BLOCK_SIZE=block, # + ROW_SIZE=next_power_of_2(maxlut), # + IS_DENSE=is_dense, # + num_warps=num_warps(maxlut) # ) # save to context # ctx.mark_dirty(x) @@ -201,28 +196,23 @@ def backward(ctx, dout): grid = (ctx.spdims[0], ctx.spdims[1] * ctx.block, M) da = torch.empty_like(dout) _blocksparse_softmax_bwd[grid]( - da, da.stride(0), - dout, dout.stride(0), - out, out.stride(0), - ctx.scale, - lut, - dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], - ctx.is_causal, - BLOCK_SIZE=ctx.block, - ROW_SIZE=next_power_of_2(ctx.maxlut), - IS_DENSE=ctx.is_dense, - num_warps=num_warps(ctx.maxlut) + da, da.stride(0), # + dout, dout.stride(0), # + out, out.stride(0), # + ctx.scale, # + lut, # + dr, ctx.rel_shape[-1], ctx.rel_strides[0], ctx.rel_strides[1], ctx.rel_strides[2], # + ctx.is_causal, # + BLOCK_SIZE=ctx.block, # + ROW_SIZE=next_power_of_2(ctx.maxlut), # + IS_DENSE=ctx.is_dense, # + num_warps=num_warps(ctx.maxlut) # ) - return (da, None, None, dr, None, - None, None, None, None, None, - None, - None, None, None, - None, - None, None, None - ) + return (da, None, None, dr, None, None, None, None, None, None, None, None, None, None, None, None, None, None) class softmax: + def __init__(self, layout, block, device, is_dense=False): self.spdims = layout.shape self.layout = layout @@ -233,8 +223,6 @@ def __init__(self, layout, block, device, is_dense=False): def __call__(self, a, *, scale=1.0, rel_logits=None, is_causal=False): if rel_logits is not None and rel_logits.dtype != a.dtype: raise ValueError(f"relative position embedding must be {a.dtype}") - a = _softmax.apply( - a, scale, rel_logits, is_causal, - self.spdims, self.block, self.lut, self.maxlut, self.is_dense, - ) + a = _softmax.apply(a, scale, rel_logits, is_causal, self.spdims, self.block, self.lut, self.maxlut, + self.is_dense) return a diff --git a/python/triton/ops/cross_entropy.py b/python/triton/ops/cross_entropy.py index 5b0bcf8367bd..88e8dae50db0 100644 --- a/python/triton/ops/cross_entropy.py +++ b/python/triton/ops/cross_entropy.py @@ -59,6 +59,7 @@ def _backward(PROBS, IDX, DPROBS, N, BLOCK: tl.constexpr): class _cross_entropy(torch.autograd.Function): + @classmethod def forward(cls, ctx, logits, indices): # make sure we 
can use triton diff --git a/python/triton/ops/flash_attention.py b/python/triton/ops/flash_attention.py index c4f725d43402..d024ba7ab005 100644 --- a/python/triton/ops/flash_attention.py +++ b/python/triton/ops/flash_attention.py @@ -15,22 +15,19 @@ @jit -def _fwd_kernel( - # fmt: off - Q, K, V, sm_scale, - L, - Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, N_CTX, - Z_H_N_CTX, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - IS_CAUSAL: tl.constexpr, - # fmt: on -): +def _fwd_kernel(Q, K, V, sm_scale, # + L, # + Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, N_CTX, # + Z_H_N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + IS_CAUSAL: tl.constexpr # + ): start_m = tl.program_id(0) off_hz = tl.program_id(1) qvk_offset = off_hz * stride_qh @@ -132,27 +129,24 @@ def _bwd_preprocess( @jit -def _bwd_kernel_one_col_block( - # fmt: off - Q, K, V, sm_scale, qk_scale, - Out, DO, - DQ, DK, DV, - L, - D, - Q_block_ptr, K_block_ptr, V_block_ptr, - DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, - stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - Z, H, N_CTX, - off_h, off_z, off_hz, start_n, num_block, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - SEQUENCE_PARALLEL: tl.constexpr, - CAUSAL: tl.constexpr, - MMA_V3: tl.constexpr - # fmt: on -): +def _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, # + Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + SEQUENCE_PARALLEL: tl.constexpr, # + CAUSAL: tl.constexpr, # + MMA_V3: tl.constexpr # + ): if CAUSAL: lo = start_n * BLOCK_M else: @@ -235,26 +229,23 @@ def _bwd_kernel_one_col_block( @jit -def _bwd_kernel( - # fmt: off - Q, K, V, sm_scale, - Out, DO, - DQ, DK, DV, - L, - D, - stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - Z, H, N_CTX, - Z_H_N_CTX, - SQ_Z_H_N_CTX, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - SEQUENCE_PARALLEL: tl.constexpr, - CAUSAL: tl.constexpr, - MMA_V3: tl.constexpr - # fmt: on -): +def _bwd_kernel(Q, K, V, sm_scale, # + Out, DO, # + DQ, DK, DV, # + L, # + D, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + Z_H_N_CTX, # + SQ_Z_H_N_CTX, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + SEQUENCE_PARALLEL: tl.constexpr, # + CAUSAL: tl.constexpr, # + MMA_V3: tl.constexpr # + ): qk_scale = sm_scale * 1.44269504 off_hz = tl.program_id(0) off_z = off_hz // H @@ -331,51 +322,46 @@ def _bwd_kernel( num_block_n = tl.cdiv(N_CTX, BLOCK_N) if not 
SEQUENCE_PARALLEL: for start_n in range(0, num_block_n): - _bwd_kernel_one_col_block( - # fmt: off - Q, K, V, sm_scale, qk_scale, Out, DO, - DQ, DK, DV, - L, - D, - Q_block_ptr, K_block_ptr, V_block_ptr, - DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, - stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - Z, H, N_CTX, - off_h, off_z, off_hz, start_n, num_block_n, - BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK_N, - SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, - CAUSAL=CAUSAL, - MMA_V3=MMA_V3 - # fmt: on - ) + _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block_n, # + BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, # + BLOCK_N=BLOCK_N, # + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # + CAUSAL=CAUSAL, # + MMA_V3=MMA_V3 # + ) else: start_n = tl.program_id(1) - _bwd_kernel_one_col_block( - # fmt: off - Q, K, V, sm_scale, qk_scale, Out, DO, - DQ, DK, DV, - L, - D, - Q_block_ptr, K_block_ptr, V_block_ptr, - DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, - stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - Z, H, N_CTX, - off_h, off_z, off_hz, start_n, num_block_n, - BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK_N, - SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, - CAUSAL=CAUSAL, - MMA_V3=MMA_V3 - # fmt: on - ) + _bwd_kernel_one_col_block(Q, K, V, sm_scale, qk_scale, Out, DO, # + DQ, DK, DV, # + L, # + D, # + Q_block_ptr, K_block_ptr, V_block_ptr, # + DO_block_ptr, DQ_block_ptr, DK_block_ptr, DV_block_ptr, # + stride_dqa, stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vn, stride_vk, # + Z, H, N_CTX, # + off_h, off_z, off_hz, start_n, num_block_n, # + BLOCK_M=BLOCK_M, BLOCK_DMODEL=BLOCK_DMODEL, # + BLOCK_N=BLOCK_N, # + SEQUENCE_PARALLEL=SEQUENCE_PARALLEL, # + CAUSAL=CAUSAL, # + MMA_V3=MMA_V3 # + ) class _attention(torch.autograd.Function): + @staticmethod def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): # only support for Ampere now @@ -393,21 +379,19 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 _fwd_kernel[grid]( - # fmt: off - q, k, v, sm_scale, - L, - o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], q.shape[2], - q.shape[0] * q.shape[1] * q.shape[2], - BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, - IS_CAUSAL=causal, - num_warps=num_warps, - num_stages=4, - # fmt: on + q, k, v, sm_scale, # + L, # + o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + q.shape[0] * 
q.shape[1] * q.shape[2], # + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, # + IS_CAUSAL=causal, # + num_warps=num_warps, # + num_stages=4 # ) ctx.save_for_backward(q, k, v, o, L) @@ -429,14 +413,14 @@ def backward(ctx, do): do = do.contiguous() if sequence_parallel: replicas = cdiv(seq_len_kv, BLOCK) - new_dq_shape = (replicas,) + q.shape + new_dq_shape = (replicas, ) + q.shape dq = torch.zeros(new_dq_shape, device=q.device, dtype=q.dtype) else: dq = torch.zeros_like(q, dtype=q.dtype) dk = torch.empty_like(k) dv = torch.empty_like(v) delta = torch.empty_like(L) - _bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1],)]( + _bwd_preprocess[(cdiv(q.shape[2], BLOCK) * ctx.grid[1], )]( o, do, delta, @@ -444,26 +428,24 @@ def backward(ctx, do): D_HEAD=ctx.BLOCK_DMODEL, ) _bwd_kernel[(ctx.grid[1], cdiv(seq_len_kv, BLOCK) if sequence_parallel else 1)]( - # fmt: off - q, k, v, ctx.sm_scale, - o, do, - dq, dk, dv, - L, - delta, - o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - q.shape[0], q.shape[1], q.shape[2], - q.shape[0] * q.shape[1] * q.shape[2], - cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], - BLOCK_M=BLOCK, BLOCK_N=BLOCK, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, - SEQUENCE_PARALLEL=sequence_parallel, - CAUSAL=ctx.causal, - MMA_V3=MMA_V3, - num_warps=8, - num_stages=1, - # fmt: on + q, k, v, ctx.sm_scale, # + o, do, # + dq, dk, dv, # + L, # + delta, # + o.numel(), q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + q.shape[0], q.shape[1], q.shape[2], # + q.shape[0] * q.shape[1] * q.shape[2], # + cdiv(seq_len_kv, BLOCK) * q.shape[0] * q.shape[1] * q.shape[2], # + BLOCK_M=BLOCK, BLOCK_N=BLOCK, # + BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + SEQUENCE_PARALLEL=sequence_parallel, # + CAUSAL=ctx.causal, # + MMA_V3=MMA_V3, # + num_warps=8, # + num_stages=1 # ) if len(dq.shape) == 5: diff --git a/python/triton/ops/matmul.py b/python/triton/ops/matmul.py index 9bbeb3650ac5..832e52727f09 100644 --- a/python/triton/ops/matmul.py +++ b/python/triton/ops/matmul.py @@ -37,8 +37,9 @@ def get_configs_io_bound(): num_stages=num_stages, num_warps=num_warps)) # split_k for split_k in [2, 4, 8, 16]: - configs.append(Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, - num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) + configs.append( + Config({'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k}, + num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C'))) return configs @@ -69,22 +70,22 @@ def get_configs_io_bound(): prune_configs_by={ 'early_config_prune': early_config_prune, 'perf_model': estimate_matmul_time, - 'top_k': 10 + 'top_k': 10, }, ) @heuristics({ 'EVEN_K': lambda args: args['K'] % (args['BLOCK_K'] * args['SPLIT_K']) == 0, }) @jit -def _kernel(A, B, C, M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - dot_out_dtype: tl.constexpr, - allow_tf32: tl.constexpr, - fp8_fast_accum: tl.constexpr, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, - GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr +def _kernel(A, B, C, M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + dot_out_dtype: tl.constexpr, # + allow_tf32: tl.constexpr, # + 
fp8_fast_accum: tl.constexpr, # + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, # + GROUP_M: tl.constexpr, SPLIT_K: tl.constexpr, EVEN_K: tl.constexpr, AB_DTYPE: tl.constexpr # ): # matrix multiplication pid = tl.program_id(0) @@ -184,14 +185,15 @@ def _call(a, b, dot_out_dtype, allow_tf32, fp8_fast_accum): ab_dtype = False # launch kernel grid = lambda META: (cdiv(M, META['BLOCK_M']) * cdiv(N, META['BLOCK_N']), META['SPLIT_K']) - _kernel[grid](a, b, c, M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - dot_out_dtype=dot_out_dtype, - allow_tf32=allow_tf32, - fp8_fast_accum=fp8_fast_accum, - GROUP_M=8, AB_DTYPE=ab_dtype) + _kernel[grid]( + a, b, c, M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + dot_out_dtype=dot_out_dtype, # + allow_tf32=allow_tf32, # + fp8_fast_accum=fp8_fast_accum, # + GROUP_M=8, AB_DTYPE=ab_dtype) return c @staticmethod diff --git a/python/triton/ops/matmul_perf_model.py b/python/triton/ops/matmul_perf_model.py index abe5325ee056..1e07b0a029bb 100644 --- a/python/triton/ops/matmul_perf_model.py +++ b/python/triton/ops/matmul_perf_model.py @@ -5,8 +5,7 @@ from .. import cdiv from .._C.libtriton.triton import runtime from ..runtime import driver -from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, - nvsmi) +from ..testing import (get_dram_gbps, get_max_simd_tflops, get_max_tensorcore_tflops, nvsmi) def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype): @@ -14,7 +13,8 @@ def get_tensorcore_tflops(backend, device, num_ctas, num_warps, dtype): total_warps = num_ctas * min(num_warps, 4) num_subcores = driver.utils.get_device_properties(device)["multiprocessor_count"] * 4 # on recent GPUs cur_sm_clock = nvsmi(['clocks.current.sm'])[0] - tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops(dtype, cur_sm_clock, backend, device) + tflops = min(num_subcores, total_warps) / num_subcores * get_max_tensorcore_tflops( + dtype, cur_sm_clock, backend, device) return tflops @@ -35,12 +35,12 @@ def get_tflops(backend, device, num_ctas, num_warps, dtype): def estimate_matmul_time( - # backend, device, - num_warps, num_stages, - A, B, C, - M, N, K, - BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, - debug=False, **kwargs + # backend, device, + num_warps, num_stages, # + A, B, C, # + M, N, K, # + BLOCK_M, BLOCK_N, BLOCK_K, SPLIT_K, # + debug=False, **kwargs # ): ''' return estimated running time in ms = max(compute, loading) + store ''' @@ -149,8 +149,9 @@ def early_config_prune(configs, named_args): optimal_num_stages = ldgsts_latency / mma_cycles # nearest stages, prefer large #stages - nearest = heapq.nsmallest(2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) - if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) + nearest = heapq.nsmallest( + 2, v, key=lambda x: 10 + abs(x[1] - optimal_num_stages) + if (x[1] - optimal_num_stages) < 0 else x[1] - optimal_num_stages) for n in nearest: pruned_configs.append(n[0]) diff --git a/python/triton/runtime/__init__.py b/python/triton/runtime/__init__.py index b13d089a10a3..e785018e0f92 100644 --- a/python/triton/runtime/__init__.py +++ b/python/triton/runtime/__init__.py @@ -1,5 +1,4 @@ -from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, - heuristics) +from .autotuner import (Autotuner, Config, Heuristics, OutOfResources, autotune, heuristics) from .driver import driver from .jit import JITFunction, 
KernelInterface, MockTensor, TensorWrapper, reinterpret diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 7deca37115b3..7e4c1d7e243b 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -9,11 +9,10 @@ class OutOfResources(Exception): + def __init__(self, required, limit, name): - self.message = ( - f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. " - + "Reducing block sizes or `num_stages` may help." - ) + self.message = (f"out of resource: {name}, Required: {required}, Hardware limit: {limit}. " + + "Reducing block sizes or `num_stages` may help.") self.required = required self.limit = limit self.name = name @@ -25,6 +24,7 @@ def __reduce__(self): class Autotuner(KernelInterface): + def __init__( self, fn, @@ -99,10 +99,8 @@ def _bench(self, *args, config, **meta): # as kwargs and by the autotuner conflicts = meta.keys() & config.kwargs.keys() if conflicts: - raise ValueError( - f"Conflicting meta-parameters: {', '.join(conflicts)}." - " Make sure that you don't re-define auto-tuned symbols." - ) + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") # augment meta-parameters with tunable ones current = dict(meta, **config.kwargs) full_nargs = {**self.nargs, **current} @@ -179,7 +177,8 @@ def prune_configs(self, kwargs): top_k = int(len(self.configs) * top_k) if len(pruned_configs) > top_k: est_timing = { - config: self.perf_model( + config: + self.perf_model( **self.nargs, **kwargs, **config.kwargs, @@ -296,6 +295,7 @@ def decorator(fn): class Heuristics(KernelInterface): + def __init__(self, fn, arg_names, values) -> None: self.fn = fn self.values = values diff --git a/python/triton/runtime/cache.py b/python/triton/runtime/cache.py index 170f614aa9b5..3799eaf08b43 100644 --- a/python/triton/runtime/cache.py +++ b/python/triton/runtime/cache.py @@ -19,6 +19,7 @@ def default_dump_dir(): class CacheManager(ABC): + def __init__(self, key): pass @@ -44,6 +45,7 @@ def put_group(self, filename: str, group: Dict[str, str]): class FileCacheManager(CacheManager): + def __init__(self, key, override=False, dump=False): self.key = key self.lock_path = None diff --git a/python/triton/runtime/driver.py b/python/triton/runtime/driver.py index f55619e01066..767a567c452b 100644 --- a/python/triton/runtime/driver.py +++ b/python/triton/runtime/driver.py @@ -26,6 +26,7 @@ def __init__(self) -> None: class CudaUtils(object): + def __new__(cls): if not hasattr(cls, "instance"): cls.instance = super(CudaUtils, cls).__new__(cls) @@ -65,6 +66,7 @@ def __init__(self): class CudaDriver(DriverBase): + def __new__(cls): if not hasattr(cls, "instance"): cls.instance = super(CudaDriver, cls).__new__(cls) @@ -81,6 +83,7 @@ def __init__(self): class HIPUtils(object): + def __new__(cls): if not hasattr(cls, "instance"): cls.instance = super(HIPUtils, cls).__new__(cls) @@ -111,6 +114,7 @@ def __init__(self): class HIPDriver(DriverBase): + def __new__(cls): if not hasattr(cls, "instance"): cls.instance = super(HIPDriver, cls).__new__(cls) @@ -122,6 +126,7 @@ def __init__(self): class UnsupportedDriver(DriverBase): + def __new__(cls): if not hasattr(cls, "instance"): cls.instance = super(UnsupportedDriver, cls).__new__(cls) @@ -138,6 +143,7 @@ def __init__(self): class LazyProxy: + def __init__(self, init_fn): self._init_fn = init_fn self._obj = None diff --git a/python/triton/runtime/errors.py b/python/triton/runtime/errors.py index 
f892a91722d9..a5d69aba685e 100644 --- a/python/triton/runtime/errors.py +++ b/python/triton/runtime/errors.py @@ -1,4 +1,5 @@ class OutOfResources(Exception): + def __init__(self, required, limit, name): self.message = f"out of resource: {name}, " f"Required: {required}, " f"Hardware limit: {limit}" self.message += ". Reducing block sizes or `num_stages` may help." diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index 5599a59ede26..c8b70bfb07d9 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -37,6 +37,7 @@ def str_to_ty(name): class TensorHandle: + def __init__(self, data, dtype): self.data = data self.dtype = dtype @@ -46,6 +47,7 @@ def __bool__(self): class BlockPointerHandle: + def __init__(self, base, shape, strides, offsets, tensor_shape, order): self.base = base self.shape = shape @@ -72,7 +74,9 @@ def materialize_pointers(self, boundary_check): def wrap_ret(compute_ret_ty): + def wrapper(fn): + def wrapped(*args, **kwargs): ret = fn(*args, **kwargs) return TensorHandle(ret.data, compute_ret_ty(*args, **kwargs)) @@ -83,6 +87,7 @@ def wrapped(*args, **kwargs): class Builder: + def __init__(self) -> None: self.arch = None # pass @@ -280,9 +285,8 @@ def create_addptr(self, ptr, offset): dtype_tt = ptr.dtype.element_ty return TensorHandle(ptr.data + (dtype_tt.primitive_bitwidth // 8) * offset.data.astype(np.uint64), ptr.dtype) - def create_tensor_pointer_load( - self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, is_volatile - ): + def create_tensor_pointer_load(self, ptr, boundary_check, padding_option, cache_modifier, eviction_policy, + is_volatile): ptrs, masks = ptr.materialize_pointers(boundary_check) assert padding_option is None other = None @@ -364,9 +368,10 @@ def create_advance(self, ptr, offsets): def patch_attr(obj, name, member, builder): - new_member = lambda *args, member=member, **kwargs: ( - member(*args, **{k: v for k, v in kwargs.items() if k != "_builder"}, _builder=builder) - ) + new_member = lambda *args, member=member, **kwargs: (member(*args, ** + {k: v + for k, v in kwargs.items() + if k != "_builder"}, _builder=builder)) setattr(obj, name, new_member) @@ -412,6 +417,7 @@ def _patch_lang_math(lang, builder): } def make_numpy(name): + def impl(*args, **kwargs): ret_type = args[0].type # TODO: incorrect ret_dtype = args[0].dtype # TODO: incorrect @@ -424,14 +430,13 @@ def impl(*args, **kwargs): return impl def make_fallback(name): + def fallback(*args, **kwargs): - raise NotImplementedError( - f""" + raise NotImplementedError(f""" {name} not supported in interpreter mode: no known numpy implementation. If you think that {name} in fact does have a numpy implementation, please add it to the mapping in python/triton/interpreter/new_interpreter.py:_patch_lang_math. 
-""" - ) +""") return fallback @@ -467,6 +472,7 @@ def _unwrap(tensor): class GridExecutor: + def __init__(self, fn, arg_names, grid): from .jit import _normalize_ty # TODO: modularize @@ -496,7 +502,7 @@ def __call__(self, *args_dev, **kwargs): # iterate through grid grid = self.grid(args) if callable(self.grid) else self.grid assert len(grid) <= 3 - grid = grid + (1,) * (3 - len(grid)) + grid = grid + (1, ) * (3 - len(grid)) builder.set_grid_dim(*grid) for x in range(grid[0]): for y in range(grid[1]): @@ -510,6 +516,7 @@ def __call__(self, *args_dev, **kwargs): class InterpretedFunction: + def _patch_lang(self, builder): lang = [value for _, value in self.fn.__globals__.items() if value in [tl, tl.core]] assert len(lang) == 1, "triton.language must be visible from within jit'd function" diff --git a/python/triton/runtime/jit.py b/python/triton/runtime/jit.py index 92dafaa947e4..998fe7475e77 100644 --- a/python/triton/runtime/jit.py +++ b/python/triton/runtime/jit.py @@ -72,9 +72,8 @@ def visit_Attribute(self, node): lhs = self.visit(node.value) while isinstance(lhs, ast.Attribute): lhs = self.visit(lhs.value) - if lhs is None or ( - getattr(lhs, "__name__", "") == "triton" or getattr(lhs, "__name__", "").endswith(".triton") - ): + if lhs is None or (getattr(lhs, "__name__", "") == "triton" + or getattr(lhs, "__name__", "").endswith(".triton")): return None return getattr(lhs, node.attr) @@ -176,7 +175,7 @@ def specialization_key(self): assert not self.param.do_not_specialize try: - return (self.value.data_ptr() % JITFunction.divisibility == 0,) + return (self.value.data_ptr() % JITFunction.divisibility == 0, ) except AttributeError: pass @@ -188,7 +187,7 @@ def specialization_key(self): self.value == 1, ) - return (False,) + return (False, ) class KernelInterface(Generic[T]): @@ -253,10 +252,11 @@ def _spec_of(arg): return arg.data_ptr() % JITFunction.divisibility == 0 elif isinstance(arg, int): return (arg % 16 == 0, arg == 1) - return (arg is None,) + return (arg is None, ) # TODO(jlebar): Fold this into the KernelArg class. 
def _get_config(self, *args): + def is_divisible_by_16(x): if hasattr(x, "data_ptr"): return x.data_ptr() % JITFunction.divisibility == 0 @@ -279,7 +279,9 @@ def is_divisible_by_8(x): if is_divisible_by_16(arg) and not param.do_not_specialize } divisible_by_8 = { - param.num for param, arg in zip(self.params, args) if is_divisible_by_8(arg) and not param.do_not_specialize + param.num + for param, arg in zip(self.params, args) + if is_divisible_by_8(arg) and not param.do_not_specialize } equal_to_1 = { param.num @@ -290,9 +292,10 @@ def is_divisible_by_8(x): # TODO: method to collect all folded args none_args = {param.num for param, arg in zip(self.params, args) if arg is None and not param.do_not_specialize} ids_of_folded_args = equal_to_1 | none_args - return namedtuple( - "instance_descriptor", ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"] - )(tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), tuple(divisible_by_8)) + return namedtuple("instance_descriptor", + ["divisible_by_16", "equal_to_1", "ids_of_folded_args", "divisible_by_8"])( # + tuple(divisible_by_16), tuple(equal_to_1), tuple(ids_of_folded_args), + tuple(divisible_by_8)) # return _triton.code_gen.instance_descriptor(divisible_by_16, # equal_to_1) @@ -356,6 +359,7 @@ def _call_hook( key = str(key) class LegacyCompiler: + def __init__(self, module, name): self.module = module self.name = name @@ -449,9 +453,8 @@ def get_special_arg(name: str, default=None): if device_type is None: device_types = [self._device_of(arg) for arg in non_constexpr_arg_values] device_types = [_device_type for _device_type in device_types if _device_type != ""] - device_type = self._conclude_device_type( - device_types, [self._pinned_memory_of(arg) for arg in non_constexpr_arg_values] - ) + device_type = self._conclude_device_type(device_types, + [self._pinned_memory_of(arg) for arg in non_constexpr_arg_values]) device_backend = None if device_type not in ["cuda"]: @@ -498,7 +501,7 @@ def get_special_arg(name: str, default=None): # Kernel is not cached; we have to compile. if key not in self.cache[device]: - configs = (self._get_config(*[arg.value for arg in args]),) + configs = (self._get_config(*[arg.value for arg in args]), ) constants = { arg.param.num: arg.value for arg in args @@ -510,21 +513,23 @@ def get_special_arg(name: str, default=None): # Build kernel signature -- doesn't include constexpr arguments. 
signature = { - arg.param.num: self._type_of(self._key_of(arg.value)) for arg in args if not arg.param.is_constexpr + arg.param.num: self._type_of(self._key_of(arg.value)) + for arg in args + if not arg.param.is_constexpr } if self._call_hook( - key, - signature, - device, - constants, - num_warps, - num_ctas, - num_stages, - enable_warp_specialization, - enable_fp_fusion, - extern_libs, - configs, + key, + signature, + device, + constants, + num_warps, + num_ctas, + num_stages, + enable_warp_specialization, + enable_fp_fusion, + extern_libs, + configs, ): return None @@ -581,7 +586,7 @@ def __init__(self, fn, version=None, do_not_specialize=None, debug=None, noinlin # function source code (without decorators) self.src = textwrap.dedent(inspect.getsource(fn)) - self.src = self.src[self.src.find("def") :] + self.src = self.src[self.src.find("def"):] # cache of just-in-time compiled kernels self.cache = defaultdict(dict) self.hash = None @@ -734,6 +739,7 @@ def data_ptr(): class TensorWrapper: + def __init__(self, base, dtype): self.dtype = dtype self.base = base diff --git a/python/triton/testing.py b/python/triton/testing.py index 27da9b684708..848db7c1b508 100644 --- a/python/triton/testing.py +++ b/python/triton/testing.py @@ -78,10 +78,7 @@ def do_bench_cudagraph(fn, rep=20, grad_to_none=None): return torch.mean(torch.tensor(ret)).item() -def do_bench(fn, warmup=25, rep=100, grad_to_none=None, - quantiles=None, - fast_flush=True, - return_mode="mean"): +def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flush=True, return_mode="mean"): assert return_mode in ["min", "max", "mean", "median"] import torch """ @@ -261,6 +258,7 @@ def __init__( class Mark: + def __init__(self, fn, benchmarks): self.fn = fn self.benchmarks = benchmarks @@ -405,12 +403,15 @@ def get_max_tensorcore_tflops(dtype, clock_rate, backend=None, device=None): tflops = num_subcores * clock_rate * ops_per_sub_core * 1e-9 return tflops + # create decorator that wraps test function into # a cuda-memcheck system call def cuda_memcheck(**target_kwargs): + def decorator(test_fn): + @functools.wraps(test_fn) def wrapper(*args, **kwargs): import psutil @@ -428,7 +429,9 @@ def wrapper(*args, **kwargs): assert "ERROR SUMMARY: 0 errors" in str(out.stdout) else: test_fn(*args, **kwargs) + return wrapper + return decorator @@ -436,22 +439,18 @@ def wrapper(*args, **kwargs): def set_gpu_clock(ref_sm_clock=1350, ref_mem_clock=1215): try: subprocess.check_output(["nvidia-smi", "-i", "0", "-pm", "1"]) - subprocess.check_output( - [ - "nvidia-smi", - "-i", - "0", - f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", - ] - ) - subprocess.check_output( - [ - "nvidia-smi", - "-i", - "0", - f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", - ] - ) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-gpu-clocks={ref_sm_clock},{ref_sm_clock}", + ]) + subprocess.check_output([ + "nvidia-smi", + "-i", + "0", + f"--lock-memory-clocks={ref_mem_clock},{ref_mem_clock}", + ]) cur_sm_clock = nvsmi(["clocks.current.sm"])[0] cur_mem_clock = nvsmi(["clocks.current.memory"])[0] assert abs(cur_sm_clock - ref_sm_clock) < 10, f"GPU SMs must run at {ref_sm_clock} MHz" diff --git a/python/triton/tools/build_extern.py b/python/triton/tools/build_extern.py index f19fbd561c07..6f00e8192593 100644 --- a/python/triton/tools/build_extern.py +++ b/python/triton/tools/build_extern.py @@ -141,8 +141,7 @@ def generate_stub_file(self, output_dir) -> None: f.write(file_str) f.close() if self._format: - 
subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], - stdout=subprocess.PIPE).communicate() + subprocess.Popen(["autopep8", "-a", "-r", "-i", output_file], stdout=subprocess.PIPE).communicate() subprocess.Popen(["isort", output_file], stdout=subprocess.PIPE).communicate() @@ -208,56 +207,36 @@ def _group_symbols(self) -> None: # Group functions together by renaming. renaming = { - 'llabs': 'abs', 'acosf': 'acos', 'acoshf': 'acosh', - 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': 'add_rn', - 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', - 'dadd_rz': 'add_rz', 'fadd_rz': 'add_rz', 'asinf': 'asin', - 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', - 'atanhf': 'atanh', 'brevll': 'brev', 'cbrtf': 'cbrt', - 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', - 'cosf': 'cos', 'coshf': 'cosh', 'cospif': 'cospi', - 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', - 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', - 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', 'ddiv_ru': 'div_ru', - 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', - 'erfcf': 'erfc', 'erfcinvf': 'erfcinv', 'erfcxf': 'erfcx', - 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', - 'exp2f': 'exp2', 'expm1f': 'expm1', 'fabsf': 'abs', - 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', - 'fdimf': 'fdim', 'ffsll': 'ffs', 'floorf': 'floor', - 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', - 'fmaf_ru': 'fma_ru', 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', - 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', - 'isinff': 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', - 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', - 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', - 'llroundf': 'llround', 'logf': 'log', 'log10f': 'log10', - 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', - 'umax': 'max', 'llmax': 'max', 'ullmax': 'max', 'fmaxf': 'max', - 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', - 'fminf': 'min', 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', - 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', 'dmul_ru': 'mul_ru', - 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', - 'umul24': 'mul24', 'umulhi': 'mulhi', 'mul64hi': 'mulhi', - 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': 'nextafter', - 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', - 'normcdfinvf': 'normcdfinv', 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', - 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', 'drcp_rd': 'rcp_rd', - 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', - 'drcp_ru': 'rcp_ru', 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', - 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', - 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', - 'roundf': 'round', 'rsqrtf': 'rsqrt', 'frsqrt_rn': 'rsqrt_rn', - 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', - 'signbitd': 'signbit', 'sinf': 'sin', 'sinhf': 'sinh', - 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', - 'dsqrt_rd': 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', - 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', 'fsqrt_rz': 'sqrt_rz', - 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', - 'fsub_rn': 'sub_rn', 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', - 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', - 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', - 'y0f': 'y0', 'y1f': 'y1', 'ynf': 'yn' + 'llabs': 'abs', 
'acosf': 'acos', 'acoshf': 'acosh', 'dadd_rd': 'add_rd', 'fadd_rd': 'add_rd', 'dadd_rn': + 'add_rn', 'fadd_rn': 'add_rn', 'dadd_ru': 'add_ru', 'fadd_ru': 'add_ru', 'dadd_rz': 'add_rz', 'fadd_rz': + 'add_rz', 'asinf': 'asin', 'asinhf': 'asinh', 'atanf': 'atan', 'atan2f': 'atan2', 'atanhf': 'atanh', + 'brevll': 'brev', 'cbrtf': 'cbrt', 'ceilf': 'ceil', 'clzll': 'clz', 'copysignf': 'copysign', 'cosf': 'cos', + 'coshf': 'cosh', 'cospif': 'cospi', 'cyl_bessel_i0f': 'cyl_bessel_i0', 'cyl_bessel_i1f': 'cyl_bessel_i1', + 'fdiv_rd': 'div_rd', 'ddiv_rd': 'div_rd', 'fdiv_rn': 'div_rn', 'ddiv_rn': 'div_rn', 'fdiv_ru': 'div_ru', + 'ddiv_ru': 'div_ru', 'fdiv_rz': 'div_rz', 'ddiv_rz': 'div_rz', 'erff': 'erf', 'erfcf': 'erfc', 'erfcinvf': + 'erfcinv', 'erfcxf': 'erfcx', 'erfinvf': 'erfinv', 'expf': 'exp', 'exp10f': 'exp10', 'exp2f': 'exp2', + 'expm1f': 'expm1', 'fabsf': 'abs', 'fabs': 'abs', 'fast_fdividef': 'fast_dividef', 'fdimf': 'fdim', 'ffsll': + 'ffs', 'floorf': 'floor', 'fmaf': 'fma', 'fmaf_rd': 'fma_rd', 'fmaf_rn': 'fma_rn', 'fmaf_ru': 'fma_ru', + 'fmaf_rz': 'fma_rz', 'fmodf': 'fmod', 'uhadd': 'hadd', 'hypotf': 'hypot', 'ilogbf': 'ilogb', 'isinff': + 'isinf', 'isinfd': 'isinf', 'isnanf': 'isnan', 'isnand': 'isnan', 'j0f': 'j0', 'j1f': 'j1', 'jnf': 'jn', + 'ldexpf': 'ldexp', 'lgammaf': 'lgamma', 'llrintf': 'llrint', 'llroundf': 'llround', 'logf': 'log', 'log10f': + 'log10', 'log1pf': 'log1p', 'log2f': 'log2', 'logbf': 'logb', 'umax': 'max', 'llmax': 'max', 'ullmax': + 'max', 'fmaxf': 'max', 'fmax': 'max', 'umin': 'min', 'llmin': 'min', 'ullmin': 'min', 'fminf': 'min', + 'fmin': 'min', 'dmul_rd': 'mul_rd', 'fmul_rd': 'mul_rd', 'dmul_rn': 'mul_rn', 'fmul_rn': 'mul_rn', + 'dmul_ru': 'mul_ru', 'fmul_ru': 'mul_ru', 'dmul_rz': 'mul_rz', 'fmul_rz': 'mul_rz', 'umul24': 'mul24', + 'umulhi': 'mulhi', 'mul64hi': 'mulhi', 'umul64hi': 'mulhi', 'nearbyintf': 'nearbyint', 'nextafterf': + 'nextafter', 'norm3df': 'norm3d', 'norm4df': 'norm4d', 'normcdff': 'normcdf', 'normcdfinvf': 'normcdfinv', + 'popcll': 'popc', 'powif': 'pow', 'powi': 'pow', 'powf': 'pow', 'rcbrtf': 'rcbrt', 'frcp_rd': 'rcp_rd', + 'drcp_rd': 'rcp_rd', 'frcp_rn': 'rcp_rn', 'drcp_rn': 'rcp_rn', 'frcp_ru': 'rcp_ru', 'drcp_ru': 'rcp_ru', + 'frcp_rz': 'rcp_rz', 'drcp_rz': 'rcp_rz', 'remainderf': 'remainder', 'urhadd': 'rhadd', 'rhypotf': 'rhypot', + 'rintf': 'rint', 'rnorm3df': 'rnorm3d', 'rnorm4df': 'rnorm4d', 'roundf': 'round', 'rsqrtf': 'rsqrt', + 'frsqrt_rn': 'rsqrt_rn', 'usad': 'sad', 'scalbnf': 'scalbn', 'signbitf': 'signbit', 'signbitd': 'signbit', + 'sinf': 'sin', 'sinhf': 'sinh', 'sinpif': 'sinpi', 'sqrtf': 'sqrt', 'fsqrt_rd': 'sqrt_rd', 'dsqrt_rd': + 'sqrt_rd', 'fsqrt_rn': 'sqrt_rn', 'dsqrt_rn': 'sqrt_rn', 'fsqrt_ru': 'sqrt_ru', 'dsqrt_ru': 'sqrt_ru', + 'fsqrt_rz': 'sqrt_rz', 'dsqrt_rz': 'sqrt_rz', 'fsub_rd': 'sub_rd', 'dsub_rd': 'sub_rd', 'fsub_rn': 'sub_rn', + 'dsub_rn': 'sub_rn', 'fsub_ru': 'sub_ru', 'dsub_ru': 'sub_ru', 'fsub_rz': 'sub_rz', 'dsub_rz': 'sub_rz', + 'tanf': 'tan', 'tanhf': 'tanh', 'tgammaf': 'tgamma', 'truncf': 'trunc', 'y0f': 'y0', 'y1f': 'y1', 'ynf': + 'yn' } for symbol in self._symbols.values(): @@ -347,8 +326,7 @@ def __init__(self, path) -> None: self._ll_file = "/tmp/extern_lib.ll" def disasm(self, lib_path: str) -> None: - subprocess.Popen([self._path, lib_path, "-o", self.ll_file], - stdout=subprocess.PIPE).communicate() + subprocess.Popen([self._path, lib_path, "-o", self.ll_file], stdout=subprocess.PIPE).communicate() @property def ll_file(self) -> str: diff --git a/python/triton/tools/compile.py 
b/python/triton/tools/compile.py index d80f15e8a1aa..a69c7100ddd0 100644 --- a/python/triton/tools/compile.py +++ b/python/triton/tools/compile.py @@ -40,10 +40,13 @@ # command-line arguments parser = ArgumentParser(description=desc) - parser.add_argument("path", help="Path to Python source containing desired kernel in its scope. File will be executed.") - parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", required=True) + parser.add_argument("path", + help="Path to Python source containing desired kernel in its scope. File will be executed.") + parser.add_argument("--kernel-name", "-n", type=str, default="", help="Name of the kernel to compile", + required=True) parser.add_argument("--num-warps", "-w", type=int, default=1, help="Number of warps to launch the kernel") - parser.add_argument("--num-stages", "-ns", type=int, default=3, help="Number of stages (meta-parameter of the kernel)") + parser.add_argument("--num-stages", "-ns", type=int, default=3, + help="Number of stages (meta-parameter of the kernel)") parser.add_argument("--out-name", "-on", type=str, default=None, help="Out name for the compiled kernel") parser.add_argument("--out-path", "-o", type=Path, default=None, help="Out filename") parser.add_argument("--signature", "-s", type=str, help="Signature of the kernel", required=True) @@ -104,7 +107,8 @@ def constexpr(s): config = triton.compiler.instance_descriptor(divisible_by_16=divisible_by_16, equal_to_1=equal_to_1) for i in equal_to_1: constexprs.update({i: 1}) - ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], num_warps=args.num_warps, num_stages=args.num_stages) + ccinfo = triton.compile(kernel, signature=signature, constants=constexprs, configs=[config], + num_warps=args.num_warps, num_stages=args.num_stages) arg_names = [] arg_types = [] for i in signature.keys(): diff --git a/python/triton/tools/link.py b/python/triton/tools/link.py index 68ace442f067..eb39b4bda4db 100644 --- a/python/triton/tools/link.py +++ b/python/triton/tools/link.py @@ -27,13 +27,12 @@ class KernelLinkerMeta: class HeaderParser: + def __init__(self) -> None: import re # [kernel_name, c signature] - self.linker_directives = re.compile( - "//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)" - ) + self.linker_directives = re.compile("//[\\s]*tt-linker:[\\s]*([\\w]+):(.+):(.+)") # [name, hash, suffix] self.kernel_name = re.compile("^([\\w]+)_([\\w]+)_([\\w]+)$") # [(type, name)] @@ -153,9 +152,7 @@ def make_global_decl(meta: KernelLinkerMeta) -> str: # generate dispatcher function for kernels with different meta-parameter and constant values def make_default_algo_kernel(meta: KernelLinkerMeta) -> str: src = f"CUresult {meta.orig_kernel_name}_default(CUstream stream, {gen_signature_with_full_args(meta)}){{\n" - src += ( - f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n" - ) + src += (f" return {meta.orig_kernel_name}(stream, {', '.join(meta.arg_names)}, 0);\n") src += "}\n" return src @@ -167,28 +164,22 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) - src += f"CUresult {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(CUstream stream, {gen_signature(meta)});\n" src += "\n" - src += ( - f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{" - ) + src += (f"CUresult {name}(CUstream stream, {gen_signature_with_full_args(metas[-1])}){{") src += "\n" for meta in sorted(metas, key=lambda m: -m.num_specs): - cond_fn = ( - lambda 
val, hint: f"({val} % {hint} == 0)" - if hint == 16 - else f"({val} == {hint})" - if hint == 1 - else None - ) - conds = " && ".join( - [ - cond_fn(val, hint) - for val, hint in zip(meta.arg_names, meta.sizes) - if hint is not None - ] - ) - src += ( - f" if ({conds})\n" if any(meta.sizes) else "if (1)\n" - ) # Edge case where no specializations hence no dispatching required + cond_fn = ( # + lambda val, hint: f"({val} % {hint} == 0)" # + if hint == 16 # + else f"({val} == {hint})" # + if hint == 1 # + else None) + conds = " && ".join([ # + cond_fn(val, hint) # + for val, hint in zip(meta.arg_names, meta.sizes) # + if hint is not None + ]) + src += (f" if ({conds})\n" if any(meta.sizes) else "if (1)\n" + ) # Edge case where no specializations hence no dispatching required arg_names = [arg for arg, hint in zip(meta.arg_names, meta.sizes) if hint != 1] src += f" return {meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}(stream, {', '.join(arg_names)});\n" src += "\n" @@ -202,9 +193,7 @@ def make_kernel_hints_dispatcher(name: str, metas: Sequence[KernelLinkerMeta]) - src += f"void {mode}_{name}() {{" src += "\n" for meta in sorted(metas, key=lambda m: -m.num_specs): - src += ( - f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n" - ) + src += (f" {mode}_{meta.orig_kernel_name}_{meta.sig_hash}_{meta.suffix}();\n") src += "}\n" return src @@ -306,10 +295,7 @@ def make_get_num_algos_def(meta: KernelLinkerMeta) -> str: fp.write(out) # generate source - defs = [ - make_kernel_hints_dispatcher(name, meta) - for name, meta in parser.kernels.items() - ] + defs = [make_kernel_hints_dispatcher(name, meta) for name, meta in parser.kernels.items()] names = [name for name in parser.kernels.keys()] func_pointers_def = make_func_pointers(names, meta) meta_const_def = make_kernel_meta_const_dispatcher(meta) diff --git a/python/tutorials/01-vector-add.py b/python/tutorials/01-vector-add.py index 3463ddf1ced1..1c1900a07481 100644 --- a/python/tutorials/01-vector-add.py +++ b/python/tutorials/01-vector-add.py @@ -25,14 +25,13 @@ @triton.jit -def add_kernel( - x_ptr, # *Pointer* to first input vector. - y_ptr, # *Pointer* to second input vector. - output_ptr, # *Pointer* to output vector. - n_elements, # Size of the vector. - BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. - # NOTE: `constexpr` so it can be used as a shape value. -): +def add_kernel(x_ptr, # *Pointer* to first input vector. + y_ptr, # *Pointer* to second input vector. + output_ptr, # *Pointer* to output vector. + n_elements, # Size of the vector. + BLOCK_SIZE: tl.constexpr, # Number of elements each program should process. + # NOTE: `constexpr` so it can be used as a shape value. + ): # There are multiple 'programs' processing different data. We identify which program # we are here: pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. @@ -66,7 +65,7 @@ def add(x: torch.Tensor, y: torch.Tensor): # The SPMD launch grid denotes the number of kernel instances that run in parallel. # It is analogous to CUDA launch grids. It can be either Tuple[int], or Callable(metaparameters) -> Tuple[int]. # In this case, we use a 1D grid where the size is the number of blocks: - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) # NOTE: # - Each torch.tensor object is implicitly converted into a pointer to its first element. 
# - `triton.jit`'ed functions can be indexed with a launch grid to obtain a callable GPU kernel. @@ -88,10 +87,8 @@ def add(x: torch.Tensor, y: torch.Tensor): output_triton = add(x, y) print(output_torch) print(output_triton) -print( - f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}' -) +print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') # %% # Seems like we're good to go! @@ -108,9 +105,7 @@ def add(x: torch.Tensor, y: torch.Tensor): @triton.testing.perf_report( triton.testing.Benchmark( x_names=['size'], # Argument names to use as an x-axis for the plot. - x_vals=[ - 2 ** i for i in range(12, 28, 1) - ], # Different possible values for `x_name`. + x_vals=[2**i for i in range(12, 28, 1)], # Different possible values for `x_name`. x_log=True, # x axis is logarithmic. line_arg='provider', # Argument name whose value corresponds to a different line in the plot. line_vals=['triton', 'torch'], # Possible values for `line_arg`. @@ -119,8 +114,7 @@ def add(x: torch.Tensor, y: torch.Tensor): ylabel='GB/s', # Label name for the y-axis. plot_name='vector-add-performance', # Name for the plot. Used also as a file name for saving the plot. args={}, # Values for function arguments not in `x_names` and `y_name`. - ) -) + )) def benchmark(size, provider): x = torch.rand(size, device='cuda', dtype=torch.float32) y = torch.rand(size, device='cuda', dtype=torch.float32) diff --git a/python/tutorials/02-fused-softmax.py b/python/tutorials/02-fused-softmax.py index 13383cc1c783..f2d4c1138586 100644 --- a/python/tutorials/02-fused-softmax.py +++ b/python/tutorials/02-fused-softmax.py @@ -71,10 +71,7 @@ def naive_softmax(x): @triton.jit -def softmax_kernel( - output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, - BLOCK_SIZE: tl.constexpr -): +def softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr): # The rows of the softmax are independent, so we parallelize across those row_idx = tl.program_id(0) # The stride represents how much we need to increase the pointer to advance 1 row @@ -118,7 +115,7 @@ def softmax(x): y = torch.empty_like(x) # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o # f the input matrix - softmax_kernel[(n_rows,)]( + softmax_kernel[(n_rows, )]( y, x, x.stride(0), @@ -158,9 +155,7 @@ def softmax(x): @triton.testing.perf_report( triton.testing.Benchmark( x_names=['N'], # argument names to use as an x-axis for the plot - x_vals=[ - 128 * i for i in range(2, 100) - ], # different possible values for `x_name` + x_vals=[128 * i for i in range(2, 100)], # different possible values for `x_name` line_arg='provider', # argument name whose value corresponds to a different line in the plot line_vals=[ 'triton', @@ -176,8 +171,7 @@ def softmax(x): ylabel="GB/s", # label name for the y-axis plot_name="softmax-performance", # name for the plot. Used also as a file name for saving the plot. 
args={'M': 4096}, # values for function arguments not in `x_names` and `y_name` - ) -) + )) def benchmark(M, N, provider): x = torch.randn(M, N, device='cuda', dtype=torch.float32) quantiles = [0.5, 0.2, 0.8] diff --git a/python/tutorials/03-matrix-multiplication.py b/python/tutorials/03-matrix-multiplication.py index 8bcae2007abd..88978f6173d9 100644 --- a/python/tutorials/03-matrix-multiplication.py +++ b/python/tutorials/03-matrix-multiplication.py @@ -163,33 +163,41 @@ # provided configs @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, + num_warps=8), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), ], key=['M', 'N', 'K'], ) @triton.jit def matmul_kernel( - # Pointers to matrices - a_ptr, b_ptr, c_ptr, - # Matrix dimensions - M, N, K, - # The stride variables represent how much to increase the ptr by when moving by 1 - # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` - # by to get the element one row down (A has M rows). - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, - ACTIVATION: tl.constexpr, + # Pointers to matrices + a_ptr, b_ptr, c_ptr, + # Matrix dimensions + M, N, K, + # The stride variables represent how much to increase the ptr by when moving by 1 + # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` + # by to get the element one row down (A has M rows). 
+ stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # + ACTIVATION: tl.constexpr # ): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) @@ -274,16 +282,14 @@ def matmul(a, b, activation=""): # Allocates output. c = torch.empty((M, N), device=a.device, dtype=a.dtype) # 1D launch kernel where each block gets its own program. - grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) matmul_kernel[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - ACTIVATION=activation + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1), # + ACTIVATION=activation # ) return c @@ -320,9 +326,7 @@ def matmul(a, b, activation=""): @triton.testing.perf_report( triton.testing.Benchmark( x_names=['M', 'N', 'K'], # Argument names to use as an x-axis for the plot - x_vals=[ - 128 * i for i in range(2, 33) - ], # Different possible values for `x_name` + x_vals=[128 * i for i in range(2, 33)], # Different possible values for `x_name` line_arg='provider', # Argument name whose value corresponds to a different line in the plot # Possible values for `line_arg` line_vals=['cublas', 'triton'], @@ -333,8 +337,7 @@ def matmul(a, b, activation=""): ylabel="TFLOPS", # Label name for the y-axis plot_name="matmul-performance", # Name for the plot, used also as a file name for saving the plot. args={}, - ) -) + )) def benchmark(M, N, K, provider): a = torch.randn((M, K), device='cuda', dtype=torch.float16) b = torch.randn((K, N), device='cuda', dtype=torch.float16) diff --git a/python/tutorials/04-low-memory-dropout.py b/python/tutorials/04-low-memory-dropout.py index 3c4d217e22b0..fe52f0d8e316 100644 --- a/python/tutorials/04-low-memory-dropout.py +++ b/python/tutorials/04-low-memory-dropout.py @@ -32,7 +32,6 @@ # # Let's first take a look at the baseline implementation. 
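# A plain-numpy reference for what the baseline computes (illustrative only,
# kept alongside the tutorial for intuition): dropout zeroes each element with
# probability p and rescales the survivors by 1 / (1 - p), so the output keeps
# the same expected value as the input.
import numpy as np

p = 0.5
x_ref = np.random.randn(10).astype(np.float32)
keep = np.random.rand(10) > p                      # keep mask, analogous to `x_keep` below
out_ref = np.where(keep, x_ref / (1 - p), 0.0)
assert np.allclose(out_ref[~keep], 0.0)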
- import tabulate import torch @@ -66,22 +65,22 @@ def dropout(x, x_keep, p): output = torch.empty_like(x) assert x.is_contiguous() n_elements = x.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024) return output # Input tensor -x = torch.randn(size=(10,)).cuda() +x = torch.randn(size=(10, )).cuda() # Dropout mask p = 0.5 -x_keep = (torch.rand(size=(10,)) > p).to(torch.int32).cuda() +x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda() # output = dropout(x, x_keep=x_keep, p=p) print(tabulate.tabulate([ ["input"] + x.tolist(), ["keep mask"] + x_keep.tolist(), - ["output"] + output.tolist() + ["output"] + output.tolist(), ])) # %% @@ -134,23 +133,24 @@ def seeded_dropout(x, p, seed): output = torch.empty_like(x) assert x.is_contiguous() n_elements = x.numel() - grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) + grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024) return output -x = torch.randn(size=(10,)).cuda() +x = torch.randn(size=(10, )).cuda() # Compare this to the baseline - dropout mask is never instantiated! output = seeded_dropout(x, p=0.5, seed=123) output2 = seeded_dropout(x, p=0.5, seed=123) output3 = seeded_dropout(x, p=0.5, seed=512) -print(tabulate.tabulate([ - ["input"] + x.tolist(), - ["output (seed = 123)"] + output.tolist(), - ["output (seed = 123)"] + output2.tolist(), - ["output (seed = 512)"] + output3.tolist() -])) +print( + tabulate.tabulate([ + ["input"] + x.tolist(), + ["output (seed = 123)"] + output.tolist(), + ["output (seed = 123)"] + output2.tolist(), + ["output (seed = 512)"] + output3.tolist(), + ])) # %% # Et Voilà! We have a triton kernel that applies the same dropout mask provided the seed is the same! diff --git a/python/tutorials/05-layer-norm.py b/python/tutorials/05-layer-norm.py index b19ff00ae2c1..914aecd7bbc3 100644 --- a/python/tutorials/05-layer-norm.py +++ b/python/tutorials/05-layer-norm.py @@ -126,24 +126,22 @@ def _layer_norm_fwd_fused( # In Stage 2, the buffers are further reduced to compute the final :math:`\nabla_{w}` and :math:`\nabla_{b}`. # In the following implementation, Stage 1 is implemented by the function :code:`_layer_norm_bwd_dx_fused` and Stage 2 is implemented by the function :code:`_layer_norm_bwd_dwdb`.
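# A plain-numpy picture of the two-stage reduction described above (illustrative
# only; the shapes mirror the GROUP_SIZE_M partial buffers allocated in the
# wrapper further below): Stage 1 accumulates one partial dW row per group of
# input rows, Stage 2 sums the GROUP_SIZE_M partial rows into the final gradient.
import numpy as np

M, N, GROUP_SIZE_M = 64, 8, 4
per_row_dw = np.random.rand(M, N)                       # stand-in for per-row contributions
partial_dw = np.zeros((GROUP_SIZE_M, N))
for row in range(M):
    partial_dw[row % GROUP_SIZE_M] += per_row_dw[row]   # Stage 1: grouped accumulation
final_dw = partial_dw.sum(axis=0)                       # Stage 2: reduce the buffers
assert np.allclose(final_dw, per_row_dw.sum(axis=0))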
+ @triton.jit -def _layer_norm_bwd_dx_fused( - DX, # pointer to the input gradient - DY, # pointer to the output gradient - DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - X, # pointer to the input - W, # pointer to the weights - B, # pointer to the biases - Mean, # pointer to the mean - Rstd, # pointer to the 1/std - Lock, # pointer to the lock - stride, # how much to increase the pointer when moving by 1 row - N, # number of columns in X - eps, # epsilon to avoid division by zero - GROUP_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr -): +def _layer_norm_bwd_dx_fused(DX, # pointer to the input gradient + DY, # pointer to the output gradient + DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + X, # pointer to the input + W, # pointer to the weights + B, # pointer to the biases + Mean, # pointer to the mean + Rstd, # pointer to the 1/std + Lock, # pointer to the lock + stride, # how much to increase the pointer when moving by 1 row + N, # number of columns in X + eps, # epsilon to avoid division by zero + GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): # Map the program id to the elements of X, DX, and DY it should compute. row = tl.program_id(0) cols = tl.arange(0, BLOCK_SIZE_N) @@ -192,16 +190,13 @@ def _layer_norm_bwd_dx_fused( @triton.jit -def _layer_norm_bwd_dwdb( - DW, # pointer to the partial sum of weights gradient - DB, # pointer to the partial sum of biases gradient - FINAL_DW, # pointer to the weights gradient - FINAL_DB, # pointer to the biases gradient - M, # GROUP_SIZE_M - N, # number of columns - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr -): +def _layer_norm_bwd_dwdb(DW, # pointer to the partial sum of weights gradient + DB, # pointer to the partial sum of biases gradient + FINAL_DW, # pointer to the weights gradient + FINAL_DB, # pointer to the biases gradient + M, # GROUP_SIZE_M + N, # number of columns + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr): # Map the program id to the elements of DW and DB it should compute. 
pid = tl.program_id(0) cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) @@ -249,9 +244,10 @@ def forward(ctx, x, normalized_shape, weight, bias, eps): # heuristics for number of warps num_warps = min(max(BLOCK_SIZE // 256, 1), 8) # enqueue kernel - _layer_norm_fwd_fused[(M,)](x_arg, y, weight, bias, mean, rstd, - x_arg.stride(0), N, eps, - BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) + _layer_norm_fwd_fused[(M, )]( # + x_arg, y, weight, bias, mean, rstd, # + x_arg.stride(0), N, eps, # + BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1) ctx.save_for_backward(x, weight, bias, mean, rstd) ctx.BLOCK_SIZE = BLOCK_SIZE ctx.num_warps = num_warps @@ -271,23 +267,25 @@ def backward(ctx, dy): locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device='cuda') _dw = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) _db = torch.empty((GROUP_SIZE_M, w.shape[0]), dtype=x.dtype, device=w.device) - dw = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) - db = torch.empty((w.shape[0],), dtype=w.dtype, device=w.device) + dw = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device) + db = torch.empty((w.shape[0], ), dtype=w.dtype, device=w.device) dx = torch.empty_like(dy) # enqueue kernel using forward pass heuristics # also compute partial sums for DW and DB x_arg = x.reshape(-1, x.shape[-1]) M, N = x_arg.shape - _layer_norm_bwd_dx_fused[(M,)](dx, dy, _dw, _db, x, w, b, m, v, locks, - x_arg.stride(0), N, ctx.eps, - BLOCK_SIZE_N=ctx.BLOCK_SIZE, - GROUP_SIZE_M=GROUP_SIZE_M, - num_warps=ctx.num_warps) + _layer_norm_bwd_dx_fused[(M, )]( # + dx, dy, _dw, _db, x, w, b, m, v, locks, # + x_arg.stride(0), N, ctx.eps, # + BLOCK_SIZE_N=ctx.BLOCK_SIZE, # + GROUP_SIZE_M=GROUP_SIZE_M, # + num_warps=ctx.num_warps) grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] # accumulate partial sums in separate kernel - _layer_norm_bwd_dwdb[grid](_dw, _db, dw, db, GROUP_SIZE_M, N, - BLOCK_SIZE_M=32, - BLOCK_SIZE_N=128, num_ctas=1) + _layer_norm_bwd_dwdb[grid]( + _dw, _db, dw, db, GROUP_SIZE_M, N, # + BLOCK_SIZE_M=32, # + BLOCK_SIZE_N=128, num_ctas=1) return dx, None, dw, db, None @@ -330,9 +328,8 @@ def test_layer_norm(M, N, dtype, eps=1e-5, device='cuda'): styles=[('blue', '-'), ('green', '-'), ('orange', '-')], ylabel='GB/s', plot_name='layer-norm-backward', - args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'} - ) -) + args={'M': 4096, 'dtype': torch.float16, 'mode': 'backward'}, + )) def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='cuda'): # create data x_shape = (M, N) @@ -345,24 +342,34 @@ def bench_layer_norm(M, N, dtype, provider, mode='backward', eps=1e-5, device='c quantiles = [0.5, 0.2, 0.8] # utility functions if provider == 'triton': - def y_fwd(): return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 + + def y_fwd(): + return layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 + if provider == 'torch': - def y_fwd(): return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 + + def y_fwd(): + return torch.nn.functional.layer_norm(x, w_shape, weight, bias, eps) # noqa: F811, E704 + if provider == 'apex': - apex_layer_norm = apex.normalization.FusedLayerNorm( - w_shape).to(x.device).to(x.dtype) + apex_layer_norm = apex.normalization.FusedLayerNorm(w_shape).to(x.device).to(x.dtype) + + def y_fwd(): + return apex_layer_norm(x) # noqa: F811, E704 - def y_fwd(): return apex_layer_norm(x) # noqa: F811, E704 # forward pass if mode == 'forward': gbps = lambda ms: 2 * 
x.numel() * x.element_size() / ms * 1e-6 ms, min_ms, max_ms = triton.testing.do_bench(y_fwd, quantiles=quantiles, rep=500) # backward pass if mode == 'backward': - def gbps(ms): return 3 * x.numel() * x.element_size() / ms * 1e-6 # noqa: F811, E704 + + def gbps(ms): + return 3 * x.numel() * x.element_size() / ms * 1e-6 # noqa: F811, E704 + y = y_fwd() - ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), - quantiles=quantiles, grad_to_none=[x], rep=500) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: y.backward(dy, retain_graph=True), quantiles=quantiles, + grad_to_none=[x], rep=500) return gbps(ms), gbps(max_ms), gbps(min_ms) diff --git a/python/tutorials/06-fused-attention.py b/python/tutorials/06-fused-attention.py index dbe1e4c2691e..9946de42580e 100644 --- a/python/tutorials/06-fused-attention.py +++ b/python/tutorials/06-fused-attention.py @@ -19,18 +19,12 @@ @triton.jit -def _attn_fwd_inner( - acc, l_i, m_i, q, - K_block_ptr, V_block_ptr, - start_m, qk_scale, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - STAGE: tl.constexpr, - offs_m: tl.constexpr, - offs_n: tl.constexpr, - N_CTX: tl.constexpr, -): +def _attn_fwd_inner(acc, l_i, m_i, q, # + K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, # + N_CTX: tl.constexpr): # range of values handled by this stage if STAGE == 1: lo, hi = 0, start_m * BLOCK_M @@ -73,6 +67,7 @@ def _attn_fwd_inner( K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) return acc, l_i, m_i + # We don't run auto-tuning everytime to keep the tutorial fast. Uncommenting # the code below and commenting out the equivalent parameters is convenient for # re-tuning. 
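# Aside on the `gbps` lambdas in the benchmark above (plain arithmetic, with
# made-up numbers): a kernel that reads and writes a tensor moves
# 2 * numel * element_size bytes, and bytes / (ms * 1e-3 s) / 1e9 bytes-per-GB
# simplifies to bytes / ms * 1e-6, which is the expression the lambdas use.
numel, element_size, ms = 4096 * 8192, 2, 1.5   # hypothetical fp16 tensor, 1.5 ms run
bytes_moved = 2 * numel * element_size          # one read + one write
gb_per_s = bytes_moved / (ms * 1e-3) / 1e9
assert abs(gb_per_s - bytes_moved / ms * 1e-6) < 1e-6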
@@ -99,19 +94,18 @@ def _attn_fwd_inner( @triton.jit -def _attn_fwd( - Q, K, V, sm_scale, M, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vk, stride_vn, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, - N_CTX: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - STAGE: tl.constexpr, -): +def _attn_fwd(Q, K, V, sm_scale, M, Out, # + stride_qz, stride_qh, stride_qm, stride_qk, # + stride_kz, stride_kh, stride_kn, stride_kk, # + stride_vz, stride_vh, stride_vk, stride_vn, # + stride_oz, stride_oh, stride_om, stride_on, # + Z, H, # + N_CTX: tl.constexpr, # + BLOCK_M: tl.constexpr, # + BLOCK_DMODEL: tl.constexpr, # + BLOCK_N: tl.constexpr, # + STAGE: tl.constexpr # + ): start_m = tl.program_id(0) off_hz = tl.program_id(1) off_z = off_hz // H @@ -167,23 +161,21 @@ def _attn_fwd( # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE if STAGE & 1: - acc, l_i, m_i = _attn_fwd_inner( - acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - start_m, qk_scale, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 4 - STAGE, offs_m, offs_n, N_CTX, - ) + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, BLOCK_DMODEL, BLOCK_N, # + 4 - STAGE, offs_m, offs_n, N_CTX # + ) # stage 2: on-band if STAGE & 2: # barrier makes it easier for compielr to schedule the # two loops independently tl.debug_barrier() - acc, l_i, m_i = _attn_fwd_inner( - acc, l_i, m_i, q, K_block_ptr, V_block_ptr, - start_m, qk_scale, - BLOCK_M, BLOCK_DMODEL, BLOCK_N, - 2, offs_m, offs_n, N_CTX, - ) + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, # + start_m, qk_scale, # + BLOCK_M, BLOCK_DMODEL, BLOCK_N, # + 2, offs_m, offs_n, N_CTX # + ) # epilogue m_i += tl.math.log2(l_i) acc = acc / l_i[:, None] @@ -193,12 +185,11 @@ def _attn_fwd( @triton.jit -def _attn_bwd_preprocess( - O, DO, - Delta, - Z, H, N_CTX, - BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr, -): +def _attn_bwd_preprocess(O, DO, # + Delta, # + Z, H, N_CTX, # + BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr # + ): off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) off_hz = tl.program_id(1) off_n = tl.arange(0, D_HEAD) @@ -212,21 +203,18 @@ def _attn_bwd_preprocess( # The main inner-loop logic for computing dK and dV. @triton.jit -def _attn_bwd_dkdv( - dk, dv, - Q, k, v, sm_scale, - DO, - M, D, - # shared by Q/K/V/DO. - stride_tok, stride_d, - H, N_CTX, - BLOCK_M1: tl.constexpr, - BLOCK_N1: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - # Filled in by the wrapper. - start_n, start_m, num_steps, - MASK: tl.constexpr, -): +def _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_DMODEL: tl.constexpr, # + # Filled in by the wrapper. + start_n, start_m, num_steps, # + MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M1) offs_n = start_n + tl.arange(0, BLOCK_N1) offs_k = tl.arange(0, BLOCK_DMODEL) @@ -268,19 +256,17 @@ def _attn_bwd_dkdv( # the main inner-loop logic for computing dQ @triton.jit -def _attn_bwd_dq( - dq, q, K, V, - do, m, D, - # shared by Q/K/V/DO. - stride_tok, stride_d, - H, N_CTX, - BLOCK_M2: tl.constexpr, - BLOCK_N2: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - # Filled in by the wrapper. 
- start_m, start_n, num_steps, - MASK: tl.constexpr, -): +def _attn_bwd_dq(dq, q, K, V, # + do, m, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, # + MASK: tl.constexpr): offs_m = start_m + tl.arange(0, BLOCK_M2) offs_n = start_n + tl.arange(0, BLOCK_N2) offs_k = tl.arange(0, BLOCK_DMODEL) @@ -317,21 +303,19 @@ def _attn_bwd_dq( @triton.jit -def _attn_bwd( - Q, K, V, sm_scale, - DO, - DQ, DK, DV, - M, D, - # shared by Q/K/V/DO. - stride_z, stride_h, stride_tok, stride_d, - H, N_CTX, - BLOCK_M1: tl.constexpr, - BLOCK_N1: tl.constexpr, - BLOCK_M2: tl.constexpr, - BLOCK_N2: tl.constexpr, - BLK_SLICE_FACTOR: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, -): +def _attn_bwd(Q, K, V, sm_scale, # + DO, # + DQ, DK, DV, # + M, D, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1: tl.constexpr, # + BLOCK_N1: tl.constexpr, # + BLOCK_M2: tl.constexpr, # + BLOCK_N2: tl.constexpr, # + BLK_SLICE_FACTOR: tl.constexpr, # + BLOCK_DMODEL: tl.constexpr): LN2: tl.constexpr = 0.6931471824645996 # = ln(2) bhid = tl.program_id(2) @@ -370,31 +354,32 @@ def _attn_bwd( num_steps = BLOCK_N1 // MASK_BLOCK_M1 - dk, dv = _attn_bwd_dkdv(dk, dv, - Q, k, v, sm_scale, - DO, - M, D, - stride_tok, stride_d, - H, N_CTX, - MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, - start_n, start_m, num_steps, - MASK=True, + dk, dv = _attn_bwd_dkdv(dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, # + start_n, start_m, num_steps, # + MASK=True # ) start_m += num_steps * MASK_BLOCK_M1 num_steps = (N_CTX - start_m) // BLOCK_M1 # Compute dK and dV for non-masked blocks. - dk, dv = _attn_bwd_dkdv(dk, dv, - Q, k, v, sm_scale, - DO, - M, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, - start_n, start_m, num_steps, - MASK=False, - ) + dk, dv = _attn_bwd_dkdv( # + dk, dv, # + Q, k, v, sm_scale, # + DO, # + M, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, # + start_n, start_m, num_steps, # + MASK=False # + ) dv_ptrs = DV + offs_n[:, None] * stride_tok + offs_k[None, :] * stride_d tl.store(dv_ptrs, dv) @@ -424,27 +409,25 @@ def _attn_bwd( # not due to anything important. I just wanted to reuse the loop # structure for dK & dV above as much as possible. num_steps = BLOCK_M2 // MASK_BLOCK_N2 - dq = _attn_bwd_dq( - dq, q, K, V, - do, m, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, - start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, - MASK=True, - ) + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, MASK_BLOCK_N2, BLOCK_DMODEL, # + start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, # + MASK=True # + ) end_n -= num_steps * MASK_BLOCK_N2 # stage 2 num_steps = end_n // BLOCK_N2 - dq = _attn_bwd_dq( - dq, q, K, V, - do, m, D, - stride_tok, stride_d, - H, N_CTX, - BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, - start_m, end_n - num_steps * BLOCK_N2, num_steps, - MASK=False, - ) + dq = _attn_bwd_dq(dq, q, K, V, # + do, m, D, # + stride_tok, stride_d, # + H, N_CTX, # + BLOCK_M2, BLOCK_N2, BLOCK_DMODEL, # + start_m, end_n - num_steps * BLOCK_N2, num_steps, # + MASK=False # + ) # Write back dQ. 
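# Background on the LN2 constant defined above (plain numpy, illustrative):
# the kernels evaluate exponentials in base 2 via exp(x) == 2**(x * log2(e)),
# and ln(2) is the factor that undoes a base-2 scaling, since
# ln(2) * log2(e) == 1 -- which appears to be why dq is rescaled by LN2 below.
import numpy as np

x = np.linspace(-3.0, 3.0, 7)
log2e = np.log2(np.e)
assert np.allclose(np.exp(x), np.exp2(x * log2e))
assert np.isclose(np.log(2.0) * log2e, 1.0)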
dq_ptrs = DQ + offs_m[:, None] * stride_tok + offs_k[None, :] * stride_d dq *= LN2 @@ -455,6 +438,7 @@ def _attn_bwd( class _attention(torch.autograd.Function): + @staticmethod def forward(ctx, q, k, v, causal, sm_scale): # shape constraints @@ -474,19 +458,19 @@ def forward(ctx, q, k, v, causal, sm_scale): grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) _attn_fwd[grid]( - q, k, v, sm_scale, M, o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], - N_CTX=q.shape[2], - BLOCK_M=BLOCK_M, - BLOCK_N=BLOCK_N, - BLOCK_DMODEL=Lk, - STAGE=stage, - num_warps=num_warps, - num_stages=num_stages, + q, k, v, sm_scale, M, o, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + k.stride(0), k.stride(1), k.stride(2), k.stride(3), # + v.stride(0), v.stride(1), v.stride(2), v.stride(3), # + o.stride(0), o.stride(1), o.stride(2), o.stride(3), # + q.shape[0], q.shape[1], # + N_CTX=q.shape[2], # + BLOCK_M=BLOCK_M, # + BLOCK_N=BLOCK_N, # + BLOCK_DMODEL=Lk, # + STAGE=stage, # + num_warps=num_warps, # + num_stages=num_stages # ) ctx.save_for_backward(q, k, v, o, M) @@ -517,23 +501,23 @@ def backward(ctx, do): pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD) delta = torch.empty_like(M) _attn_bwd_preprocess[pre_grid]( - o, do, - delta, - BATCH, N_HEAD, N_CTX, - BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL, + o, do, # + delta, # + BATCH, N_HEAD, N_CTX, # + BLOCK_M=PRE_BLOCK, D_HEAD=ctx.BLOCK_DMODEL # ) grid = (N_CTX // BLOCK_N1, 1, BATCH * N_HEAD) _attn_bwd[grid]( - q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, - M, delta, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - N_HEAD, N_CTX, - BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, - BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, - BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, - BLOCK_DMODEL=ctx.BLOCK_DMODEL, - num_warps=NUM_WARPS, - num_stages=NUM_STAGES, + q, arg_k, v, ctx.sm_scale, do, dq, dk, dv, # + M, delta, # + q.stride(0), q.stride(1), q.stride(2), q.stride(3), # + N_HEAD, N_CTX, # + BLOCK_M1=BLOCK_M1, BLOCK_N1=BLOCK_N1, # + BLOCK_M2=BLOCK_M2, BLOCK_N2=BLOCK_N2, # + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, # + BLOCK_DMODEL=ctx.BLOCK_DMODEL, # + num_warps=NUM_WARPS, # + num_stages=NUM_STAGES # ) return dq, dk, dv, None, None @@ -546,21 +530,9 @@ def backward(ctx, do): @pytest.mark.parametrize("causal", [True]) def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): torch.manual_seed(20) - q = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - k = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) - v = ( - torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda") - .normal_(mean=0.0, std=0.5) - .requires_grad_() - ) + q = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, N_CTX, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) sm_scale = 0.5 dout = torch.randn_like(q) # reference implementation @@ -621,14 +593,11 @@ def test_op(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): "mode": mode, "causal": 
causal, }, - ) - ) + )) @triton.testing.perf_report(configs) -def bench_flash_attention( - BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda" -): +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal, mode, provider, dtype=torch.float16, device="cuda"): assert mode in ["fwd", "bwd"] warmup = 25 rep = 100 @@ -647,9 +616,7 @@ def bench_flash_attention( fn = lambda: o.backward(do, retain_graph=True) ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) if provider == "flash": - qkv = torch.randn( - (BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True - ) + qkv = torch.randn((BATCH, N_CTX, 3, H, D_HEAD), dtype=dtype, device=device, requires_grad=True) fn = lambda: flash_attn_func(qkv, causal=causal) if mode == "bwd": o = fn() diff --git a/python/tutorials/07-math-functions.py b/python/tutorials/07-math-functions.py index 1ded3aa984d6..f60f07efbc18 100644 --- a/python/tutorials/07-math-functions.py +++ b/python/tutorials/07-math-functions.py @@ -22,10 +22,10 @@ @triton.jit def asin_kernel( - x_ptr, - y_ptr, - n_elements, - BLOCK_SIZE: tl.constexpr, + x_ptr, + y_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, ): pid = tl.program_id(axis=0) block_start = pid * BLOCK_SIZE @@ -35,12 +35,12 @@ def asin_kernel( x = tl.math.asin(x) tl.store(y_ptr + offsets, x, mask=mask) + # %% # Using the default libdevice library path # ----------------------------------------- # We can use the default libdevice library path encoded in `triton/language/math.py` - torch.manual_seed(0) size = 98432 x = torch.rand(size, device='cuda') @@ -48,14 +48,12 @@ def asin_kernel( output_torch = torch.asin(x) assert x.is_cuda and output_triton.is_cuda n_elements = output_torch.numel() -grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) +grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), ) asin_kernel[grid](x, output_triton, n_elements, BLOCK_SIZE=1024) print(output_torch) print(output_triton) -print( - f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}' -) +print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') # %% # Customize the libdevice library path @@ -67,7 +65,5 @@ def asin_kernel( extern_libs={'libdevice': '/usr/local/cuda/nvvm/libdevice/libdevice.10.bc'}) print(output_torch) print(output_triton) -print( - f'The maximum difference between torch and triton is ' - f'{torch.max(torch.abs(output_torch - output_triton))}' -) +print(f'The maximum difference between torch and triton is ' + f'{torch.max(torch.abs(output_torch - output_triton))}') diff --git a/python/tutorials/08-experimental-block-pointer.py b/python/tutorials/08-experimental-block-pointer.py index 7147b69de6cc..4486349fbe77 100644 --- a/python/tutorials/08-experimental-block-pointer.py +++ b/python/tutorials/08-experimental-block-pointer.py @@ -98,14 +98,22 @@ @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 
'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4), - triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), - triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, + num_warps=4), + triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), + triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, + num_warps=2), ], key=['M', 'N', 'K'], ) @@ -118,13 +126,11 @@ def matmul_kernel_with_block_pointers( # The stride variables represent how much to increase the ptr by when moving by 1 # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr` # by to get the element one row down (A has M rows). - stride_am, stride_ak, - stride_bk, stride_bn, + stride_am, stride_ak, # + stride_bk, stride_bn, # stride_cm, stride_cn, # Meta-parameters - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr -): + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): """Kernel for computing the matmul C = A x B. A has shape (M, K), B has shape (K, N) and C has shape (M, N) """ @@ -196,16 +202,13 @@ def matmul(a, b): # Allocates output. c = torch.empty((M, N), device=a.device, dtype=a.dtype) # 1D launch kernel where each block gets its own program. 
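# For intuition about the "1D launch" comment above (plain Python, hypothetical
# shapes): the grid holds cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N) programs, where
# cdiv is ceiling division, so every tile is covered even when the matrix shape
# is not a multiple of the block shape.
def cdiv(a, b):
    # ceiling division for non-negative integers, matching triton.cdiv
    return (a + b - 1) // b

M, N = 300, 500
BLOCK_M, BLOCK_N = 128, 128
assert cdiv(M, BLOCK_M) * cdiv(N, BLOCK_N) == 3 * 4   # 3 row tiles x 4 column tiles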
- grid = lambda META: ( - triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), - ) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) matmul_kernel_with_block_pointers[grid]( - a, b, c, - M, N, K, - a.stride(0), a.stride(1), - b.stride(0), b.stride(1), - c.stride(0), c.stride(1), - ) + a, b, c, # + M, N, K, # + a.stride(0), a.stride(1), # + b.stride(0), b.stride(1), # + c.stride(0), c.stride(1)) return c diff --git a/python/tutorials/09-experimental-tma-matrix-multiplication.py b/python/tutorials/09-experimental-tma-matrix-multiplication.py index 8a79720c79a0..8cf81ef69902 100644 --- a/python/tutorials/09-experimental-tma-matrix-multiplication.py +++ b/python/tutorials/09-experimental-tma-matrix-multiplication.py @@ -40,23 +40,24 @@ @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, + num_warps=4), # triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=2), # triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4), ], key=['M', 'N', 'K'], ) @triton.jit -def matmul_kernel( - a_ptr, b_ptr, z_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_zm, stride_zn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr, - A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, - B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr -): +def matmul_kernel(a_ptr, b_ptr, z_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_zm, stride_zn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, # + A_ORDER_0: tl.constexpr, A_ORDER_1: tl.constexpr, # + B_ORDER_0: tl.constexpr, B_ORDER_1: tl.constexpr # + ): pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -70,9 +71,11 @@ def matmul_kernel( block_offset_n = pid_n * BLOCK_SIZE_N a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), - offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(A_ORDER_0, A_ORDER_1)) + offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), + order=(A_ORDER_0, A_ORDER_1)) b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), - offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(B_ORDER_0, B_ORDER_1)) + offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), + order=(B_ORDER_0, B_ORDER_1)) z = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) offs_m = block_offset_m + tl.arange(0, BLOCK_SIZE_M) @@ -101,15 +104,17 @@ def matmul(a, b, a_order, b_order): z = torch.empty((M, N), device=a.device, dtype=torch.float16) def grid(META): - return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) - matmul_kernel[grid](a_ptr=a, b_ptr=b, z_ptr=z, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_zm=z.stride(0), stride_zn=z.stride(1), - A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], - B_ORDER_0=b_order[0], B_ORDER_1=b_order[1] - ) + return 
(triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + + matmul_kernel[grid]( + a_ptr=a, b_ptr=b, z_ptr=z, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_zm=z.stride(0), stride_zn=z.stride(1), # + A_ORDER_0=a_order[0], A_ORDER_1=a_order[1], # + B_ORDER_0=b_order[0], B_ORDER_1=b_order[1] # + ) return z @@ -160,14 +165,12 @@ def test_matmul(): # label name for the lines line_names=["cuBLAS", "Triton"], # line styles - styles=[('green', '-'), ('green', '--'), - ('blue', '-'), ('blue', '--')], + styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], ylabel="TFLOPS", # label name for the y-axis plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot. args={}, - ) -) + )) def benchmark(M, N, K, TRANS_A, TRANS_B, provider): if (TRANS_A): a = torch.randn((K, M), device='cuda', dtype=torch.float16).T @@ -185,14 +188,15 @@ def benchmark(M, N, K, TRANS_A, TRANS_B, provider): quantiles = [0.5, 0.2, 0.8] if provider == 'cublas': - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, + fast_flush=False) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: matmul(a, b, a_order, b_order), rep=100, quantiles=quantiles, fast_flush=False) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b, a_order, b_order), rep=100, + quantiles=quantiles, fast_flush=False) def perf(ms): return 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) diff --git a/python/tutorials/10-experimental-tma-store-matrix-multiplication.py b/python/tutorials/10-experimental-tma-store-matrix-multiplication.py index 37d58863d083..966e8d1e25b1 100644 --- a/python/tutorials/10-experimental-tma-store-matrix-multiplication.py +++ b/python/tutorials/10-experimental-tma-store-matrix-multiplication.py @@ -40,21 +40,21 @@ @triton.autotune( configs=[ - triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4), + triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, + num_warps=4), # triton.Config({'BLOCK_SIZE_M': 512, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=7, num_warps=4, num_ctas=4), ], key=['M', 'N', 'K'], ) @triton.jit -def matmul_kernel( - a_ptr, b_ptr, c_ptr, - M, N, K, - stride_am, stride_ak, - stride_bk, stride_bn, - stride_cm, stride_cn, - BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, - GROUP_SIZE_M: tl.constexpr, -): +def matmul_kernel(a_ptr, b_ptr, c_ptr, # + M, N, K, # + stride_am, stride_ak, # + stride_bk, stride_bn, # + stride_cm, stride_cn, # + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr # + ): pid = tl.program_id(axis=0) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) @@ -67,20 +67,10 @@ def matmul_kernel( block_offset_m = pid_m * BLOCK_SIZE_M block_offset_n = pid_n * BLOCK_SIZE_N - a_tile_ptr = tl.make_block_ptr( - base=a_ptr, shape=( - M, K), strides=( - stride_am, stride_ak), offsets=( - block_offset_m, 0), block_shape=( - BLOCK_SIZE_M, BLOCK_SIZE_K), order=( - 1, 0)) - b_tile_ptr = tl.make_block_ptr( - base=b_ptr, shape=( - K, N), 
strides=( - stride_bk, stride_bn), offsets=( - 0, block_offset_n), block_shape=( - BLOCK_SIZE_K, BLOCK_SIZE_N), order=( - 0, 1)) + a_tile_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), + offsets=(block_offset_m, 0), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_K), order=(1, 0)) + b_tile_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), + offsets=(0, block_offset_n), block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_N), order=(0, 1)) accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for k in range(0, K, BLOCK_SIZE_K): @@ -91,7 +81,8 @@ def matmul_kernel( b_tile_ptr = tl.advance(b_tile_ptr, [BLOCK_SIZE_K, 0]) c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), - offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), order=(1, 0)) + offsets=(block_offset_m, block_offset_n), block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N), + order=(1, 0)) tl.store(c_block_ptr, accumulator) @@ -101,20 +92,19 @@ def matmul(a, b): assert a.shape[1] == b.shape[0], "incompatible dimensions" M, K = a.shape K, N = b.shape - assert ( - K % 32 == 0 - ), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K" + assert (K % 32 == 0), "We don't check memory-out-of-bounds with K so K must be divisible by BLOCK_SIZE_K" c = torch.empty((M, N), device=a.device, dtype=torch.float32) def grid(META): - return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),) - - matmul_kernel[grid](a_ptr=a, b_ptr=b, c_ptr=c, - M=M, N=N, K=K, - stride_am=a.stride(0), stride_ak=a.stride(1), - stride_bk=b.stride(0), stride_bn=b.stride(1), - stride_cm=c.stride(0), stride_cn=c.stride(1)) + return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + + matmul_kernel[grid]( + a_ptr=a, b_ptr=b, c_ptr=c, # + M=M, N=N, K=K, # + stride_am=a.stride(0), stride_ak=a.stride(1), # + stride_bk=b.stride(0), stride_bn=b.stride(1), # + stride_cm=c.stride(0), stride_cn=c.stride(1)) return c @@ -126,12 +116,7 @@ def grid(META): golden = torch.nn.functional.normalize(torch.matmul(a, b)) torch.set_printoptions(profile="full") -assert_close( - c, - golden, - rtol=1e-2, - atol=1e-3, - check_dtype=False) +assert_close(c, golden, rtol=1e-2, atol=1e-3, check_dtype=False) @triton.testing.perf_report( @@ -143,7 +128,7 @@ def grid(META): [2048, 1024, 1024], [2048, 2048, 2048], [2048, 4096, 4096], - [2048, 8192, 8192] + [2048, 8192, 8192], ], # different possible values for `x_name` line_arg='provider', # argument name whose value corresponds to a different line in the plot @@ -152,27 +137,26 @@ def grid(META): # label name for the lines line_names=["cuBLAS", "Triton"], # line styles - styles=[('green', '-'), ('green', '--'), - ('blue', '-'), ('blue', '--')], + styles=[('green', '-'), ('green', '--'), ('blue', '-'), ('blue', '--')], ylabel="TFLOPS", # label name for the y-axis plot_name="matmul-performance", # name for the plot. Used also as a file name for saving the plot. 
args={}, - ) -) + )) def benchmark(M, N, K, provider): a = torch.randn((M, K), device='cuda', dtype=torch.float16) b = torch.randn((N, K), device='cuda', dtype=torch.float16).T quantiles = [0.5, 0.2, 0.8] if provider == 'cublas': - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch.matmul(a, b), rep=100, quantiles=quantiles, + fast_flush=False) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench( - lambda: matmul(a, b), rep=100, quantiles=quantiles, fast_flush=False) + ms, min_ms, max_ms = triton.testing.do_bench(lambda: matmul(a, b), rep=100, quantiles=quantiles, + fast_flush=False) def perf(ms): return 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) diff --git a/python/tutorials/11-grouped-gemm.py b/python/tutorials/11-grouped-gemm.py index e27acebd1627..43be11382f9c 100644 --- a/python/tutorials/11-grouped-gemm.py +++ b/python/tutorials/11-grouped-gemm.py @@ -1,4 +1,3 @@ - """ Group GEMM ============================ @@ -35,38 +34,30 @@ @triton.autotune( configs=[ - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'NUM_SM': 84, - } - ), - triton.Config( - { - 'BLOCK_SIZE_M': 128, - 'BLOCK_SIZE_N': 128, - 'BLOCK_SIZE_K': 32, - 'NUM_SM': 128, - } - ), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'NUM_SM': 84, - } - ), - triton.Config( - { - 'BLOCK_SIZE_M': 64, - 'BLOCK_SIZE_N': 64, - 'BLOCK_SIZE_K': 32, - 'NUM_SM': 128, - } - ), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 84, + }), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 128, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 84, + }), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'NUM_SM': 128, + }), ], key=['group_size'], ) @@ -102,9 +93,7 @@ def grouped_matmul_kernel( num_n_tiles = tl.cdiv(gn, BLOCK_SIZE_N) num_tiles = num_m_tiles * num_n_tiles # iterate through the tiles in the current gemm problem - while ( - tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles - ): + while (tile_idx >= last_problem_end and tile_idx < last_problem_end + num_tiles): # pick up a tile from the current gemm problem k = gk lda = tl.load(g_lds + g * 3) @@ -124,9 +113,7 @@ def grouped_matmul_kernel( offs_k = tl.arange(0, BLOCK_SIZE_K) a_ptrs = a_ptr + offs_am[:, None] * lda + offs_k[None, :] b_ptrs = b_ptr + offs_k[:, None] * ldb + offs_bn[None, :] - accumulator = tl.zeros( - (BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32 - ) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) for kk in range(0, tl.cdiv(k, BLOCK_SIZE_K)): # hint to Triton compiler to do proper loop pipelining tl.multiple_of(a_ptrs, [16, 16]) @@ -174,7 +161,7 @@ def group_gemm_fn(group_A, group_B): group_C.append(C) A_addrs.append(A.data_ptr()) B_addrs.append(B.data_ptr()) - C_addrs .append(C.data_ptr()) + C_addrs.append(C.data_ptr()) g_sizes += [M, N, K] g_lds += [A.stride(0), B.stride(0), C.stride(0)] @@ -182,14 +169,10 @@ def group_gemm_fn(group_A, group_B): d_a_ptrs = torch.tensor(A_addrs, device=device) d_b_ptrs = torch.tensor(B_addrs, device=device) d_c_ptrs = torch.tensor(C_addrs, device=device) - d_g_sizes = torch.tensor( - g_sizes, dtype=torch.int32, device=device - ) - d_g_lds 
= torch.tensor( - g_lds, dtype=torch.int32, device=device - ) + d_g_sizes = torch.tensor(g_sizes, dtype=torch.int32, device=device) + d_g_lds = torch.tensor(g_lds, dtype=torch.int32, device=device) # we use a fixed number of CTA, and it's auto-tunable - grid = lambda META: (META['NUM_SM'],) + grid = lambda META: (META['NUM_SM'], ) grouped_matmul_kernel[grid]( d_a_ptrs, d_b_ptrs, @@ -227,7 +210,7 @@ def group_gemm_fn(group_A, group_B): # only launch the kernel, no tensor preparation here to remove all overhead def triton_perf_fn(a_ptrs, b_ptrs, c_ptrs, sizes, lds, group_size): - grid = lambda META: (META['NUM_SM'],) + grid = lambda META: (META['NUM_SM'], ) grouped_matmul_kernel[grid]( a_ptrs, b_ptrs, @@ -247,7 +230,7 @@ def torch_perf_fn(group_A, group_B): triton.testing.Benchmark( # argument names to use as an x-axis for the plot x_names=['N'], - x_vals=[2 ** i for i in range(7, 11)], # different possible values for `x_name` + x_vals=[2**i for i in range(7, 11)], # different possible values for `x_name` line_arg='provider', # argument name whose value corresponds to a different line in the plot # possible values for `line_arg`` @@ -260,8 +243,7 @@ def torch_perf_fn(group_A, group_B): plot_name="group-gemm-performance", # name for the plot. Used also as a file name for saving the plot. args={}, - ) -) + )) def benchmark(N, provider): group_size = 4 group_A = [] @@ -281,7 +263,7 @@ def benchmark(N, provider): group_C.append(C) A_addrs.append(A.data_ptr()) B_addrs.append(B.data_ptr()) - C_addrs .append(C.data_ptr()) + C_addrs.append(C.data_ptr()) g_sizes += [N, N, N] g_lds += [N, N, N] @@ -295,7 +277,8 @@ def benchmark(N, provider): if provider == 'cublas': ms, min_ms, max_ms = triton.testing.do_bench(lambda: torch_perf_fn(group_A, group_B), quantiles=quantiles) if provider == 'triton': - ms, min_ms, max_ms = triton.testing.do_bench(lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size), quantiles=quantiles) + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: triton_perf_fn(d_a_ptrs, d_b_ptrs, d_c_ptrs, d_g_sizes, d_g_lds, group_size), quantiles=quantiles) return ms, max_ms, min_ms diff --git a/test/lit.cfg.py b/test/lit.cfg.py index 5ea9c458dcd0..b4d1528cdb31 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -30,20 +30,14 @@ config.substitutions.append(('%PATH%', config.environment['PATH'])) config.substitutions.append(('%shlibext', config.llvm_shlib_ext)) -llvm_config.with_system_environment( - ['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) +llvm_config.with_system_environment(['HOME', 'INCLUDE', 'LIB', 'TMP', 'TEMP']) # llvm_config.use_default_substitutions() # excludes: A list of directories to exclude from the testsuite. The 'Inputs' # subdirectories contain auxiliary inputs for various tests in their parent # directories. -config.excludes = [ - 'Inputs', - 'Examples', - 'CMakeLists.txt', - 'README.txt', - 'LICENSE.txt'] +config.excludes = ['Inputs', 'Examples', 'CMakeLists.txt', 'README.txt', 'LICENSE.txt'] # test_source_root: The root path where tests are located. 
config.test_source_root = os.path.dirname(__file__) @@ -52,10 +46,7 @@ config.test_exec_root = os.path.join(config.triton_obj_root, 'test') config.triton_tools_dir = os.path.join(config.triton_obj_root, 'bin') config.filecheck_dir = os.path.join(config.triton_obj_root, 'bin', 'FileCheck') -tool_dirs = [ - config.triton_tools_dir, - config.llvm_tools_dir, - config.filecheck_dir] +tool_dirs = [config.triton_tools_dir, config.llvm_tools_dir, config.filecheck_dir] # Tweak the PATH to include the tools dir. for d in tool_dirs:
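Stepping back from the lit configuration, the benchmark hunks earlier in this patch all report TFLOPS the same way: the millisecond latency returned by `triton.testing.do_bench` is converted with a `perf(ms)` helper. A quick standalone illustration of that arithmetic (the `perf_tflops` name and the 1.5 ms latency are made up for the example, not a measurement from this patch):

```python
# Mirrors the perf(ms) helpers above: a GEMM does 2*M*N*K floating-point ops,
# so TFLOPS = flops * 1e-12 / seconds, with the ms latency converted to seconds.
def perf_tflops(ms: float, M: int, N: int, K: int) -> float:
    return 2 * M * N * K * 1e-12 / (ms * 1e-3)

# Hypothetical example: a 4096 x 4096 x 4096 matmul finishing in 1.5 ms.
print(f"{perf_tflops(1.5, 4096, 4096, 4096):.1f} TFLOPS")  # ~91.6
```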