Skip to content

Commit

Permalink
WIP: End-to-end tests for vision and audio (#386)
Browse files Browse the repository at this point in the history
* update readme

* extract audio related code into audio utils

* add test cases for audio and vision

* revert docs v2

* revert docs v2

* fix test cases

* add test for text only vision case

* add text only case for audio

* format code

* skip text-only tests for now so they do not affect the reported coverage

* revert cli doc from main branch
  • Loading branch information
wirthual authored Sep 29, 2024
1 parent 6c1ad68 commit 638205c
Show file tree
Hide file tree
Showing 3 changed files with 246 additions and 1 deletion.
4 changes: 3 additions & 1 deletion libs/infinity_emb/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
# Model names shared across the test suite via the pytest namespace.
pytest.DEFAULT_BERT_MODEL = "michaelfeil/bge-small-en-v1.5"
pytest.DEFAULT_RERANKER_MODEL = "mixedbread-ai/mxbai-rerank-xsmall-v1"
pytest.DEFAULT_CLASSIFIER_MODEL = "SamLowe/roberta-base-go_emotions"
pytest.DEFAULT_AUDIO_MODEL = "laion/clap-htsat-unfused"
pytest.DEFAULT_VISION_MODEL = "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M"

# Every engine method exercised by the end-to-end tests.
# (Removed a stale duplicate assignment without "audio_embed" that was
# immediately overwritten by this line.)
pytest.ENGINE_METHODS = ["embed", "image_embed", "classify", "rerank", "audio_embed"]


@pytest.fixture
Expand Down
122 changes: 122 additions & 0 deletions libs/infinity_emb/tests/end_to_end/test_torch_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import pytest
import torch
from asgi_lifespan import LifespanManager
from fastapi import status
from httpx import AsyncClient

from infinity_emb import create_server
from infinity_emb.args import EngineArgs
from infinity_emb.primitives import Device, InferenceEngine

# NOTE(review): prefix says "ct2" but the engine below is torch — presumably
# copied from a CTranslate2 test. The prefix is part of every test URL, so
# confirm with the other end-to-end tests before renaming.
PREFIX = "/v1_ct2"
MODEL: str = pytest.DEFAULT_AUDIO_MODEL  # type: ignore[assignment]
# Larger batches when CUDA is available; conservative size on CPU.
batch_size = 32 if torch.cuda.is_available() else 8

# Build the FastAPI app once at import time so all tests in this module share
# a single engine instance for the audio (CLAP) model.
app = create_server(
    url_prefix=PREFIX,
    engine_args_list=[
        EngineArgs(
            model_name_or_path=MODEL,
            batch_size=batch_size,
            engine=InferenceEngine.torch,
            # Falls back to CPU on Apple silicon — presumably MPS is not
            # supported for this model path; confirm before changing.
            device=Device.auto if not torch.backends.mps.is_available() else Device.cpu,
        )
    ],
)


@pytest.fixture()
async def client():
    """Yield an HTTP test client bound to the app with lifespan events running."""
    async with AsyncClient(
        app=app, base_url="http://test", timeout=20
    ) as http_client:
        async with LifespanManager(app):
            yield http_client


@pytest.mark.anyio
async def test_model_route(client):
    """The models route lists the audio model with stats and capabilities."""
    response = await client.get(f"{PREFIX}/models")
    assert response.status_code == 200
    payload = response.json()
    assert "data" in payload
    first_entry = payload["data"][0]
    assert first_entry.get("id", "") == MODEL
    assert isinstance(first_entry.get("stats"), dict)
    assert "audio_embed" in first_entry["capabilities"]


@pytest.mark.anyio
async def test_audio_single(client):
    """A single audio URL yields exactly one non-empty embedding."""
    beep_url = (
        "https://github.com/michaelfeil/infinity/raw/"
        "3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/"
        "libs/infinity_emb/tests/data/audio/beep.wav"
    )

    response = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": beep_url},
    )
    assert response.status_code == 200
    payload = response.json()
    assert "model" in payload
    assert "usage" in payload
    embeddings = payload["data"]
    assert embeddings[0]["object"] == "embedding"
    assert len(embeddings[0]["embedding"]) > 0


@pytest.mark.anyio
@pytest.mark.skip("text only")
async def test_audio_single_text_only(client):
    """Plain text sent to the audio endpoint also yields an embedding.

    Skipped for now (see commit note) so the text-only path does not skew
    coverage numbers for the audio tests.
    """
    # Fixed typo in the prompt: was "a sound of a at".
    text = "a sound of a cat"

    response = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": text},
    )
    assert response.status_code == 200
    rdata = response.json()
    assert "model" in rdata
    assert "usage" in rdata
    rdata_results = rdata["data"]
    assert rdata_results[0]["object"] == "embedding"
    assert len(rdata_results[0]["embedding"]) > 0


@pytest.mark.anyio
@pytest.mark.parametrize("no_of_audios", [1, 5, 10])
async def test_audio_multiple(client, no_of_audios):
    """Batched audio inputs return one embedding per input URL."""
    beep_url = (
        "https://github.com/michaelfeil/infinity/raw/"
        "3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/"
        "libs/infinity_emb/tests/data/audio/beep.wav"
    )

    response = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": [beep_url] * no_of_audios},
    )
    assert response.status_code == 200
    payload = response.json()
    embeddings = payload["data"]
    assert len(embeddings) == no_of_audios
    assert "model" in payload
    assert "usage" in payload
    assert embeddings[0]["object"] == "embedding"
    assert len(embeddings[0]["embedding"]) > 0


@pytest.mark.anyio
async def test_audio_fail(client):
    """A URL that does not resolve to audio content is rejected with 400."""
    # NOTE(review): depends on an external host returning 404 — flaky offline.
    bad_url = "https://www.google.com/404"

    response = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": bad_url},
    )
    assert response.status_code == status.HTTP_400_BAD_REQUEST


@pytest.mark.anyio
async def test_audio_empty(client):
    """An empty input list fails request validation with 422."""
    response_empty = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": []},
    )
    assert response_empty.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
121 changes: 121 additions & 0 deletions libs/infinity_emb/tests/end_to_end/test_torch_vision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import pytest
import torch
from asgi_lifespan import LifespanManager
from fastapi import status
from httpx import AsyncClient

from infinity_emb import create_server
from infinity_emb.args import EngineArgs
from infinity_emb.primitives import Device, InferenceEngine

# NOTE(review): prefix says "ct2" but the engine below is torch — presumably
# copied from a CTranslate2 test. The prefix is part of every test URL, so
# confirm with the other end-to-end tests before renaming.
PREFIX = "/v1_ct2"
MODEL: str = pytest.DEFAULT_VISION_MODEL  # type: ignore[assignment]
# Larger batches when CUDA is available; conservative size on CPU.
batch_size = 32 if torch.cuda.is_available() else 8

# Build the FastAPI app once at import time so all tests in this module share
# a single engine instance for the vision (CLIP) model.
app = create_server(
    url_prefix=PREFIX,
    engine_args_list=[
        EngineArgs(
            model_name_or_path=MODEL,
            batch_size=batch_size,
            engine=InferenceEngine.torch,
            # Falls back to CPU on Apple silicon — presumably MPS is not
            # supported for this model path; confirm before changing.
            device=Device.auto if not torch.backends.mps.is_available() else Device.cpu,
        )
    ],
)


@pytest.fixture()
async def client():
    """Yield an HTTP test client bound to the app with lifespan events running."""
    async with AsyncClient(
        app=app, base_url="http://test", timeout=20
    ) as http_client:
        async with LifespanManager(app):
            yield http_client


@pytest.mark.anyio
async def test_model_route(client):
    """The models route lists the vision model with stats and capabilities."""
    response = await client.get(f"{PREFIX}/models")
    assert response.status_code == 200
    payload = response.json()
    assert "data" in payload
    first_entry = payload["data"][0]
    assert first_entry.get("id", "") == MODEL
    assert isinstance(first_entry.get("stats"), dict)
    assert "image_embed" in first_entry["capabilities"]


@pytest.mark.anyio
async def test_vision_single(client):
    """A single image URL yields exactly one non-empty embedding."""
    coco_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"

    response = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": coco_image_url},
    )
    assert response.status_code == 200
    payload = response.json()
    assert "model" in payload
    assert "usage" in payload
    embeddings = payload["data"]
    assert embeddings[0]["object"] == "embedding"
    assert len(embeddings[0]["embedding"]) > 0


@pytest.mark.anyio
@pytest.mark.skip("text only")
async def test_vision_single_text_only(client):
    """Plain text sent to the image endpoint also yields an embedding.

    Skipped for now (see commit note) so the text-only path does not skew
    coverage numbers for the vision tests.
    """
    # Fixed grammar in the prompt: was "a image of a cat".
    text = "an image of a cat"

    response = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": text},
    )
    assert response.status_code == 200
    rdata = response.json()
    assert "model" in rdata
    assert "usage" in rdata
    rdata_results = rdata["data"]
    assert rdata_results[0]["object"] == "embedding"
    assert len(rdata_results[0]["embedding"]) > 0


@pytest.mark.anyio
@pytest.mark.parametrize("no_of_images", [1, 5, 10])
async def test_vision_multiple(client, no_of_images):
    """Batched image inputs return one embedding per input URL."""
    coco_image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"

    response = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": [coco_image_url] * no_of_images},
    )
    assert response.status_code == 200
    payload = response.json()
    embeddings = payload["data"]
    assert len(embeddings) == no_of_images
    assert "model" in payload
    assert "usage" in payload
    assert embeddings[0]["object"] == "embedding"
    assert len(embeddings[0]["embedding"]) > 0


@pytest.mark.anyio
async def test_vision_fail(client):
    """A URL that does not resolve to an image is rejected with 400."""
    # NOTE(review): depends on an external host returning 404 — flaky offline.
    bad_url = "https://www.google.com/404"

    response = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": bad_url},
    )
    assert response.status_code == status.HTTP_400_BAD_REQUEST


@pytest.mark.anyio
async def test_vision_empty(client):
    """An empty input list fails request validation with 422."""
    response = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": []},
    )
    assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY

0 comments on commit 638205c

Please sign in to comment.