diff --git a/medusa/model/modeling_llama_kv.py b/medusa/model/modeling_llama_kv.py
index abf9382..3d61719 100644
--- a/medusa/model/modeling_llama_kv.py
+++ b/medusa/model/modeling_llama_kv.py
@@ -22,7 +22,7 @@
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
-    is_flash_attn_available,
+    is_flash_attn_2_available,
     logging,
     replace_return_docstrings,
 )