Backing out on-device TMA store. (#3688)
Summary:
Pull Request resolved: #3688

X-link: facebookresearch/FBGEMM#764

- Always allocate the workspace (a minimal sketch follows).
  - Allocating on every call is almost free with PyTorch sub-allocation (the caching allocator).
  - Reusing a cached workspace instead could cause problems with multi-processing and CUDA graph capture.
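A minimal sketch of the always-allocate pattern, for context only. The helper name and the 128-byte descriptor size are assumptions for illustration; the actual change is inline in `_grouped_gemm` in the diff below, using `utils.TmaAutoTuneHelper.TMA_SIZE`.

```python
# Hypothetical helper illustrating the always-allocate workspace pattern.
# TMA_SIZE and the function name are illustrative, not FBGEMM code.
import torch

TMA_SIZE = 128  # assumed bytes per TMA descriptor


def allocate_tma_workspace(device: torch.device, num_sms: int) -> torch.Tensor:
    # torch.empty is a sub-allocation from PyTorch's caching allocator, so the
    # per-call cost is small, and no global buffer is shared across processes
    # or across CUDA graph captures.
    return torch.empty(num_sms * TMA_SIZE, device=device, dtype=torch.uint8)
```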

- Disable the on-device TMA store for now (fallback store pattern sketched below).
  - Running into issues with the on-device TMA store, e.g. NaN results for G=1 (per the TODOs this change removes from the tests).
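The fallback path writes results with an ordinary masked `tl.store` instead of an on-device TMA descriptor store. The toy kernel below shows only that store pattern (assumed names; it is not the grouped GEMM kernel itself):

```python
# Standalone sketch of a boundary-masked tile store in Triton (illustrative only).
import torch
import triton
import triton.language as tl


@triton.jit
def _masked_store_tile(
    c_ptr,  # row-major (M, N) output buffer
    M,
    N,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    # Stand-in for the GEMM accumulator tile.
    acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
    # Boundary-masked global store instead of a TMA descriptor store.
    mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc, mask=mask)


def demo() -> None:
    M, N, BLOCK_M, BLOCK_N = 100, 70, 64, 64
    c = torch.full((M, N), -1.0, device="cuda")
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    _masked_store_tile[grid](c, M, N, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N)
    assert torch.equal(c, torch.zeros_like(c))  # every element written exactly once
```

In the kernels changed below, the same pattern sits behind an `if USE_TMA_STORE: ... else: tl.store(...)` branch, and `_grouped_gemm` currently passes `USE_TMA_STORE = False`.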

Reviewed By: jiawenliu64, jwfromm

Differential Revision: D69602533

fbshipit-source-id: 7c776231711830c777e6808894dfc89ebdd4ed2c
levendlee authored and facebook-github-bot committed Feb 15, 2025
1 parent a4be13a commit e024eb7
Showing 2 changed files with 78 additions and 61 deletions.
21 changes: 9 additions & 12 deletions fbgemm_gpu/experimental/gemm/test/grouped_gemm_test.py
@@ -6,6 +6,7 @@

# pyre-strict

import logging
import unittest
from typing import Tuple

@@ -72,12 +73,10 @@ def _test_grouped_gemm_fp8_rowwise(

torch.testing.assert_close(result, expected_result, atol=2e-2, rtol=1.6e-2)

_test_grouped_gemm_fp8_rowwise((16, 512, 256, 256), torch.device("cuda"))
_test_grouped_gemm_fp8_rowwise((8, 512, 256, 256), torch.device("cuda"))
_test_grouped_gemm_fp8_rowwise((4, 512, 256, 256), torch.device("cuda"))
_test_grouped_gemm_fp8_rowwise((2, 512, 256, 256), torch.device("cuda"))
# TODO(shikaili): G=1 could produce NaNs results with on-device TMA store. Need to debug.
# _test_grouped_gemm_fp8_rowwise((1, 512, 256, 256), torch.device("cuda"))
for G in (1, 2, 4, 8, 16):
    for M in (64, 512):
        logging.info(f"Testing FP8 GMM with G={G}, M={M}")
        _test_grouped_gemm_fp8_rowwise((G, M, 256, 256), torch.device("cuda"))

def test_grouped_gemm_bf16(self) -> None:
def _test_grouped_gemm_bf16(
@@ -109,9 +108,7 @@ def _test_grouped_gemm_bf16(

torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)

_test_grouped_gemm_bf16((16, 512, 256, 256), torch.device("cuda"))
_test_grouped_gemm_bf16((8, 512, 256, 256), torch.device("cuda"))
_test_grouped_gemm_bf16((4, 512, 256, 256), torch.device("cuda"))
_test_grouped_gemm_bf16((2, 512, 256, 256), torch.device("cuda"))
# TODO(shikaili): G=1 could produce NaNs results with on-device TMA store. Need to debug.
# _test_grouped_gemm_bf16((1, 512, 256, 256), torch.device("cuda"))
for G in (1, 2, 4, 8, 16):
    for M in (64, 512):
        logging.info(f"Testing BF16 GMM with G={G}, M={M}")
        _test_grouped_gemm_bf16((G, M, 256, 256), torch.device("cuda"))
118 changes: 69 additions & 49 deletions fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py
@@ -34,6 +34,7 @@
for num_stages in [3, 4]
for num_warps in [4, 8]
for num_ctas in [1]
if not (block_size_m == 64 and num_warps == 8)
],
key=["G", "M_BUCKET", "N", "K"],
)
@@ -50,6 +51,7 @@ def _kernel_grouped_gemm(
N: tl.constexpr,
K: tl.constexpr,
NUM_SMS: tl.constexpr,
USE_TMA_STORE: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
@@ -76,16 +78,17 @@
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles

# pyre-ignore
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start_offset * N,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size],
element_ty=c_ptr.dtype.element_ty,
)
# pyre-ignore
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
if USE_TMA_STORE:
# pyre-ignore
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start_offset * N,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size],
element_ty=c_ptr.dtype.element_ty,
)
# pyre-ignore
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

# Move across tiles
while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
@@ -113,13 +116,25 @@ def _kernel_grouped_gemm(
)
accumulator += tl.dot(a, b.T)

m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
tl._experimental_descriptor_store(
c_desc_ptr,
accumulator.to(c_ptr.dtype.element_ty),
[m_offset, n_offset],
)
if USE_TMA_STORE:
m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
tl._experimental_descriptor_store(
c_desc_ptr,
accumulator.to(c_ptr.dtype.element_ty),
[m_offset, n_offset],
)
else:
offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c = accumulator.to(c_ptr.dtype.element_ty)
tl.store(
c_ptr
+ (M_start_offset + offs_am[:, None]) * N
+ offs_bn[None, :],
c,
mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
)
tidx += NUM_SMS

iterated_tiles += num_tiles
@@ -143,6 +158,7 @@ def _kernel_grouped_gemm(
for num_stages in [3, 4]
for num_warps in [4, 8]
for num_ctas in [1]
if not (block_size_m == 64 and num_warps == 8)
],
key=["G", "M_BUCKET", "N", "K"],
)
@@ -161,6 +177,7 @@ def _kernel_grouped_gemm_fp8_rowwise(
N: tl.constexpr,
K: tl.constexpr,
NUM_SMS: tl.constexpr,
USE_TMA_STORE: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
@@ -187,16 +204,17 @@ def _kernel_grouped_gemm_fp8_rowwise(
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles

# pyre-ignore
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start_offset * N,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size],
element_ty=c_ptr.dtype.element_ty,
)
# pyre-ignore
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
if USE_TMA_STORE:
# pyre-ignore
tl.extra.cuda.experimental_device_tensormap_create2d(
desc_ptr=c_desc_ptr,
global_address=c_ptr + M_start_offset * N,
load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
global_size=[m_size, n_size],
element_ty=c_ptr.dtype.element_ty,
)
# pyre-ignore
tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

# Move across tiles
while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
@@ -236,21 +254,27 @@ def _kernel_grouped_gemm_fp8_rowwise(
)
c = accumulator.to(tl.float32) * a_scale * b_scale

m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
tl._experimental_descriptor_store(
c_desc_ptr,
c.to(c_ptr.dtype.element_ty),
[m_offset, n_offset],
)
if USE_TMA_STORE:
m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
tl._experimental_descriptor_store(
c_desc_ptr,
c.to(c_ptr.dtype.element_ty),
[m_offset, n_offset],
)
else:
tl.store(
c_ptr
+ (M_start_offset + offs_am[:, None]) * N
+ offs_bn[None, :],
c,
mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
)
tidx += NUM_SMS

iterated_tiles += num_tiles


_ON_DEVICE_TMA_WORKSPACE = {}


def _grouped_gemm(
x: torch.Tensor,
w: torch.Tensor,
@@ -263,10 +287,6 @@ def _grouped_gemm(

G = m_offsets.shape[0]

# TODO(shikaili): G=1 could produce NaNs results with on-device TMA store. Need to debug.
if G == 1:
raise NotImplementedError("Grouped GEMM with NUM_GROUPS=1 is not supported yet")

assert x.is_contiguous()
assert w.is_contiguous()
assert m_offsets.is_contiguous()
@@ -283,14 +303,11 @@ def _grouped_gemm(

NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

global _ON_DEVICE_TMA_WORKSPACE
if x.device not in _ON_DEVICE_TMA_WORKSPACE:
_ON_DEVICE_TMA_WORKSPACE[x.device] = torch.empty(
NUM_SMS * utils.TmaAutoTuneHelper.TMA_SIZE,
device=x.device,
dtype=torch.uint8,
)
workspace = _ON_DEVICE_TMA_WORKSPACE[x.device]
workspace = torch.empty(
NUM_SMS * utils.TmaAutoTuneHelper.TMA_SIZE,
device=x.device,
dtype=torch.uint8,
)

def grid(META):
nonlocal desc_helper
@@ -320,6 +337,7 @@ def grid(META):
desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

M_BUCKET = triton.next_power_of_2(M)
USE_TMA_STORE = False
if x_scale is not None and w_scale is not None:
assert x_scale.is_contiguous()
assert w_scale.is_contiguous()
@@ -336,6 +354,7 @@ def grid(META):
N,
K,
NUM_SMS,
USE_TMA_STORE,
)
else:
assert x_scale is None
@@ -351,6 +370,7 @@ def grid(META):
N,
K,
NUM_SMS,
USE_TMA_STORE,
)

return y