Enable the run_async=False option for the InferenceClient. (#17)
* Add `tools.request` for `run_async=False`
* Fix type annotations and refactor tests

---------

Co-authored-by: Dongwoo Arthur Kim <[email protected]>
SangwonSUH and kimdwkimdw authored Dec 15, 2023
1 parent 0ef6067 commit da0998d
Showing 5 changed files with 155 additions and 70 deletions.
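For context, a minimal usage sketch of the new flag, pieced together from the updated tests below. It assumes a Triton server serving the repo's `sample` echo model at the test defaults (`localhost:8101` over gRPC); adjust the host, port, and model name for your setup.

import numpy as np

from tritony import InferenceClient

# run_async=False routes requests through the synchronous tools.request path added in this commit.
client = InferenceClient.create_with("sample", "localhost:8101", protocol="grpc", run_async=False)

sample = np.random.rand(1, 100).astype(np.float32)
result = client(sample)  # the sample model echoes its input, so result is close to sample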
16 changes: 16 additions & 0 deletions tests/common_fixtures.py
@@ -0,0 +1,16 @@
import os

import pytest

MODEL_NAME = os.environ.get("MODEL_NAME", "sample")
TRITON_HOST = os.environ.get("TRITON_HOST", "localhost")
TRITON_HTTP = os.environ.get("TRITON_HTTP", "8100")
TRITON_GRPC = os.environ.get("TRITON_GRPC", "8101")


@pytest.fixture(params=[("http", TRITON_HTTP, True), ("grpc", TRITON_GRPC, True), ("grpc", TRITON_GRPC, False)])
def config(request):
"""
Returns a tuple of (protocol, port, run_async)
"""
return request.param
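Each test module now imports this shared fixture instead of declaring its own `protocol_and_port` fixture; re-listing it in `__all__` keeps linters from treating the import as unused, while pytest picks the fixture up from the module namespace. A sketch of the pattern (the test name is hypothetical; the real tests appear in the diffs below):

from .common_fixtures import MODEL_NAME, TRITON_HOST, config

__all__ = ["config"]


def test_something(config):  # hypothetical test
    protocol, port, run_async = config
    ...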
41 changes: 18 additions & 23 deletions tests/test_connect.py
@@ -1,27 +1,18 @@
import os

import grpc
import numpy as np
import pytest

from tritony import InferenceClient

MODEL_NAME = os.environ.get("MODEL_NAME", "sample")
TRITON_HOST = os.environ.get("TRITON_HOST", "localhost")
TRITON_HTTP = os.environ.get("TRITON_HTTP", "8000")
TRITON_GRPC = os.environ.get("TRITON_GRPC", "8001")

from .common_fixtures import MODEL_NAME, TRITON_HOST, config

@pytest.fixture(params=[("http", TRITON_HTTP), ("grpc", TRITON_GRPC)])
def protocol_and_port(request):
return request.param
__all__ = ["config"]


def test_basics(protocol_and_port):
protocol, port = protocol_and_port
print(f"Testing {protocol}")
def test_basics(config):
protocol, port, run_async = config
print(f"Testing {protocol} with run_async={run_async}")

client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol)
client = InferenceClient.create_with(MODEL_NAME, f"{TRITON_HOST}:{port}", protocol=protocol, run_async=run_async)

sample = np.random.rand(1, 100).astype(np.float32)
result = client(sample)
@@ -31,23 +22,27 @@ def test_basics(protocol_and_port):
assert np.isclose(result, sample).all()


def test_batching(protocol_and_port):
protocol, port = protocol_and_port
print(f"{__name__}, Testing {protocol}")
def test_batching(config):
protocol, port, run_async = config
print(f"{__name__}, Testing {protocol} with run_async={run_async}")

client = InferenceClient.create_with("sample_autobatching", f"{TRITON_HOST}:{port}", protocol=protocol)
client = InferenceClient.create_with(
"sample_autobatching", f"{TRITON_HOST}:{port}", protocol=protocol, run_async=run_async
)

sample = np.random.rand(100, 100).astype(np.float32)
# client automatically makes sub batches with (50, 2, 100)
result = client(sample)
assert np.isclose(result, sample).all()


def test_exception(protocol_and_port):
protocol, port = protocol_and_port
print(f"{__name__}, Testing {protocol}")
def test_exception(config):
protocol, port, run_async = config
print(f"{__name__}, Testing {protocol} with run_async={run_async}")

client = InferenceClient.create_with("sample_autobatching", f"{TRITON_HOST}:{port}", protocol=protocol)
client = InferenceClient.create_with(
"sample_autobatching", f"{TRITON_HOST}:{port}", protocol=protocol, run_async=run_async
)

sample = np.random.rand(100, 100, 100).astype(np.float32)
# client automatically makes sub batches with (50, 2, 100)
45 changes: 15 additions & 30 deletions tests/test_model_call.py
@@ -1,30 +1,20 @@
import os

import numpy as np
import pytest

from tritony import InferenceClient

TRITON_HOST = os.environ.get("TRITON_HOST", "localhost")
TRITON_HTTP = os.environ.get("TRITON_HTTP", "8000")
TRITON_GRPC = os.environ.get("TRITON_GRPC", "8001")

from .common_fixtures import TRITON_HOST, config

EPSILON = 1e-8
__all__ = ["config"]


@pytest.fixture(params=[("http", TRITON_HTTP), ("grpc", TRITON_GRPC)])
def protocol_and_port(request):
return request.param

def get_client(protocol, port, run_async, model_name):
print(f"Testing {protocol} with run_async={run_async}", flush=True)
return InferenceClient.create_with(model_name, f"{TRITON_HOST}:{port}", protocol=protocol, run_async=run_async)

def get_client(protocol, port, model_name):
print(f"Testing {protocol}", flush=True)
return InferenceClient.create_with(model_name, f"{TRITON_HOST}:{port}", protocol=protocol)


def test_swithcing(protocol_and_port):
client = get_client(*protocol_and_port, model_name="sample")
def test_swithcing(config):
client = get_client(*config, model_name="sample")

sample = np.random.rand(1, 100).astype(np.float32)
result = client(sample)
@@ -35,16 +25,16 @@ def test_swithcing(protocol_and_port):
assert np.isclose(result, sample).all()


def test_with_input_name(protocol_and_port):
client = get_client(*protocol_and_port, model_name="sample")
def test_with_input_name(config):
client = get_client(*config, model_name="sample")

sample = np.random.rand(100, 100).astype(np.float32)
result = client({client.default_model_spec.model_input[0].name: sample})
assert np.isclose(result, sample).all()


def test_with_parameters(protocol_and_port):
client = get_client(*protocol_and_port, model_name="sample")
def test_with_parameters(config):
client = get_client(*config, model_name="sample")

sample = np.random.rand(1, 100).astype(np.float32)
ADD_VALUE = 1
@@ -53,8 +43,8 @@ def test_with_parameters(protocol_and_port):
assert np.isclose(result[0], sample[0] + ADD_VALUE).all()


def test_with_optional(protocol_and_port):
client = get_client(*protocol_and_port, model_name="sample_optional")
def test_with_optional(config):
client = get_client(*config, model_name="sample_optional")

sample = np.random.rand(1, 100).astype(np.float32)

@@ -71,16 +61,11 @@ def test_with_optional(protocol_and_port):
assert np.isclose(result[0], sample[0] - OPTIONAL_SUB_VALUE, rtol=EPSILON).all()


def test_reload_model_spec(protocol_and_port):
client = get_client(*protocol_and_port, model_name="sample_autobatching")
def test_reload_model_spec(config):
client = get_client(*config, model_name="sample_autobatching")
# force to change max_batch_size
client.default_model_spec.max_batch_size = 4

sample = np.random.rand(8, 100).astype(np.float32)
result = client(sample)
assert np.isclose(result, sample).all()


if __name__ == "__main__":
test_with_parameters(("grpc", "8101"))
test_with_optional(("grpc", "8101"))
10 changes: 5 additions & 5 deletions tritony/helpers.py
@@ -4,7 +4,7 @@
from collections import defaultdict
from enum import Enum
from types import SimpleNamespace
from typing import Any, Optional, Union
from typing import Any

import attrs
from attrs import define
@@ -18,7 +18,7 @@ class TritonProtocol(Enum):
http = "http"


COMPRESSION_ALGORITHM_MAP = defaultdict(int)
COMPRESSION_ALGORITHM_MAP: dict[str, int] = defaultdict(int)
COMPRESSION_ALGORITHM_MAP.update({"deflate": 1, "gzip": 2})


@@ -83,13 +83,13 @@ class TritonClientFlag:
concurrency: int = 6 # only for TritonProtocol.http client
verbose: bool = False
input_dims: int = 1
compression_algorithm: Optional[str] = None
compression_algorithm: str | None = None
ssl: bool = False


def init_triton_client(
flag: TritonClientFlag,
) -> Union[grpcclient.InferenceServerClient, httpclient.InferenceServerClient]:
) -> grpcclient.InferenceServerClient | httpclient.InferenceServerClient:
assert not (
flag.streaming and not (flag.protocol is TritonProtocol.grpc)
), "Streaming is only allowed with gRPC protocol"
@@ -107,7 +107,7 @@ def init_triton_client(


def get_triton_client(
triton_client: Union[grpcclient.InferenceServerClient, httpclient.InferenceServerClient],
triton_client: grpcclient.InferenceServerClient | httpclient.InferenceServerClient,
model_name: str,
model_version: str,
protocol: TritonProtocol,
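The annotation changes above trade typing.Optional/typing.Union for PEP 604 unions and a built-in generic. As a compatibility note (the project's minimum Python version is not stated in this diff): `str | None` and the client union are evaluated at runtime in these positions and so need Python 3.10+, and `dict[str, int]` needs 3.9+; on older interpreters the same spellings only work with deferred annotation evaluation. A small equivalence sketch:

from __future__ import annotations  # optional: defers evaluation so the new syntax also works on older Pythons

from collections import defaultdict

# Equivalent spellings of the annotations used above:
#   Optional[str]  ->  str | None
#   Union[A, B]    ->  A | B
#   Dict[str, int] ->  dict[str, int]
COMPRESSION_ALGORITHM_MAP: dict[str, int] = defaultdict(int)
compression_algorithm: str | None = None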