Skip to content

Commit

Permalink
[kernel] update triton init #4740 (#4740)
Browse files Browse the repository at this point in the history
  • Loading branch information
oahzxl authored Sep 18, 2023
1 parent d151dca commit 32e7f99
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions colossalai/kernel/triton/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
from .rms_norm import rmsnorm_forward
from .rotary_embedding_kernel import rotary_embedding_fwd
from .softmax import softmax
from .token_attention_kernel import token_attention_fwd
try:
import triton
HAS_TRITON = True

__all__ = [
"llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
"copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
]
from .context_attention import bloom_context_attn_fwd, llama_context_attn_fwd
from .copy_kv_cache_dest import copy_kv_cache_to_dest
from .fused_layernorm import layer_norm
from .rms_norm import rmsnorm_forward
from .rotary_embedding_kernel import rotary_embedding_fwd
from .softmax import softmax
from .token_attention_kernel import token_attention_fwd

__all__ = [
"llama_context_attn_fwd", "bloom_context_attn_fwd", "softmax", "layer_norm", "rmsnorm_forward",
"copy_kv_cache_to_dest", "rotary_embedding_fwd", "token_attention_fwd"
]

except ImportError:
HAS_TRITON = False
print("Triton is not installed. Please install Triton to use Triton kernels.")

0 comments on commit 32e7f99

Please sign in to comment.