From 638205c8e21de0e3f87a6e4809ea2eb246079a77 Mon Sep 17 00:00:00 2001 From: wirthual Date: Sat, 28 Sep 2024 20:15:49 -0700 Subject: [PATCH] WIP: End to End test for vision and audio (#386) * update readme * extract audio related code into audio utils * add test cases for audio and vision * revert docs v2 * revert docs v2 * fix test cases * add test for text only vision case * add text only case for audio * format code * skip text test for now to see updated coverage * revert cli doc from main branch --- libs/infinity_emb/tests/conftest.py | 4 +- .../tests/end_to_end/test_torch_audio.py | 122 ++++++++++++++++++ .../tests/end_to_end/test_torch_vision.py | 121 +++++++++++++++++ 3 files changed, 246 insertions(+), 1 deletion(-) create mode 100644 libs/infinity_emb/tests/end_to_end/test_torch_audio.py create mode 100644 libs/infinity_emb/tests/end_to_end/test_torch_vision.py diff --git a/libs/infinity_emb/tests/conftest.py b/libs/infinity_emb/tests/conftest.py index 4fec5bd5..df557695 100644 --- a/libs/infinity_emb/tests/conftest.py +++ b/libs/infinity_emb/tests/conftest.py @@ -8,8 +8,10 @@ pytest.DEFAULT_BERT_MODEL = "michaelfeil/bge-small-en-v1.5" pytest.DEFAULT_RERANKER_MODEL = "mixedbread-ai/mxbai-rerank-xsmall-v1" pytest.DEFAULT_CLASSIFIER_MODEL = "SamLowe/roberta-base-go_emotions" +pytest.DEFAULT_AUDIO_MODEL = "laion/clap-htsat-unfused" +pytest.DEFAULT_VISION_MODEL = "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M" -pytest.ENGINE_METHODS = ["embed", "image_embed", "classify", "rerank"] +pytest.ENGINE_METHODS = ["embed", "image_embed", "classify", "rerank", "audio_embed"] @pytest.fixture diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_audio.py b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py new file mode 100644 index 00000000..80625b43 --- /dev/null +++ b/libs/infinity_emb/tests/end_to_end/test_torch_audio.py @@ -0,0 +1,122 @@ +import pytest +import torch +from asgi_lifespan import LifespanManager +from fastapi import status +from httpx import
AsyncClient + +from infinity_emb import create_server +from infinity_emb.args import EngineArgs +from infinity_emb.primitives import Device, InferenceEngine + +PREFIX = "/v1_ct2" +MODEL: str = pytest.DEFAULT_AUDIO_MODEL # type: ignore[assignment] +batch_size = 32 if torch.cuda.is_available() else 8 + +app = create_server( + url_prefix=PREFIX, + engine_args_list=[ + EngineArgs( + model_name_or_path=MODEL, + batch_size=batch_size, + engine=InferenceEngine.torch, + device=Device.auto if not torch.backends.mps.is_available() else Device.cpu, + ) + ], +) + + +@pytest.fixture() +async def client(): + async with AsyncClient( + app=app, base_url="http://test", timeout=20 + ) as client, LifespanManager(app): + yield client + + +@pytest.mark.anyio +async def test_model_route(client): + response = await client.get(f"{PREFIX}/models") + assert response.status_code == 200 + rdata = response.json() + assert "data" in rdata + assert rdata["data"][0].get("id", "") == MODEL + assert isinstance(rdata["data"][0].get("stats"), dict) + assert "audio_embed" in rdata["data"][0]["capabilities"] + + +@pytest.mark.anyio +async def test_audio_single(client): + audio_url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav" + + response = await client.post( + f"{PREFIX}/embeddings_audio", + json={"model": MODEL, "input": audio_url}, + ) + assert response.status_code == 200 + rdata = response.json() + assert "model" in rdata + assert "usage" in rdata + rdata_results = rdata["data"] + assert rdata_results[0]["object"] == "embedding" + assert len(rdata_results[0]["embedding"]) > 0 + + +@pytest.mark.anyio +@pytest.mark.skip("text only") +async def test_audio_single_text_only(client): + text = "a sound of a at" + + response = await client.post( + f"{PREFIX}/embeddings_audio", + json={"model": MODEL, "input": text}, + ) + assert response.status_code == 200 + rdata = response.json() + assert "model" in rdata + assert 
"usage" in rdata + rdata_results = rdata["data"] + assert rdata_results[0]["object"] == "embedding" + assert len(rdata_results[0]["embedding"]) > 0 + + +@pytest.mark.anyio +@pytest.mark.parametrize("no_of_audios", [1, 5, 10]) +async def test_audio_multiple(client, no_of_audios): + audio_urls = [ + "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav" + ] * no_of_audios + + response = await client.post( + f"{PREFIX}/embeddings_audio", + json={"model": MODEL, "input": audio_urls}, + ) + assert response.status_code == 200 + rdata = response.json() + rdata_results = rdata["data"] + assert len(rdata_results) == no_of_audios + assert "model" in rdata + assert "usage" in rdata + assert rdata_results[0]["object"] == "embedding" + assert len(rdata_results[0]["embedding"]) > 0 + + +@pytest.mark.anyio +async def test_audio_fail(client): + audio_url = "https://www.google.com/404" + + response = await client.post( + f"{PREFIX}/embeddings_audio", + json={"model": MODEL, "input": audio_url}, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + +@pytest.mark.anyio +async def test_audio_empty(client): + audio_url_empty = [] + + response_empty = await client.post( + f"{PREFIX}/embeddings_audio", + json={"model": MODEL, "input": audio_url_empty}, + ) + assert response_empty.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY diff --git a/libs/infinity_emb/tests/end_to_end/test_torch_vision.py b/libs/infinity_emb/tests/end_to_end/test_torch_vision.py new file mode 100644 index 00000000..f6e5b8a1 --- /dev/null +++ b/libs/infinity_emb/tests/end_to_end/test_torch_vision.py @@ -0,0 +1,121 @@ +import pytest +import torch +from asgi_lifespan import LifespanManager +from fastapi import status +from httpx import AsyncClient + +from infinity_emb import create_server +from infinity_emb.args import EngineArgs +from infinity_emb.primitives import Device, InferenceEngine + +PREFIX = "/v1_ct2" +MODEL: str = 
pytest.DEFAULT_VISION_MODEL # type: ignore[assignment] +batch_size = 32 if torch.cuda.is_available() else 8 + +app = create_server( + url_prefix=PREFIX, + engine_args_list=[ + EngineArgs( + model_name_or_path=MODEL, + batch_size=batch_size, + engine=InferenceEngine.torch, + device=Device.auto if not torch.backends.mps.is_available() else Device.cpu, + ) + ], +) + + +@pytest.fixture() +async def client(): + async with AsyncClient( + app=app, base_url="http://test", timeout=20 + ) as client, LifespanManager(app): + yield client + + +@pytest.mark.anyio +async def test_model_route(client): + response = await client.get(f"{PREFIX}/models") + assert response.status_code == 200 + rdata = response.json() + assert "data" in rdata + assert rdata["data"][0].get("id", "") == MODEL + assert isinstance(rdata["data"][0].get("stats"), dict) + assert "image_embed" in rdata["data"][0]["capabilities"] + + +@pytest.mark.anyio +async def test_vision_single(client): + image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" + + response = await client.post( + f"{PREFIX}/embeddings_image", + json={"model": MODEL, "input": image_url}, + ) + assert response.status_code == 200 + rdata = response.json() + assert "model" in rdata + assert "usage" in rdata + rdata_results = rdata["data"] + assert rdata_results[0]["object"] == "embedding" + assert len(rdata_results[0]["embedding"]) > 0 + + +@pytest.mark.anyio +@pytest.mark.skip("text only") +async def test_vision_single_text_only(client): + text = "a image of a cat" + + response = await client.post( + f"{PREFIX}/embeddings_image", + json={"model": MODEL, "input": text}, + ) + assert response.status_code == 200 + rdata = response.json() + assert "model" in rdata + assert "usage" in rdata + rdata_results = rdata["data"] + assert rdata_results[0]["object"] == "embedding" + assert len(rdata_results[0]["embedding"]) > 0 + + +@pytest.mark.anyio +@pytest.mark.parametrize("no_of_images", [1, 5, 10]) +async def test_vision_multiple(client, 
no_of_images): + image_urls = [ + "http://images.cocodataset.org/val2017/000000039769.jpg" + ] * no_of_images + + response = await client.post( + f"{PREFIX}/embeddings_image", + json={"model": MODEL, "input": image_urls}, + ) + assert response.status_code == 200 + rdata = response.json() + rdata_results = rdata["data"] + assert len(rdata_results) == no_of_images + assert "model" in rdata + assert "usage" in rdata + assert rdata_results[0]["object"] == "embedding" + assert len(rdata_results[0]["embedding"]) > 0 + + +@pytest.mark.anyio +async def test_vision_fail(client): + image_url = "https://www.google.com/404" + + response = await client.post( + f"{PREFIX}/embeddings_image", + json={"model": MODEL, "input": image_url}, + ) + assert response.status_code == status.HTTP_400_BAD_REQUEST + + +@pytest.mark.anyio +async def test_vision_empty(client): + image_url_empty = [] + response = await client.post( + f"{PREFIX}/embeddings_image", + json={"model": MODEL, "input": image_url_empty}, + ) + assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY