-
Notifications
You must be signed in to change notification settings - Fork 119
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
WIP: End to End test for vision and audio (#386)
* update readme * extract audio related code into audio utils * add test cases for audio and vision * revert docs v2 * revert docs v2 * fix test cases * add test for text only vision case * add text only case for audio * format code * skip text test for not to see updated coverage * revert cli doc from main branch
- Loading branch information
Showing
3 changed files
with
246 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import pytest | ||
import torch | ||
from asgi_lifespan import LifespanManager | ||
from fastapi import status | ||
from httpx import AsyncClient | ||
|
||
from infinity_emb import create_server | ||
from infinity_emb.args import EngineArgs | ||
from infinity_emb.primitives import Device, InferenceEngine | ||
|
||
PREFIX = "/v1_ct2" | ||
MODEL: str = pytest.DEFAULT_AUDIO_MODEL # type: ignore[assignment] | ||
batch_size = 32 if torch.cuda.is_available() else 8 | ||
|
||
app = create_server( | ||
url_prefix=PREFIX, | ||
engine_args_list=[ | ||
EngineArgs( | ||
model_name_or_path=MODEL, | ||
batch_size=batch_size, | ||
engine=InferenceEngine.torch, | ||
device=Device.auto if not torch.backends.mps.is_available() else Device.cpu, | ||
) | ||
], | ||
) | ||
|
||
|
||
@pytest.fixture()
async def client():
    """Yield an httpx AsyncClient wired to the in-process ASGI app.

    LifespanManager runs the app's startup/shutdown events so the
    inference engine is actually loaded for the duration of each test.
    """
    # httpx deprecated (and removed in 0.28) the AsyncClient(app=...) shortcut;
    # pass an explicit ASGITransport instead. Function-scope import keeps this
    # fix self-contained without touching the module import block.
    from httpx import ASGITransport

    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test", timeout=20
    ) as client, LifespanManager(app):
        yield client
|
||
|
||
@pytest.mark.anyio
async def test_model_route(client):
    """The /models route lists the audio model and its capabilities."""
    resp = await client.get(f"{PREFIX}/models")
    assert resp.status_code == 200
    payload = resp.json()
    assert "data" in payload
    entry = payload["data"][0]
    assert entry.get("id", "") == MODEL
    assert isinstance(entry.get("stats"), dict)
    assert "audio_embed" in entry["capabilities"]
|
||
|
||
@pytest.mark.anyio
async def test_audio_single(client):
    """A single audio URL yields exactly one non-empty embedding."""
    beep_url = (
        "https://github.com/michaelfeil/infinity/raw/"
        "3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/"
        "libs/infinity_emb/tests/data/audio/beep.wav"
    )

    resp = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": beep_url},
    )
    assert resp.status_code == 200
    payload = resp.json()
    assert "model" in payload
    assert "usage" in payload
    first = payload["data"][0]
    assert first["object"] == "embedding"
    assert len(first["embedding"]) > 0
|
||
|
||
@pytest.mark.anyio
@pytest.mark.skip("text only")
async def test_audio_single_text_only(client):
    """Plain text sent to the audio endpoint also yields an embedding."""
    query = "a sound of a at"

    resp = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": query},
    )
    assert resp.status_code == 200
    payload = resp.json()
    assert "model" in payload
    assert "usage" in payload
    first = payload["data"][0]
    assert first["object"] == "embedding"
    assert len(first["embedding"]) > 0
|
||
|
||
@pytest.mark.anyio
@pytest.mark.parametrize("no_of_audios", [1, 5, 10])
async def test_audio_multiple(client, no_of_audios):
    """Batched audio input returns one embedding per URL.

    Fix: the original only inspected ``data[0]``; every returned item is
    now validated, so a batching bug that corrupts later items is caught.
    """
    audio_urls = [
        "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"
    ] * no_of_audios

    response = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": audio_urls},
    )
    assert response.status_code == 200
    rdata = response.json()
    assert "model" in rdata
    assert "usage" in rdata
    results = rdata["data"]
    assert len(results) == no_of_audios
    # Validate every item, not just the first one.
    for item in results:
        assert item["object"] == "embedding"
        assert len(item["embedding"]) > 0
|
||
|
||
@pytest.mark.anyio
async def test_audio_fail(client):
    """A URL that does not resolve to audio data is rejected with 400."""
    bad_url = "https://www.google.com/404"

    resp = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": bad_url},
    )
    assert resp.status_code == status.HTTP_400_BAD_REQUEST
|
||
|
||
@pytest.mark.anyio
async def test_audio_empty(client):
    """An empty input list fails request validation with 422."""
    resp = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": []},
    )
    assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
121 changes: 121 additions & 0 deletions
121
libs/infinity_emb/tests/end_to_end/test_torch_vision.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
import pytest | ||
import torch | ||
from asgi_lifespan import LifespanManager | ||
from fastapi import status | ||
from httpx import AsyncClient | ||
|
||
from infinity_emb import create_server | ||
from infinity_emb.args import EngineArgs | ||
from infinity_emb.primitives import Device, InferenceEngine | ||
|
||
PREFIX = "/v1_ct2" | ||
MODEL: str = pytest.DEFAULT_VISION_MODEL # type: ignore[assignment] | ||
batch_size = 32 if torch.cuda.is_available() else 8 | ||
|
||
app = create_server( | ||
url_prefix=PREFIX, | ||
engine_args_list=[ | ||
EngineArgs( | ||
model_name_or_path=MODEL, | ||
batch_size=batch_size, | ||
engine=InferenceEngine.torch, | ||
device=Device.auto if not torch.backends.mps.is_available() else Device.cpu, | ||
) | ||
], | ||
) | ||
|
||
|
||
@pytest.fixture()
async def client():
    """Yield an httpx AsyncClient wired to the in-process ASGI app.

    LifespanManager runs the app's startup/shutdown events so the
    inference engine is actually loaded for the duration of each test.
    """
    # httpx deprecated (and removed in 0.28) the AsyncClient(app=...) shortcut;
    # pass an explicit ASGITransport instead. Function-scope import keeps this
    # fix self-contained without touching the module import block.
    from httpx import ASGITransport

    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test", timeout=20
    ) as client, LifespanManager(app):
        yield client
|
||
|
||
@pytest.mark.anyio
async def test_model_route(client):
    """The /models route lists the vision model and its capabilities."""
    resp = await client.get(f"{PREFIX}/models")
    assert resp.status_code == 200
    payload = resp.json()
    assert "data" in payload
    entry = payload["data"][0]
    assert entry.get("id", "") == MODEL
    assert isinstance(entry.get("stats"), dict)
    assert "image_embed" in entry["capabilities"]
|
||
|
||
@pytest.mark.anyio
async def test_vision_single(client):
    """A single image URL yields exactly one non-empty embedding."""
    coco_url = "http://images.cocodataset.org/val2017/000000039769.jpg"

    resp = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": coco_url},
    )
    assert resp.status_code == 200
    payload = resp.json()
    assert "model" in payload
    assert "usage" in payload
    first = payload["data"][0]
    assert first["object"] == "embedding"
    assert len(first["embedding"]) > 0
|
||
|
||
@pytest.mark.anyio
@pytest.mark.skip("text only")
async def test_vision_single_text_only(client):
    """Plain text sent to the image endpoint also yields an embedding."""
    query = "a image of a cat"

    resp = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": query},
    )
    assert resp.status_code == 200
    payload = resp.json()
    assert "model" in payload
    assert "usage" in payload
    first = payload["data"][0]
    assert first["object"] == "embedding"
    assert len(first["embedding"]) > 0
|
||
|
||
@pytest.mark.anyio
@pytest.mark.parametrize("no_of_images", [1, 5, 10])
async def test_vision_multiple(client, no_of_images):
    """Batched image input returns one embedding per URL.

    Fix: the original only inspected ``data[0]``; every returned item is
    now validated, so a batching bug that corrupts later items is caught.
    """
    image_urls = [
        "http://images.cocodataset.org/val2017/000000039769.jpg"
    ] * no_of_images

    response = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": image_urls},
    )
    assert response.status_code == 200
    rdata = response.json()
    assert "model" in rdata
    assert "usage" in rdata
    results = rdata["data"]
    assert len(results) == no_of_images
    # Validate every item, not just the first one.
    for item in results:
        assert item["object"] == "embedding"
        assert len(item["embedding"]) > 0
|
||
|
||
@pytest.mark.anyio
async def test_vision_fail(client):
    """A URL that does not resolve to image data is rejected with 400."""
    bad_url = "https://www.google.com/404"

    resp = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": bad_url},
    )
    assert resp.status_code == status.HTTP_400_BAD_REQUEST
|
||
|
||
@pytest.mark.anyio
async def test_vision_empty(client):
    """An empty input list fails request validation with 422."""
    resp = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": []},
    )
    assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY