Skip to content

Commit

Permalink
Re-organize SLL ops, pt 8 (#3663)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#738

Pull Request resolved: #3663

- Re-organize the remaining SLL triton ops

Differential Revision: D68970862
  • Loading branch information
q10 authored and facebook-github-bot committed Feb 7, 2025
1 parent 2cef43a commit c108e35
Show file tree
Hide file tree
Showing 18 changed files with 206 additions and 197 deletions.
12 changes: 0 additions & 12 deletions .github/scripts/fbgemm_gpu_test.bash
Original file line number Diff line number Diff line change
Expand Up @@ -83,18 +83,6 @@ __configure_fbgemm_gpu_test_cpu () {
# These tests have non-CPU operators referenced in @given
./uvm/copy_test.py
./uvm/uvm_test.py
./sll/triton_sll_test.py
./sll/array_jagged_bmm_jagged_out_test.py
./sll/jagged_dense_elementwise_add_test.py
./sll/jagged_flash_attention_basic_test.py
./sll/jagged_jagged_bmm_jagged_out_test.py
./sll/jagged_dense_flash_attention_test.py
./sll/jagged_dense_bmm_test.py
./sll/jagged_dense_elementwise_mul_jagged_out_test.py
./sll/jagged_jagged_bmm_test.py
./sll/jagged_softmax_test.py
./sll/jagged2_to_padded_dense_test.py
./sll/multi_head_jagged_flash_attention_test.py
)
}

Expand Down
21 changes: 1 addition & 20 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,6 @@
meta_jagged_self_substraction_jagged_out,
)

from fbgemm_gpu.sll.triton_sll import ( # noqa F401
jagged_dense_elementwise_mul_jagged_out,
triton_jagged_self_substraction_jagged_out,
)

from fbgemm_gpu.utils import TorchLibraryFragment

lib = TorchLibraryFragment("fbgemm")
Expand Down Expand Up @@ -262,25 +257,11 @@
},
}

# pyre-ignore[5]
sll_gpu_registrations = {
"sll_jagged_self_substraction_jagged_out": {
"CUDA": triton_jagged_self_substraction_jagged_out,
},
"sll_jagged_dense_elementwise_mul_jagged_out": {
"CUDA": jagged_dense_elementwise_mul_jagged_out,
"AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
},
}

for op_name, dispatches in sll_cpu_registrations.items():
lib.register(op_name, dispatches)

if torch.cuda.is_available():
from fbgemm_gpu.sll.triton import op_registrations

for op_name, dispatches in op_registrations.items():
lib.register(op_name, dispatches)
from fbgemm_gpu.sll.triton import op_registrations as sll_gpu_registrations

for op_name, dispatches in sll_gpu_registrations.items():
lib.register(op_name, dispatches)
16 changes: 16 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@
JaggedDenseAdd, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_dense_elementwise_mul_jagged_out import ( # noqa F401
jagged_dense_elementwise_mul_jagged_out,
JaggedDenseElementwiseMul, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_dense_flash_attention import ( # noqa F401
jagged_dense_flash_attention,
JaggedDenseFlashAttention, # noqa F401
Expand All @@ -47,6 +52,10 @@
JaggedFlashAttentionBasic, # noqa F401
)

from fbgemm_gpu.sll.triton.triton_jagged_self_substraction_jagged_out import (
triton_jagged_self_substraction_jagged_out,
)

from fbgemm_gpu.sll.triton.triton_jagged_softmax import ( # noqa F401
jagged2_softmax,
Jagged2Softmax, # noqa F401
Expand Down Expand Up @@ -108,4 +117,11 @@
"CUDA": multi_head_jagged_flash_attention,
"AutogradCUDA": multi_head_jagged_flash_attention,
},
"sll_jagged_self_substraction_jagged_out": {
"CUDA": triton_jagged_self_substraction_jagged_out,
},
"sll_jagged_dense_elementwise_mul_jagged_out": {
"CUDA": jagged_dense_elementwise_mul_jagged_out,
"AutogradCUDA": jagged_dense_elementwise_mul_jagged_out,
},
}
22 changes: 22 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/triton/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,28 @@
import torch


def next_power_of_two(N: int) -> int:
if N > 4096:
raise Exception(f"{N} is too large that is not supported yet")

if N > 2048:
return 4096
elif N > 1024:
return 2048
elif N > 512:
return 1024
elif N > 256:
return 512
elif N > 128:
return 256
elif N > 64:
return 128
elif N > 32:
return 64
else:
return 32


def expect_contiguous(x: torch.Tensor) -> torch.Tensor:
if not x.is_contiguous():
return x.contiguous()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,61 +11,6 @@
import triton.language as tl


def next_power_of_two(N: int) -> int:
if N > 4096:
raise Exception(f"{N} is too large that is not supported yet")

if N > 2048:
return 4096
elif N > 1024:
return 2048
elif N > 512:
return 1024
elif N > 256:
return 512
elif N > 128:
return 256
elif N > 64:
return 128
elif N > 32:
return 64
else:
return 32


@triton.jit
def jagged_self_substraction_jagged_out_kernel(
a_ptr, # jagged
b_ptr, # jagged
a_offsets_ptr,
b_offsets_ptr,
max_seq_len,
BLOCK_SIZE: tl.constexpr,
):
pid_batch = tl.program_id(0)
pid_index = tl.program_id(1)

a_offset = tl.load(a_offsets_ptr + pid_batch)
a_length = tl.load(a_offsets_ptr + pid_batch + 1) - a_offset
a_length = tl.minimum(a_length, max_seq_len + 1)

if a_length <= 1:
return

N = a_length - 1
if pid_index >= N:
return

a_cur = tl.load(a_ptr + a_offset + pid_index)
offs = tl.arange(0, BLOCK_SIZE)
mask = offs < N
a_row = tl.load(a_ptr + a_offset + offs + 1, mask=mask)
b = a_cur - a_row

b_offset = tl.load(b_offsets_ptr + pid_batch)
tl.store(b_ptr + b_offset + pid_index * N + offs, b, mask=mask)


@triton.jit
def jagged_dense_elementwise_mul_jagged_out_kernel(
a_ptr, # 1d jagged
Expand Down Expand Up @@ -123,33 +68,6 @@ def jagged_dense_elementwise_mul_jagged_out_kernel(
c_ptrs += BLOCK_N


def triton_jagged_self_substraction_jagged_out(
jagged_A: torch.Tensor,
offsets_a: torch.Tensor,
offsets_b: torch.Tensor,
max_seq_len,
) -> torch.Tensor:
B = offsets_a.size(0) - 1

jagged_B = torch.empty(
(int(offsets_b[-1].item())), device=jagged_A.device, dtype=jagged_A.dtype
)

BLOCK_SIZE = max(next_power_of_two(max_seq_len), 16)
grid = (B, max_seq_len)

jagged_self_substraction_jagged_out_kernel[grid](
jagged_A,
jagged_B,
offsets_a,
offsets_b,
max_seq_len,
BLOCK_SIZE, # pyre-fixme[6]: For 6th argument expected `constexpr` but got `int`.
)

return jagged_B


def triton_jagged_dense_elementwise_mul_jagged_out(
jagged_A,
dense_B,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
import triton
import triton.language as tl

from .common import next_power_of_two


@triton.jit
def jagged_self_substraction_jagged_out_kernel(
a_ptr, # jagged
b_ptr, # jagged
a_offsets_ptr,
b_offsets_ptr,
max_seq_len,
BLOCK_SIZE: tl.constexpr,
):
pid_batch = tl.program_id(0)
pid_index = tl.program_id(1)

a_offset = tl.load(a_offsets_ptr + pid_batch)
a_length = tl.load(a_offsets_ptr + pid_batch + 1) - a_offset
a_length = tl.minimum(a_length, max_seq_len + 1)

if a_length <= 1:
return

N = a_length - 1
if pid_index >= N:
return

a_cur = tl.load(a_ptr + a_offset + pid_index)
offs = tl.arange(0, BLOCK_SIZE)
mask = offs < N
a_row = tl.load(a_ptr + a_offset + offs + 1, mask=mask)
b = a_cur - a_row

b_offset = tl.load(b_offsets_ptr + pid_batch)
tl.store(b_ptr + b_offset + pid_index * N + offs, b, mask=mask)


def triton_jagged_self_substraction_jagged_out(
jagged_A: torch.Tensor,
offsets_a: torch.Tensor,
offsets_b: torch.Tensor,
max_seq_len,
) -> torch.Tensor:
B = offsets_a.size(0) - 1

jagged_B = torch.empty(
(int(offsets_b[-1].item())), device=jagged_A.device, dtype=jagged_A.dtype
)

BLOCK_SIZE = max(next_power_of_two(max_seq_len), 16)
grid = (B, max_seq_len)

jagged_self_substraction_jagged_out_kernel[grid](
jagged_A,
jagged_B,
offsets_a,
offsets_b,
max_seq_len,
BLOCK_SIZE,
)

return jagged_B
6 changes: 3 additions & 3 deletions fbgemm_gpu/test/sll/array_jagged_bmm_jagged_out_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class ArrayJaggedBmmJaggedTest(unittest.TestCase):
)
@unittest.skipIf(*gpu_unavailable)
@unittest.skipIf(*running_on_rocm)
@settings(deadline=20000)
@settings(deadline=30000)
def test_triton_array_jagged_bmm_jagged_out(
self,
B: int,
Expand Down Expand Up @@ -157,7 +157,7 @@ def ref_array_jagged_bmm_jagged_out(
)
@unittest.skipIf(*gpu_unavailable)
@unittest.skipIf(*running_on_rocm)
@settings(deadline=20000)
@settings(deadline=30000)
def test_triton_array_jagged_bmm_jagged_out_with_grad(
self,
B: int,
Expand Down Expand Up @@ -244,7 +244,7 @@ def test_triton_array_jagged_bmm_jagged_out_with_grad(
)
@unittest.skipIf(*gpu_unavailable)
@unittest.skipIf(*running_on_rocm)
@settings(deadline=20000)
@settings(deadline=30000)
def test_triton_array_jagged_bmm_jagged_out_meta_backend(
self,
B: int,
Expand Down
3 changes: 1 addition & 2 deletions fbgemm_gpu/test/sll/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
# pyre-ignore-all-errors[56]

import fbgemm_gpu
import fbgemm_gpu.sll.cpu_sll
import fbgemm_gpu.sll.triton_sll
import fbgemm_gpu.sll
import torch

# pyre-fixme[16]: Module `fbgemm_gpu` has no attribute `open_source`.
Expand Down
Loading

0 comments on commit c108e35

Please sign in to comment.