Numerical Fix. #3688

Closed
wants to merge 1 commit
21 changes: 9 additions & 12 deletions fbgemm_gpu/experimental/gemm/test/grouped_gemm_test.py
@@ -6,6 +6,7 @@

# pyre-strict

+import logging
import unittest
from typing import Tuple

@@ -72,12 +73,10 @@ def _test_grouped_gemm_fp8_rowwise(

torch.testing.assert_close(result, expected_result, atol=2e-2, rtol=1.6e-2)

-_test_grouped_gemm_fp8_rowwise((16, 512, 256, 256), torch.device("cuda"))
-_test_grouped_gemm_fp8_rowwise((8, 512, 256, 256), torch.device("cuda"))
-_test_grouped_gemm_fp8_rowwise((4, 512, 256, 256), torch.device("cuda"))
-_test_grouped_gemm_fp8_rowwise((2, 512, 256, 256), torch.device("cuda"))
-# TODO(shikaili): G=1 could produce NaNs results with on-device TMA store. Need to debug.
-# _test_grouped_gemm_fp8_rowwise((1, 512, 256, 256), torch.device("cuda"))
+for G in (1, 2, 4, 8, 16):
+    for M in (64, 512):
+        logging.info(f"Testing FP8 GMM with G={G}, M={M}")
+        _test_grouped_gemm_fp8_rowwise((G, M, 256, 256), torch.device("cuda"))

def test_grouped_gemm_bf16(self) -> None:
def _test_grouped_gemm_bf16(
@@ -109,9 +108,7 @@ def _test_grouped_gemm_bf16(

torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)

-_test_grouped_gemm_bf16((16, 512, 256, 256), torch.device("cuda"))
-_test_grouped_gemm_bf16((8, 512, 256, 256), torch.device("cuda"))
-_test_grouped_gemm_bf16((4, 512, 256, 256), torch.device("cuda"))
-_test_grouped_gemm_bf16((2, 512, 256, 256), torch.device("cuda"))
-# TODO(shikaili): G=1 could produce NaNs results with on-device TMA store. Need to debug.
-# _test_grouped_gemm_bf16((1, 512, 256, 256), torch.device("cuda"))
+for G in (1, 2, 4, 8, 16):
+    for M in (64, 512):
+        logging.info(f"Testing BF16 GMM with G={G}, M={M}")
+        _test_grouped_gemm_bf16((G, M, 256, 256), torch.device("cuda"))
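
Reviewer note: for orientation, the shape tuples in these tests read as (G, M, N, K): x is an (M, K) activation whose rows are partitioned into G groups, and w stacks one (N, K) weight block per group. The new loops also exercise G=1 and the smaller M=64, the cases the old TODOs skipped. A minimal eager-mode sketch of the computation under test, assuming m_offsets holds cumulative end-row offsets per group (reference_grouped_gemm is an illustrative name, not part of this PR):

import torch

def reference_grouped_gemm(
    x: torch.Tensor,         # (M, K): rows of all groups, concatenated
    w: torch.Tensor,         # (G * N, K): one (N, K) weight block per group
    m_offsets: torch.Tensor, # (G,): cumulative end-row offset of each group
) -> torch.Tensor:
    G = m_offsets.numel()
    N = w.shape[0] // G
    out = torch.empty(x.shape[0], N, dtype=x.dtype, device=x.device)
    start = 0
    for g in range(G):
        end = int(m_offsets[g])
        # Group g's rows hit group g's weight block: (m_g, K) @ (K, N).
        out[start:end] = x[start:end] @ w[g * N : (g + 1) * N].T
        start = end
    return out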
118 changes: 69 additions & 49 deletions fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py
@@ -34,6 +34,7 @@
for num_stages in [3, 4]
for num_warps in [4, 8]
for num_ctas in [1]
+if not (block_size_m == 64 and num_warps == 8)
],
key=["G", "M_BUCKET", "N", "K"],
)
@@ -50,6 +51,7 @@ def _kernel_grouped_gemm(
N: tl.constexpr,
K: tl.constexpr,
NUM_SMS: tl.constexpr,
+USE_TMA_STORE: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
@@ -76,16 +78,17 @@
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles

-# pyre-ignore
-tl.extra.cuda.experimental_device_tensormap_create2d(
-    desc_ptr=c_desc_ptr,
-    global_address=c_ptr + M_start_offset * N,
-    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
-    global_size=[m_size, n_size],
-    element_ty=c_ptr.dtype.element_ty,
-)
-# pyre-ignore
-tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+if USE_TMA_STORE:
+    # pyre-ignore
+    tl.extra.cuda.experimental_device_tensormap_create2d(
+        desc_ptr=c_desc_ptr,
+        global_address=c_ptr + M_start_offset * N,
+        load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+        global_size=[m_size, n_size],
+        element_ty=c_ptr.dtype.element_ty,
+    )
+    # pyre-ignore
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

# Move across tiles
while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
@@ -113,13 +116,25 @@
)
accumulator += tl.dot(a, b.T)

-m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
-n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
-tl._experimental_descriptor_store(
-    c_desc_ptr,
-    accumulator.to(c_ptr.dtype.element_ty),
-    [m_offset, n_offset],
-)
+if USE_TMA_STORE:
+    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+    tl._experimental_descriptor_store(
+        c_desc_ptr,
+        accumulator.to(c_ptr.dtype.element_ty),
+        [m_offset, n_offset],
+    )
+else:
+    offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c = accumulator.to(c_ptr.dtype.element_ty)
+    tl.store(
+        c_ptr
+        + (M_start_offset + offs_am[:, None]) * N
+        + offs_bn[None, :],
+        c,
+        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
+    )
tidx += NUM_SMS

iterated_tiles += num_tiles
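
Reviewer note: on the non-TMA path the kernel now bounds-checks each tile explicitly, so partial tiles at a group's ragged edge never write out of range; the PR falls back to this path because of the NaN issue recorded in the removed TODOs. A small NumPy sketch of the mask that tl.store applies for one partial tile (shapes chosen purely for illustration):

import numpy as np

BLOCK_SIZE_M, BLOCK_SIZE_N = 128, 64
m_size, n_size = 200, 100      # this group's output is not a multiple of the tile size
tile_m_idx, tile_n_idx = 1, 1  # the bottom-right, partial tile

offs_am = tile_m_idx * BLOCK_SIZE_M + np.arange(BLOCK_SIZE_M)  # rows 128..255
offs_bn = tile_n_idx * BLOCK_SIZE_N + np.arange(BLOCK_SIZE_N)  # cols 64..127
mask = (offs_am[:, None] < m_size) & (offs_bn[None, :] < n_size)
print(mask.sum())  # 72 * 36 = 2592 of the 128 * 64 lanes actually store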
@@ -143,6 +158,7 @@ def _kernel_grouped_gemm(
for num_stages in [3, 4]
for num_warps in [4, 8]
for num_ctas in [1]
+if not (block_size_m == 64 and num_warps == 8)
],
key=["G", "M_BUCKET", "N", "K"],
)
@@ -161,6 +177,7 @@ def _kernel_grouped_gemm_fp8_rowwise(
N: tl.constexpr,
K: tl.constexpr,
NUM_SMS: tl.constexpr,
+USE_TMA_STORE: tl.constexpr,
# tile sizes
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
@@ -187,16 +204,17 @@
num_n_tiles = tl.cdiv(n_size, BLOCK_SIZE_N)
num_tiles = num_m_tiles * num_n_tiles

-# pyre-ignore
-tl.extra.cuda.experimental_device_tensormap_create2d(
-    desc_ptr=c_desc_ptr,
-    global_address=c_ptr + M_start_offset * N,
-    load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
-    global_size=[m_size, n_size],
-    element_ty=c_ptr.dtype.element_ty,
-)
-# pyre-ignore
-tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)
+if USE_TMA_STORE:
+    # pyre-ignore
+    tl.extra.cuda.experimental_device_tensormap_create2d(
+        desc_ptr=c_desc_ptr,
+        global_address=c_ptr + M_start_offset * N,
+        load_size=[BLOCK_SIZE_M, BLOCK_SIZE_N],
+        global_size=[m_size, n_size],
+        element_ty=c_ptr.dtype.element_ty,
+    )
+    # pyre-ignore
+    tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(c_desc_ptr)

# Move across tiles
while tidx >= iterated_tiles and tidx < iterated_tiles + num_tiles:
@@ -236,21 +254,27 @@ def _kernel_grouped_gemm_fp8_rowwise(
)
c = accumulator.to(tl.float32) * a_scale * b_scale

-m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
-n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
-tl._experimental_descriptor_store(
-    c_desc_ptr,
-    c.to(c_ptr.dtype.element_ty),
-    [m_offset, n_offset],
-)
+if USE_TMA_STORE:
+    m_offset = (tile_m_idx * BLOCK_SIZE_M).to(tl.int32)
+    n_offset = (tile_n_idx * BLOCK_SIZE_N).to(tl.int32)
+    tl._experimental_descriptor_store(
+        c_desc_ptr,
+        c.to(c_ptr.dtype.element_ty),
+        [m_offset, n_offset],
+    )
+else:
+    tl.store(
+        c_ptr
+        + (M_start_offset + offs_am[:, None]) * N
+        + offs_bn[None, :],
+        c,
+        mask=offs_am[:, None] < m_size and offs_bn[None, :] < n_size,
+    )
tidx += NUM_SMS

iterated_tiles += num_tiles


-_ON_DEVICE_TMA_WORKSPACE = {}


def _grouped_gemm(
x: torch.Tensor,
w: torch.Tensor,
@@ -263,10 +287,6 @@ def _grouped_gemm(

G = m_offsets.shape[0]

-# TODO(shikaili): G=1 could produce NaNs results with on-device TMA store. Need to debug.
-if G == 1:
-    raise NotImplementedError("Grouped GEMM with NUM_GROUPS=1 is not supported yet")

assert x.is_contiguous()
assert w.is_contiguous()
assert m_offsets.is_contiguous()
@@ -283,14 +303,11 @@

NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count

-global _ON_DEVICE_TMA_WORKSPACE
-if x.device not in _ON_DEVICE_TMA_WORKSPACE:
-    _ON_DEVICE_TMA_WORKSPACE[x.device] = torch.empty(
-        NUM_SMS * utils.TmaAutoTuneHelper.TMA_SIZE,
-        device=x.device,
-        dtype=torch.uint8,
-    )
-workspace = _ON_DEVICE_TMA_WORKSPACE[x.device]
+workspace = torch.empty(
+    NUM_SMS * utils.TmaAutoTuneHelper.TMA_SIZE,
+    device=x.device,
+    dtype=torch.uint8,
+)

def grid(META):
nonlocal desc_helper
@@ -320,6 +337,7 @@ def grid(META):
desc_w = desc_helper.get_tma_descriptor_kernel_param("w")

M_BUCKET = triton.next_power_of_2(M)
+USE_TMA_STORE = False
if x_scale is not None and w_scale is not None:
assert x_scale.is_contiguous()
assert w_scale.is_contiguous()
@@ -336,6 +354,7 @@ def grid(META):
N,
K,
NUM_SMS,
+USE_TMA_STORE,
)
else:
assert x_scale is None
@@ -351,6 +370,7 @@ def grid(META):
N,
K,
NUM_SMS,
+USE_TMA_STORE,
)

return y
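
Reviewer note: host-side, the PR drops the per-device _ON_DEVICE_TMA_WORKSPACE cache in favor of a fresh workspace allocation per call, and threads USE_TMA_STORE=False into both kernels, leaving the store-descriptor workspace idle until the TMA-store NaN issue is debugged. A hedged sketch of driving this path directly; the m_offsets construction, its dtype, and the argument order are assumptions, since the full signature is elided from this diff:

import torch

# Module path inferred from the file path shown above.
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import _grouped_gemm

G, M, N, K = 2, 512, 256, 256
device = torch.device("cuda")

x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
w = torch.randn(G * N, K, dtype=torch.bfloat16, device=device)

# Even split of M rows across G groups; offsets assumed to be cumulative row ends.
m_sizes = torch.full((G,), M // G, dtype=torch.int32, device=device)
m_offsets = torch.cumsum(m_sizes, dim=0).to(torch.int32)

y = _grouped_gemm(x, w, m_offsets)  # (M, N), group outputs stacked along dim 0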