Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Adrian Gonzalez-Martin committed Aug 5, 2021
1 parent 0335265 commit 6eb964d
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions tests/test_mlserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest
from mlserver.settings import ModelSettings
from mlserver.types import InferenceRequest, RequestInput
from mlserver.utils import to_ndarray
from mlserver.codecs import NumpyCodec
from pytest_cases import fixture, parametrize_with_cases
from pytest_cases.common_pytest_lazy_values import is_lazy

Expand All @@ -16,7 +16,11 @@

@pytest.fixture
def inference_request() -> InferenceRequest:
return InferenceRequest(inputs=[RequestInput(name="payload", shape=[4], data=[1, 2, 3, 4], datatype="FP32")])
return InferenceRequest(
inputs=[
RequestInput(name="payload", shape=[4], data=[1, 2, 3, 4], datatype="FP32")
]
)


@fixture
Expand All @@ -40,15 +44,18 @@ async def test_load(mlserver_runtime: InferenceRuntime):
assert isinstance(mlserver_runtime._model, BaseModel)


async def test_predict(mlserver_runtime: InferenceRuntime, inference_request: InferenceRequest):
async def test_predict(
mlserver_runtime: InferenceRuntime, inference_request: InferenceRequest
):
# NOTE: pytest-cases doesn't wait for async fixtures
# TODO: Raise issue in pytest-cases repo
mlserver_runtime = await mlserver_runtime
res = await mlserver_runtime.predict(inference_request)

assert len(res.outputs) == 1

pipeline_input = to_ndarray(inference_request.inputs[0])
codec = NumpyCodec()
pipeline_input = codec.decode(inference_request.inputs[0])
custom_model = copy.copy(mlserver_runtime._model)
# Ensure direct call to class does not try to do remote
custom_model.set_remote(False)
Expand All @@ -58,11 +65,14 @@ async def test_predict(mlserver_runtime: InferenceRuntime, inference_request: In

pipeline_output = res.outputs[0].data

assert expected_output.tolist() == pipeline_output
assert expected_output.tolist() == pipeline_output.__root__


async def test_load_wrapped_class(inference_pipeline_class, inference_request: InferenceRequest):
pipeline_input = to_ndarray(inference_request.inputs[0])
async def test_load_wrapped_class(
inference_pipeline_class, inference_request: InferenceRequest
):
codec = NumpyCodec()
pipeline_input = codec.decode(inference_request.inputs[0])

inference_pipeline_class(pipeline_input)
assert inference_pipeline_class.counter == 1
Expand Down

0 comments on commit 6eb964d

Please sign in to comment.