From 4359ab3172c5d5f37cde8475ee03328e3fe3619e Mon Sep 17 00:00:00 2001 From: Xiangyu Tian <109123695+xiangyuT@users.noreply.github.com> Date: Fri, 14 Jun 2024 15:15:32 +0800 Subject: [PATCH] LLM: Add /generate_stream endpoint for Pipeline-Parallel-FastAPI example (#11187) Add /generate_stream and OpenAI-formatted endpoint for Pipeline-Parallel-FastAPI example --- .../GPU/Pipeline-Parallel-FastAPI/README.md | 28 +- .../Pipeline-Parallel-FastAPI/benchmark.py | 270 +++++++++++++ .../Pipeline-Parallel-FastAPI/gradio_webui.py | 69 ++++ .../openai_protocol.py | 367 ++++++++++++++++++ .../pipeline_models.py | 70 +++- .../pipeline_serving.py | 240 +++++++++++- .../Pipeline-Parallel-FastAPI/prompt/1024.txt | 1 + .../Pipeline-Parallel-FastAPI/prompt/128.txt | 1 + .../Pipeline-Parallel-FastAPI/prompt/2048.txt | 1 + .../Pipeline-Parallel-FastAPI/prompt/32.txt | 1 + .../GPU/Pipeline-Parallel-FastAPI/run.sh | 11 +- 11 files changed, 1041 insertions(+), 18 deletions(-) create mode 100644 python/llm/example/GPU/Pipeline-Parallel-FastAPI/benchmark.py create mode 100644 python/llm/example/GPU/Pipeline-Parallel-FastAPI/gradio_webui.py create mode 100644 python/llm/example/GPU/Pipeline-Parallel-FastAPI/openai_protocol.py create mode 100644 python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/1024.txt create mode 100644 python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/128.txt create mode 100644 python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/2048.txt create mode 100644 python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/32.txt diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md index 9bc8f254d44..b4203eb6f58 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/README.md @@ -18,7 +18,8 @@ pip install --pre --upgrade ipex-llm[xpu] --extra-index-url https://pytorch-exte pip install oneccl_bind_pt==2.1.100 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ # configures OneAPI environment variables source /opt/intel/oneapi/setvars.sh -pip install mpi4py fastapi uvicorn +pip install mpi4py fastapi uvicorn openai +pip install gradio # for gradio web UI conda install -c conda-forge -y gperftools=2.10 # to enable tcmalloc pip install transformers==4.31.0 # for llama2 models @@ -69,3 +70,28 @@ Please change the test url accordingly. # set t/c to the number of concurrencies to test full throughput. wrk -t1 -c1 -d5m -s ./wrk_script_1024.lua http://127.0.0.1:8000/generate/ --timeout 1m ``` + +## 5. Using the `benchmark.py` Script + +The `benchmark.py` script is designed to evaluate the performance of a streaming service by measuring response times and other relevant metrics. Below are the details on how to use the script effectively: + +### Command Line Arguments + +- `--prompt_length`: Specifies the length of the prompt used in the test. Acceptable values are `32`, `128`, `1024`, and `2048`. +- `--max_concurrent_requests`: Defines the levels of concurrency for the requests. You can specify multiple values to test different levels of concurrency in one run. +- `--max_new_tokens`: Sets the maximum number of new tokens that the model will generate per request. Default is `128`. 
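Under the hood, each benchmark worker issues a streaming POST to the `/generate_stream/` endpoint added in this PR, with a JSON body containing `prompt` and `n_predict`. For reference, here is a minimal single-request sketch (assuming the pipeline-parallel FastAPI server from this example is already running on port 8000; the prompt text and token budget below are placeholders):

```python
import requests

# One streaming request, mirroring the payload benchmark.py sends.
url = "http://localhost:8000/generate_stream/"
payload = {
    "prompt": "Once upon a time, there existed a little girl who liked to have adventures.",
    "n_predict": 32,  # maximum number of new tokens to generate
}

with requests.post(url, json=payload, stream=True) as response:
    response.raise_for_status()
    for line in response.iter_lines():
        if line:
            # Each non-empty line is one chunk of the server's event stream.
            print(line.decode("utf-8"))
```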
+ +### Usage Example +You can run the script with specific settings for prompt length, concurrent requests, and max new tokens by using the following command: + +```bash +python benchmark.py --prompt_length 1024 --max_concurrent_requests 1 2 3 --max_new_tokens 128 +``` + +This command sets the prompt length to 1024, tests concurrency levels of 1, 2, and 3, and configures the model to generate up to 128 new tokens per request. The results are saved in log files named according to the concurrency level (1.log, 2.log, 3.log). + +## 6. Gradio Web UI + +```bash +python ./gradio_webui.py -m Llama-2-13b-chat-hf +``` \ No newline at end of file diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/benchmark.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/benchmark.py new file mode 100644 index 00000000000..5b32796c105 --- /dev/null +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/benchmark.py @@ -0,0 +1,270 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import requests +import time +from concurrent.futures import ThreadPoolExecutor +import concurrent +import numpy as np +from tqdm import tqdm +import json +import argparse +from typing import List, Tuple + + +# Execute single request +def perform_request(session, url, payload, headers): + start_time = time.perf_counter() + with session.post(url, json=payload, headers=headers, stream=True) as response: + response.raise_for_status() + + first_token_time = None + last_token_time = 0 + first_token_inference_time = None + next_token_inference_time = None + next_token_time = [] + i = 0 + for line in response.iter_lines(): + + token_time = time.perf_counter() - start_time + if line: + data = line.decode("utf-8").strip() + i = i + 1 + try: + json_data = json.loads(data) + if json_data["message"] is not None: + if first_token_time is None: + first_token_time = token_time + else: + next_token_time.append(token_time - last_token_time) + last_token_time = token_time + except json.JSONDecodeError: + pass + end_time = time.perf_counter() + return ( + first_token_time, + np.mean(next_token_time), + end_time - start_time, + first_token_inference_time, + next_token_inference_time, + ) + + +def extend_list_to_length(lst, target_length): + if target_length <= len(lst): + return lst[:] + times = target_length // len(lst) + remainder = target_length % len(lst) + extended_list = lst * times + lst[:remainder] + + return extended_list + + +def benchmark( + llm_urls, + prompt, + num_requests, + max_concurrent_requests, + max_tokens, + prompt_length, + is_warmup=False, +): + + headers = {"Content-Type": "application/json"} + + first_token_latencies = [] + next_token_latencies = [] + total_responce_times = [] + first_token_inference_times = [] + next_token_inference_times = [] + cur_url_index = 0 + + with requests.Session() as session: + with ThreadPoolExecutor(max_workers=max_concurrent_requests) as executor: + llm_url = llm_urls[cur_url_index] + cur_url_index = (cur_url_index + 1) % len(llm_urls) + + 
cur_llm_urls = extend_list_to_length(llm_urls, max_concurrent_requests) + cur_len = len(cur_llm_urls) + + payload = { + "prompt": prompt, + "n_predict": max_tokens, + } + futures = [ + executor.submit( + perform_request, + session, + cur_llm_urls[index % cur_len], + payload, + headers, + ) + for index in range(num_requests) + ] + + start_time = time.perf_counter() + + if is_warmup: + phase = "Warm Up" + else: + phase = "Benchmarking" + with tqdm(total=num_requests, desc=phase, unit="req", ncols=100) as pbar: + for future in concurrent.futures.as_completed(futures): + try: + ( + first_token_latency, + next_token_latency, + total_responce_time, + first_token_inference_time, + next_token_inference_time, + ) = future.result() + first_token_latencies.append(first_token_latency) + next_token_latencies.append(next_token_latency) + total_responce_times.append(total_responce_time) + if first_token_inference_time: + first_token_inference_times.append( + first_token_inference_time + ) + if next_token_inference_time: + next_token_inference_times.append(next_token_inference_time) + except Exception as e: + print(f"Request failed: {e}") + pbar.update(1) + + if is_warmup: + return + total_time = time.perf_counter() - start_time + log_file = f"{max_concurrent_requests}.log" + + with open(log_file, "w") as file: + print( + f"Total time for {num_requests} requests with {max_concurrent_requests} concurrent requests: {total_time} seconds.", + file=file, + ) + print( + f"Average response time: {np.mean(total_responce_times)}", file=file + ) + + print( + f"Token throughput: {num_requests * max_tokens / total_time}", + file=file, + ) + print( + f"Total token throughput: {(max_tokens + prompt_length) * num_requests / total_time}", + file=file, + ) + print(file=file) + + if first_token_latencies: + average_first_token_latency = sum(first_token_latencies) / len( + first_token_latencies + ) + p90_first_token_latency = np.percentile(first_token_latencies, 90) + p95_first_token_latency = np.percentile(first_token_latencies, 95) + # average_first_token_inference_latency = np.mean( + # first_token_inference_times + # ) + print( + f"Average first token latency: {average_first_token_latency * 1000} milliseconds.", + file=file, + ) + print( + f"P90 first token latency: {p90_first_token_latency * 1000} milliseconds.", + file=file, + ) + print( + f"P95 first token latency: {p95_first_token_latency * 1000} milliseconds.", + file=file, + ) + # print( + # f"Average first token inference latency: {average_first_token_inference_latency * 1000} milliseconds.", + # file=file, + # ) + print(file=file) + + if next_token_latencies: + average_next_token_latency = sum(next_token_latencies) / len( + next_token_latencies + ) + p90_next_token_latency = np.percentile(next_token_latencies, 90) + p95_next_token_latency = np.percentile(next_token_latencies, 95) + # average_next_token_inference_latency = np.mean( + # next_token_inference_times + # ) + print( + f"Average next token latency: {average_next_token_latency * 1000} milliseconds.", + file=file, + ) + print( + f"P90 next token latency: {p90_next_token_latency * 1000} milliseconds.", + file=file, + ) + print( + f"P95 next token latency: {p95_next_token_latency * 1000} milliseconds.", + file=file, + ) + # print( + # f"Average next token inference latency: {average_next_token_inference_latency * 1000} milliseconds.", + # file=file, + # ) + print(file=file) + + +LLM_URLS = [f"http://localhost:{PORT}/generate_stream/" for PORT in [8000]] + +parser = argparse.ArgumentParser(description="Set 
prompt length.")
+parser.add_argument(
+    "--prompt_length",
+    type=int,
+    choices=[32, 128, 1024, 2048],
+    default=1024,
+    help="Length of the prompt: 32, 128, 1024, or 2048",
+)
+parser.add_argument(
+    "--max_concurrent_requests",
+    type=int,
+    nargs="+",
+    default=[1, 2, 4, 5, 6],
+    help="List of maximum concurrent requests to test.",
+)
+parser.add_argument(
+    "--max_new_tokens",
+    type=int,
+    default=128,
+    help="Maximum number of new tokens that the model will generate per request.",
+)
+args = parser.parse_args()
+PROMPT_LENGTH = args.prompt_length
+PROMPT = open(f"prompt/{PROMPT_LENGTH}.txt", "r").read()
+MAX_TOKENS = args.max_new_tokens
+
+
+for MAX_CONCURRENT_REQUESTS in args.max_concurrent_requests:
+    NUM_WARMUP = 5 * MAX_CONCURRENT_REQUESTS
+    NUM_REQUESTS = 10 * MAX_CONCURRENT_REQUESTS
+
+    # warm up
+    benchmark(
+        LLM_URLS,
+        PROMPT,
+        NUM_WARMUP,
+        MAX_CONCURRENT_REQUESTS,
+        MAX_TOKENS,
+        PROMPT_LENGTH,
+        is_warmup=True,
+    )
+
+    benchmark(LLM_URLS, PROMPT, NUM_REQUESTS, MAX_CONCURRENT_REQUESTS, MAX_TOKENS, PROMPT_LENGTH)
diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/gradio_webui.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/gradio_webui.py
new file mode 100644
index 00000000000..252bc31334e
--- /dev/null
+++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/gradio_webui.py
@@ -0,0 +1,69 @@
+import argparse
+
+import gradio as gr
+from openai import OpenAI
+
+# Argument parser setup
+parser = argparse.ArgumentParser(
+    description='Chatbot Interface with Customizable Parameters')
+parser.add_argument('--model-url',
+                    type=str,
+                    default='http://localhost:8000/v1',
+                    help='Model URL')
+parser.add_argument('-m',
+                    '--model',
+                    type=str,
+                    required=True,
+                    help='Model name for the chatbot')
+parser.add_argument("--host", type=str, default=None)
+parser.add_argument("--port", type=int, default=8001)
+
+# Parse the arguments
+args = parser.parse_args()
+
+# Set OpenAI's API key and API base to use the local FastAPI server.
+openai_api_key = "EMPTY"
+openai_api_base = args.model_url
+
+# Create an OpenAI client to interact with the API server
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+
+
+def predict(message, history):
+    # Convert chat history to OpenAI format
+    history_openai_format = [{
+        "role": "system",
+        "content": "You are a great AI assistant."
+ }] + for human, assistant in history: + history_openai_format.append({"role": "user", "content": human}) + history_openai_format.append({ + "role": "assistant", + "content": assistant + }) + history_openai_format.append({"role": "user", "content": message}) + + # Create a chat completion request and send it to the API server + stream = client.chat.completions.create( + model=args.model, # Model name to use + messages=history_openai_format, # Chat history + stream=True, # Stream response + ) + + # Read and return generated text from response stream + partial_message = "" + for chunk in stream: + # import pdb + # pdb.set_trace() + # partial_message += (chunk.delta['content'] or "") + partial_message += (chunk.choices[0].delta.content or "") + yield partial_message + + +# Create and launch a chat interface with Gradio +gr.ChatInterface(predict).queue().launch(server_name=args.host, + server_port=args.port, + share=True) \ No newline at end of file diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/openai_protocol.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/openai_protocol.py new file mode 100644 index 00000000000..546431b7f06 --- /dev/null +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/openai_protocol.py @@ -0,0 +1,367 @@ +# Adapted from +# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py +import time +from typing import Dict, List, Literal, Optional, Union + +import torch +from openai.types.chat import ChatCompletionMessageParam +from pydantic import BaseModel, ConfigDict, Field, model_validator +from typing_extensions import Annotated + +# from vllm.sampling_params import SamplingParams +def random_uuid() -> str: + return str(uuid.uuid4().hex) + +class OpenAIBaseModel(BaseModel): + # OpenAI API does not allow extra fields + model_config = ConfigDict(extra="forbid") + + +class ErrorResponse(OpenAIBaseModel): + object: str = "error" + message: str + type: str + param: Optional[str] = None + code: int + + +class ModelPermission(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") + object: str = "model_permission" + created: int = Field(default_factory=lambda: int(time.time())) + allow_create_engine: bool = False + allow_sampling: bool = True + allow_logprobs: bool = True + allow_search_indices: bool = False + allow_view: bool = True + allow_fine_tuning: bool = False + organization: str = "*" + group: Optional[str] = None + is_blocking: bool = False + + +class ModelCard(OpenAIBaseModel): + id: str + object: str = "model" + created: int = Field(default_factory=lambda: int(time.time())) + owned_by: str = "vllm" + root: Optional[str] = None + parent: Optional[str] = None + permission: List[ModelPermission] = Field(default_factory=list) + + +class ModelList(OpenAIBaseModel): + object: str = "list" + data: List[ModelCard] = Field(default_factory=list) + + +class UsageInfo(OpenAIBaseModel): + prompt_tokens: int = 0 + total_tokens: int = 0 + completion_tokens: Optional[int] = 0 + + +class ResponseFormat(OpenAIBaseModel): + # type must be "json_object" or "text" + type: Literal["text", "json_object"] + + +class ChatCompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/chat/create + messages: List[ChatCompletionMessageParam] + model: str + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[bool] = False + top_logprobs: 
Optional[int] = None + max_tokens: Optional[int] = None + n: Optional[int] = 1 + presence_penalty: Optional[float] = 0.0 + response_format: Optional[ResponseFormat] = None + seed: Optional[int] = Field(None, + ge=torch.iinfo(torch.long).min, + le=torch.iinfo(torch.long).max) + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + temperature: Optional[float] = 0.7 + top_p: Optional[float] = 1.0 + user: Optional[str] = None + + # doc: begin-chat-completion-sampling-params + best_of: Optional[int] = None + use_beam_search: Optional[bool] = False + top_k: Optional[int] = -1 + min_p: Optional[float] = 0.0 + repetition_penalty: Optional[float] = 1.0 + length_penalty: Optional[float] = 1.0 + early_stopping: Optional[bool] = False + ignore_eos: Optional[bool] = False + min_tokens: Optional[int] = 0 + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + skip_special_tokens: Optional[bool] = True + spaces_between_special_tokens: Optional[bool] = True + # doc: end-chat-completion-sampling-params + + # doc: begin-chat-completion-extra-params + echo: Optional[bool] = Field( + default=False, + description=( + "If true, the new message will be prepended with the last message " + "if they belong to the same role."), + ) + add_generation_prompt: Optional[bool] = Field( + default=True, + description= + ("If true, the generation prompt will be added to the chat template. " + "This is a parameter used by chat template in tokenizer config of the " + "model."), + ) + include_stop_str_in_output: Optional[bool] = Field( + default=False, + description=( + "Whether to include the stop string in the output. " + "This is only applied when the stop or stop_token_ids is set."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description=("If specified, the output will follow the JSON schema."), + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. 
If set, must be either " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) + + # doc: end-chat-completion-extra-params + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + return data + + +class CompletionRequest(OpenAIBaseModel): + # Ordered by official OpenAI API documentation + # https://platform.openai.com/docs/api-reference/completions/create + model: str + prompt: Union[List[int], List[List[int]], str, List[str]] + best_of: Optional[int] = None + echo: Optional[bool] = False + frequency_penalty: Optional[float] = 0.0 + logit_bias: Optional[Dict[str, float]] = None + logprobs: Optional[int] = None + max_tokens: Optional[int] = 16 + n: int = 1 + presence_penalty: Optional[float] = 0.0 + seed: Optional[int] = Field(None, + ge=torch.iinfo(torch.long).min, + le=torch.iinfo(torch.long).max) + stop: Optional[Union[str, List[str]]] = Field(default_factory=list) + stream: Optional[bool] = False + suffix: Optional[str] = None + temperature: Optional[float] = 1.0 + top_p: Optional[float] = 1.0 + user: Optional[str] = None + + # doc: begin-completion-sampling-params + use_beam_search: Optional[bool] = False + top_k: Optional[int] = -1 + min_p: Optional[float] = 0.0 + repetition_penalty: Optional[float] = 1.0 + length_penalty: Optional[float] = 1.0 + early_stopping: Optional[bool] = False + stop_token_ids: Optional[List[int]] = Field(default_factory=list) + ignore_eos: Optional[bool] = False + min_tokens: Optional[int] = 0 + skip_special_tokens: Optional[bool] = True + spaces_between_special_tokens: Optional[bool] = True + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + # doc: end-completion-sampling-params + + # doc: begin-completion-extra-params + include_stop_str_in_output: Optional[bool] = Field( + default=False, + description=( + "Whether to include the stop string in the output. " + "This is only applied when the stop or stop_token_ids is set."), + ) + response_format: Optional[ResponseFormat] = Field( + default=None, + description= + ("Similar to chat completion, this parameter specifies the format of " + "output. Only {'type': 'json_object'} or {'type': 'text' } is " + "supported."), + ) + guided_json: Optional[Union[str, dict, BaseModel]] = Field( + default=None, + description=("If specified, the output will follow the JSON schema."), + ) + guided_regex: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the regex pattern."), + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description=( + "If specified, the output will be exactly one of the choices."), + ) + guided_grammar: Optional[str] = Field( + default=None, + description=( + "If specified, the output will follow the context free grammar."), + ) + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. 
If set, must be one of " + "'outlines' / 'lm-format-enforcer'")) + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding.")) + + # doc: end-completion-extra-params + + @model_validator(mode="before") + @classmethod + def check_guided_decoding_count(cls, data): + guide_count = sum([ + "guided_json" in data and data["guided_json"] is not None, + "guided_regex" in data and data["guided_regex"] is not None, + "guided_choice" in data and data["guided_choice"] is not None + ]) + if guide_count > 1: + raise ValueError( + "You can only use one kind of guided decoding " + "('guided_json', 'guided_regex' or 'guided_choice').") + return data + + +class LogProbs(OpenAIBaseModel): + text_offset: List[int] = Field(default_factory=list) + token_logprobs: List[Optional[float]] = Field(default_factory=list) + tokens: List[str] = Field(default_factory=list) + top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None + + +class CompletionResponseChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + + +class CompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class CompletionResponseStreamChoice(OpenAIBaseModel): + index: int + text: str + logprobs: Optional[LogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( + default=None, + description=( + "The stop string or token id that caused the completion " + "to stop, None if the completion finished for some other reason " + "including encountering the EOS token"), + ) + + +class CompletionStreamResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") + object: str = "text_completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[CompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class ChatMessage(OpenAIBaseModel): + role: str + content: str + + +class ChatCompletionResponseChoice(OpenAIBaseModel): + index: int + message: ChatMessage + logprobs: Optional[LogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: str = "chat.completion" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseChoice] + usage: Optional[UsageInfo] = Field(default=None) + + +class DeltaMessage(OpenAIBaseModel): + role: Optional[str] = None + content: Optional[str] = None + + +class ChatCompletionResponseStreamChoice(OpenAIBaseModel): + index: int + delta: DeltaMessage + logprobs: Optional[LogProbs] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None + + +class ChatCompletionStreamResponse(OpenAIBaseModel): + id: str = 
Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") + object: str = "chat.completion.chunk" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + choices: List[ChatCompletionResponseStreamChoice] + usage: Optional[UsageInfo] = Field(default=None) diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py index dfc09d6d57c..f8f5ba77ca8 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_models.py @@ -289,6 +289,12 @@ def __init__(self, checkpoint, rank, world_size, low_bit, max_num_seqs): self.send_buff = None self.dict_lock = threading.Lock() + self.streamer = {} + self.token_cache = {} + self.print_len = {} + self.is_finish = {} + self.model_name = checkpoint + # def generate(self, input_ids=None, max_tokens=5, attention_mask=None): # times = [] @@ -422,7 +428,7 @@ async def process_step(self, tokenizer, result_dict): if cur_batch is None: if not self.waiting_requests.empty(): - # await asyncio.sleep(0.01) + await asyncio.sleep(0.01) cur_batch = await self.add_request(tokenizer) cur_input = self.input_ids_dict[cur_batch.batch_id] else: @@ -447,6 +453,44 @@ async def process_step(self, tokenizer, result_dict): # cur_batch.input_len += 1 cur_batch.input_len = 1 cur_batch.prompt_lengths = [x + 1 for x in cur_batch.prompt_lengths] + + for index, request_id in enumerate(cur_batch.request_ids): + + if not self.is_finish.get(request_id, False): + remain = cur_batch.max_tokens - len(self.tokens[cur_id]) + + if self.streamer.get(request_id, None) is None: + self.streamer[request_id] = asyncio.Queue() + + if next_ids[index].int() == tokenizer.eos_token_id: + remain = 0 + self.is_finish[request_id] = True + + if self.token_cache.get(request_id, None) is None: + self.token_cache[request_id] = [] + self.print_len[request_id] = 0 + self.token_cache[request_id].extend(next_ids[index].tolist()) + + text = tokenizer.decode(self.token_cache[request_id]) + if text.endswith("\n"): + printable_text = text[self.print_len[request_id]:] + self.token_cache[request_id] = [] + self.print_len[request_id] = 0 + elif len(text) > 0 and _is_chinese_char(ord(text[-1])): + printable_text = text[self.print_len[request_id]:] + self.print_len[request_id] += len(printable_text) + else: + printable_text = text[self.print_len[request_id] : text.rfind(" ") + 1] + self.print_len[request_id] += len(printable_text) + + if remain > 0: + await self.streamer[request_id].put((remain, printable_text)) + else: + printable_text = printable_text + text[self.print_len[request_id]:] + self.token_cache.pop(request_id, None) + self.print_len.pop(request_id, None) + await self.streamer[request_id].put((remain, printable_text)) + if len(self.tokens[cur_id]) >= cur_batch.max_tokens: # Finish a batch # logger.info(self.tokens[cur_id]) @@ -509,3 +553,27 @@ async def process_step(self, tokenizer, result_dict): self.on_going_batches[:-1] = self.on_going_batches[1:] self.on_going_batches[self.world_size - 1] = cur_batch + +def _is_chinese_char(cp): + """Checks whether CP is the codepoint of a CJK character.""" + # This defines a "chinese character" as anything in the CJK Unicode block: + # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block) + # + # Note that the CJK Unicode block is NOT all Japanese and Korean characters, + # despite its name. 
The modern Korean Hangul alphabet is a different block, + # as is Japanese Hiragana and Katakana. Those alphabets are used to write + # space-separated words, so they are not treated specially and handled + # like the all of the other languages. + if ( + (cp >= 0x4E00 and cp <= 0x9FFF) + or (cp >= 0x3400 and cp <= 0x4DBF) # + or (cp >= 0x20000 and cp <= 0x2A6DF) # + or (cp >= 0x2A700 and cp <= 0x2B73F) # + or (cp >= 0x2B740 and cp <= 0x2B81F) # + or (cp >= 0x2B820 and cp <= 0x2CEAF) # + or (cp >= 0xF900 and cp <= 0xFAFF) + or (cp >= 0x2F800 and cp <= 0x2FA1F) # + ): # + return True + + return False \ No newline at end of file diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py index 6b85a611f31..aeaf0d1741a 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py @@ -3,7 +3,10 @@ import torch.distributed as dist import os +import ipex_llm +from ipex_llm.utils.common import invalidInputError import oneccl_bindings_for_pytorch +import json from transformers.utils import logging logger = logging.get_logger(__name__) @@ -20,11 +23,12 @@ import time from transformers import AutoTokenizer, AutoConfig, LlamaTokenizer -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import StreamingResponse from pydantic import BaseModel import uvicorn import asyncio, uuid -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Any, Callable, Union import argparse def get_int_from_env(env_keys, default): @@ -38,8 +42,22 @@ def get_int_from_env(env_keys, default): class PromptRequest(BaseModel): prompt: str - n_predict: int = 32 + n_predict: Optional[int] = 256 + req_type: str = 'completion' +from openai.types.chat import ChatCompletionMessageParam +class ChatCompletionRequest(BaseModel): + messages: List[ChatCompletionMessageParam] + model: str + max_tokens: Optional[int] = None + stream: Optional[bool] = False + + +class CompletionRequest(BaseModel): + model: str + prompt: Union[List[int], List[List[int]], str, List[str]] + max_tokens: Optional[int] = None + stream: Optional[bool] = False empty_req = PromptRequest(prompt="", n_predict=0) @@ -49,8 +67,112 @@ class PromptRequest(BaseModel): request_queue: asyncio.Queue = asyncio.Queue() result_dict: Dict[str, str] = {} +streamer_dict = {} local_rank = my_rank -max_num_seqs = get_int_from_env(["MAX_NUM_SEQS"], "16") + + +from openai_protocol import ( + ChatCompletionResponseStreamChoice, + ChatCompletionStreamResponse, + ChatCompletionResponseChoice, + ChatCompletionResponse, + ChatMessage, + DeltaMessage, + CompletionResponseChoice, + CompletionResponse, + CompletionResponseStreamChoice, + CompletionStreamResponse, +) + + +async def chat_stream_generator(local_model, delta_text_queue, request_id): + model_name = local_model.model_name + index = 0 + while True: + if not delta_text_queue.empty(): + with local_model.dict_lock: + remain, delta_text = await delta_text_queue.get() + # print(remain) + choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(role="assistant", content=delta_text), + logprobs=None, + finish_reason=None) + chunk = ChatCompletionStreamResponse( + id=request_id, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + index = index + 1 + if remain == 0: 
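+                # `remain` is the number of tokens this request may still generate; it is
+                # placed on the per-request streamer queue by process_step() in
+                # pipeline_models.py. A value of 0 means the final token has been produced
+                # (max_tokens reached or EOS emitted), so send one last chunk and close the
+                # event stream. Note this example reports finish_reason="length" in both cases.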
+ choice_data = ChatCompletionResponseStreamChoice( + index=index, + delta=DeltaMessage(role="assistant", content=None), + logprobs=None, + finish_reason="length") + chunk = ChatCompletionStreamResponse( + id=request_id, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + break + else: + await asyncio.sleep(0) + local_model.streamer.pop(request_id, None) + + +async def completion_stream_generator(local_model, delta_text_queue, request_id): + model_name = local_model.model_name + index = 0 + while True: + if not delta_text_queue.empty(): + with local_model.dict_lock: + remain, delta_text = await delta_text_queue.get() + # print(remain) + choice_data = CompletionResponseStreamChoice( + index=index, + text=delta_text, + logprobs=None, + finish_reason=None) + chunk = CompletionStreamResponse( + id=request_id, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + index = index + 1 + if remain == 0: + choice_data = CompletionResponseStreamChoice( + index=index, + text=None, + logprobs=None, + finish_reason="length") + chunk = CompletionStreamResponse( + id=request_id, + choices=[choice_data], + model=model_name) + data = chunk.model_dump_json(exclude_unset=True) + yield f"data: {data}\n\n" + break + else: + await asyncio.sleep(0) + local_model.streamer.pop(request_id, None) + + +async def generator(local_model, delta_text_queue, request_id): + while True: + if not delta_text_queue.empty(): + with local_model.dict_lock: + remain, delta_text = await delta_text_queue.get() + yield delta_text + if remain == 0: + break + else: + await asyncio.sleep(0) + # streamer_dict.pop(request_id, None) + local_model.streamer.pop(request_id, None) @app.post("/generate/") @@ -58,16 +180,106 @@ async def generate(prompt_request: PromptRequest): request_id = str(uuid.uuid4()) await local_model.waiting_requests.put((request_id, prompt_request)) while True: - if request_id in result_dict: - with local_model.dict_lock: - output_str = result_dict[request_id] - if len(output_str) == 0: - logger.info(f"Why? 
{request_id}") - # await asyncio.sleep(0.1) - # continue - result_dict.pop(request_id) - return {"generated_text": output_str} - await asyncio.sleep(0) + await asyncio.sleep(0) + cur_streamer = local_model.streamer.get(request_id, None) + if cur_streamer is not None: + output_str = [] + async for item in generator(local_model, cur_streamer, request_id): + output_str.append(item) + return request_id, "".join(output_str) + + +@app.post("/generate_stream/") +async def generate_stream(prompt_request: PromptRequest): + request_id = str(uuid.uuid4()) + "stream" + await local_model.waiting_requests.put((request_id, prompt_request)) + while True: + await asyncio.sleep(0) + cur_streamer = local_model.streamer.get(request_id, None) + if cur_streamer is not None: + if prompt_request.req_type == 'completion': + cur_generator = completion_stream_generator(local_model, cur_streamer, request_id) + elif prompt_request.req_type == 'chat': + cur_generator = chat_stream_generator(local_model, cur_streamer, request_id) + else: + invalidInputError(False, "Invalid Request Type.") + + return request_id, StreamingResponse( + content=cur_generator, media_type="text/event-stream" + ) + + +DEFAULT_SYSTEM_PROMPT = """\ +""" + +def get_prompt(messages) -> str: + prompt = "" + for msg in messages: + role = msg["role"] + content = msg["content"] + if role == "system": + prompt += f"<>\n{content}\n<>\n\n" + elif role == "user": + prompt += f"[INST] {content} [/INST] " + elif role == "assistant": + prompt += f"{content} " + else: + raise ValueError(f"Unknown role: {role}") + return prompt.strip() + +@app.post("/v1/chat/completions") +async def create_chat_completion(request: ChatCompletionRequest): + model_name = local_model.model_name + if request.max_tokens is None: + n_predict = 256 + else: + n_predict = request.max_tokens + prompt_request = PromptRequest( + prompt=get_prompt(request.messages), + n_predict=n_predict, + req_type="chat" + ) + if request.stream: + request_id, result = await generate_stream(prompt_request) + else: + request_id, result = await generate(prompt_request) + choice_data = ChatCompletionResponseChoice( + index=0, + message=ChatMessage(role="assistant", content=result), + logprobs=None, + finish_reason="length") + result = ChatCompletionResponse( + id=request_id, + choices=[choice_data], + model=model_name) + return result + +@app.post("/v1/completions") +async def create_completion(request: CompletionRequest): + model_name = local_model.model_name + if request.max_tokens is None: + n_predict = 256 + else: + n_predict = request.max_tokens + prompt_request = PromptRequest( + prompt=request.prompt, + n_predict=n_predict, + req_type="completion" + ) + if request.stream: + request_id, result = await generate_stream(prompt_request) + else: + request_id, result = await generate(prompt_request) + choice_data = CompletionResponseChoice( + index=0, + text=result, + logprobs=None, + finish_reason="length") + result = CompletionResponse( + id=request_id, + choices=[choice_data], + model=model_name) + return result def generate_text(prompt: List[str], n_predict = 32): diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/1024.txt b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/1024.txt new file mode 100644 index 00000000000..4146a49f586 --- /dev/null +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/1024.txt @@ -0,0 +1 @@ +Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun. 
However, her parents were always telling her to stay close to home, to be careful, and to avoid any danger. But the little girl was stubborn, and she wanted to see what was on the other side of the mountain. So she sneaked out of the house one night, leaving a note for her parents, and set off on her journey. As she climbed the mountain, the little girl felt a sense of excitement and wonder. She had never been this far away from home before, and she couldnt wait to see what she would find on the other side. She climbed higher and higher, her lungs burning from the thin air, until she finally reached the top of the mountain. And there, she found a beautiful meadow filled with wildflowers and a sparkling stream. The little girl danced and played in the meadow, feeling free and alive. She knew she had to return home eventually, but for now, she was content to enjoy her adventure. As the sun began to set, the little girl reluctantly made her way back down the mountain, but she knew that she would never forget her adventure and the joy of discovering something new and exciting. And whenever she felt scared or unsure, she would remember the thrill of climbing the mountain and the beauty of the meadow on the other side, and she would know that she could face any challenge that came her way, with courage and determination. She carried the memories of her journey in her heart, a constant reminder of the strength she possessed. The little girl returned home to her worried parents, who had discovered her note and anxiously awaited her arrival. They scolded her for disobeying their instructions and venturing into the unknown. But as they looked into her sparkling eyes and saw the glow on her face, their anger softened. They realized that their little girl had grown, that she had experienced something extraordinary. The little girl shared her tales of the mountain and the meadow with her parents, painting vivid pictures with her words. She spoke of the breathtaking view from the mountaintop, where the world seemed to stretch endlessly before her. She described the delicate petals of the wildflowers, vibrant hues that danced in the gentle breeze. And she recounted the soothing melody of the sparkling stream, its waters reflecting the golden rays of the setting sun. Her parents listened intently, captivated by her story. They realized that their daughter had discovered a part of herself on that journey—a spirit of curiosity and a thirst for exploration. They saw that she had learned valuable lessons about independence, resilience, and the beauty that lies beyond ones comfort zone. From that day forward, the little girls parents encouraged her to pursue her dreams and embrace new experiences. They understood that while there were risks in the world, there were also rewards waiting to be discovered. They supported her as she continued to embark on adventures, always reminding her to stay safe but never stifling her spirit. As the years passed, the little girl grew into a remarkable woman, fearlessly exploring the world and making a difference wherever she went. The lessons she had learned on that fateful journey stayed with her, guiding her through challenges and inspiring her to live life to the fullest. And so, the once timid little girl became a symbol of courage and resilience, a reminder to all who knew her that the greatest joys in life often lie just beyond the mountains we fear to climb. 
Her story spread far and wide, inspiring others to embrace their own journeys and discover the wonders that awaited them. In the end, the little girls adventure became a timeless tale, passed down through generations, reminding us all that sometimes, the greatest rewards come to those who dare to step into the unknown and follow their hearts. With each passing day, the little girls story continued to inspire countless individuals, igniting a spark within their souls and encouraging them to embark on their own extraordinary adventures. The tale of her bravery and determination resonated deeply with people from all walks of life, reminding them of the limitless possibilities that awaited them beyond the boundaries of their comfort zones. People marveled at the little girls unwavering spirit and her unwavering belief in the power of dreams. They saw themselves reflected in her journey, finding solace in the knowledge that they too could overcome their fears and pursue their passions. The little girl's story became a beacon of hope, a testament to the human spirit diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/128.txt b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/128.txt new file mode 100644 index 00000000000..bd766236a53 --- /dev/null +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/128.txt @@ -0,0 +1 @@ +In a distant future, humanity has expanded across the galaxy, establishing colonies on numerous planets. The interstellar community thrives under the guidance of the United Galactic Federation, which ensures peace and prosperity. However, a new threat emerges from the unknown regions of space, challenging the stability and security of the galaxy. Brave explorers and seasoned warriors must unite to uncover the secrets of this mysterious force and protect the future of all sentient beings. Please continue the above story as long as possible, preferably more than 1000 tokens. \ No newline at end of file diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/2048.txt b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/2048.txt new file mode 100644 index 00000000000..b05a7e7518a --- /dev/null +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/2048.txt @@ -0,0 +1 @@ +“You’re an idiot,” she said.\nI smiled and leaned back in the chair, looking at her over my glasses. “No, I’m not.”\n“If you were smart you would have learned to dance years ago. You’ve got two left feet.” She held up both of her hands with four fingers extended then made a circular motion that looked like an airplane.\nI leaned forward and put my glasses on the table in front of me, reaching for her hands as I did so, grabbing them before they could leave mine. “The next time you do something like this, call me. The phone number is right here,” I said as I pointed at a piece of paper under a stack of papers on my desk.\n“Fine,” she huffed and turned to leave the room. But she stopped at the doorway when she saw the bookshelves that lined one wall. “What are these for?” She stepped closer, tilting her head back and forth as she looked up. The shelves were three stories high with stacks of books on every level.\n“Books.” I smiled again. “I have a lot of books.”\nShe didn’t respond to that so I continued: “And there are more in the basement.”\n“But you can’t move them all here, right? This place is just too small for all those books. 
Maybe we should look for a bigger office building.” She looked back at me but said nothing as she took another few steps towards the door and then stopped again when she saw my grandfather clock on the wall.\n“And this?” she pointed to the clock, which had been in the family for over seventy years. “It’s just a clock isn’t it?”\nI laughed. “You can say that, but I know better.” It was then that I told her my grandfather’s story. He made that clock, and it was his favorite possession. When he died she inherited the clock; or at least she thought she did. After a few weeks of trying to sell it on eBay, she gave up because no one would pay what she felt it was worth.\n“You should have had an auction,” she suggested, leaning in towards me again. “Then maybe you could get more for it.”\n“No,” I shook my head. “I don’t want to sell the clock.”\nShe smiled, but this time it didn’t reach her eyes. She took a step back and looked at me again, not saying anything, just staring. The only sound was the ticking of the grandfather clock in the background as she waited for my next words.\n“My grandfather made this clock. He did everything by hand.” I could see that she had no idea what to say or do so I continued: “It’s his favorite possession, and it means more to me than anything else he ever owned. So, if you want the books, you can have them…” I looked at her face for just a second before continuing, “but you won’t take the clock.”\nShe finally responded with: “But what about the money?” She looked around again and said, “I think we could make more selling these books than you would get from all of them. You must have thousands of books here!”\nI took another step forward and put my hand on her shoulder as I spoke to her in a very low voice. “You’ve got it all wrong,” I told her. “There are only two or three hundred books. I’m not looking for money – I’m looking for someone who understands how important this clock is.”\n“How much do you want for the books?” she asked, still staring at me intently as she waited for my answer.\n“Forget about the money,” I said again. “If you really want to buy them, we can take our time and talk more later. But if you just want their value in paperbacks, that’s not what they’re worth.” She still seemed confused by everything I had said so far, so I tried to simplify my words as much as possible: “The books are mine; the clock is my grandfather’s. These books have been passed down through several generations of our family and are irreplaceable. Do you understand?”\n“I guess not,” she answered as she walked away from me, still looking at me but not saying a word. She took two more steps before turning around to say one last thing: “Well, good luck with the books, then.” With that, she went back into her house and out of sight, still walking without talking.\nAfter a few minutes, I slowly walked back toward my grandfather’s home. As I got closer, I could see the roof in the distance; the white crosses on the top of it were hard to miss. It seemed as if the entire town had gathered around there at that moment – people were all over the road around us, watching the commotion and chattering about what was going on.\nWhen my grandfather first saw me, he looked up from his chair with a smile on his face: “There you are.” He looked down at his hands, then back toward me as I walked forward to talk to him for the first time in years: “It’s been too long since we last spoke; it’s good to see you again.”\n“And you,” I said. 
Then, looking past my grandfather and directly into the face of the man who was sitting next to him (my mother’s father), I said, “I see he got your clock back for you, too. How is he?” My grandfather smiled as he looked up at me again:\n“He’s fine,” he answered, still smiling as he watched my mother’s family and mine chat with one another in the middle of all these people – a situation that I had never seen before. “Come on inside.” He stood up from his chair to do just that; my mom and her sister were already walking out of the building. “I have things for you.”\nMy grandfather led us inside, down some steps where he used to serve as the pastor in his church; there was a big room full of chairs at the bottom with pictures on the wall – all kinds of pictures, from when my family first started coming here to visit and other pictures we took while staying here over the years. All these photographs were all around us as I followed my grandfather through the building:\n“My house is just up the street,” he said. He stopped at a picture on the wall that was taken in the summer when we came to visit, smiling as he looked toward it with his arms folded – the picture was of him and his wife and two of their daughters, all standing together by one of the trees outside; there were other pictures around this one, some from much earlier than when my grandfather first started serving here. “We used to sit in a booth in that restaurant right over there – you remember?” I nodded as we went past it.\nMy grandfather stopped at another picture on the wall: it was of him and his wife with two other families, all sitting around a table together, smiling. He looked down at this one for a moment; then he said, “We used to do things like this every year, when we came to visit.” It was an older picture than the last one my grandfather had stopped in front of; I didn’t know it before but now I realized how much he has aged.\nMy grandparents have lived together for many years. They used to live in a house right next door, so they could walk over whenever they wanted; that is what they have done here all these years – as my grandfather said, “we’ve come here every summer since I was eleven.” But he and his wife are getting old now. He isn’t able to walk much anymore, but it makes him happy when he does: “My health has not been good lately,” he said.\n“You will never have a better time in your life than this one right now; you will never be as happy as you are now.” And for the first time since I have known him – since I was very little and started coming here every summer – my grandfather smiled at me, his eyes sparkling with excitement.\n“I know,” I said. “That’s why I’m really looking forward to it. It will be a lot of fun.” Then he turned back to the picture again; “See this?” he asked, pointing. “I remember that day, all sixteen of us there together. I was eleven then – my dad had taken me and my brother for our first trip away from home – and that was when we used to go to the cottage.” He stared at it for a while longer; he had tears in his eyes. “I loved this picture,” he said, turning it over again with one hand so I could see the back of it.\n“This is my best memory,” he explained. “It was taken on my birthday. That’s what makes me happiest.” He pointed to a man who had a pipe in his mouth. “That’s my uncle,” he said. “He gave all of us kids cigars for our birthdays, and we used to take turns lighting them – then everyone would sit around outside in the sunshine and smoke together like that. 
It was such a good time.” Then he held up his hand, as if to say, that’s enough now; and he went on, “Anyway, I don’ diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/32.txt b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/32.txt new file mode 100644 index 00000000000..4dbb31cb2a2 --- /dev/null +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/prompt/32.txt @@ -0,0 +1 @@ +Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh index a15b6c51ff5..3c6243d613e 100644 --- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh +++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/run.sh @@ -1,7 +1,12 @@ source /opt/intel/oneapi/setvars.sh export no_proxy=localhost export FI_PROVIDER=tcp -export OMP_NUM_THREADS=8 +export OMP_NUM_THREADS=32 + +export LD_PRELOAD=${LD_PRELOAD}:${CONDA_PREFIX}/lib/libtcmalloc.so +basekit_root=/opt/intel/oneapi +source $basekit_root/setvars.sh --force +source $basekit_root/ccl/latest/env/vars.sh --force export USE_XETLA=OFF export SYCL_PI_LEVEL_ZERO_USE_IMMEDIATE_COMMANDLISTS=2 @@ -9,4 +14,6 @@ export TORCH_LLM_ALLREDUCE=0 export MODEL_PATH=YOUR_MODEL_PATH export NUM_GPUS=2 -CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8 +export BIGDL_QUANTIZE_KV_CACHE=1 + +CCL_ZE_IPC_EXCHANGE=sockets torchrun --standalone --nnodes=1 --nproc-per-node $NUM_GPUS pipeline_serving.py --repo-id-or-model-path $MODEL_PATH --low-bit fp8 --max-num-seqs 4
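
Once the server is up (via `run.sh` above), the new OpenAI-compatible `/v1/chat/completions` endpoint can also be exercised directly with the `openai` client, the same way `gradio_webui.py` does. A minimal sketch, assuming the server listens on port 8000 and using `Llama-2-13b-chat-hf` (the model name from the README's Gradio example) as a stand-in for whatever model you are actually serving:

```python
from openai import OpenAI

# Point the client at the local pipeline-parallel FastAPI server; the key is a
# placeholder, matching what gradio_webui.py uses.
client = OpenAI(api_key="EMPTY", base_url="http://localhost:8000/v1")

# Stream a chat completion and print the reply incrementally.
stream = client.chat.completions.create(
    model="Llama-2-13b-chat-hf",  # substitute the model you are serving
    messages=[{"role": "user", "content": "What is AI?"}],
    max_tokens=64,
    stream=True,
)

for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="", flush=True)
print()
```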