Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix streaming handling for builtin assistants #462

Merged
merged 8 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ dependencies:
- pip
- git-lfs
- pip:
- httpx_sse
- ijson
- sse_starlette
- python-dotenv
- pytest >=6
- pytest-mock
Expand Down
49 changes: 49 additions & 0 deletions tests/assistants/streaming_server.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import json
import random

import sse_starlette
from fastapi import FastAPI, Request, Response, status
from fastapi.responses import StreamingResponse

app = FastAPI()


@app.get("/health")
async def health():
    """Liveness probe: answer with an empty body and HTTP 200."""
    empty_ok = Response(b"", status_code=status.HTTP_200_OK)
    return empty_ok


@app.post("/sse")
async def sse(request: Request):
    """Stream each element of the posted JSON payload back as a server-sent event."""
    payload = await request.json()

    async def events():
        # One event per element; the event data is the element re-serialized as JSON.
        for item in payload:
            serialized = json.dumps(item)
            yield sse_starlette.ServerSentEvent(serialized)

    return sse_starlette.EventSourceResponse(events())


@app.post("/jsonl")
async def jsonl(request: Request):
    """Stream each element of the posted JSON payload back as one JSON line."""
    payload = await request.json()

    async def lines():
        # Newline-terminated JSON documents, one per element of the payload.
        for item in payload:
            yield json.dumps(item) + "\n"

    return StreamingResponse(lines())


@app.post("/json")
async def json_(request: Request):
    """Stream the raw request body back in randomly sized pieces of 1 to 10 bytes."""
    raw = await request.body()
    total = len(raw)

    async def pieces():
        # Cut the body at random offsets so the pieces do not line up with
        # any structure in the payload; the final piece may be shorter.
        offset = 0
        while offset < total:
            step = random.randint(1, 10)
            yield raw[offset : offset + step]
            offset += step

    return StreamingResponse(pieces())
112 changes: 110 additions & 2 deletions tests/assistants/test_api.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import asyncio
import itertools
import json
import os
import time
from pathlib import Path

import httpx
import pytest

from ragna import assistants
from ragna._compat import anext
from ragna.assistants._http_api import HttpApiAssistant
from ragna._utils import timeout_after
from ragna.assistants._http_api import HttpApiAssistant, HttpStreamingProtocol
from ragna.core import Message, RagnaException
from tests.utils import skip_on_windows
from tests.utils import background_subprocess, get_available_port, skip_on_windows

HTTP_API_ASSISTANTS = [
assistant
Expand All @@ -30,3 +37,104 @@ async def test_api_call_error_smoke(mocker, assistant):

with pytest.raises(RagnaException, match="API call failed"):
await anext(chunks)


@pytest.fixture
def streaming_server():
    """Run the streaming test app under uvicorn and yield its base URL.

    The server is started in a background subprocess on a free port. The
    fixture polls the ``/health`` endpoint until it responds successfully,
    giving up after 10 seconds.
    """
    port = get_available_port()
    base_url = f"http://localhost:{port}"
    health_url = f"{base_url}/health"
    app_dir = Path(__file__).parent
    command = (
        "uvicorn",
        f"--app-dir={app_dir}",
        f"--port={port}",
        "streaming_server:app",
    )

    def is_healthy():
        # A refused connection just means the server is not listening yet.
        try:
            response = httpx.get(health_url)
        except httpx.ConnectError:
            return False
        return response.is_success

    @timeout_after(10, message="Streaming server failed to start")
    def wait_until_healthy():
        while not is_healthy():
            time.sleep(0.2)

    with background_subprocess(*command):
        wait_until_healthy()

        yield base_url


class HttpStreamingAssistant(HttpApiAssistant):
    """Test assistant that talks to the local streaming server.

    Concrete subclasses are created through :meth:`new`, which fixes the
    streaming protocol as a class attribute.
    """

    _API_KEY_ENV_VAR = None

    @staticmethod
    def new(base_url, streaming_protocol):
        """Build a subclass bound to ``streaming_protocol`` and instantiate it."""
        class_name = (
            f"{streaming_protocol.name.title()}{HttpStreamingAssistant.__name__}"
        )
        namespace = {"_STREAMING_PROTOCOL": streaming_protocol}
        subclass = type(class_name, (HttpStreamingAssistant,), namespace)
        return subclass(base_url)

    def __init__(self, base_url):
        super().__init__()
        # The server route is the lower-cased protocol name.
        route = self._STREAMING_PROTOCOL.name.lower()
        self._endpoint = f"{base_url}/{route}"

    async def answer(self, messages):
        """Post the last message's content and yield the streamed chunks.

        Iteration stops early at the first chunk with a truthy ``"break"``
        entry; that chunk is not yielded.
        """
        uses_json = self._STREAMING_PROTOCOL is HttpStreamingProtocol.JSON
        parse_kwargs = {"item": "item"} if uses_json else {}

        stream = self._call_api(
            "POST",
            self._endpoint,
            content=messages[-1].content,
            parse_kwargs=parse_kwargs,
        )
        async for chunk in stream:
            if chunk.get("break"):
                break

            yield chunk


@pytest.mark.parametrize("streaming_protocol", list(HttpStreamingProtocol))
async def test_http_streaming(streaming_server, streaming_protocol):
    """Check that each streaming protocol returns the posted chunks in order.

    The chunks are sent as the message content, echoed back by the streaming
    server, and must come out of ``answer`` unchanged, complete, and without
    extras.
    """
    assistant = HttpStreamingAssistant.new(streaming_server, streaming_protocol)

    expected_chunks = [{"chunk": chunk} for chunk in ["foo", "bar", "baz"]]
    prompt = Message(content=json.dumps(expected_chunks))

    # Collect the whole stream before comparing. Pulling from a plain iterator
    # with next() inside this coroutine would raise StopIteration on a surplus
    # chunk, which Python turns into an opaque RuntimeError (PEP 479). A list
    # comparison reports missing, extra or differing chunks directly.
    actual_chunks = [chunk async for chunk in assistant.answer([prompt])]

    assert actual_chunks == expected_chunks


@pytest.mark.parametrize("streaming_protocol", list(HttpStreamingProtocol))
def test_http_streaming_termination(streaming_server, streaming_protocol):
# Non-regression test for https://github.com/Quansight/ragna/pull/462
Comment on lines +122 to +124
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nenb This test fails when I revert my patch in this PR. See the CI runs for 0b9211e.


async def main():
    """Stream three chunks and verify iteration stops before the "break" chunk.

    The assistant's ``answer`` stops at the first chunk whose ``"break"`` entry
    is truthy, so only the chunks before it are expected.
    """
    assistant = HttpStreamingAssistant.new(streaming_server, streaming_protocol)

    data = [
        {"chunk": "foo", "break": False},
        {"chunk": "bar", "break": False},
        {"chunk": "baz", "break": True},
    ]
    expected_chunks = list(
        itertools.takewhile(lambda chunk: not chunk["break"], data)
    )
    prompt = Message(content=json.dumps(data))

    # Collect the whole stream before comparing. Calling next() on a plain
    # iterator inside this coroutine would surface a surplus chunk as an
    # opaque RuntimeError (PEP 479) instead of an assertion failure.
    actual_chunks = [chunk async for chunk in assistant.answer([prompt])]

    assert actual_chunks == expected_chunks

asyncio.run(main())
8 changes: 1 addition & 7 deletions tests/deploy/ui/test_ui.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import socket
import subprocess
import sys
import time
Expand All @@ -11,12 +10,7 @@
from ragna._utils import timeout_after
from ragna.deploy import Config
from tests.deploy.utils import TestAssistant


def get_available_port():
with socket.socket() as s:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved it to the generic test utils as this is no longer just needed for deploy.

s.bind(("", 0))
return s.getsockname()[1]
from tests.utils import get_available_port


@pytest.fixture
Expand Down
20 changes: 20 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,27 @@
import contextlib
import platform
import socket
import subprocess
import sys

import pytest

# Marker for tests that are known not to work on Windows; they are skipped there.
skip_on_windows = pytest.mark.skipif(
    platform.system() == "Windows", reason="Test is broken on Windows"
)


@contextlib.contextmanager
def background_subprocess(*args, stdout=sys.stdout, stderr=sys.stdout, **kwargs):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We had this before, but it was removed in #322. Probably can also be used by the UI tests, but we can do that in a follow-up.

process = subprocess.Popen(args, stdout=stdout, stderr=stderr, **kwargs)
try:
yield process
finally:
process.kill()
process.communicate()


def get_available_port():
    """Return a port number that the OS reported as free at call time.

    Binding to port 0 lets the OS choose the port. The socket is closed before
    returning, so the port is not reserved for the caller.
    """
    with socket.socket() as probe:
        probe.bind(("", 0))
        _host, port = probe.getsockname()
    return port
Loading