diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py
index 12810f5a7..a09cd8886 100644
--- a/templates/types/streaming/fastapi/app/api/routers/chat.py
+++ b/templates/types/streaming/fastapi/app/api/routers/chat.py
@@ -9,6 +9,8 @@ from llama_index.core.llms import ChatMessage, MessageRole
 
 from app.engine import get_chat_engine
 from app.api.routers.vercel_response import VercelStreamResponse
+from app.api.routers.messaging import EventCallbackHandler
+import asyncio
 
 chat_router = r = APIRouter()
 
@@ -92,15 +94,59 @@ async def chat(
 ):
     last_message_content, messages = await parse_chat_data(data)
 
+    event_handler = EventCallbackHandler()
+    chat_engine.callback_manager.handlers.append(event_handler)
     response = await chat_engine.astream_chat(last_message_content, messages)
 
-    async def event_generator(request: Request, response: StreamingAgentChatResponse):
+    async def content_generator():
         # Yield the text response
-        async for token in response.async_response_gen():
-            # If client closes connection, stop sending events
-            if await request.is_disconnected():
-                break
-            yield VercelStreamResponse.convert_text(token)
+        async def _text_generator():
+            async for token in response.async_response_gen():
+                yield VercelStreamResponse.convert_text(token)
+            # The text stream is exhausted; signal the event handler so
+            # async_event_gen() can terminate once its queue is drained
+            event_handler.is_done = True
+
+        # Yield the events from the event handler
+        async def _event_generator():
+            async for event in event_handler.async_event_gen():
+                title = event.get_title()
+                # Skip events that don't have a user-facing title
+                if title is None:
+                    continue
+                yield VercelStreamResponse.convert_data(
+                    {
+                        "type": "events",
+                        "data": {"title": title},
+                    }
+                )
+
+        # Consume items from both generators in the order they arrive.
+        # asyncio.as_completed() only accepts awaitables, not async
+        # generators, so we race one __anext__() task per generator and
+        # reschedule it after each yielded item.
+        async def _merge(generators):
+            pending = {
+                asyncio.ensure_future(gen.__anext__()): gen for gen in generators
+            }
+            while pending:
+                done, _ = await asyncio.wait(
+                    pending.keys(), return_when=asyncio.FIRST_COMPLETED
+                )
+                for task in done:
+                    gen = pending.pop(task)
+                    try:
+                        item = task.result()
+                    except StopAsyncIteration:
+                        continue
+                    pending[asyncio.ensure_future(gen.__anext__())] = gen
+                    yield item
+
+        async for item in _merge([_text_generator(), _event_generator()]):
+            # If client closes connection, stop sending events
+            if await request.is_disconnected():
+                break
+            yield item
 
         # Yield the source nodes
         yield VercelStreamResponse.convert_data(
@@ -115,7 +161,7 @@
             }
         )
 
-    return VercelStreamResponse(content=event_generator(request, response))
+    return VercelStreamResponse(content=content_generator())
 
 
 # non-streaming endpoint - delete if not needed
diff --git a/templates/types/streaming/fastapi/app/api/routers/messaging.py b/templates/types/streaming/fastapi/app/api/routers/messaging.py
new file mode 100644
index 000000000..f8b0c0d6e
--- /dev/null
+++ b/templates/types/streaming/fastapi/app/api/routers/messaging.py
@@ -0,0 +1,80 @@
+import asyncio
+from typing import AsyncGenerator, Dict, Any, List, Optional
+
+from llama_index.core.callbacks.base import BaseCallbackHandler
+from llama_index.core.callbacks.schema import CBEventType, EventPayload
+from pydantic import BaseModel
+
+
+class CallbackEvent(BaseModel):
+    event_type: CBEventType
+    payload: Optional[Dict[str, Any]] = None
+    event_id: str = ""
+
+    def get_title(self) -> Optional[str]:
+        # CBEventType.RETRIEVE is emitted twice: on_event_start carries
+        # the query string, on_event_end carries the retrieved nodes
+        if self.event_type != CBEventType.RETRIEVE:
+            return None
+        if self.payload is not None and EventPayload.QUERY_STR in self.payload:
+            return f"Retrieving context for query: '{self.payload[EventPayload.QUERY_STR]}'"
+        return "Retrieved sources to use as context for the query"
+
+
+class EventCallbackHandler(BaseCallbackHandler):
+    _aqueue: asyncio.Queue
+    is_done: bool = False
+
+    def __init__(self):
+        """Initialize the event callback handler."""
+        ignored_events = [
+            CBEventType.CHUNKING,
+            CBEventType.NODE_PARSING,
+            CBEventType.EMBEDDING,
+            CBEventType.LLM,
+        ]
+        super().__init__(ignored_events, ignored_events)
+        self._aqueue = asyncio.Queue()
+
+    def on_event_start(
+        self,
+        event_type: CBEventType,
+        payload: Optional[Dict[str, Any]] = None,
+        event_id: str = "",
+        **kwargs: Any,
+    ) -> str:
+        self._aqueue.put_nowait(
+            CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
+        )
+        return event_id
+
+    def on_event_end(
+        self,
+        event_type: CBEventType,
+        payload: Optional[Dict[str, Any]] = None,
+        event_id: str = "",
+        **kwargs: Any,
+    ) -> None:
+        self._aqueue.put_nowait(
+            CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
+        )
+
+    def start_trace(self, trace_id: Optional[str] = None) -> None:
+        """No-op."""
+
+    def end_trace(
+        self,
+        trace_id: Optional[str] = None,
+        trace_map: Optional[Dict[str, List[str]]] = None,
+    ) -> None:
+        """No-op."""
+
+    async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
+        # Drain the queue until the producer marks the stream as done and
+        # no events are left; the short timeout lets the loop re-check
+        # is_done instead of blocking forever on an empty queue
+        while not self._aqueue.empty() or not self.is_done:
+            try:
+                yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
+            except asyncio.TimeoutError:
+                pass
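
For reference, a minimal standalone sketch of the queue-based producer/consumer pattern that EventCallbackHandler.async_event_gen() relies on: a producer fills an asyncio.Queue while a consumer drains it until a done flag is set. All names below are illustrative, not part of the change:

import asyncio


async def main():
    queue: asyncio.Queue = asyncio.Queue()
    done = False

    async def producer():
        # Mirrors on_event_start/on_event_end pushing CallbackEvents
        nonlocal done
        for i in range(3):
            queue.put_nowait(f"event-{i}")
            await asyncio.sleep(0.05)
        done = True  # mirrors event_handler.is_done = True in chat.py

    async def consumer():
        # Same drain-until-done loop as async_event_gen()
        while not queue.empty() or not done:
            try:
                print(await asyncio.wait_for(queue.get(), timeout=0.1))
            except asyncio.TimeoutError:
                pass

    await asyncio.gather(producer(), consumer())


asyncio.run(main())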