From 0e0bd309e2903c18cc6827c0885d26a12fbeb5cb Mon Sep 17 00:00:00 2001
From: "Wang, Jian4" <61138589+hzjane@users.noreply.github.com>
Date: Mon, 6 May 2024 10:06:20 +0800
Subject: [PATCH] LLM: Enable Speculative on Fastchat (#10909)

* init
* enable streamer
* update
* update
* remove deprecated
* update
* update
* add gpu example
---
 .../doc/LLM/Quickstart/fastchat_quickstart.md  | 17 ++++++
 .../src/ipex_llm/serving/fastchat/README.md    | 52 +++++++------------
 .../serving/fastchat/ipex_llm_worker.py        | 16 +++++-
 .../llm/src/ipex_llm/transformers/loader.py    |  6 +++
 .../src/ipex_llm/transformers/speculative.py   | 16 +++---
 5 files changed, 66 insertions(+), 41 deletions(-)

diff --git a/docs/readthedocs/source/doc/LLM/Quickstart/fastchat_quickstart.md b/docs/readthedocs/source/doc/LLM/Quickstart/fastchat_quickstart.md
index 726b7684a72..2eb44aced0c 100644
--- a/docs/readthedocs/source/doc/LLM/Quickstart/fastchat_quickstart.md
+++ b/docs/readthedocs/source/doc/LLM/Quickstart/fastchat_quickstart.md
@@ -61,6 +61,23 @@ export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
 python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path REPO_ID_OR_YOUR_MODEL_PATH --low-bit "sym_int4" --trust-remote-code --device "xpu"
 ```
 
+#### For self-speculative decoding example:
+
+You can use IPEX-LLM to run the `self-speculative decoding` example. Refer to [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on Intel MAX GPUs and Intel CPUs.
+
+```bash
+# On CPU, the only supported low-bit format for self-speculative decoding is bf16.
+source ipex-llm-init -t
+python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "bf16" --trust-remote-code --device "cpu" --speculative
+
+# On GPU, the only supported low-bit format for self-speculative decoding is fp16.
+source /opt/intel/oneapi/setvars.sh
+export ENABLE_SDP_FUSION=1
+export SYCL_CACHE_PERSISTENT=1
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "fp16" --trust-remote-code --device "xpu" --speculative
+```
+
 You can get output like this:
 
 ```bash
diff --git a/python/llm/src/ipex_llm/serving/fastchat/README.md b/python/llm/src/ipex_llm/serving/fastchat/README.md
index ec905219ab5..553f5fa7d3a 100644
--- a/python/llm/src/ipex_llm/serving/fastchat/README.md
+++ b/python/llm/src/ipex_llm/serving/fastchat/README.md
@@ -25,9 +25,11 @@ You may install **`ipex-llm`** with `FastChat` as follows:
 
 ```bash
 pip install --pre --upgrade ipex-llm[serving]
+pip install transformers==4.36.0
 
 # Or
 pip install --pre --upgrade ipex-llm[all]
+
 ```
 
 To add GPU support for FastChat, you may install **`ipex-llm`** as follows:
@@ -51,39 +53,6 @@ python3 -m fastchat.serve.controller
 
 Using IPEX-LLM in FastChat does not impose any new limitations on model usage. Therefore, all Hugging Face Transformer models can be utilized in FastChat.
 
-#### IPEX-LLM model worker (deprecated)
-
-> Warning: This method has been deprecated, please change to use `IPEX-LLM` [worker](#ipex-llm-worker) instead.
-
-FastChat determines the Model adapter to use through path matching. Therefore, in order to load models using IPEX-LLM, you need to make some modifications to the model's name.
-
-For instance, assuming you have downloaded the `llama-7b-hf` from [HuggingFace](https://huggingface.co/decapoda-research/llama-7b-hf). Then, to use the `IPEX-LLM` as backend, you need to change name from `llama-7b-hf` to `ipex-llm-7b`.The key point here is that the model's path should include "ipex" and **should not include paths matched by other model adapters**.
-
-Then we will use `ipex-llm-7b` as model-path.
-
-> note: This is caused by the priority of name matching list. The new added `IPEX-LLM` adapter is at the tail of the name-matching list so that it has the lowest priority. If model path contains other keywords like `vicuna` which matches to another adapter with higher priority, then the `IPEX-LLM` adapter will not work.
-
-A special case is `ChatGLM` models. For these models, you do not need to do any changes after downloading the model and the `IPEX-LLM` backend will be used automatically.
-
-Then we can run model workers
-
-```bash
-# On CPU
-python3 -m ipex_llm.serving.fastchat.model_worker --model-path PATH/TO/ipex-llm-7b --device cpu
-
-# On GPU
-python3 -m ipex_llm.serving.fastchat.model_worker --model-path PATH/TO/ipex-llm-7b --device xpu
-```
-
-If you run successfully using `ipex_llm` backend, you can see the output in log like this:
-
-```bash
-INFO - Converting the current model to sym_int4 format......
-```
-
-> note: We currently only support int4 quantization for this method.
-
-
 #### IPEX-LLM worker
 
 To integrate IPEX-LLM with `FastChat` efficiently, we have provided a new model_worker implementation named `ipex_llm_worker.py`.
@@ -104,6 +73,23 @@ For GPU example:
 python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "sym_int4" --trust-remote-code --device "xpu"
 ```
 
+#### For self-speculative decoding example:
+
+You can use IPEX-LLM to run the `self-speculative decoding` example. Refer to [here](https://github.com/intel-analytics/ipex-llm/tree/c9fac8c26bf1e1e8f7376fa9a62b32951dd9e85d/python/llm/example/GPU/Speculative-Decoding) for more details on Intel MAX GPUs and Intel CPUs.
+
+```bash
+# On CPU, the only supported low-bit format for self-speculative decoding is bf16.
+source ipex-llm-init -t
+python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "bf16" --trust-remote-code --device "cpu" --speculative
+
+# On GPU, the only supported low-bit format for self-speculative decoding is fp16.
+source /opt/intel/oneapi/setvars.sh
+export ENABLE_SDP_FUSION=1
+export SYCL_CACHE_PERSISTENT=1
+export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=1
+python3 -m ipex_llm.serving.fastchat.ipex_llm_worker --model-path lmsys/vicuna-7b-v1.5 --low-bit "fp16" --trust-remote-code --device "xpu" --speculative
+```
+
 For a full list of accepted arguments, you can refer to the main method of the `ipex_llm_worker.py`
 
 #### IPEX-LLM vLLM worker
diff --git a/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py b/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py
index 9fbe1cb7329..2c2346458bb 100644
--- a/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py
+++ b/python/llm/src/ipex_llm/serving/fastchat/ipex_llm_worker.py
@@ -63,6 +63,7 @@ def __init__(
         device: str = "cpu",
         no_register: bool = False,
         trust_remote_code: bool = False,
+        speculative: bool = False,
         stream_interval: int = 4,
     ):
         super().__init__(
@@ -82,11 +83,13 @@ def __init__(
         )
 
         logger.info(f"Using low bit format: {self.load_in_low_bit}, device: {device}")
+        if speculative:
+            logger.info("Using self-speculative decoding for generation")
 
         self.device = device
-
+        self.speculative = speculative
         self.model, self.tokenizer = load_model(
-            model_path, device, self.load_in_low_bit, trust_remote_code
+            model_path, device, self.load_in_low_bit, trust_remote_code, speculative
         )
         self.stream_interval = stream_interval
         self.context_len = get_context_length(self.model.config)
@@ -98,6 +101,7 @@ def generate_stream_gate(self, params):
         # context length is self.context_length
         prompt = params["prompt"]
         temperature = float(params.get("temperature", 1.0))
+        do_sample = bool(params.get("do_sample", False))
         repetition_penalty = float(params.get("repetition_penalty", 1.0))
         top_p = float(params.get("top_p", 1.0))
         top_k = int(params.get("top_k", 1))
@@ -165,6 +169,7 @@ def generate_stream_gate(self, params):
             streamer=streamer,
             temperature=temperature,
             repetition_penalty=repetition_penalty,
+            do_sample=do_sample,
             top_p=top_p,
             top_k=top_k,
         )
@@ -314,6 +319,12 @@ async def api_model_details(request: Request):
     parser.add_argument(
         "--device", type=str, default="cpu", help="Device for executing model, cpu/xpu"
     )
+    parser.add_argument(
+        "--speculative",
+        action="store_true",
+        default=False,
+        help="Whether to use self-speculative decoding",
+    )
     parser.add_argument(
         "--trust-remote-code",
         action="store_true",
@@ -335,5 +346,6 @@ async def api_model_details(request: Request):
         args.device,
         args.no_register,
         args.trust_remote_code,
+        args.speculative,
     )
     uvicorn.run(app, host=args.host, port=args.port, log_level="info")
diff --git a/python/llm/src/ipex_llm/transformers/loader.py b/python/llm/src/ipex_llm/transformers/loader.py
index 2e00eb09172..1a05ba3bee8 100644
--- a/python/llm/src/ipex_llm/transformers/loader.py
+++ b/python/llm/src/ipex_llm/transformers/loader.py
@@ -45,6 +45,7 @@ def load_model(
     device: str = "cpu",
     low_bit: str = 'sym_int4',
     trust_remote_code: bool = True,
+    speculative: bool = False,
 ):
     """Load a model using BigDL LLM backend."""
 
@@ -64,6 +65,11 @@ def load_model(
     else:
         model_kwargs.update({"load_in_low_bit": low_bit, "torch_dtype": 'auto'})
 
+    if speculative:
+        invalidInputError(low_bit == "fp16" or low_bit == "bf16",
+                          "Self-speculative decoding only supports low_bit fp16 or bf16")
+        model_kwargs["speculative"] = True
+
     # Load tokenizer
     tokenizer = tokenizer_cls.from_pretrained(model_path, trust_remote_code=True)
     model = model_cls.from_pretrained(model_path, **model_kwargs)
diff --git a/python/llm/src/ipex_llm/transformers/speculative.py b/python/llm/src/ipex_llm/transformers/speculative.py
index 31309451c06..2128998cb9e 100644
--- a/python/llm/src/ipex_llm/transformers/speculative.py
+++ b/python/llm/src/ipex_llm/transformers/speculative.py
@@ -97,6 +97,7 @@ def generate(
             new_speculative_kwargs[var] = value
         return self.speculative_generate(inputs=inputs,
                                          draft_model=self.draft_model,
+                                         streamer=streamer,
                                          **new_speculative_kwargs)
     else:
         # When `draft_model` is false, these attributes
@@ -512,7 +513,7 @@ def _crop_past_key_values(self, past_key_values, new_cache_size, _enable_ipex=Fa
     return past_key_values
 
 
-def _prepare_generate_args(self, inputs, generation_config, **sampling_kwargs):
+def _prepare_generate_args(self, inputs, generation_config, streamer=None, **sampling_kwargs):
     if generation_config is None:
         generation_config = self.generation_config
 
@@ -591,8 +592,8 @@ def _prepare_generate_args(self, inputs, generation_config, **sampling_kwargs):
     # 5. Prepare `input_ids` which will be used for auto-regressive generation
     input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 
-    # if streamer is not None:
-    #     streamer.put(input_ids.cpu())
+    if streamer is not None:
+        streamer.put(input_ids.cpu())
 
     input_ids_length = input_ids.shape[-1]
 
@@ -658,6 +659,7 @@ def speculative_generate(self,
                          min_step_draft=3,
                          generation_config: Optional[GenerationConfig] = None,
                          attention_mask=None,
+                         streamer: Optional["BaseStreamer"] = None,
                          **sampling_kwargs):
     invalidInputError(draft_model is not None,
                       "Draft model should be provided.")
@@ -666,7 +668,7 @@ def speculative_generate(self,
     min_step_draft = min_step_draft if min_step_draft >= 1 else 1
 
     input_ids, generation_config, logits_processor, stopping_criteria, \
-        model_kwargs = _prepare_generate_args(self, inputs, generation_config,
+        model_kwargs = _prepare_generate_args(self, inputs, generation_config, streamer,
                                               **sampling_kwargs)
 
     step = 0
@@ -1061,7 +1063,8 @@ def speculative_generate(self,
 
             generate_ids[:, step:step+output_ids.size(1)] = output_ids
             current_input_ids = output_ids[:, -1:]
-
+            if streamer is not None:
+                streamer.put(output_ids.cpu())
             step += output_ids.size(1)
 
             # remove one generated by the base model
@@ -1094,7 +1097,8 @@ def speculative_generate(self,
                     idx = output_ids_list.index(generation_config.eos_token_id)
                     step -= (len(output_ids_list) - idx - 1)
                     break
-
+        if streamer is not None:
+            streamer.end()
         step = min(step, max_new_tokens)
         e2e_toc = time.time()
         self.n_token_generated = step
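
---

For reference, below is a minimal, hypothetical sketch (not part of the patch) of the path this change enables, exercised directly through the patched `load_model` helper rather than through the FastChat worker. The prompt, `max_new_tokens`, and device choice are placeholders; `speculative=True` goes through the new check in `loader.py`, which only accepts `fp16` on GPU or `bf16` on CPU, and the `speculative.py` changes forward a streamer so accepted tokens can be emitted incrementally.

```python
# Hypothetical standalone usage of the patched loader; model path, prompt and
# device below are placeholders and not part of the patch itself.
from transformers import TextStreamer

from ipex_llm.transformers.loader import load_model

device = "xpu"  # use "cpu" together with low_bit="bf16" instead

# speculative=True triggers the new check in loader.py: only fp16 (GPU) or bf16 (CPU).
model, tokenizer = load_model(
    "lmsys/vicuna-7b-v1.5",
    device=device,
    low_bit="fp16",
    trust_remote_code=True,
    speculative=True,
)

prompt = "What is speculative decoding?"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

# With the speculative.py changes above, generate() forwards the streamer to
# speculative_generate(), so accepted tokens are streamed as they are produced.
streamer = TextStreamer(tokenizer, skip_prompt=True)
output_ids = model.generate(
    input_ids,
    max_new_tokens=128,
    do_sample=False,
    streamer=streamer,
)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```

Launching `ipex_llm_worker.py` with `--speculative` (as in the README examples above) wires the same path through the worker, with `params["do_sample"]` now forwarded to `generate()` as well.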