Add asyncio support for tritonclient (beta) #23

Merged · 3 commits · Dec 4, 2024
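As a quick orientation, here is a minimal usage sketch of the asyncio client introduced by this PR, pieced together from the tests added below; the model name and server address are placeholders and a running Triton server is assumed:

```python
import asyncio

import numpy as np

from tritony import InferenceClient


async def main():
    # Asyncio-enabled client added in this PR ("http" is also accepted, per the tests below).
    client = InferenceClient.create_with_asyncio("sample", "localhost:8001", protocol="grpc")

    sample = np.random.rand(1, 100).astype(np.float32)

    # Awaitable inference; plain arrays and {"model_in": array} dicts both appear in the tests.
    result = await client.aio_infer(sample)
    assert np.isclose(result, sample).all()


asyncio.run(main())
```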
3 changes: 2 additions & 1 deletion .github/workflows/pre-commit_pytest.yml
@@ -52,7 +52,8 @@ jobs:
TRITON_CLIENT_TIMEOUT: 1
run: |
tritonserver --model-repo=$GITHUB_WORKSPACE/model_repository &
pip install .[tests]
pip install uv
uv pip install .[tests]
sleep 3

curl -v ${TRITON_HOST}:${TRITON_HTTP}/v2/health/ready
2 changes: 1 addition & 1 deletion bin/run_triton_tritony_sample.sh
@@ -11,7 +11,7 @@ docker run -it --rm --name triton_tritony \
-e OMP_NUM_THREADS=2 \
-e OPENBLAS_NUM_THREADS=2 \
--shm-size=1g \
nvcr.io/nvidia/tritonserver:23.08-pyt-python-py3 \
nvcr.io/nvidia/tritonserver:24.05-pyt-python-py3 \
tritonserver --model-repository=/models \
--exit-timeout-secs 15 \
--min-supported-compute-capability 7.0 \
29 changes: 29 additions & 0 deletions model_repository/sample_sleep_1sec/1/model.py
@@ -0,0 +1,29 @@
import json
import time

import triton_python_backend_utils as pb_utils


class TritonPythonModel:
def initialize(self, args):
self.model_config = model_config = json.loads(args["model_config"])
output_configs = model_config["output"]

self.output_name_list = [output_config["name"] for output_config in output_configs]
self.output_dtype_list = [
pb_utils.triton_string_to_numpy(output_config["data_type"]) for output_config in output_configs
]

def execute(self, requests):
responses = [None for _ in requests]
for idx, request in enumerate(requests):
current_add_value = int(json.loads(request.parameters()).get("add", 0))
in_tensor = [item.as_numpy() + current_add_value for item in request.inputs() if item.name() == "model_in"]
out_tensor = [
pb_utils.Tensor(output_name, x.astype(output_dtype))
for x, output_name, output_dtype in zip(in_tensor, self.output_name_list, self.output_dtype_list)
]
inference_response = pb_utils.InferenceResponse(output_tensors=out_tensor)
responses[idx] = inference_response
time.sleep(1)
return responses
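For context, the backend above echoes `model_in` shifted by the optional `add` request parameter and sleeps one second per request; that artificial latency is what the async tests later in this diff measure. A small illustration of the expected behavior (the values are made up):

```python
import numpy as np

# Suppose the request carries parameters {"add": 2}.
model_in = np.array([1.0, 2.0, 3.0], dtype=np.float32)

# execute() returns model_in + add as "model_out", after a 1-second sleep.
model_out = (model_in + 2).astype(np.float32)  # -> [3.0, 4.0, 5.0]
```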
40 changes: 40 additions & 0 deletions model_repository/sample_sleep_1sec/config.pbtxt
@@ -0,0 +1,40 @@
backend: "python"
max_batch_size: 0

input [
{
name: "model_in"
data_type: TYPE_FP32
dims: [ -1 ]
}
]

output [
{
name: "model_out"
data_type: TYPE_FP32
dims: [ -1 ]
}
]

instance_group [{ kind: KIND_CPU, count: 10 }]

model_warmup {
name: "RandomSampleInput"
batch_size: 1
inputs [{
key: "model_in"
value: {
data_type: TYPE_FP32
dims: [ 10 ]
random_data: true
}
}, {
key: "model_in"
value: {
data_type: TYPE_FP32
dims: [ 10 ]
zero_data: true
}
}]
}
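The `count: 10` instance group above is what lets the concurrency test later in this diff stay under its time bound: ten backend instances can serve ten simultaneous 1-second requests roughly in parallel. A back-of-the-envelope check, assuming negligible scheduling overhead:

```python
n_tasks = 10           # concurrent aio_infer calls in test_multiple_tasks
sleep_per_request = 1  # seconds, from model.py above
instances = 10         # instance_group count in this config

# With enough instances, requests overlap instead of queuing behind one another.
expected_wall_time = sleep_per_request * (n_tasks / instances)  # ~1 second
assert expected_wall_time < 2  # the bound asserted in test_multiple_tasks
```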
16 changes: 0 additions & 16 deletions packaging.md

This file was deleted.

5 changes: 4 additions & 1 deletion pyproject.toml
@@ -17,4 +17,7 @@ line-length = 120
ignore = ["F811","F841","E203","E402","E501","E712","B019"]

[tool.lint.isort]
forced-separate = ["tests"]
forced-separate = ["tests"]

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "function"
3 changes: 3 additions & 0 deletions pytest.ini
@@ -0,0 +1,3 @@
[pytest]
log_cli=true
log_level=NOTSET
1 change: 1 addition & 0 deletions setup.cfg
@@ -69,6 +69,7 @@ tests =
pytest-xdist
pytest-mpl
pytest-cov
pytest-asyncio
pytest
pre-commit
coveralls
16 changes: 16 additions & 0 deletions tests/common_fixtures.py
@@ -1,7 +1,15 @@
import logging
import os

import pytest

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
)


MODEL_NAME = os.environ.get("MODEL_NAME", "sample")
TRITON_HOST = os.environ.get("TRITON_HOST", "localhost")
TRITON_HTTP = os.environ.get("TRITON_HTTP", "8100")
@@ -14,3 +22,11 @@ def config(request):
Returns a tuple of (protocol, port, run_async)
"""
return request.param


@pytest.fixture(params=[("http", TRITON_HTTP), ("grpc", TRITON_GRPC)])
def async_config(request):
"""
Returns a tuple of (protocol, port)
"""
return request.param
58 changes: 58 additions & 0 deletions tests/test_async_connect.py
@@ -0,0 +1,58 @@
import asyncio
import logging

import numpy as np
import pytest

from tritony import InferenceClient

from .common_fixtures import MODEL_NAME, TRITON_HOST, async_config

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
)

__all__ = ["async_config"]
EPSILON = 1e-8


def get_client(protocol, port, model_name):
print(f"Testing {protocol}", flush=True)
return InferenceClient.create_with_asyncio(model_name, f"{TRITON_HOST}:{port}", protocol=protocol)


@pytest.mark.asyncio
async def test_basics(async_config):
protocol, port = async_config

client = get_client(*async_config, model_name=MODEL_NAME)
sample = np.random.rand(1, 100).astype(np.float32)

result = await client.aio_infer(sample)
assert np.isclose(result, sample).all()

result = await client.aio_infer({"model_in": sample})
assert np.isclose(result, sample).all()


@pytest.mark.asyncio
async def test_multiple_tasks(async_config):
n_multiple_tasks = 10
protocol, port = async_config
print(f"Testing {protocol}:{port}")

client_list = [get_client(*async_config, model_name="sample_sleep_1sec") for _ in range(n_multiple_tasks)]

sample = np.random.rand(1, 100).astype(np.float32)
tasks = [client.aio_infer(sample) for client in client_list]

start_time = asyncio.get_event_loop().time()
results = await asyncio.gather(*tasks)
end_time = asyncio.get_event_loop().time()

for result in results:
assert np.isclose(result, sample).all()

assert (end_time - start_time) < 2, f"Time taken: {end_time - start_time}"
59 changes: 48 additions & 11 deletions tritony/helpers.py
@@ -10,7 +10,10 @@
from attrs import define
from tritonclient import grpc as grpcclient
from tritonclient import http as httpclient
from tritonclient.grpc import aio as aio_grpcclient
from tritonclient.grpc import model_config_pb2
from tritonclient.grpc.aio import model_config_pb2
from tritonclient.http import aio as aio_httpclient


class TritonProtocol(Enum):
@@ -77,7 +80,7 @@ class TritonClientFlag:
url: str
model_name: str
model_version: str = "1"
protocol: TritonProtocol | str = attrs.field(converter=TritonProtocol, default=TritonProtocol.grpc)
protocol: TritonProtocol = attrs.field(converter=TritonProtocol, default=TritonProtocol.grpc)
streaming: bool = False # TODO: not implemented
async_set: bool = True # TODO: not totally implemented
concurrency: int = 6 # only for TritonProtocol.http client
@@ -86,22 +89,37 @@
compression_algorithm: str | None = None
ssl: bool = False

use_aio_tritonclient: bool = False


def init_triton_client(
flag: TritonClientFlag,
) -> grpcclient.InferenceServerClient | httpclient.InferenceServerClient:
) -> (
grpcclient.InferenceServerClient
| httpclient.InferenceServerClient
| aio_grpcclient.InferenceServerClient
| aio_httpclient.InferenceServerClient
):
assert not (
flag.streaming and flag.protocol is not TritonProtocol.grpc
), "Streaming is only allowed with gRPC protocol"

if flag.protocol is TritonProtocol.grpc:
# Create gRPC client for communicating with the server
triton_client = grpcclient.InferenceServerClient(url=flag.url, verbose=flag.verbose, ssl=flag.ssl)
if not flag.use_aio_tritonclient:
if flag.protocol is TritonProtocol.grpc:
# Create gRPC client for communicating with the server
triton_client = grpcclient.InferenceServerClient(url=flag.url, verbose=flag.verbose, ssl=flag.ssl)
else:
# Specify large enough concurrency to handle the
# the number of requests.
concurrency = flag.concurrency if flag.async_set else 1
triton_client = httpclient.InferenceServerClient(
url=flag.url, verbose=flag.verbose, concurrency=concurrency
)
else:
# Specify large enough concurrency to handle the
# the number of requests.
concurrency = flag.concurrency if flag.async_set else 1
triton_client = httpclient.InferenceServerClient(url=flag.url, verbose=flag.verbose, concurrency=concurrency)
if flag.protocol is TritonProtocol.grpc:
triton_client = aio_grpcclient.InferenceServerClient(url=flag.url, verbose=flag.verbose, ssl=flag.ssl)
else:
triton_client = aio_httpclient.InferenceServerClient(url=flag.url, verbose=flag.verbose)

return triton_client

Expand All @@ -111,7 +129,7 @@ def get_triton_client(
model_name: str,
model_version: str,
protocol: TritonProtocol,
) -> (int, list[TritonModelInput], list[str]):
) -> tuple[int, list[TritonModelInput], list[str]]:
"""
(required in)
:param triton_client:
@@ -138,6 +156,25 @@
return max_batch_size, input_list, output_name_list


async def async_get_triton_client(
triton_client: aio_grpcclient.InferenceServerClient | aio_httpclient.InferenceServerClient,
model_name: str,
model_version: str,
protocol: TritonProtocol,
) -> tuple[int, list[TritonModelInput], list[str]]:
args = dict(model_name=model_name, model_version=model_version)

model_config = await triton_client.get_model_config(**args)
if protocol is TritonProtocol.http:
model_config = dict_to_attr(model_config)
elif protocol is TritonProtocol.grpc:
model_config = model_config.config

max_batch_size, input_list, output_name_list = parse_model(model_config)

return max_batch_size, input_list, output_name_list


def parse_model_input(
model_input: model_config_pb2.ModelInput | SimpleNamespace,
) -> TritonModelInput:
@@ -161,7 +198,7 @@

def parse_model(
model_config: model_config_pb2.ModelConfig | SimpleNamespace,
) -> (int, list[TritonModelInput], list[str]):
) -> tuple[int, list[TritonModelInput], list[str]]:
return (
model_config.max_batch_size,
[parse_model_input(model_config_input) for model_config_input in model_config.input],
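To tie the helper changes together, a non-authoritative sketch of how the new flag and the async config lookup are meant to compose, assuming `tritony.helpers` exports these names as shown in this diff and using a placeholder server address:

```python
import asyncio

from tritony.helpers import TritonClientFlag, async_get_triton_client, init_triton_client


async def main():
    # use_aio_tritonclient=True selects the tritonclient aio clients wired in above.
    flag = TritonClientFlag(url="localhost:8001", model_name="sample", use_aio_tritonclient=True)

    client = init_triton_client(flag)  # aio gRPC client, since protocol defaults to grpc

    # Async counterpart of get_triton_client: fetches and parses the model config.
    max_batch_size, input_list, output_names = await async_get_triton_client(
        client, flag.model_name, flag.model_version, flag.protocol
    )
    print(max_batch_size, input_list, output_names)


asyncio.run(main())
```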