From a95500c4a94461120594264e2ca19975bb584c87 Mon Sep 17 00:00:00 2001 From: Adrian Gonzalez-Martin Date: Thu, 6 Apr 2023 12:36:33 +0100 Subject: [PATCH] Improvements for HuggingFace runtime (#1077) --- docs/examples/huggingface/README.ipynb | 93 +++++----- docs/examples/huggingface/README.md | 7 - runtimes/huggingface/README.md | 42 +++++ .../mlserver_huggingface/common.py | 162 ++++-------------- .../mlserver_huggingface/errors.py | 42 +++++ .../mlserver_huggingface/runtime.py | 63 +------ .../mlserver_huggingface/settings.py | 145 ++++++++++++++++ runtimes/huggingface/setup.py | 1 - runtimes/huggingface/tests/conftest.py | 25 +-- runtimes/huggingface/tests/test_common.py | 31 +++- runtimes/huggingface/tests/test_runtime.py | 26 ++- .../huggingface/tests/test_runtime_cases.py | 24 +++ 12 files changed, 389 insertions(+), 272 deletions(-) create mode 100644 runtimes/huggingface/mlserver_huggingface/errors.py create mode 100644 runtimes/huggingface/mlserver_huggingface/settings.py create mode 100644 runtimes/huggingface/tests/test_runtime_cases.py diff --git a/docs/examples/huggingface/README.ipynb b/docs/examples/huggingface/README.ipynb index feaacb0af..d6bae76d4 100644 --- a/docs/examples/huggingface/README.ipynb +++ b/docs/examples/huggingface/README.ipynb @@ -18,7 +18,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 8, "id": "b5b2588c", "metadata": {}, "outputs": [], @@ -42,7 +42,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 9, "id": "6df62443", "metadata": {}, "outputs": [ @@ -59,7 +59,6 @@ "{\n", " \"name\": \"transformer\",\n", " \"implementation\": \"mlserver_huggingface.HuggingFaceRuntime\",\n", - " \"parallel_workers\": 0,\n", " \"parameters\": {\n", " \"extra\": {\n", " \"task\": \"text-generation\",\n", @@ -93,7 +92,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 10, "id": "759ad7df", "metadata": {}, "outputs": [ @@ -101,17 +100,16 @@ "data": { "text/plain": [ "{'model_name': 'transformer',\n", - " 'model_version': None,\n", - " 'id': '9b24304e-730f-4a98-bfde-8949851388a9',\n", - " 'parameters': None,\n", + " 'id': 'eb160c6b-8223-4342-ad92-6ac301a9fa5d',\n", + " 'parameters': {},\n", " 'outputs': [{'name': 'output',\n", - " 'shape': [1],\n", + " 'shape': [1, 1],\n", " 'datatype': 'BYTES',\n", - " 'parameters': None,\n", - " 'data': ['[{\"generated_text\": \"this is a test-case where you\\'re checking if someone\\'s going to have an encrypted file that they like to open, or whether their file has a hidden contents if their file is not opened. 
If it\\'s the same file, when all the\"}]']}]}" + " 'parameters': {'content_type': 'hg_jsonlist'},\n", + " 'data': ['{\"generated_text\": \"this is a testnet with 1-3,000-bit nodes as nodes.\"}']}]}" ] }, - "execution_count": 22, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -145,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 11, "id": "6d185281", "metadata": {}, "outputs": [ @@ -162,7 +160,6 @@ "{\n", " \"name\": \"transformer\",\n", " \"implementation\": \"mlserver_huggingface.HuggingFaceRuntime\",\n", - " \"parallel_workers\": 0,\n", " \"parameters\": {\n", " \"extra\": {\n", " \"task\": \"text-generation\",\n", @@ -197,7 +194,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 12, "id": "39d8b438", "metadata": {}, "outputs": [ @@ -205,17 +202,16 @@ "data": { "text/plain": [ "{'model_name': 'transformer',\n", - " 'model_version': None,\n", - " 'id': '296ea44e-7696-4584-af5a-148a7083b2e7',\n", - " 'parameters': None,\n", + " 'id': '9c482c8d-b21e-44b1-8a42-7650a9dc01ef',\n", + " 'parameters': {},\n", " 'outputs': [{'name': 'output',\n", - " 'shape': [1],\n", + " 'shape': [1, 1],\n", " 'datatype': 'BYTES',\n", - " 'parameters': None,\n", - " 'data': ['[{\"generated_text\": \"this is a test that allows us to define the value type, and a function is defined directly with these variables.\\\\n\\\\n\\\\nThe function is defined for a parameter with type\\\\nIn this example,\\\\nif you pass a message function like\\\\ntype\"}]']}]}" + " 'parameters': {'content_type': 'hg_jsonlist'},\n", + " 'data': ['{\"generated_text\": \"this is a test of the \\\\\"safe-code-safe-code-safe-code\\\\\" approach. The method only accepts two parameters as parameters: the code. The parameter \\'unsafe-code-safe-code-safe-code\\' should\"}']}]}" ] }, - "execution_count": 23, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -256,7 +252,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 13, "id": "4492dc01", "metadata": {}, "outputs": [ @@ -273,7 +269,6 @@ "{\n", " \"name\": \"transformer\",\n", " \"implementation\": \"mlserver_huggingface.HuggingFaceRuntime\",\n", - " \"parallel_workers\": 0,\n", " \"parameters\": {\n", " \"extra\": {\n", " \"task\": \"question-answering\"\n", @@ -296,7 +291,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "d7aaf365", "metadata": { "scrolled": true @@ -305,18 +300,17 @@ { "data": { "text/plain": [ - "{'model_name': 'gpt2-model',\n", - " 'model_version': None,\n", - " 'id': '204ad4e7-79ea-40b4-8efb-aed16dedf7ed',\n", - " 'parameters': None,\n", + "{'model_name': 'transformer',\n", + " 'id': '4efac938-86d8-41a1-b78f-7690b2dcf197',\n", + " 'parameters': {},\n", " 'outputs': [{'name': 'output',\n", - " 'shape': [1],\n", + " 'shape': [1, 1],\n", " 'datatype': 'BYTES',\n", - " 'parameters': None,\n", - " 'data': ['{\"score\": 0.9869922995567322, \"start\": 12, \"end\": 18, \"answer\": \"Seldon\"}']}]}" + " 'parameters': {'content_type': 'hg_jsonlist'},\n", + " 'data': ['{\"score\": 0.9869915843009949, \"start\": 12, \"end\": 18, \"answer\": \"Seldon\"}']}]}" ] }, - "execution_count": 13, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -352,7 +346,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "id": "8e70c7d7", "metadata": {}, "outputs": [ @@ -369,7 +363,6 @@ "{\n", " \"name\": \"transformer\",\n", " \"implementation\": 
\"mlserver_huggingface.HuggingFaceRuntime\",\n", - " \"parallel_workers\": 0,\n", " \"parameters\": {\n", " \"extra\": {\n", " \"task\": \"text-classification\"\n", @@ -392,7 +385,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 16, "id": "2f704413", "metadata": { "scrolled": true @@ -402,17 +395,16 @@ "data": { "text/plain": [ "{'model_name': 'transformer',\n", - " 'model_version': None,\n", - " 'id': '463ceddb-f426-4815-9c46-9fa9fc5272b1',\n", - " 'parameters': None,\n", + " 'id': '835eabbd-daeb-4423-a64f-a7c4d7c60a9b',\n", + " 'parameters': {},\n", " 'outputs': [{'name': 'output',\n", - " 'shape': [1],\n", + " 'shape': [1, 1],\n", " 'datatype': 'BYTES',\n", - " 'parameters': None,\n", - " 'data': ['[{\"label\": \"NEGATIVE\", \"score\": 0.9996137022972107}]']}]}" + " 'parameters': {'content_type': 'hg_jsonlist'},\n", + " 'data': ['{\"label\": \"NEGATIVE\", \"score\": 0.9996137022972107}']}]}" ] }, - "execution_count": 19, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -448,7 +440,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 17, "id": "827472eb", "metadata": {}, "outputs": [ @@ -465,7 +457,6 @@ "{\n", " \"name\": \"transformer\",\n", " \"implementation\": \"mlserver_huggingface.HuggingFaceRuntime\",\n", - " \"parallel_workers\": 0,\n", " \"max_batch_size\": 128,\n", " \"max_batch_time\": 1,\n", " \"parameters\": {\n", @@ -491,7 +482,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 19, "id": "888501c1", "metadata": {}, "outputs": [ @@ -499,7 +490,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Elapsed time: 81.57849169999827\n" + "Elapsed time: 66.42268538899953\n" ] } ], @@ -546,7 +537,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 6, "id": "032b8f4e", "metadata": {}, "outputs": [ @@ -563,7 +554,6 @@ "{\n", " \"name\": \"transformer\",\n", " \"implementation\": \"mlserver_huggingface.HuggingFaceRuntime\",\n", - " \"parallel_workers\": 0,\n", " \"parameters\": {\n", " \"extra\": {\n", " \"task\": \"text-generation\",\n", @@ -632,7 +622,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 7, "id": "810a4abe", "metadata": {}, "outputs": [ @@ -649,7 +639,6 @@ "{\n", " \"name\": \"transformer\",\n", " \"implementation\": \"mlserver_huggingface.HuggingFaceRuntime\",\n", - " \"parallel_workers\": 0,\n", " \"max_batch_size\": 128,\n", " \"max_batch_time\": 1,\n", " \"parameters\": {\n", @@ -718,7 +707,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -732,7 +721,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.10" + "version": "3.9.8" } }, "nbformat": 4, diff --git a/docs/examples/huggingface/README.md b/docs/examples/huggingface/README.md index 663b74285..9db1f14b3 100644 --- a/docs/examples/huggingface/README.md +++ b/docs/examples/huggingface/README.md @@ -27,7 +27,6 @@ We will show how to add share a task { "name": "transformer", "implementation": "mlserver_huggingface.HuggingFaceRuntime", - "parallel_workers": 0, "parameters": { "extra": { "task": "text-generation", @@ -76,7 +75,6 @@ We can download pretrained optimized models from the hub if available by enablin { "name": "transformer", "implementation": "mlserver_huggingface.HuggingFaceRuntime", - "parallel_workers": 0, "parameters": { "extra": { "task": "text-generation", @@ -127,7 +125,6 @@ We can 
support multiple other transformers other than just text generation, belo { "name": "transformer", "implementation": "mlserver_huggingface.HuggingFaceRuntime", - "parallel_workers": 0, "parameters": { "extra": { "task": "question-answering" @@ -172,7 +169,6 @@ requests.post("http://localhost:8080/v2/models/transformer/infer", json=inferenc { "name": "transformer", "implementation": "mlserver_huggingface.HuggingFaceRuntime", - "parallel_workers": 0, "parameters": { "extra": { "task": "text-classification" @@ -217,7 +213,6 @@ We first test the time taken with the device=-1 which configures CPU by default { "name": "transformer", "implementation": "mlserver_huggingface.HuggingFaceRuntime", - "parallel_workers": 0, "max_batch_size": 128, "max_batch_time": 1, "parameters": { @@ -271,7 +266,6 @@ Now we'll run the benchmark with GPU configured, which we can do by setting `dev { "name": "transformer", "implementation": "mlserver_huggingface.HuggingFaceRuntime", - "parallel_workers": 0, "parameters": { "extra": { "task": "text-generation", @@ -319,7 +313,6 @@ We will also configure `max_batch_time` which specifies` the maximum amount of t { "name": "transformer", "implementation": "mlserver_huggingface.HuggingFaceRuntime", - "parallel_workers": 0, "max_batch_size": 128, "max_batch_time": 1, "parameters": { diff --git a/runtimes/huggingface/README.md b/runtimes/huggingface/README.md index 24a5098ef..dfd0bdbfb 100644 --- a/runtimes/huggingface/README.md +++ b/runtimes/huggingface/README.md @@ -12,3 +12,45 @@ pip install mlserver mlserver-huggingface For further information on how to use MLServer with HuggingFace, you can check out this [worked out example](../../docs/examples/huggingface/README.md). + +## Settings + +The HuggingFace runtime exposes a couple extra parameters which can be used to +customise how the runtime behaves. +These settings can be added under the `parameters.extra` section of your +`model-settings.json` file, e.g. + +```{code-block} json +--- +emphasize-lines: 5-8 +--- +{ + "name": "qa", + "implementation": "mlserver_huggingface.HuggingFaceRuntime", + "parameters": { + "extra": { + "task": "question-answering", + "optimum_model": true + } + } +} +``` + +````{note} +These settings can also be injected through environment variables prefixed with `MLSERVER_MODEL_HUGGINGFACE_`, e.g. + +```bash +MLSERVER_MODEL_HUGGINGFACE_TASK="question-answering" +MLSERVER_MODEL_HUGGINGFACE_OPTIMUM_MODEL=true +``` +```` + +### Reference + +You can find the full reference of the accepted extra settings for the +HuggingFace runtime below: + +```{eval-rst} + +.. 
autopydantic_settings:: mlserver_huggingface.settings.HuggingFaceSettings +``` diff --git a/runtimes/huggingface/mlserver_huggingface/common.py b/runtimes/huggingface/mlserver_huggingface/common.py index a3d0eb5b7..8ba48e22e 100644 --- a/runtimes/huggingface/mlserver_huggingface/common.py +++ b/runtimes/huggingface/mlserver_huggingface/common.py @@ -1,154 +1,60 @@ -import os import json -from typing import Optional, Dict -from distutils.util import strtobool - import numpy as np -from pydantic import BaseSettings -from mlserver.errors import MLServerError + +from typing import Callable +from functools import partial from mlserver.settings import ModelSettings -from transformers.pipelines import pipeline +from optimum.pipelines import pipeline as opt_pipeline +from transformers.pipelines import pipeline as trf_pipeline from transformers.pipelines.base import Pipeline -from transformers.models.auto.tokenization_auto import AutoTokenizer - -try: - # Optimum 1.7 changed the import name from `SUPPORTED_TASKS` to - # `ORT_SUPPORTED_TASKS`. - # We'll try to import the more recent one, falling back to the previous - # import name if not present. - # https://github.com/huggingface/optimum/blob/987b02e4f6e2a1c9325b364ff764da2e57e89902/optimum/pipelines/__init__.py#L18 - from optimum.pipelines import ORT_SUPPORTED_TASKS as SUPPORTED_OPTIMUM_TASKS -except ImportError: - from optimum.pipelines import SUPPORTED_TASKS as SUPPORTED_OPTIMUM_TASKS - - -HUGGINGFACE_TASK_TAG = "task" - -ENV_PREFIX_HUGGINGFACE_SETTINGS = "MLSERVER_MODEL_HUGGINGFACE_" -HUGGINGFACE_PARAMETERS_TAG = "huggingface_parameters" -PARAMETERS_ENV_NAME = "PREDICTIVE_UNIT_PARAMETERS" - - -class InvalidTranformerInitialisation(MLServerError): - def __init__(self, code: int, reason: str): - super().__init__( - f"Huggingface server failed with {code}, {reason}", - status_code=code, - ) - - -class HuggingFaceSettings(BaseSettings): - """ - Parameters that apply only to alibi huggingface models - """ - - class Config: - env_prefix = ENV_PREFIX_HUGGINGFACE_SETTINGS - - task: str = "" - # Why need this filed? 
- # for translation task, required a suffix to specify source and target - # related issue: https://github.com/SeldonIO/MLServer/issues/947 - task_suffix: str = "" - pretrained_model: Optional[str] = None - pretrained_tokenizer: Optional[str] = None - framework: Optional[str] = None - optimum_model: bool = False - device: int = -1 - - @property - def task_name(self): - if self.task == "translation": - return f"{self.task}{self.task_suffix}" - return self.task - - -def parse_parameters_from_env() -> Dict: - """ - TODO - """ - parameters = json.loads(os.environ.get(PARAMETERS_ENV_NAME, "[]")) - - type_dict = { - "INT": int, - "FLOAT": float, - "DOUBLE": float, - "STRING": str, - "BOOL": bool, - } - - parsed_parameters = {} - for param in parameters: - name = param.get("name") - value = param.get("value") - type_ = param.get("type") - if type_ == "BOOL": - parsed_parameters[name] = bool(strtobool(value)) - else: - try: - parsed_parameters[name] = type_dict[type_](value) - except ValueError: - raise InvalidTranformerInitialisation( - "Bad model parameter: " - + name - + " with value " - + value - + " can't be parsed as a " - + type_, - reason="MICROSERVICE_BAD_PARAMETER", - ) - except KeyError: - raise InvalidTranformerInitialisation( - "Bad model parameter type: " - + type_ - + " valid are INT, FLOAT, DOUBLE, STRING, BOOL", - reason="MICROSERVICE_BAD_PARAMETER", - ) - return parsed_parameters + +from .settings import HuggingFaceSettings + + +OPTIMUM_ACCELERATOR = "ort" + +_PipelineConstructor = Callable[..., Pipeline] def load_pipeline_from_settings( hf_settings: HuggingFaceSettings, settings: ModelSettings ) -> Pipeline: - """ - TODO - """ # TODO: Support URI for locally downloaded artifacts # uri = model_parameters.uri - model = hf_settings.pretrained_model - tokenizer = hf_settings.pretrained_tokenizer - device = hf_settings.device + pipeline = _get_pipeline_class(hf_settings) - if model and not tokenizer: - tokenizer = model + batch_size = 1 + if settings.max_batch_size: + batch_size = settings.max_batch_size - if hf_settings.optimum_model: - optimum_class = SUPPORTED_OPTIMUM_TASKS[hf_settings.task]["class"][0] - model = optimum_class.from_pretrained( - hf_settings.pretrained_model, - from_transformers=True, - ) - tokenizer = AutoTokenizer.from_pretrained(tokenizer) - # Device needs to be set to -1 due to known issue - # https://github.com/huggingface/optimum/issues/191 - device = -1 - - batch_size = 1 if settings.max_batch_size == 0 else settings.max_batch_size - pp = pipeline( + tokenizer = hf_settings.pretrained_tokenizer + if not tokenizer: + tokenizer = hf_settings.pretrained_model + + hf_pipeline = pipeline( hf_settings.task_name, - model=model, + model=hf_settings.pretrained_model, tokenizer=tokenizer, - device=device, + device=hf_settings.device, batch_size=batch_size, framework=hf_settings.framework, ) # If max_batch_size > 0 we need to ensure tokens are padded if settings.max_batch_size: - pp.tokenizer.pad_token_id = [str(pp.model.config.eos_token_id)] # type: ignore + model = hf_pipeline.model + eos_token_id = model.config.eos_token_id + hf_pipeline.tokenizer.pad_token_id = [str(eos_token_id)] # type: ignore + + return hf_pipeline + + +def _get_pipeline_class(hf_settings: HuggingFaceSettings) -> _PipelineConstructor: + if hf_settings.optimum_model: + return partial(opt_pipeline, accelerator=OPTIMUM_ACCELERATOR) - return pp + return trf_pipeline class NumpyEncoder(json.JSONEncoder): diff --git a/runtimes/huggingface/mlserver_huggingface/errors.py 
b/runtimes/huggingface/mlserver_huggingface/errors.py new file mode 100644 index 000000000..2cca2dca6 --- /dev/null +++ b/runtimes/huggingface/mlserver_huggingface/errors.py @@ -0,0 +1,42 @@ +from typing import List + +from mlserver.errors import MLServerError + + +class MissingHuggingFaceSettings(MLServerError): + def __init__(self): + super().__init__("Missing HuggingFace Runtime settings.") + + +class InvalidTransformersTask(MLServerError): + def __init__(self, task: str, available_tasks: List[str]): + msg = f"Invalid transformer task: {task}. Available tasks: {available_tasks}." + super().__init__(msg) + + +class InvalidOptimumTask(MLServerError): + def __init__(self, task: str, available_tasks: List[str]): + msg = ( + "Invalid transformer task for Optimum model: {task}. " + f"Available Optimum tasks: {available_tasks}." + ) + super().__init__(msg) + + +class InvalidModelParameter(MLServerError): + def __init__(self, name: str, value: str, param_type: str): + msg = ( + f"Bad model parameter: {name}" + f" with value {value}" + f" can't be parsed as a {param_type}" + ) + super().__init__(msg) + + +class InvalidModelParameterType(MLServerError): + def __init__(self, param_type: str): + msg = ( + f"Bad model parameter type: {param_type}." + f" Only valid types are INT, FLOAT, DOUBLE, STRING, BOOL." + ) + super().__init__(msg) diff --git a/runtimes/huggingface/mlserver_huggingface/runtime.py b/runtimes/huggingface/mlserver_huggingface/runtime.py index 624a1a5db..45d4ecf08 100644 --- a/runtimes/huggingface/mlserver_huggingface/runtime.py +++ b/runtimes/huggingface/mlserver_huggingface/runtime.py @@ -1,21 +1,15 @@ import asyncio + from mlserver.model import MLModel from mlserver.settings import ModelSettings +from mlserver.logging import logger from mlserver.types import ( InferenceRequest, InferenceResponse, ) -from transformers.pipelines import SUPPORTED_TASKS - -from mlserver.logging import logger -from .common import ( - HuggingFaceSettings, - parse_parameters_from_env, - InvalidTranformerInitialisation, - load_pipeline_from_settings, - SUPPORTED_OPTIMUM_TASKS, -) +from .settings import get_huggingface_settings +from .common import load_pipeline_from_settings from .codecs import HuggingfaceRequestCodec from .metadata import METADATA @@ -24,66 +18,25 @@ class HuggingFaceRuntime(MLModel): """Runtime class for specific Huggingface models""" def __init__(self, settings: ModelSettings): - env_params = parse_parameters_from_env() - if not env_params and ( - not settings.parameters or not settings.parameters.extra - ): - raise InvalidTranformerInitialisation( - 500, - "Settings parameters not provided via config file nor env variables", - ) - - extra = env_params or settings.parameters.extra # type: ignore - self.hf_settings = HuggingFaceSettings(**extra) # type: ignore - - if self.hf_settings.task not in SUPPORTED_TASKS: - raise InvalidTranformerInitialisation( - 500, - ( - f"Invalid transformer task: {self.hf_settings.task}." - f" Available tasks: {SUPPORTED_TASKS.keys()}" - ), - ) - - if self.hf_settings.optimum_model: - if self.hf_settings.task not in SUPPORTED_OPTIMUM_TASKS: - raise InvalidTranformerInitialisation( - 500, - ( - f"Invalid transformer task for " - f"OPTIMUM model: {self.hf_settings.task}. 
" - f"Supported Optimum tasks: {SUPPORTED_OPTIMUM_TASKS.keys()}" - ), - ) - + self.hf_settings = get_huggingface_settings(settings) super().__init__(settings) async def load(self) -> bool: # Loading & caching pipeline in asyncio loop to avoid blocking - print("=" * 80) - print(self.hf_settings.task_name) - print("loading model...") + logger.info(f"Loading model for task '{self.hf_settings.task_name}'...") await asyncio.get_running_loop().run_in_executor( None, load_pipeline_from_settings, self.hf_settings, self.settings, ) - print("(re)loading model...") + # Now we load the cached model which should not block asyncio self._model = load_pipeline_from_settings(self.hf_settings, self.settings) self._merge_metadata() - print("model has been loaded!") return True async def predict(self, payload: InferenceRequest) -> InferenceResponse: - """ - TODO - """ - - # Adding some logging as hard to debug given the many types of input accepted - logger.debug("Payload %s", payload) - # TODO: convert and validate? kwargs = self.decode_request(payload, default_codec=HuggingfaceRequestCodec) args = kwargs.pop("args", []) @@ -93,8 +46,6 @@ async def predict(self, payload: InferenceRequest) -> InferenceResponse: args = [list(array_inputs)] + args prediction = self._model(*args, **kwargs) - logger.debug("Prediction %s", prediction) - return self.encode_response( payload=prediction, default_codec=HuggingfaceRequestCodec ) diff --git a/runtimes/huggingface/mlserver_huggingface/settings.py b/runtimes/huggingface/mlserver_huggingface/settings.py new file mode 100644 index 000000000..f46758463 --- /dev/null +++ b/runtimes/huggingface/mlserver_huggingface/settings.py @@ -0,0 +1,145 @@ +import os +import orjson + +from typing import Optional, Dict +from pydantic import BaseSettings +from distutils.util import strtobool +from transformers.pipelines import SUPPORTED_TASKS + +try: + # Optimum 1.7 changed the import name from `SUPPORTED_TASKS` to + # `ORT_SUPPORTED_TASKS`. + # We'll try to import the more recent one, falling back to the previous + # import name if not present. + # https://github.com/huggingface/optimum/blob/987b02e4f6e2a1c9325b364ff764da2e57e89902/optimum/pipelines/__init__.py#L18 + from optimum.pipelines import ORT_SUPPORTED_TASKS as SUPPORTED_OPTIMUM_TASKS +except ImportError: + from optimum.pipelines import SUPPORTED_TASKS as SUPPORTED_OPTIMUM_TASKS + +from mlserver.settings import ModelSettings + +from .errors import ( + MissingHuggingFaceSettings, + InvalidTransformersTask, + InvalidOptimumTask, + InvalidModelParameter, + InvalidModelParameterType, +) + +ENV_PREFIX_HUGGINGFACE_SETTINGS = "MLSERVER_MODEL_HUGGINGFACE_" +PARAMETERS_ENV_NAME = "PREDICTIVE_UNIT_PARAMETERS" + + +class HuggingFaceSettings(BaseSettings): + """ + Parameters that apply only to HuggingFace models + """ + + class Config: + env_prefix = ENV_PREFIX_HUGGINGFACE_SETTINGS + + # TODO: Document fields + task: str = "" + """ + Pipeline task to load. + You can see the available Optimum and Transformers tasks available in the + links below: + + - `Optimum Tasks `_ + - `Transformer Tasks `_ + """ # noqa: E501 + + task_suffix: str = "" + """ + Suffix to append to the base task name. + Useful for, e.g. translation tasks which require a suffix on the task name + to specify source and target. + """ + + pretrained_model: Optional[str] = None + """ + Name of the model that should be loaded in the pipeline. + """ + + pretrained_tokenizer: Optional[str] = None + """ + Name of the tokenizer that should be loaded in the pipeline. 
+ """ + + framework: Optional[str] = None + """ + The framework to use, either "pt" for PyTorch or "tf" for TensorFlow. + """ + + optimum_model: bool = False + """ + Flag to decide whether the pipeline should use a Optimum-optimised model or + the standard Transformers model. + Under the hood, this will enable the model to use the optimised ONNX + runtime. + """ + + device: int = -1 + """ + Device in which this pipeline will be loaded (e.g., "cpu", "cuda:1", "mps", + or a GPU ordinal rank like 1). + """ + + @property + def task_name(self): + if self.task == "translation": + return f"{self.task}{self.task_suffix}" + return self.task + + +def parse_parameters_from_env() -> Dict: + """ + This method parses the environment variables injected via SCv1. + """ + # TODO: Once support for SCv1 is deprecated, we should remove this method and rely + # purely on settings coming via the `model-settings.json` file. + parameters = orjson.loads(os.environ.get(PARAMETERS_ENV_NAME, "[]")) + + type_dict = { + "INT": int, + "FLOAT": float, + "DOUBLE": float, + "STRING": str, + "BOOL": bool, + } + + parsed_parameters = {} + for param in parameters: + name = param.get("name") + value = param.get("value") + type_ = param.get("type") + if type_ == "BOOL": + parsed_parameters[name] = bool(strtobool(value)) + else: + try: + parsed_parameters[name] = type_dict[type_](value) + except ValueError: + raise InvalidModelParameter(name, value, type_) + except KeyError: + raise InvalidModelParameterType(type_) + return parsed_parameters + + +def get_huggingface_settings(model_settings: ModelSettings) -> HuggingFaceSettings: + env_params = parse_parameters_from_env() + if not env_params and ( + not model_settings.parameters or not model_settings.parameters.extra + ): + raise MissingHuggingFaceSettings() + + extra = env_params or model_settings.parameters.extra # type: ignore + hf_settings = HuggingFaceSettings(**extra) # type: ignore + + if hf_settings.task not in SUPPORTED_TASKS: + raise InvalidTransformersTask(hf_settings.task, SUPPORTED_TASKS.keys()) + + if hf_settings.optimum_model: + if hf_settings.task not in SUPPORTED_OPTIMUM_TASKS: + raise InvalidOptimumTask(hf_settings.task, SUPPORTED_OPTIMUM_TASKS.keys()) + + return hf_settings diff --git a/runtimes/huggingface/setup.py b/runtimes/huggingface/setup.py index 9d38081bc..4430b3354 100644 --- a/runtimes/huggingface/setup.py +++ b/runtimes/huggingface/setup.py @@ -36,7 +36,6 @@ def _load_description() -> str: install_requires=[ "mlserver", "optimum[onnxruntime]>=1.4.0, <1.8.0", - "transformers", "Pillow", ], long_description=_load_description(), diff --git a/runtimes/huggingface/tests/conftest.py b/runtimes/huggingface/tests/conftest.py index e278270ea..7b6c07998 100644 --- a/runtimes/huggingface/tests/conftest.py +++ b/runtimes/huggingface/tests/conftest.py @@ -3,9 +3,6 @@ from mlserver.utils import install_uvloop_event_loop from mlserver.types import InferenceRequest, RequestInput -from mlserver.settings import ModelSettings, ModelParameters - -from mlserver_huggingface import HuggingFaceRuntime # test a prediction spend long time, so add this command argument to enable test tasks @@ -14,7 +11,7 @@ def pytest_addoption(parser): parser.addoption("--test-hg-tasks", action="store_true", default=False) -@pytest.fixture(scope="module") +@pytest.fixture() def event_loop(): # NOTE: We need to override the `event_loop` fixture to change its scope to # `module`, so that it can be used downstream on other `module`-scoped @@ -25,26 +22,6 @@ def event_loop(): loop.close() 
-@pytest.fixture(scope="module") -def model_settings() -> ModelSettings: - return ModelSettings( - name="foo", - implementation=HuggingFaceRuntime, - parameters=ModelParameters( - extra={ - "task": "question-answering", - } - ), - ) - - -@pytest.fixture(scope="module") -async def runtime(model_settings: ModelSettings) -> HuggingFaceRuntime: - runtime = HuggingFaceRuntime(model_settings) - runtime.ready = await runtime.load() - return runtime - - @pytest.fixture def inference_request() -> InferenceRequest: return InferenceRequest( diff --git a/runtimes/huggingface/tests/test_common.py b/runtimes/huggingface/tests/test_common.py index 5c431c43d..4c3a3ca79 100644 --- a/runtimes/huggingface/tests/test_common.py +++ b/runtimes/huggingface/tests/test_common.py @@ -1,6 +1,16 @@ import pytest + from typing import Dict -from mlserver_huggingface.common import HuggingFaceSettings +from optimum.onnxruntime.modeling_ort import ORTModelForQuestionAnswering +from transformers.models.distilbert.modeling_distilbert import ( + DistilBertForQuestionAnswering, +) + +from mlserver.settings import ModelSettings, ModelParameters + +from mlserver_huggingface.runtime import HuggingFaceRuntime +from mlserver_huggingface.settings import HuggingFaceSettings +from mlserver_huggingface.common import load_pipeline_from_settings @pytest.mark.parametrize( @@ -16,3 +26,22 @@ def test_settings_task_name(envs: Dict[str, str], expected: str): setting = HuggingFaceSettings(**envs) assert setting.task_name == expected + + +@pytest.mark.parametrize( + "optimum_model, expected", + [(True, ORTModelForQuestionAnswering), (False, DistilBertForQuestionAnswering)], +) +def test_load_pipeline(optimum_model: bool, expected): + hf_settings = HuggingFaceSettings( + task="question-answering", optimum_model=optimum_model + ) + model_settings = ModelSettings( + name="foo", + implementation=HuggingFaceRuntime, + parameters=ModelParameters(extra=hf_settings.dict()), + ) + + pipeline = load_pipeline_from_settings(hf_settings, model_settings) + + assert isinstance(pipeline.model, expected) diff --git a/runtimes/huggingface/tests/test_runtime.py b/runtimes/huggingface/tests/test_runtime.py index 996284127..125ed80aa 100644 --- a/runtimes/huggingface/tests/test_runtime.py +++ b/runtimes/huggingface/tests/test_runtime.py @@ -1,26 +1,46 @@ import json +from typing import Awaitable from transformers.pipelines.question_answering import QuestionAnsweringPipeline +from pytest_cases import fixture, parametrize_with_cases +from mlserver.settings import ModelSettings from mlserver.types import InferenceRequest from mlserver_huggingface import HuggingFaceRuntime -def test_load(runtime: HuggingFaceRuntime): +@fixture +@parametrize_with_cases("model_settings") +async def future_runtime(model_settings: ModelSettings) -> HuggingFaceRuntime: + # NOTE: The pytest-cases doesn't work too well yet with AsyncIO, therefore + # we need to treat the fixture as an Awaitable and await it in the tests. 
+ # https://github.com/smarie/python-pytest-cases/issues/286 + runtime = HuggingFaceRuntime(model_settings) + runtime.ready = await runtime.load() + return runtime + + +async def test_load(future_runtime: Awaitable[HuggingFaceRuntime]): + runtime = await future_runtime assert runtime.ready assert isinstance(runtime._model, QuestionAnsweringPipeline) -async def test_infer(runtime: HuggingFaceRuntime, inference_request: InferenceRequest): +async def test_infer( + future_runtime: Awaitable[HuggingFaceRuntime], inference_request: InferenceRequest +): + runtime = await future_runtime res = await runtime.predict(inference_request) pred = json.loads(res.outputs[0].data[0]) assert pred["answer"] == "Seldon" async def test_infer_multiple( - runtime: HuggingFaceRuntime, inference_request: InferenceRequest + future_runtime: Awaitable[HuggingFaceRuntime], inference_request: InferenceRequest ): + runtime = await future_runtime + # Send request with two elements for request_input in inference_request.inputs: input_data = request_input.data[0] diff --git a/runtimes/huggingface/tests/test_runtime_cases.py b/runtimes/huggingface/tests/test_runtime_cases.py new file mode 100644 index 000000000..0e6af4934 --- /dev/null +++ b/runtimes/huggingface/tests/test_runtime_cases.py @@ -0,0 +1,24 @@ +from mlserver.settings import ModelSettings, ModelParameters +from mlserver_huggingface import HuggingFaceRuntime + + +def case_optimum_settings() -> ModelSettings: + return ModelSettings( + name="foo", + implementation=HuggingFaceRuntime, + parameters=ModelParameters( + extra={"task": "question-answering", "optimum_model": True} + ), + ) + + +def case_transformers_settings() -> ModelSettings: + return ModelSettings( + name="foo", + implementation=HuggingFaceRuntime, + parameters=ModelParameters( + extra={ + "task": "question-answering", + } + ), + )
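
For orientation, the sketch below (illustrative only, not part of the diff) shows how the pieces introduced in this patch fit together: the `parameters.extra` block of `model-settings.json` is resolved into a `HuggingFaceSettings` instance via `get_huggingface_settings`, which `load_pipeline_from_settings` then uses to build either a Transformers or an Optimum (ONNX Runtime) pipeline. The model name, task, and prompt values are placeholders.

```python
# Illustrative sketch (not part of the patch): wiring the new settings helpers
# from this PR into a pipeline load. Values below are placeholders.
from mlserver.settings import ModelSettings, ModelParameters

from mlserver_huggingface import HuggingFaceRuntime
from mlserver_huggingface.settings import get_huggingface_settings
from mlserver_huggingface.common import load_pipeline_from_settings

# Equivalent to a `model-settings.json` file with a `parameters.extra` section.
model_settings = ModelSettings(
    name="transformer",
    implementation=HuggingFaceRuntime,
    parameters=ModelParameters(
        extra={"task": "question-answering", "optimum_model": False}
    ),
)

# Resolves the HuggingFace-specific settings (env vars prefixed with
# MLSERVER_MODEL_HUGGINGFACE_ are also picked up) and validates the task
# against the supported Transformers / Optimum tasks.
hf_settings = get_huggingface_settings(model_settings)

# Builds a Transformers pipeline, or an Optimum one when `optimum_model` is set.
pipeline = load_pipeline_from_settings(hf_settings, model_settings)
print(pipeline(question="Who makes MLServer?", context="MLServer is made by Seldon."))
```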