Skip to content

Commit

Permalink
Re-organize SLL ops, pt 9 (pytorch#3665)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#3665

X-link: facebookresearch/FBGEMM#740

- Move cpu_sll and meta_sll to their own folders

Reviewed By: brad-mengchi

Differential Revision: D69227334
  • Loading branch information
q10 authored and facebook-github-bot committed Feb 9, 2025
1 parent 3182ea5 commit c07e85d
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 83 deletions.
88 changes: 5 additions & 83 deletions fbgemm_gpu/fbgemm_gpu/sll/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,30 +9,8 @@

import torch

from fbgemm_gpu.sll.cpu_sll import ( # noqa F401
cpu_array_jagged_bmm_jagged_out,
cpu_dense_jagged_cat_jagged_out,
cpu_jagged2_softmax,
cpu_jagged2_to_padded_dense,
cpu_jagged_dense_bmm,
cpu_jagged_dense_elementwise_add,
cpu_jagged_dense_elementwise_mul_jagged_out,
cpu_jagged_dense_flash_attention,
cpu_jagged_flash_attention_basic,
cpu_jagged_jagged_bmm,
cpu_jagged_jagged_bmm_jagged_out,
cpu_jagged_self_substraction_jagged_out,
cpu_jagged_softmax,
)

from fbgemm_gpu.sll.meta_sll import ( # noqa F401
meta_array_jagged_bmm_jagged_out,
meta_jagged2_softmax,
meta_jagged_dense_elementwise_mul_jagged_out,
meta_jagged_jagged_bmm_jagged_out,
meta_jagged_self_substraction_jagged_out,
)

from fbgemm_gpu.sll.cpu import op_registrations as sll_cpu_registrations
from fbgemm_gpu.sll.meta import op_registrations as sll_meta_registrations
from fbgemm_gpu.utils import TorchLibraryFragment

lib = TorchLibraryFragment("fbgemm")
Expand Down Expand Up @@ -198,68 +176,12 @@
# need the autograd forward to save the context because we don't need to do
# backward.

# pyre-ignore[5]
sll_cpu_registrations = {
"sll_jagged_dense_bmm": {
"CPU": cpu_jagged_dense_bmm,
"AutogradCPU": cpu_jagged_dense_bmm,
},
"sll_jagged_jagged_bmm": {
"CPU": cpu_jagged_jagged_bmm,
"AutogradCPU": cpu_jagged_jagged_bmm,
},
"sll_dense_jagged_cat_jagged_out": {
"CPU": cpu_dense_jagged_cat_jagged_out,
},
"sll_jagged_self_substraction_jagged_out": {
"CPU": cpu_jagged_self_substraction_jagged_out,
"Meta": meta_jagged_self_substraction_jagged_out,
},
"sll_jagged2_to_padded_dense": {
"CPU": cpu_jagged2_to_padded_dense,
"AutogradCPU": cpu_jagged2_to_padded_dense,
},
"sll_jagged_dense_elementwise_mul_jagged_out": {
"CPU": cpu_jagged_dense_elementwise_mul_jagged_out,
"AutogradCPU": cpu_jagged_dense_elementwise_mul_jagged_out,
"Meta": meta_jagged_dense_elementwise_mul_jagged_out,
},
"sll_jagged_softmax": {
"CPU": cpu_jagged_softmax,
"AutogradCPU": cpu_jagged_softmax,
},
"sll_jagged2_softmax": {
"CPU": cpu_jagged2_softmax,
"AutogradCPU": cpu_jagged2_softmax,
"AutogradMeta": meta_jagged2_softmax,
},
"sll_array_jagged_bmm_jagged_out": {
"CPU": cpu_array_jagged_bmm_jagged_out,
"AutogradCPU": cpu_array_jagged_bmm_jagged_out,
"AutogradMeta": meta_array_jagged_bmm_jagged_out,
},
"sll_jagged_jagged_bmm_jagged_out": {
"CPU": cpu_jagged_jagged_bmm_jagged_out,
"AutogradCPU": cpu_jagged_jagged_bmm_jagged_out,
"AutogradMeta": meta_jagged_jagged_bmm_jagged_out,
},
"sll_jagged_flash_attention_basic": {
"CPU": cpu_jagged_flash_attention_basic,
"AutogradCPU": cpu_jagged_flash_attention_basic,
},
"sll_jagged_dense_elementwise_add": {
"CPU": cpu_jagged_dense_elementwise_add,
"AutogradCPU": cpu_jagged_dense_elementwise_add,
},
"sll_jagged_dense_flash_attention": {
"CPU": cpu_jagged_dense_flash_attention,
"AutogradCPU": cpu_jagged_dense_flash_attention,
},
}

for op_name, dispatches in sll_cpu_registrations.items():
lib.register(op_name, dispatches)

for op_name, dispatches in sll_meta_registrations.items():
lib.register(op_name, dispatches)

if torch.cuda.is_available():
from fbgemm_gpu.sll.triton import op_registrations as sll_gpu_registrations

Expand Down
80 changes: 80 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/cpu/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from fbgemm_gpu.sll.cpu.cpu_sll import ( # noqa F401
cpu_array_jagged_bmm_jagged_out,
cpu_array_jagged_bmm_jagged_out_kernel, # noqa F401
cpu_dense_jagged_cat_jagged_out,
cpu_jagged2_softmax,
cpu_jagged2_to_padded_dense,
cpu_jagged_dense_bmm,
cpu_jagged_dense_elementwise_add,
cpu_jagged_dense_elementwise_mul_jagged_out,
cpu_jagged_dense_flash_attention,
cpu_jagged_flash_attention_basic,
cpu_jagged_jagged_bmm,
cpu_jagged_jagged_bmm_jagged_out,
cpu_jagged_jagged_bmm_jagged_out_kernel, # noqa F401
cpu_jagged_self_substraction_jagged_out,
cpu_jagged_softmax,
)

# pyre-ignore[5]
op_registrations = {
"sll_jagged_dense_bmm": {
"CPU": cpu_jagged_dense_bmm,
"AutogradCPU": cpu_jagged_dense_bmm,
},
"sll_jagged_jagged_bmm": {
"CPU": cpu_jagged_jagged_bmm,
"AutogradCPU": cpu_jagged_jagged_bmm,
},
"sll_dense_jagged_cat_jagged_out": {
"CPU": cpu_dense_jagged_cat_jagged_out,
},
"sll_jagged_self_substraction_jagged_out": {
"CPU": cpu_jagged_self_substraction_jagged_out,
},
"sll_jagged2_to_padded_dense": {
"CPU": cpu_jagged2_to_padded_dense,
"AutogradCPU": cpu_jagged2_to_padded_dense,
},
"sll_jagged_dense_elementwise_mul_jagged_out": {
"CPU": cpu_jagged_dense_elementwise_mul_jagged_out,
"AutogradCPU": cpu_jagged_dense_elementwise_mul_jagged_out,
},
"sll_jagged_softmax": {
"CPU": cpu_jagged_softmax,
"AutogradCPU": cpu_jagged_softmax,
},
"sll_jagged2_softmax": {
"CPU": cpu_jagged2_softmax,
"AutogradCPU": cpu_jagged2_softmax,
},
"sll_array_jagged_bmm_jagged_out": {
"CPU": cpu_array_jagged_bmm_jagged_out,
"AutogradCPU": cpu_array_jagged_bmm_jagged_out,
},
"sll_jagged_jagged_bmm_jagged_out": {
"CPU": cpu_jagged_jagged_bmm_jagged_out,
"AutogradCPU": cpu_jagged_jagged_bmm_jagged_out,
},
"sll_jagged_flash_attention_basic": {
"CPU": cpu_jagged_flash_attention_basic,
"AutogradCPU": cpu_jagged_flash_attention_basic,
},
"sll_jagged_dense_elementwise_add": {
"CPU": cpu_jagged_dense_elementwise_add,
"AutogradCPU": cpu_jagged_dense_elementwise_add,
},
"sll_jagged_dense_flash_attention": {
"CPU": cpu_jagged_dense_flash_attention,
"AutogradCPU": cpu_jagged_dense_flash_attention,
},
}
File renamed without changes.
35 changes: 35 additions & 0 deletions fbgemm_gpu/fbgemm_gpu/sll/meta/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from fbgemm_gpu.sll.meta.meta_sll import ( # noqa F401
meta_array_jagged_bmm_jagged_out,
meta_jagged2_softmax,
meta_jagged_dense_elementwise_mul_jagged_out,
meta_jagged_jagged_bmm_jagged_out,
meta_jagged_self_substraction_jagged_out,
)

# pyre-ignore[5]
op_registrations = {
"sll_jagged_self_substraction_jagged_out": {
"Meta": meta_jagged_self_substraction_jagged_out,
},
"sll_jagged_dense_elementwise_mul_jagged_out": {
"Meta": meta_jagged_dense_elementwise_mul_jagged_out,
},
"sll_jagged2_softmax": {
"AutogradMeta": meta_jagged2_softmax,
},
"sll_array_jagged_bmm_jagged_out": {
"AutogradMeta": meta_array_jagged_bmm_jagged_out,
},
"sll_jagged_jagged_bmm_jagged_out": {
"AutogradMeta": meta_jagged_jagged_bmm_jagged_out,
},
}
File renamed without changes.

0 comments on commit c07e85d

Please sign in to comment.