Add vLLM-XPU version's README/examples (intel#9536)
* test

* test

* fix last kv cache

* add xpu readme

* remove numactl for xpu example

* fix link error

* update max_num_batched_tokens logic

* add explanation

* add xpu environment version requirement

* refine gpu memory

* fix

* fix style
gc-fu authored Nov 28, 2023
1 parent 22d2157 commit fe2def0
Showing 11 changed files with 287 additions and 60 deletions.
15 changes: 9 additions & 6 deletions python/llm/example/CPU/vLLM-Serving/README.md
@@ -31,7 +31,7 @@ pip3 install "pydantic<2" # Required for OpenAI server.
### 2. Configure recommended environment variables

```bash
source bigdl-llm-init
source bigdl-llm-init -t
```

### 3. Offline inference/Service
@@ -55,9 +55,12 @@ To fully utilize the continuous batching feature of the `vLLM`, you can send req

```bash
#!/bin/bash
numactl -C 48-95 -m 1 python -m bigdl.llm.vllm.examples.api_server \
# You may also want to adjust the `--max-num-batched-tokens` argument; it sets the hard limit
# on the total batched prompt length the server will accept
numactl -C 48-95 -m 1 python -m bigdl.llm.vllm.entrypoints.openai.api_server \
--model /MODEL_PATH/Llama-2-7b-chat-hf-bigdl/ --port 8000 \
--load-format 'auto' --device cpu --dtype bfloat16
--load-format 'auto' --device cpu --dtype bfloat16 \
--max-num-batched-tokens 4096
```

Then you can access the API server as follows:
@@ -80,12 +83,12 @@ Currently we have only supported LLaMA family model (including `llama`, `vicuna`

#### 4.1 Add model code

Create or clone the Pytorch model code to `./models`.
Create or clone the PyTorch model code to `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/models`.

#### 4.2 Rewrite the forward methods

Refering to `./models/bigdl_llama.py`, it's necessary to maintain a `kv_cache`, which is a nested list of dictionary that maps `req_id` to a three-dimensional tensor **(the structure may vary from models)**. Before the model's actual `forward` method, you could prepare a `past_key_values` according to current `req_id`, and after that you need to update the `kv_cache` with `output.past_key_values`. The clearence will be executed when the request is finished.
Referring to `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py`, it is necessary to maintain a `kv_cache`, which is a nested list of dictionaries that maps `req_id` to a three-dimensional tensor **(the structure may vary across models)**. Before the model's actual `forward` method, prepare `past_key_values` according to the current `req_id`; after the call, update the `kv_cache` with `output.past_key_values`. The cache entries are cleared when the request finishes.

#### 4.3 Register new model

Finally, register your `*ForCausalLM` class to the _MODEL_REGISTRY in `./models/model_loader.py`.
Finally, register your `*ForCausalLM` class in the `_MODEL_REGISTRY` in `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/model_loader.py`.
6 changes: 3 additions & 3 deletions python/llm/example/CPU/vLLM-Serving/offline_inference.py
@@ -14,7 +14,7 @@
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/examples/offline_inference.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
@@ -31,8 +31,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from bigdl.llm.vllm.examples.llm import LLM
from bigdl.llm.vllm.structure.sampling_params import SamplingParams
from bigdl.llm.vllm.entrypoints.llm import LLM
from bigdl.llm.vllm.sampling_params import SamplingParams

# Sample prompts.
prompts = [
1 change: 1 addition & 0 deletions python/llm/example/GPU/README.md
@@ -4,6 +4,7 @@ This folder contains examples of running BigDL-LLM on Intel GPU:

- [HF-Transformers-AutoModels](HF-Transformers-AutoModels): running any ***Hugging Face Transformers*** model on BigDL-LLM (using the standard AutoModel APIs)
- [QLoRA-FineTuning](QLoRA-FineTuning): running ***QLoRA finetuning*** using BigDL-LLM on Intel GPUs
- [vLLM-Serving](vLLM-Serving): running the ***vLLM*** serving framework on Intel GPUs (with BigDL-LLM low-bit optimized models)
- [Deepspeed-AutoTP](Deepspeed-AutoTP): running distributed inference using ***DeepSpeed AutoTP*** (with BigDL-LLM low-bit optimized models) on Intel GPUs
- [PyTorch-Models](PyTorch-Models): running any PyTorch model on BigDL-LLM (with "one-line code change")

109 changes: 109 additions & 0 deletions python/llm/example/GPU/vLLM-Serving/README.md
@@ -0,0 +1,109 @@
# vLLM continuous batching on Intel GPUs (experimental support)

This example demonstrates how to serve a LLaMA2-7B model using vLLM continuous batching on an Intel GPU (with BigDL-LLM low-bit optimizations).

The code shown in the following example is ported from [vLLM](https://github.com/vllm-project/vllm/tree/v0.2.1.post1).

## Example: Serving LLaMA2-7B using Intel GPU

In this example, we will run the Llama2-7b model on an Arc A770 and provide an `OpenAI-compatible` interface for users.

### 0. Environment

To use Intel GPUs for deep-learning tasks, you should install the XPU driver and the oneAPI Base Toolkit. Please check the requirements [here](https://github.com/intel-analytics/BigDL/tree/main/python/llm/example/GPU#requirements).

After installing the toolkit, run the following commands in your environment before starting vLLM on the GPU:
```bash
source /opt/intel/oneapi/setvars.sh
# sycl-ls will list all the compatible Intel GPUs in your environment
sycl-ls

# Example output with one Arc A770:
[opencl:acc:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device 1.2 [2023.16.7.0.21_160000]
[opencl:cpu:1] Intel(R) OpenCL, 13th Gen Intel(R) Core(TM) i9-13900K 3.0 [2023.16.7.0.21_160000]
[opencl:gpu:2] Intel(R) OpenCL Graphics, Intel(R) Arc(TM) A770 Graphics 3.0 [23.17.26241.33]
[ext_oneapi_level_zero:gpu:0] Intel(R) Level-Zero, Intel(R) Arc(TM) A770 Graphics 1.3 [1.3.26241]
```

### 1. Install

To run vLLM continuous batching on Intel GPUs, install the dependencies as follows:

```bash
# First create a conda environment
conda create -n bigdl-vllm python==3.9
conda activate bigdl-vllm
# Install dependencies
pip3 install psutil
pip3 install sentencepiece # Required for LLaMA tokenizer.
pip3 install numpy
pip3 install "transformers>=4.33.1" # Required for Code Llama.
pip install --pre --upgrade bigdl-llm[xpu] -f https://developer.intel.com/ipex-whl-stable-xpu
pip3 install fastapi
pip3 install "uvicorn[standard]"
pip3 install "pydantic<2" # Required for OpenAI server.
```

### 2. Configure recommended environment variables

```bash
export USE_XETLA=OFF
export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
```

### 3. Offline inference/Service

#### Offline inference

To get a quick feel for offline inference with vLLM, use the following example:

```bash
#!/bin/bash

# Please first modify the MODEL_PATH in offline_inference.py
python offline_inference.py
```

#### Service

To fully utilize the continuous batching feature of `vLLM`, you can send requests to the service using `curl` or similar methods. Requests sent to the engine are batched at the token level: queries are executed in the same `forward` step of the LLM and removed as soon as they finish, instead of waiting for all sequences to complete.

```bash
#!/bin/bash
# You may also want to adjust the `--max-num-batched-tokens` argument; it sets the hard limit
# on the total batched prompt length the server will accept
python -m bigdl.llm.vllm.entrypoints.openai.api_server \
--model /MODEL_PATH/Llama-2-7b-chat-hf/ --port 8000 \
--load-format 'auto' --device xpu --dtype bfloat16 \
--max-num-batched-tokens 4096
```

Then you can access the API server as follows:

```bash

curl http://localhost:8000/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "/MODEL_PATH/Llama-2-7b-chat-hf-bigdl/",
"prompt": "San Francisco is a",
"max_tokens": 128,
"temperature": 0
}' &
```

### 4. (Optional) Add a new model

Currently we only support LLaMA family models (including `llama`, `vicuna`, `llama-2`, etc.). To use another model, you may need to add some adaptations.

#### 4.1 Add model code

Create or clone the PyTorch model code to `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/models`.

#### 4.2 Rewrite the forward methods

Referring to `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/models/bigdl_llama.py`, it is necessary to maintain a `kv_cache`, which is a nested list of dictionaries that maps `req_id` to a three-dimensional tensor **(the structure may vary across models)**. Before the model's actual `forward` method, prepare `past_key_values` according to the current `req_id`; after the call, update the `kv_cache` with `output.past_key_values`. The cache entries are cleared when the request finishes.
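
This mirrors the `kv_cache` structure created in `llm_engine.py` in this commit (`[[dict() for _ in range(2)] for _ in range(32)]`). Below is a minimal, hypothetical sketch of that bookkeeping pattern; the class and method names are illustrative, not the actual `bigdl_llama.py` API:

```python
# Illustrative sketch only; see bigdl_llama.py for the real implementation.
from typing import Dict, List

import torch


class KVCacheStore:
    """kv_cache[layer][0] holds keys, kv_cache[layer][1] holds values, keyed by req_id."""

    def __init__(self, num_layers: int = 32):
        self.kv_cache: List[List[Dict[int, torch.Tensor]]] = [
            [dict(), dict()] for _ in range(num_layers)
        ]

    def gather_past(self, req_ids: List[int]):
        """Collect each request's cached keys/values before calling the model's forward."""
        return [
            ([layer[0].get(rid) for rid in req_ids],
             [layer[1].get(rid) for rid in req_ids])
            for layer in self.kv_cache
        ]

    def update(self, req_ids: List[int], past_key_values):
        """Write output.past_key_values back under each req_id after forward."""
        for layer_idx, (keys, values) in enumerate(past_key_values):
            for batch_idx, rid in enumerate(req_ids):
                self.kv_cache[layer_idx][0][rid] = keys[batch_idx]
                self.kv_cache[layer_idx][1][rid] = values[batch_idx]

    def clear(self, req_id: int):
        """Drop a finished request's entries, mirroring the scheduler's free_seq."""
        for layer in self.kv_cache:
            for kv in layer:
                kv.pop(req_id, None)
```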

#### 4.3 Register new model

Finally, register your `*ForCausalLM` class in the `_MODEL_REGISTRY` in `BigDL/python/llm/src/bigdl/llm/vllm/model_executor/model_loader.py`.
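
A hedged sketch of what that registration might look like; the exact keys and class names in `model_loader.py` may differ, and `BigDLMistralForCausalLM` here is purely hypothetical:

```python
# Hypothetical sketch; check model_loader.py for the actual registry layout.
from bigdl.llm.vllm.model_executor.models.bigdl_llama import BigDLLlamaForCausalLM
# from bigdl.llm.vllm.model_executor.models.bigdl_mistral import BigDLMistralForCausalLM

_MODEL_REGISTRY = {
    "LlamaForCausalLM": BigDLLlamaForCausalLM,
    # "MistralForCausalLM": BigDLMistralForCausalLM,  # map the HF architecture name to your new class
}
```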
57 changes: 57 additions & 0 deletions python/llm/example/GPU/vLLM-Serving/offline_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Some parts of this file is adapted from
# https://github.com/vllm-project/vllm/blob/v0.2.1.post1/examples/offline_inference.py
# which is licensed under Apache License 2.0
#
# Copyright 2023 The vLLM team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from bigdl.llm.vllm.entrypoints.llm import LLM
from bigdl.llm.vllm.sampling_params import SamplingParams

# Sample prompts.
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# Create an LLM.
# llm = LLM(model="facebook/opt-125m")
llm = LLM(model="YOUR_MODEL_PATH", dtype="bfloat16", device="xpu")
# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(prompts, sampling_params)
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
19 changes: 17 additions & 2 deletions python/llm/src/bigdl/llm/vllm/core/scheduler.py
@@ -79,6 +79,7 @@ class FixedWindowScheduler:
def __init__(
self,
scheduler_config: SchedulerConfig,
kv_cache: Optional,
) -> None:
self.scheduler_config = scheduler_config
self.prompt_limit = min(self.scheduler_config.max_model_len,
@@ -98,6 +99,7 @@ def __init__(
# Sequence groups in the RUNNING state.
self.running: List[SequenceGroup] = []
self.cleaned: List[int] = []
self.kv_cache = kv_cache
# Co(gc): We no longer have the swapped space as we are not deciding which to swap
# bigdl-llm change end

@@ -150,6 +152,8 @@ def _schedule(self) -> SchedulerOutputs:
# We restrict how many requests that can be run using these three arguments
# Co(gc): If there are waiting requests, we will just try to add it into the
# running state if not exceeds the stage
# Co(gc): Record seq_len for prefill requests
seq_lens = []
# Co(gc): prefilled requests are prioritized over decoding stage requests
while self.waiting:
seq_group = self.waiting[0]
@@ -178,7 +182,9 @@
# bigdl-llm change end

# If the number of batched tokens exceeds the limit, stop.
if (num_batched_tokens + num_prompt_tokens >
new_seq_lens = seq_lens + [num_prompt_tokens]
num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)
if (num_batched_tokens >
self.scheduler_config.max_num_batched_tokens):
break

@@ -192,6 +198,8 @@
seq_group = self.waiting.pop(0)
for seq in seq_group.get_seqs():
seq.status = SequenceStatus.RUNNING
# Co(gc): Only updated the seq_lens when all check passes
seq_lens = new_seq_lens
# bigdl-llm change start
# summary: removing block_manager related logic.
# self._allocate(seq_group)
@@ -204,7 +212,7 @@
scheduler_outputs = SchedulerOutputs(
scheduled_seq_groups=scheduled,
prompt_run=True,
num_batched_tokens=num_batched_tokens,
num_batched_tokens=len(seq_lens) * max(seq_lens) if seq_lens else 0,
ignored_seq_groups=ignored_seq_groups,
finished_seqs=finished_seqs,
)
@@ -258,6 +266,13 @@ def free_seq(self, seq: Sequence) -> None:
# summary: The original code free the block in block_manager.
# now, we added it into a list to pass to worker in the next model_execute stage.
self.cleaned.append(seq.seq_id)
for i in range(len(self.kv_cache)):
for j in range(2):
if not self.kv_cache[i][j].get(seq.seq_id) is None:
del self.kv_cache[i][j][seq.seq_id]
# del self.kv_cache[seq.seq_id]
# logger.info(f"freed seqs: {seq.seq_id} .
# now kv cache is: {list(self.kv_cache[0][0].keys())} ")
# bigdl-llm change end

def free_finished_seq_groups(self) -> None:
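The `_schedule` hunks above change how a prefill batch is charged against `--max-num-batched-tokens`: prompts are padded to the longest prompt in the batch, so the cost is `len(seq_lens) * max(seq_lens)` rather than the plain sum. A small standalone sketch of that admission check (a hypothetical helper, not the scheduler's actual API):

```python
# Illustrative only: mirrors the padded-batch accounting added to FixedWindowScheduler above.
from typing import List


def fits_token_budget(seq_lens: List[int], num_prompt_tokens: int,
                      max_num_batched_tokens: int) -> bool:
    """True if adding a prompt of num_prompt_tokens keeps the padded prefill batch within budget."""
    new_seq_lens = seq_lens + [num_prompt_tokens]
    # Prompts are batched as one padded tensor, so the cost is batch_size * longest_prompt.
    return len(new_seq_lens) * max(new_seq_lens) <= max_num_batched_tokens


# Example with the README's default --max-num-batched-tokens 4096:
assert fits_token_budget([], 100, 4096)               # 1 * 100  = 100
assert fits_token_budget([100], 300, 4096)            # 2 * 300  = 600
assert not fits_token_budget([100, 300], 2048, 4096)  # 3 * 2048 = 6144 > 4096
```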
7 changes: 4 additions & 3 deletions python/llm/src/bigdl/llm/vllm/engine/async_llm_engine.py
@@ -243,15 +243,16 @@ async def _run_workers_async(
) -> Any:
"""Runs the given method on all workers."""
# bigdl-llm change start
all_outputs = []
coros = []
for worker in self.workers:
# if self.parallel_config.worker_use_ray:
# executor = partial(worker.execute_method.remote, method)
# else:
executor = getattr(worker, method)
coros.append(asyncio.get_event_loop().run_in_executor(
None, partial(executor, *args, **kwargs)))

output = executor(*args, **kwargs)
all_outputs.append(output)
all_outputs = await asyncio.gather(*coros)

# if self.parallel_config.worker_use_ray:
# all_outputs = await asyncio.gather(*all_outputs)
6 changes: 4 additions & 2 deletions python/llm/src/bigdl/llm/vllm/engine/llm_engine.py
@@ -36,7 +36,7 @@
#

import time
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union, Dict

from bigdl.llm.vllm.config import ModelConfig, SchedulerConfig
from bigdl.llm.vllm.core.scheduler import SchedulerOutputs, FixedWindowScheduler
@@ -127,6 +127,7 @@ def __init__(
# self.parallel_config = parallel_config
self.scheduler_config = scheduler_config
self.log_stats = log_stats
self.kv_cache = [[dict() for _ in range(2)] for _ in range(32)]
# self._verify_args()

self.tokenizer = get_tokenizer(
@@ -142,7 +143,7 @@ def __init__(
self._init_workers()

# Co(gc): we create a fixed scheduler
self.scheduler = FixedWindowScheduler(scheduler_config)
self.scheduler = FixedWindowScheduler(scheduler_config, kv_cache=self.kv_cache)

# Logging.
self.last_logging_time = 0.0
@@ -170,6 +171,7 @@ def _init_workers(self):
self.scheduler_config,
0,
# distributed_init_method,
kv_cache=self.kv_cache
)
self.workers.append(worker)
self._run_workers(