diff --git a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py
index 0567280d26c..bbcb392f70a 100644
--- a/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py
+++ b/python/llm/example/GPU/Pipeline-Parallel-FastAPI/pipeline_serving.py
@@ -192,7 +192,6 @@ async def generate(prompt_request: PromptRequest):
     return request_id, "".join(output_str)
 
 
-@app.post("/generate_stream/")
 async def generate_stream(prompt_request: PromptRequest):
     request_id = str(uuid.uuid4()) + "stream"
     await local_model.waiting_requests.put((request_id, prompt_request))
@@ -211,6 +210,11 @@ async def generate_stream(prompt_request: PromptRequest):
         content=cur_generator, media_type="text/event-stream"
     )
 
+@app.post("/generate_stream/")
+async def generate_stream_api(prompt_request: PromptRequest):
+    request_id, result = await generate_stream(prompt_request)
+    return result
+
 
 DEFAULT_SYSTEM_PROMPT = """\
 """
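
The change moves the `@app.post("/generate_stream/")` route onto a thin wrapper (`generate_stream_api`) so that `generate_stream` can also be awaited directly by other handlers while the HTTP endpoint keeps returning only the `StreamingResponse`. For reference, a minimal client sketch that exercises the `/generate_stream/` endpoint is shown below; the host, port, and the `PromptRequest` field names (`prompt`, `n_predict`) are assumptions, since they are defined outside this hunk.

```python
# Hypothetical client sketch for the /generate_stream/ endpoint.
# Host, port, and request fields are assumptions; adjust them to match
# the actual PromptRequest model and server configuration.
import requests

payload = {"prompt": "What is AI?", "n_predict": 32}  # assumed field names

with requests.post("http://localhost:8000/generate_stream/",
                   json=payload, stream=True) as resp:
    resp.raise_for_status()
    # The endpoint responds with text/event-stream; print chunks as they arrive.
    for line in resp.iter_lines(decode_unicode=True):
        if line:
            print(line)
```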