# Extracted from patch cdd60b5b9b5081b53314eec381e6f6835b538e3b
# "Add gradio chatbot for openai webserver" — arkohut, Sat 30 Dec 2023
# File added by the patch: examples/gradio_openai_webserver.py
"""Gradio chat UI backed by an OpenAI-compatible (e.g. vLLM) API server.

Run with:  python gradio_openai_webserver.py -m <model-name>
"""
import argparse

import gradio as gr
from openai import OpenAI


def parse_args():
    """Build and parse the command-line arguments for the chatbot UI."""
    parser = argparse.ArgumentParser(
        description='Chatbot Interface with Customizable Parameters')
    parser.add_argument('--model-url',
                        type=str,
                        default='http://localhost:8000/v1',
                        help='Model URL')
    parser.add_argument('-m',
                        '--model',
                        type=str,
                        required=True,
                        help='Model name for the chatbot')
    parser.add_argument('--temp',
                        type=float,
                        default=0.8,
                        help='Temperature for text generation')
    parser.add_argument('--stop-token-ids',
                        type=str,
                        default='',
                        help='Comma-separated stop token IDs')
    parser.add_argument("--host", type=str, default=None)
    parser.add_argument("--port", type=int, default=8001)
    return parser.parse_args()


def parse_stop_token_ids(raw):
    """Parse a comma-separated string of token IDs into a list of ints.

    Matches the original inline parsing exactly: an empty string yields
    an empty list, and blank entries between commas are skipped.
    (The original comprehension shadowed the ``id`` builtin; the loop
    variable is renamed here.)
    """
    if not raw:
        return []
    return [int(token.strip()) for token in raw.split(',') if token.strip()]


args = parse_args()

# Set OpenAI's API key and API base to use vLLM's API server.
# vLLM ignores the key, but the OpenAI client requires one to be set.
client = OpenAI(
    api_key="EMPTY",
    base_url=args.model_url,
)


def predict(message, history):
    """Stream a chat completion for *message* given Gradio *history*.

    Yields the progressively accumulated assistant reply so Gradio can
    render the response token by token.
    """
    # Convert Gradio's [(user, assistant), ...] history to OpenAI format.
    history_openai_format = [{
        "role": "system",
        "content": "You are a great ai assistant."
    }]
    for human, assistant in history:
        history_openai_format.append({"role": "user", "content": human})
        history_openai_format.append({
            "role": "assistant",
            "content": assistant
        })
    history_openai_format.append({"role": "user", "content": message})

    # Create a chat completion request and send it to the API server.
    stream = client.chat.completions.create(
        model=args.model,  # Model name to use
        messages=history_openai_format,  # Chat history
        temperature=args.temp,  # Temperature for text generation
        stream=True,  # Stream response
        extra_body={
            'repetition_penalty': 1,
            'stop_token_ids': parse_stop_token_ids(args.stop_token_ids),
        })

    # Read and return generated text from the response stream.
    partial_message = ""
    for chunk in stream:
        # `delta.content` may be None on role-only or final chunks.
        partial_message += (chunk.choices[0].delta.content or "")
        yield partial_message


if __name__ == "__main__":
    # NOTE(review): share=True publishes the UI through a public Gradio
    # tunnel link — confirm that exposure is intended before deploying.
    gr.ChatInterface(predict).queue().launch(server_name=args.host,
                                             server_port=args.port,
                                             share=True)