diff --git a/python/llm/example/CPU/vLLM-Serving/README.md b/python/llm/example/CPU/vLLM-Serving/README.md index 44162d30f3a..1ff3161ece9 100644 --- a/python/llm/example/CPU/vLLM-Serving/README.md +++ b/python/llm/example/CPU/vLLM-Serving/README.md @@ -31,7 +31,7 @@ pip3 install "pydantic<2" # Required for OpenAI server. ### 2. Configure recommended environment variables ```bash -source bigdl-llm-init +source bigdl-llm-init -t ``` ### 3. Offline inference/Service @@ -55,9 +55,12 @@ To fully utilize the continuous batching feature of the `vLLM`, you can send req ```bash #!/bin/bash -numactl -C 48-95 -m 1 python -m bigdl.llm.vllm.examples.api_server \ +# You may also want to adjust the `--max-num-batched-tokens` argument, it indicates the hard limit +# of batched prompt length the server will accept +numactl -C 48-95 -m 1 python -m bigdl.llm.vllm.entrypoints.openai.api_server \ --model /MODEL_PATH/Llama-2-7b-chat-hf-bigdl/ --port 8000 \ - --load-format 'auto' --device cpu --dtype bfloat16 + --load-format 'auto' --device cpu --dtype bfloat16 \ + --max-num-batched-tokens 4096 ``` Then you can access the api server as follows: @@ -80,12 +83,12 @@ Currently we have only supported LLaMA family model (including `llama`, `vicuna` #### 4.1 Add model code -Create or clone the Pytorch model code to `./models`. +Create or clone the Pytorch model code to `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/models`. #### 4.2 Rewrite the forward methods -Refering to `./models/bigdl_llama.py`, it's necessary to maintain a `kv_cache`, which is a nested list of dictionary that maps `req_id` to a three-dimensional tensor **(the structure may vary from models)**. Before the model's actual `forward` method, you could prepare a `past_key_values` according to current `req_id`, and after that you need to update the `kv_cache` with `output.past_key_values`. The clearence will be executed when the request is finished. +Refering to `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py`, it's necessary to maintain a `kv_cache`, which is a nested list of dictionary that maps `req_id` to a three-dimensional tensor **(the structure may vary from models)**. Before the model's actual `forward` method, you could prepare a `past_key_values` according to current `req_id`, and after that you need to update the `kv_cache` with `output.past_key_values`. The clearence will be executed when the request is finished. #### 4.3 Register new model -Finally, register your `*ForCausalLM` class to the _MODEL_REGISTRY in `./models/model_loader.py`. +Finally, register your `*ForCausalLM` class to the _MODEL_REGISTRY in `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/model_loader.py`. diff --git a/python/llm/example/CPU/vLLM-Serving/offline_inference.py b/python/llm/example/CPU/vLLM-Serving/offline_inference.py index 99605420c84..45f4aa18988 100644 --- a/python/llm/example/CPU/vLLM-Serving/offline_inference.py +++ b/python/llm/example/CPU/vLLM-Serving/offline_inference.py @@ -14,7 +14,7 @@ # limitations under the License. # # Some parts of this file is adapted from -# https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py +# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/examples/offline_inference.py # which is licensed under Apache License 2.0 # # Copyright 2023 The vLLM team. All rights reserved. @@ -31,8 +31,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from bigdl.llm.vllm.examples.llm import LLM -from bigdl.llm.vllm.structure.sampling_params import SamplingParams +from bigdl.llm.vllm.entrypoints.llm import LLM +from bigdl.llm.vllm.sampling_params import SamplingParams # Sample prompts. prompts = [ diff --git a/python/llm/example/GPU/README.md b/python/llm/example/GPU/README.md index 0649cd30b17..e0b1d6f83ed 100644 --- a/python/llm/example/GPU/README.md +++ b/python/llm/example/GPU/README.md @@ -4,6 +4,7 @@ This folder contains examples of running BigDL-LLM on Intel GPU: - [HF-Transformers-AutoModels](HF-Transformers-AutoModels): running any ***Hugging Face Transformers*** model on BigDL-LLM (using the standard AutoModel APIs) - [QLoRA-FineTuning](QLoRA-FineTuning): running ***QLoRA finetuning*** using BigDL-LLM on Intel GPUs +- [vLLM-Serving](vLLM-Serving): running ***vLLM*** serving framework on intel GPUs (with BigDL-LLM low-bit optimized models) - [Deepspeed-AutoTP](Deepspeed-AutoTP): running distributed inference using ***DeepSpeed AutoTP*** (with BigDL-LLM low-bit optimized models) on Intel GPUs - [PyTorch-Models](PyTorch-Models): running any PyTorch model on BigDL-LLM (with "one-line code change") diff --git a/python/llm/example/GPU/vLLM-Serving/README.md b/python/llm/example/GPU/vLLM-Serving/README.md new file mode 100644 index 00000000000..b2b3f2f7908 --- /dev/null +++ b/python/llm/example/GPU/vLLM-Serving/README.md @@ -0,0 +1,109 @@ +# vLLM continuous batching on Intel GPUs (experimental support) + +This example demonstrates how to serve a LLaMA2-7B model using vLLM continuous batching on Intel GPU (with BigDL-LLM low-bits optimizations). + +The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.2.1.post1). + +## Example: Serving LLaMA2-7B using Intel GPU + +In this example, we will run Llama2-7b model using Arc A770 and provide `OpenAI-compatible` interface for users. + +### 0. Environment + +To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit. Please check the requirements at [here](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU#requirements). + +After install the toolkit, run the following commands in your environment before starting vLLM GPU: +```bash +source /opt/intel/oneapi/setvars.sh +# sycl-ls will list all the compatible Intel GPUs in your environment +sycl-ls + +# Example output with one Arc A770: +[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device 1.2 [2023.16.7.0.21_160000] +[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i9-13900K 3.0 [2023.16.7.0.21_160000] +[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics 3.0 [23.17.26241.33] +[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26241] +``` + +### 1. Install + +To run vLLM continuous batching on Intel GPUs, install the dependencies as follows: + +```bash +# First create an conda environment +conda create -n bigdl-vllm python==3.9 +conda activate bigdl-vllm +# Install dependencies +pip3 install psutil +pip3 install sentencepiece # Required for LLaMA tokenizer. +pip3 install numpy +pip3 install "transformers>=4.33.1" # Required for Code Llama. +pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu +pip3 install fastapi +pip3 install "uvicorn[standard]" +pip3 install "pydantic<2" # Required for OpenAI server. +``` + +### 2. Configure recommended environment variables + +```bash +export USE_XETLA=OFF +export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1 +``` + +### 3. Offline inference/Service + +#### Offline inference + +To run offline inference using vLLM for a quick impression, use the following example: + +```bash +#!/bin/bash + +# Please first modify the MODEL_PATH in offline_inference.py +python offline_inference.py +``` + +#### Service + +To fully utilize the continuous batching feature of the `vLLM`, you can send requests to the service using curl or other similar methods. The requests sent to the engine will be batched at token level. Queries will be executed in the same `forward` step of the LLM and be removed when they are finished instead of waiting for all sequences to be finished. + +```bash +#!/bin/bash +# You may also want to adjust the `--max-num-batched-tokens` argument, it indicates the hard limit +# of batched prompt length the server will accept +python -m bigdl.llm.vllm.entrypoints.openai.api_server \ + --model /MODEL_PATH/Llama-2-7b-chat-hf/ --port 8000 \ + --load-format 'auto' --device xpu --dtype bfloat16 \ + --max-num-batched-tokens 4096 +``` + +Then you can access the api server as follows: + +```bash + + curl http://localhost:8000/v1/completions \ + -H "Content-Type: application/json" \ + -d '{ + "model": "/MODEL_PATH/Llama-2-7b-chat-hf-bigdl/", + "prompt": "San Francisco is a", + "max_tokens": 128, + "temperature": 0 + }' & +``` + +### 4. (Optional) Add a new model + +Currently we have only supported LLaMA family model (including `llama`, `vicuna`, `llama-2`, etc.). To use aother model, you may need add some adaptions. + +#### 4.1 Add model code + +Create or clone the Pytorch model code to `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/models`. + +#### 4.2 Rewrite the forward methods + +Refering to `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py`, it's necessary to maintain a `kv_cache`, which is a nested list of dictionary that maps `req_id` to a three-dimensional tensor **(the structure may vary from models)**. Before the model's actual `forward` method, you could prepare a `past_key_values` according to current `req_id`, and after that you need to update the `kv_cache` with `output.past_key_values`. The clearence will be executed when the request is finished. + +#### 4.3 Register new model + +Finally, register your `*ForCausalLM` class to the _MODEL_REGISTRY in `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/model_loader.py`. diff --git a/python/llm/example/GPU/vLLM-Serving/offline_inference.py b/python/llm/example/GPU/vLLM-Serving/offline_inference.py new file mode 100644 index 00000000000..994781d6b45 --- /dev/null +++ b/python/llm/example/GPU/vLLM-Serving/offline_inference.py @@ -0,0 +1,57 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/examples/offline_inference.py +# which is licensed under Apache License 2.0 +# +# Copyright 2023 The vLLM team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from bigdl.llm.vllm.entrypoints.llm import LLM +from bigdl.llm.vllm.sampling_params import SamplingParams + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +# llm = LLM(model="facebook/opt-125m") +llm = LLM(model="YOUR_MODEL_PATH", dtype="bfloat16", device="xpu") +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/python/llm/src/bigdl/llm/vllm/core/scheduler.py b/python/llm/src/bigdl/llm/vllm/core/scheduler.py index 8a5018d5964..e5f51b94cec 100644 --- a/python/llm/src/bigdl/llm/vllm/core/scheduler.py +++ b/python/llm/src/bigdl/llm/vllm/core/scheduler.py @@ -79,6 +79,7 @@ class FixedWindowScheduler: def __init__( self, scheduler_config: SchedulerConfig, + kv_cache: Optional, ) -> None: self.scheduler_config = scheduler_config self.prompt_limit = min(self.scheduler_config.max_model_len, @@ -98,6 +99,7 @@ def __init__( # Sequence groups in the RUNNING state. self.running: List[SequenceGroup] = [] self.cleaned: List[int] = [] + self.kv_cache = kv_cache # Co(gc): We no longer have the swapped space as we are not deciding which to swap # bigdl-llm change end @@ -150,6 +152,8 @@ def _schedule(self) -> SchedulerOutputs: # We restrict how many requests that can be run using these three arguments # Co(gc): If there are waiting requests, we will just try to add it into the # running state if not exceeds the stage + # Co(gc): Record seq_len for prefill requests + seq_lens = [] # Co(gc): prefilled requests are prioritized over decoding stage requests while self.waiting: seq_group = self.waiting[0] @@ -178,7 +182,9 @@ def _schedule(self) -> SchedulerOutputs: # bigdl-llm change end # If the number of batched tokens exceeds the limit, stop. - if (num_batched_tokens + num_prompt_tokens > + new_seq_lens = seq_lens + [num_prompt_tokens] + num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) + if (num_batched_tokens > self.scheduler_config.max_num_batched_tokens): break @@ -192,6 +198,8 @@ def _schedule(self) -> SchedulerOutputs: seq_group = self.waiting.pop(0) for seq in seq_group.get_seqs(): seq.status = SequenceStatus.RUNNING + # Co(gc): Only updated the seq_lens when all check passes + seq_lens = new_seq_lens # bigdl-llm change start # summary: removing block_manager related logic. # self._allocate(seq_group) @@ -204,7 +212,7 @@ def _schedule(self) -> SchedulerOutputs: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, prompt_run=True, - num_batched_tokens=num_batched_tokens, + num_batched_tokens=len(seq_lens) * max(seq_lens) if seq_lens else 0, ignored_seq_groups=ignored_seq_groups, finished_seqs=finished_seqs, ) @@ -258,6 +266,13 @@ def free_seq(self, seq: Sequence) -> None: # summary: The original code free the block in block_manager. # now, we added it into a list to pass to worker in the next model_execute stage. self.cleaned.append(seq.seq_id) + for i in range(len(self.kv_cache)): + for j in range(2): + if not self.kv_cache[i][j].get(seq.seq_id) is None: + del self.kv_cache[i][j][seq.seq_id] + # del self.kv_cache[seq.seq_id] + # logger.info(f"freed seqs: {seq.seq_id} . + # now kv cache is: {list(self.kv_cache[0][0].keys())} ") # bigdl-llm change end def free_finished_seq_groups(self) -> None: diff --git a/python/llm/src/bigdl/llm/vllm/engine/async_llm_engine.py b/python/llm/src/bigdl/llm/vllm/engine/async_llm_engine.py index aa8cadf9b62..46eaf741c9c 100644 --- a/python/llm/src/bigdl/llm/vllm/engine/async_llm_engine.py +++ b/python/llm/src/bigdl/llm/vllm/engine/async_llm_engine.py @@ -243,15 +243,16 @@ async def _run_workers_async( ) -> Any: """Runs the given method on all workers.""" # bigdl-llm change start - all_outputs = [] + coros = [] for worker in self.workers: # if self.parallel_config.worker_use_ray: # executor = partial(worker.execute_method.remote, method) # else: executor = getattr(worker, method) + coros.append(asyncio.get_event_loop().run_in_executor( + None, partial(executor, *args, **kwargs))) - output = executor(*args, **kwargs) - all_outputs.append(output) + all_outputs = await asyncio.gather(*coros) # if self.parallel_config.worker_use_ray: # all_outputs = await asyncio.gather(*all_outputs) diff --git a/python/llm/src/bigdl/llm/vllm/engine/llm_engine.py b/python/llm/src/bigdl/llm/vllm/engine/llm_engine.py index c40f45613a7..5978fc38373 100644 --- a/python/llm/src/bigdl/llm/vllm/engine/llm_engine.py +++ b/python/llm/src/bigdl/llm/vllm/engine/llm_engine.py @@ -36,7 +36,7 @@ # import time -from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, Dict from bigdl.llm.vllm.config import ModelConfig, SchedulerConfig from bigdl.llm.vllm.core.scheduler import SchedulerOutputs, FixedWindowScheduler @@ -127,6 +127,7 @@ def __init__( # self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.log_stats = log_stats + self.kv_cache = [[dict() for _ in range(2)] for _ in range(32)] # self._verify_args() self.tokenizer = get_tokenizer( @@ -142,7 +143,7 @@ def __init__( self._init_workers() # Co(gc): we create a fixed scheduler - self.scheduler = FixedWindowScheduler(scheduler_config) + self.scheduler = FixedWindowScheduler(scheduler_config, kv_cache=self.kv_cache) # Logging. self.last_logging_time = 0.0 @@ -170,6 +171,7 @@ def _init_workers(self): self.scheduler_config, 0, # distributed_init_method, + kv_cache=self.kv_cache ) self.workers.append(worker) self._run_workers( diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py index a59890397ab..b53bfc8f5bf 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py @@ -23,9 +23,9 @@ from bigdl.llm.vllm.sequence import SequenceOutputs, SequenceGroupMetadata from bigdl.llm.vllm.model_executor.layers.bigdl_sampler import BigDLSampler from bigdl.llm.vllm.model_executor.models.bigdl_model import BigDLModelForCausalLM +from bigdl.llm.vllm.logger import init_logger import math import time - from transformers.generation.logits_process import ( LogitsProcessorList, RepetitionPenaltyLogitsProcessor, @@ -35,6 +35,9 @@ ) +logger = init_logger(__name__) + + def _pad_to_max(x: List[int], max_len: int, padding_id: int = 0) -> List[int]: return x + [padding_id] * (max_len - len(x)) @@ -87,7 +90,6 @@ def __init__( else: self.device = torch.device(device) self.dtype = self.model.dtype - self.kv_cache_size = [0] self.last_seq_ids = [] self.tmp_kv_cache = None self.pad_token_id = config.pad_token_id @@ -170,15 +172,25 @@ def forward( # "return_dict": True, } # pdb.set_trace() + if self.device.type == 'xpu': + torch.xpu.empty_cache() st_timestamp = time.perf_counter() outputs = self.model.forward(**kwargs) + # tmp = torch.xpu.memory_stats() + # logger.info(f"0: {tmp['allocated_bytes.all.current']}") + # self.last_seq_ids = cur_seq_ids[:] + # self.last_kv_cache = outputs.past_key_values + self._set_last_seq_ids(cur_seq_ids[:]) + self._set_last_kv_cache(outputs.past_key_values) - self.last_seq_ids = cur_seq_ids[:] - self.tmp_kv_cache = outputs.past_key_values logits = outputs.logits[:, -1, :] bigdl_output = self.sampler(logits, input_metadata, st_timestamp) + # tmp = torch.xpu.memory_stats() + # logger.info(f"before: {tmp['allocated_bytes.all.current']}") - self.update_kv_cache(cur_seq_ids, outputs.past_key_values, + self.update_kv_cache(cur_seq_ids, kv_cache, kv_cache_size_0, kv_cache_size_1) + # tmp = torch.xpu.memory_stats() + # logger.info(f"after: {tmp['allocated_bytes.all.current']}") return bigdl_output diff --git a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py index 3a8198f94c6..4bc95d1382e 100644 --- a/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py +++ b/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_model.py @@ -20,21 +20,31 @@ from transformers import LlamaConfig from bigdl.llm.vllm.sequence import SequenceOutputs, SequenceGroupMetadata +from bigdl.llm.transformers.models.utils import extend_kv_cache + +zero_cache_dict = {} + + +def get_zero_tensor(length, cur_size, device, pos): + if length not in zero_cache_dict: + tmp_size = cur_size[:] + tmp_size[pos] = length + zero_cache_dict[length] = torch.zeros(tmp_size, device=device) + return zero_cache_dict[length].narrow(pos, 0, length - cur_size[pos]) def _pad_kv_cache_view(t: torch.Tensor, len: int, device: torch.device, pos: int = 2) -> torch.Tensor: cur_size = list(t.size()) if cur_size[pos] < len: - tmp_size = cur_size[:] - tmp_size[pos] = len - cur_size[pos] - zeros = torch.zeros(tmp_size, device=device) + zeros = get_zero_tensor(len, cur_size, device, pos) padded_view = torch.cat((zeros, t), dim=pos) return padded_view - if cur_size[pos] > len: + elif cur_size[pos] > len: padded_view = t.narrow(pos, cur_size[pos] - len, len) return padded_view - return t + else: + return t class BigDLModelForCausalLM(nn.Module): @@ -52,10 +62,23 @@ def __init__( "cuda" if torch.cuda.is_available() else "cpu") else: self.device = torch.device(device) + if device == 'xpu': + try: + import intel_extension_for_pytorch as ipex + except ImportError: + print("Intel Extension for PyTorch is not installed, \ + but is required for xpu inference.") + self.max_seq_limit = max_model_len self.last_kv_cache = None self.last_seq_ids = None + def _set_last_kv_cache(self, last_kv_cache): + self.last_kv_cache = last_kv_cache + + def _set_last_seq_ids(self, last_seq_ids): + self.last_seq_ids = last_seq_ids + # This is an implementation for models that KV Cache shape in (batch_size, num_heads, # sequence_length, embed_size_per_head). def prepare_kv_cache( @@ -69,39 +92,41 @@ def prepare_kv_cache( max_seq_limit = self.max_seq_limit if (self.last_kv_cache is not None) and cur_seq_ids == self.last_seq_ids: if self.last_kv_cache[0][0].size(2) < max_seq_limit: - bigdl_kv_cache = self.tmp_kv_cache + bigdl_kv_cache = self.last_kv_cache else: bigdl_kv_cache = [[tmp.narrow(2, self.last_kv_cache[0][0].size(2) - max_seq_limit, max_seq_limit) for tmp in tmp_list] for tmp_list in self.last_kv_cache] + del self.last_kv_cache else: + del self.last_kv_cache bigdl_kv_cache = [] for i in range(kv_cache_size_0): cur_list = [] for j in range(kv_cache_size_1): - cur_view = None + views = [] + max_len = 0 for seq_group_meta_data in seq_group_meta_data_lists: seq_ids = list(seq_group_meta_data.seq_data.keys()) seq_id = seq_ids[0] seq_data = seq_group_meta_data.seq_data[seq_id] - view_size = [1] + list(kv_cache[seq_id][i][j].shape) - if cur_view is None: - cur_view = kv_cache[seq_id][i][j].view(view_size) - else: - if cur_view.size(2) != view_size[2]: - max_len = max(cur_view.size(2), view_size[2]) - cur_view = _pad_kv_cache_view(cur_view, max_len, self.device) - tmp_view = _pad_kv_cache_view( - kv_cache[seq_id][i][j].view(view_size), - max_len, self.device) - cur_view = torch.cat((cur_view, tmp_view), dim=0) - else: - cur_view = torch.cat( - (cur_view, kv_cache[seq_id][i][j].view(view_size)), dim=0) - if cur_view.size(2) > max_seq_limit: + view_size = [1] + list(kv_cache[i][j][seq_id].shape) + views.append(kv_cache[i][j][seq_id].view(view_size)) + max_len = max(max_len, view_size[2]) + + views = [_pad_kv_cache_view(v, max_len, self.device) for v in views] + cur_view = torch.cat(views, dim=0) + + if cur_view.size(2) > max_seq_limit * 1.5: cur_view = _pad_kv_cache_view(cur_view, max_seq_limit, self.device) cur_list.append(cur_view) + + for seq_group_meta_data in seq_group_meta_data_lists: + seq_ids = list(seq_group_meta_data.seq_data.keys()) + seq_id = seq_ids[0] + del kv_cache[i][j][seq_id] bigdl_kv_cache.append(cur_list) + return bigdl_kv_cache # This is an implementation for models that KV Cache shape in (batch_size, num_heads, @@ -109,20 +134,16 @@ def prepare_kv_cache( def update_kv_cache( self, cur_seq_ids: List[int], - past_key_values: List[List[torch.Tensor]], - kv_cache: Dict, + kv_cache, kv_cache_size_0: int, kv_cache_size_1: int, ) -> None: - index = 0 - for seq_id in cur_seq_ids: - if kv_cache.get(seq_id) is None: - kv_cache[seq_id] = [[[] for _ in range(kv_cache_size_1)] - for _ in range(kv_cache_size_0)] - for i in range(kv_cache_size_0): - for j in range(kv_cache_size_1): - kv_cache[seq_id][i][j] = past_key_values[i][j][index] - index = index + 1 + for i in range(kv_cache_size_0): + for j in range(kv_cache_size_1): + index = 0 + for seq_id in cur_seq_ids: + kv_cache[i][j][seq_id] = self.last_kv_cache[i][j][index] + index = index + 1 def forward( self, diff --git a/python/llm/src/bigdl/llm/vllm/worker/worker.py b/python/llm/src/bigdl/llm/vllm/worker/worker.py index b01824bb2f4..b0005e798d8 100644 --- a/python/llm/src/bigdl/llm/vllm/worker/worker.py +++ b/python/llm/src/bigdl/llm/vllm/worker/worker.py @@ -68,6 +68,7 @@ def __init__( scheduler_config: SchedulerConfig, rank: Optional[int] = None, # distributed_init_method: Optional[str] = None, + kv_cache: Optional[Dict] = None, ) -> None: self.model_config = model_config # self.parallel_config = parallel_config @@ -84,17 +85,18 @@ def __init__( self.cache_events = None self.gpu_cache = None - self.kv_cache = dict() + self.kv_cache = kv_cache def clean_finished_seqs(self, finished_seqs: List[int]): """ This function cleans the finished sequences and their KVCache in self.kv_cache """ - for seq_id in finished_seqs: - if seq_id not in self.kv_cache.keys(): - warnings.warn(f"Duplicate key {seq_id} received during clean worker's KVCache") - continue - del self.kv_cache[seq_id] + pass + # for seq_id in finished_seqs: + # if seq_id not in self.kv_cache.keys(): + # # warnings.warn(f"Duplicate key {seq_id} received during clean worker's KVCache") + # continue + # del self.kv_cache[seq_id] def init_model(self): if self.model_config.device == 'gpu': @@ -282,6 +284,10 @@ def execute_model( if finished_seqs: self.clean_finished_seqs(finished_seqs) + # if self.model_config.device == 'xpu': + # import intel_extension_for_pytorch as ipex + # torch.xpu.empty_cache() + cache_events = None # If there is no input, we don't need to execute the model. if not seq_group_metadata_list: