Commit 5e2a639
add unsupported backend error
IsaacRe committed Oct 28, 2024
1 parent e4b0287 · commit 5e2a639
Showing 3 changed files with 13 additions and 2 deletions.
vllm/attention/layer.py (2 additions, 1 deletion)
@@ -79,7 +79,8 @@ def __init__(
         attn_backend = get_attn_backend(num_heads, head_size, num_kv_heads,
                                         sliding_window, dtype, kv_cache_dtype,
                                         block_size, blocksparse_params
-                                        is not None)
+                                        is not None,
+                                        cache_config.enable_kvcompress)
         impl_cls = attn_backend.get_impl_cls()
         self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                              alibi_slopes, sliding_window, kv_cache_dtype,
vllm/attention/selector.py (10 additions, 1 deletion)
@@ -26,6 +26,9 @@ class _Backend(enum.Enum):
     IPEX = enum.auto()


+KVC_SUPPORTED_BACKENDS = [_Backend.FLASH_ATTN]
+
+
 def backend_name_to_enum(backend_name: str) -> _Backend:
     assert backend_name is not None

@@ -96,6 +99,7 @@ def get_attn_backend(
     kv_cache_dtype: Optional[str],
     block_size: int,
     is_blocksparse: bool = False,
+    enable_kvcompress: bool = False,
 ) -> Type[AttentionBackend]:
     """Selects which attention backend to use and lazily imports it."""

@@ -107,7 +111,7 @@ def get_attn_backend(

     backend = which_attn_to_use(num_heads, head_size, num_kv_heads,
                                 sliding_window, dtype, kv_cache_dtype,
-                                block_size)
+                                block_size, enable_kvcompress)
     if backend == _Backend.FLASH_ATTN:
         from vllm.attention.backends.flash_attn import (  # noqa: F401
             FlashAttentionBackend)
@@ -158,6 +162,7 @@ def which_attn_to_use(
     dtype: torch.dtype,
     kv_cache_dtype: Optional[str],
     block_size: int,
+    enable_kvcompress: bool = False,
 ) -> _Backend:
     """Returns which flash attention backend to use."""
     # Default case.
@@ -262,6 +267,10 @@ def which_attn_to_use(
                 "`pip install vllm-flash-attn` for better performance.")
             selected_backend = _Backend.XFORMERS

+    if enable_kvcompress and selected_backend not in KVC_SUPPORTED_BACKENDS:
+        raise ValueError(f"selected backend {selected_backend.name} is not "
+                         "compatible with KV-Compress.")
+
     return selected_backend
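For illustration, a minimal sketch (not part of this commit) of how the new guard surfaces to a caller. It assumes upstream vLLM's behavior of honoring the VLLM_ATTENTION_BACKEND environment variable when picking a backend; the head counts and shapes are made up.

# Sketch only: force a backend outside KVC_SUPPORTED_BACKENDS, then ask the
# selector for a backend with KV-Compress enabled. Expects the new ValueError.
import os
os.environ["VLLM_ATTENTION_BACKEND"] = "XFORMERS"

import torch
from vllm.attention.selector import which_attn_to_use

try:
    which_attn_to_use(num_heads=32, head_size=128, num_kv_heads=32,
                      sliding_window=None, dtype=torch.float16,
                      kv_cache_dtype=None, block_size=16,
                      enable_kvcompress=True)
except ValueError as err:
    print(err)  # selected backend XFORMERS is not compatible with KV-Compress.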


vllm/worker/cache_engine.py (1 addition, 0 deletions)
@@ -68,6 +68,7 @@ def __init__(
             model_config.dtype,
             cache_config.cache_dtype,
             self.block_size,
+            self.cache_config.enable_kvcompress,
         )

         # Initialize the cache. KV-Compress uses a unified KV cache where
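For context, a hedged sketch of where the flag originates: CacheEngine reads enable_kvcompress off the CacheConfig it already holds and forwards it as the new trailing argument of get_attn_backend. The CacheConfig constructor arguments below follow upstream vLLM; treating enable_kvcompress as a plain attribute is an assumption about this fork's config surface.

# Sketch: CacheConfig carries the flag that CacheEngine forwards above.
from vllm.config import CacheConfig

cache_config = CacheConfig(block_size=16,
                           gpu_memory_utilization=0.9,
                           swap_space=4,
                           cache_dtype="auto")
cache_config.enable_kvcompress = True  # assumed to default to False in the fork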
