Skip to content

Commit

Permalink
started python implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusschiesser committed Apr 18, 2024
1 parent ce6aeb5 commit 814c49c
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 7 deletions.
39 changes: 32 additions & 7 deletions templates/types/streaming/fastapi/app/api/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from llama_index.core.llms import ChatMessage, MessageRole
from app.engine import get_chat_engine
from app.api.routers.vercel_response import VercelStreamResponse
from app.api.routers.messaging import EventCallbackHandler
import asyncio

chat_router = r = APIRouter()

Expand Down Expand Up @@ -92,15 +94,38 @@ async def chat(
):
last_message_content, messages = await parse_chat_data(data)

event_handler = EventCallbackHandler()
chat_engine.callback_manager.handlers.append(event_handler)
response = await chat_engine.astream_chat(last_message_content, messages)

async def event_generator(request: Request, response: StreamingAgentChatResponse):
async def content_generator():
# Yield the text response
async for token in response.async_response_gen():
# If client closes connection, stop sending events
if await request.is_disconnected():
break
yield VercelStreamResponse.convert_text(token)
async def _text_generator():
async for token in response.async_response_gen():
yield VercelStreamResponse.convert_text(token)
# TODO: ideally we don't need is_done and we just consume till _text_generator is finished below
event_handler.is_done = True

# Yield the events from the event handler
async def _event_generator():
async for event in event_handler.async_event_gen():
yield VercelStreamResponse.convert_data(
{
"type": "events",
"data": {"title": event.get_title()},
}
)

# TODO: idea here is to consume items yielded by both of the generators above in the order they are coming in
# Snippet below doesn't work - produces this error:
# async for item in asyncio.as_completed([_text_generator(), _event_generator()]):
# TypeError: 'async for' requires an object with __aiter__ method, got generator
async for item in asyncio.as_completed([_text_generator(), _event_generator()]):
async for value in item:
# If client closes connection, stop sending events
if await request.is_disconnected():
break
yield value

# Yield the source nodes
yield VercelStreamResponse.convert_data(
Expand All @@ -115,7 +140,7 @@ async def event_generator(request: Request, response: StreamingAgentChatResponse
}
)

return VercelStreamResponse(content=event_generator(request, response))
return VercelStreamResponse(content=content_generator())


# non-streaming endpoint - delete if not needed
Expand Down
83 changes: 83 additions & 0 deletions templates/types/streaming/fastapi/app/api/routers/messaging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import asyncio
from typing import AsyncGenerator, Dict, Any, List, Optional

from llama_index.core.callbacks.base import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEventType, EventPayload
from pydantic import BaseModel


class CallbackEvent(BaseModel):
    """A single llama_index callback event captured for streaming to the client.

    Wraps the raw ``(event_type, payload, event_id)`` triple handed to the
    callback handler so it can be queued and later rendered as a data event.
    """

    event_type: CBEventType
    payload: Optional[Dict[str, Any]] = None
    event_id: str = ""

    def get_title(self) -> str:
        """Return a human-readable title for this event.

        Currently just echoes the event id.
        """
        # TODO: we get two CBEventType.RETRIEVE events
        # For the on_event_start we should render:
        # "Retrieving context for query <query_str>"
        # For the on_event_end we should render:
        # "Retrieved <nodes> sources to use as context for the query"
        return self.event_id


class EventCallbackHandler(BaseCallbackHandler):
    """Callback handler that relays llama_index events to an async consumer.

    Events are pushed onto an internal :class:`asyncio.Queue` as they occur;
    ``async_event_gen`` drains that queue as an async generator. The producer
    must set ``is_done = True`` once no further events will arrive so the
    generator can terminate.
    """

    _aqueue: asyncio.Queue
    # Flag flipped by the producer when streaming has finished.
    # NOTE: the original declared this as `is_done: False` (an annotation-only
    # attribute with a value used as a type), so `self.is_done` raised
    # AttributeError; it needs a real boolean default.
    is_done: bool = False

    def __init__(self) -> None:
        """Initialize the handler, muting noisy low-level event types."""
        # These events fire very frequently and are not useful to surface
        # to the client, so ignore them for both starts and ends.
        ignored_events = [
            CBEventType.CHUNKING,
            CBEventType.NODE_PARSING,
            CBEventType.EMBEDDING,
            CBEventType.LLM,
        ]
        super().__init__(ignored_events, ignored_events)
        self._aqueue = asyncio.Queue()

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> str:
        """Queue the start event and return its id, as the base API expects."""
        self._aqueue.put_nowait(
            CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
        )
        # The original implicitly returned None despite the `-> str` annotation;
        # return the event id to honor the declared contract.
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        """Queue the end event for the consumer."""
        self._aqueue.put_nowait(
            CallbackEvent(event_id=event_id, event_type=event_type, payload=payload)
        )

    def start_trace(self, trace_id: Optional[str] = None) -> None:
        """No-op."""

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        """No-op."""

    async def async_event_gen(self) -> AsyncGenerator[CallbackEvent, None]:
        """Yield queued events until the queue is drained and ``is_done`` is set.

        Polls the queue with a short timeout so the generator notices
        ``is_done`` promptly even when no events are arriving.
        """
        while not self._aqueue.empty() or not self.is_done:
            try:
                yield await asyncio.wait_for(self._aqueue.get(), timeout=0.1)
            except asyncio.TimeoutError:
                # No event yet; loop to re-check the termination condition.
                pass

0 comments on commit 814c49c

Please sign in to comment.