From f0f1d0f8de7f33ba96487ddb6be1cedecc53ffe7 Mon Sep 17 00:00:00 2001 From: martinbrose <13284268+martinbrose@users.noreply.github.com> Date: Sun, 20 Aug 2023 12:50:45 +0100 Subject: [PATCH 1/5] Add text classification to inference client --- docs/source/guides/inference.md | 2 +- src/huggingface_hub/inference/_client.py | 48 ++++++++++++++++++ .../inference/_generated/_async_client.py | 49 +++++++++++++++++++ ...lientVCRTest.test_text_classification.yaml | 48 ++++++++++++++++++ tests/test_inference_client.py | 8 +++ 5 files changed, 154 insertions(+), 1 deletion(-) create mode 100644 tests/cassettes/InferenceClientVCRTest.test_text_classification.yaml diff --git a/docs/source/guides/inference.md b/docs/source/guides/inference.md index 5ee7d7a114..ee01fc1dae 100644 --- a/docs/source/guides/inference.md +++ b/docs/source/guides/inference.md @@ -139,7 +139,7 @@ has a simple API that supports the most common tasks. Here is a list of the curr | | [Sentence Similarity](https://huggingface.co/tasks/sentence-similarity) | ✅ | [`~InferenceClient.sentence_similarity`] | | | [Summarization](https://huggingface.co/tasks/summarization) | ✅ | [`~InferenceClient.summarization`] | | | [Table Question Answering](https://huggingface.co/tasks/table-question-answering) | | | -| | [Text Classification](https://huggingface.co/tasks/text-classification) | | | +| | [Text Classification](https://huggingface.co/tasks/text-classification) | ✅ | [`~InferenceClient.text_classification`] | | | [Text Generation](https://huggingface.co/tasks/text-generation) | ✅ | [`~InferenceClient.text_generation`] | | | [Token Classification](https://huggingface.co/tasks/token-classification) | | | | | [Translation](https://huggingface.co/tasks/translation) | | | diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 15d0c4edff..a4fdcecace 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -763,6 +763,54 @@ def summarization( response = self.post(json=payload, model=model, task="summarization") return _bytes_to_dict(response)[0]["summary_text"] + def text_classification( + self, text: List[str], *, parameters: Optional[Dict[str, Any]] = None, model: Optional[str] = None + ) -> List[ClassificationOutput]: + """ + Perform sentiment-analysis on the given text. + + Args: + text (`str`): + A list of strings to be classified. + parameters (`Dict[str, Any]`, *optional*): + Additional parameters for the text classification task. Defaults to None. For more details about the available + parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#text-classification-task) + model (`str`, *optional*): + The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used. + Defaults to None. + + Returns: + `List[Dict]`: a list of dictionaries containing the predicted label and associated probability. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `HTTPError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + >>> from huggingface_hub import InferenceClient + >>> client = InferenceClient() + >>> output = client.text_classification(["I like you", "I love you"]) + >>> output + [[{'label': 'POSITIVE', 'score': 0.9998695850372314}, + {'label': 'NEGATIVE', 'score': 0.0001304351753788069}], + [{'label': 'POSITIVE', 'score': 0.9998656511306763}, + {'label': 'NEGATIVE', 'score': 0.00013436275185085833}]] + ``` + """ + payload: Dict[str, Any] = {"inputs": text} + if parameters is not None: + payload["parameters"] = parameters + response = self.post( + json=payload, + model=model, + task="text-classification", + ) + return _bytes_to_dict(response) + @overload def text_generation( # type: ignore self, diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index f9df5e2644..7b95e7c592 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -770,6 +770,55 @@ async def summarization( response = await self.post(json=payload, model=model, task="summarization") return _bytes_to_dict(response)[0]["summary_text"] + async def text_classification( + self, text: List[str], *, parameters: Optional[Dict[str, Any]] = None, model: Optional[str] = None + ) -> List[ClassificationOutput]: + """ + Perform sentiment-analysis on the given text. + + Args: + text (`str`): + A list of strings to be classified. + parameters (`Dict[str, Any]`, *optional*): + Additional parameters for the text classification task. Defaults to None. For more details about the available + parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#text-classification-task) + model (`str`, *optional*): + The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to + a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used. + Defaults to None. + + Returns: + `List[Dict]`: a list of dictionaries containing the predicted label and associated probability. + + Raises: + [`InferenceTimeoutError`]: + If the model is unavailable or the request times out. + `aiohttp.ClientResponseError`: + If the request fails with an HTTP error status code other than HTTP 503. + + Example: + ```py + # Must be run in an async context + >>> from huggingface_hub import AsyncInferenceClient + >>> client = AsyncInferenceClient() + >>> output = await client.text_classification(["I like you", "I love you"]) + >>> output + [[{'label': 'POSITIVE', 'score': 0.9998695850372314}, + {'label': 'NEGATIVE', 'score': 0.0001304351753788069}], + [{'label': 'POSITIVE', 'score': 0.9998656511306763}, + {'label': 'NEGATIVE', 'score': 0.00013436275185085833}]] + ``` + """ + payload: Dict[str, Any] = {"inputs": text} + if parameters is not None: + payload["parameters"] = parameters + response = await self.post( + json=payload, + model=model, + task="text-classification", + ) + return _bytes_to_dict(response) + @overload async def text_generation( # type: ignore self, diff --git a/tests/cassettes/InferenceClientVCRTest.test_text_classification.yaml b/tests/cassettes/InferenceClientVCRTest.test_text_classification.yaml new file mode 100644 index 0000000000..bf6d63347b --- /dev/null +++ b/tests/cassettes/InferenceClientVCRTest.test_text_classification.yaml @@ -0,0 +1,48 @@ +interactions: +- request: + body: '{"inputs": ["I like you", "I love you."]}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate, br + Connection: + - keep-alive + Content-Length: + - '41' + Content-Type: + - application/json + X-Amzn-Trace-Id: + - b658f44b-c82c-4a0c-9fc1-c287ea0b66d3 + user-agent: + - unknown/None; hf_hub/0.17.0.dev0; python/3.10.12 + method: POST + uri: https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english + response: + body: + string: '[[{"label":"POSITIVE","score":0.9998695850372314},{"label":"NEGATIVE","score":0.0001304351753788069}],[{"label":"POSITIVE","score":0.9998705387115479},{"label":"NEGATIVE","score":0.00012938841246068478}]]' + headers: + Connection: + - keep-alive + Content-Length: + - '204' + Content-Type: + - application/json + Date: + - Sun, 20 Aug 2023 11:48:55 GMT + access-control-allow-credentials: + - 'true' + vary: + - Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-compute-time: + - '0.033' + x-compute-type: + - cache + x-request-id: + - MiuTWky1u3OlV7JlitniT + x-sha: + - 3d65bad49c7ba6f71920504507a8927f4b9db6c0 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index 2fe2b9bdb9..17e413cb95 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -200,6 +200,14 @@ def test_summarization(self) -> None: " surpassed the Washington Monument to become the tallest man-made structure in the world.", ) + def test_text_classification(self) -> None: + output = self.client.text_classification(["I like you", "I love you."]) + self.assertIsInstance(output, list) + self.assertEqual(len(output), 2) + for item in output: + self.assertIsInstance(item[0]["score"], float) + self.assertIsInstance(item[0]["label"], str) + def test_text_generation(self) -> None: """Tested separately in `test_inference_text_generation.py`.""" From 8ed6dd95c8604ccb19a298361acba739935cf276 Mon Sep 17 00:00:00 2001 From: martinbrose <13284268+martinbrose@users.noreply.github.com> Date: Tue, 5 Sep 2023 00:38:43 +0100 Subject: [PATCH 2/5] Address PR review comments --- src/huggingface_hub/inference/_client.py | 20 +++++-------------- .../inference/_generated/_async_client.py | 20 +++++-------------- tests/test_inference_client.py | 9 ++++----- 3 files changed, 14 insertions(+), 35 deletions(-) diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index bf17cb31c6..ca6da4a5d2 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -764,18 +764,13 @@ def summarization( response = self.post(json=payload, model=model, task="summarization") return _bytes_to_dict(response)[0]["summary_text"] - def text_classification( - self, text: List[str], *, parameters: Optional[Dict[str, Any]] = None, model: Optional[str] = None - ) -> List[ClassificationOutput]: + def text_classification(self, text: str, *, model: Optional[str] = None) -> ClassificationOutput: """ Perform sentiment-analysis on the given text. Args: text (`str`): - A list of strings to be classified. - parameters (`Dict[str, Any]`, *optional*): - Additional parameters for the text classification task. Defaults to None. For more details about the available - parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#text-classification-task) + A string to be classified. model (`str`, *optional*): The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used. @@ -794,23 +789,18 @@ def text_classification( ```py >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() - >>> output = client.text_classification(["I like you", "I love you"]) + >>> output = client.text_classification("I like you") >>> output - [[{'label': 'POSITIVE', 'score': 0.9998695850372314}, - {'label': 'NEGATIVE', 'score': 0.0001304351753788069}], - [{'label': 'POSITIVE', 'score': 0.9998656511306763}, - {'label': 'NEGATIVE', 'score': 0.00013436275185085833}]] + {'label': 'POSITIVE', 'score': 0.9998695850372314} ``` """ payload: Dict[str, Any] = {"inputs": text} - if parameters is not None: - payload["parameters"] = parameters response = self.post( json=payload, model=model, task="text-classification", ) - return _bytes_to_dict(response) + return _bytes_to_dict(response)[0][0] @overload def text_generation( # type: ignore diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index a637dc0cff..05e624f3a2 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -771,18 +771,13 @@ async def summarization( response = await self.post(json=payload, model=model, task="summarization") return _bytes_to_dict(response)[0]["summary_text"] - async def text_classification( - self, text: List[str], *, parameters: Optional[Dict[str, Any]] = None, model: Optional[str] = None - ) -> List[ClassificationOutput]: + async def text_classification(self, text: str, *, model: Optional[str] = None) -> ClassificationOutput: """ Perform sentiment-analysis on the given text. Args: text (`str`): - A list of strings to be classified. - parameters (`Dict[str, Any]`, *optional*): - Additional parameters for the text classification task. Defaults to None. For more details about the available - parameters, please refer to [this page](https://huggingface.co/docs/api-inference/detailed_parameters#text-classification-task) + A string to be classified. model (`str`, *optional*): The model to use for the text classification task. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed Inference Endpoint. If not provided, the default recommended text classification model will be used. @@ -802,23 +797,18 @@ async def text_classification( # Must be run in an async context >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() - >>> output = await client.text_classification(["I like you", "I love you"]) + >>> output = await client.text_classification("I like you") >>> output - [[{'label': 'POSITIVE', 'score': 0.9998695850372314}, - {'label': 'NEGATIVE', 'score': 0.0001304351753788069}], - [{'label': 'POSITIVE', 'score': 0.9998656511306763}, - {'label': 'NEGATIVE', 'score': 0.00013436275185085833}]] + {'label': 'POSITIVE', 'score': 0.9998695850372314} ``` """ payload: Dict[str, Any] = {"inputs": text} - if parameters is not None: - payload["parameters"] = parameters response = await self.post( json=payload, model=model, task="text-classification", ) - return _bytes_to_dict(response) + return _bytes_to_dict(response)[0][0] @overload async def text_generation( # type: ignore diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index 83ae6c2f78..427ef13ad1 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -201,12 +201,11 @@ def test_summarization(self) -> None: ) def test_text_classification(self) -> None: - output = self.client.text_classification(["I like you", "I love you."]) - self.assertIsInstance(output, list) + output = self.client.text_classification("I like you") + self.assertIsInstance(output, dict) self.assertEqual(len(output), 2) - for item in output: - self.assertIsInstance(item[0]["score"], float) - self.assertIsInstance(item[0]["label"], str) + self.assertIsInstance(output["score"], float) + self.assertIsInstance(output["label"], str) def test_text_generation(self) -> None: """Tested separately in `test_inference_text_generation.py`.""" From 559872febdbf96e715de92222011afce63adfc3f Mon Sep 17 00:00:00 2001 From: martinbrose <13284268+martinbrose@users.noreply.github.com> Date: Tue, 5 Sep 2023 18:19:06 +0100 Subject: [PATCH 3/5] Return a list of dictionaries and update test --- src/huggingface_hub/inference/_client.py | 4 ++-- src/huggingface_hub/inference/_generated/_async_client.py | 4 ++-- tests/test_inference_client.py | 7 ++++--- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 00f429d147..26cd10e222 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -765,7 +765,7 @@ def summarization( response = self.post(json=payload, model=model, task="summarization") return _bytes_to_dict(response)[0]["summary_text"] - def text_classification(self, text: str, *, model: Optional[str] = None) -> ClassificationOutput: + def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]: """ Perform sentiment-analysis on the given text. @@ -801,7 +801,7 @@ def text_classification(self, text: str, *, model: Optional[str] = None) -> Clas model=model, task="text-classification", ) - return _bytes_to_dict(response)[0][0] + return _bytes_to_list(response) @overload def text_generation( # type: ignore diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index a4ab5d86f9..3d66487e42 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -772,7 +772,7 @@ async def summarization( response = await self.post(json=payload, model=model, task="summarization") return _bytes_to_dict(response)[0]["summary_text"] - async def text_classification(self, text: str, *, model: Optional[str] = None) -> ClassificationOutput: + async def text_classification(self, text: str, *, model: Optional[str] = None) -> List[ClassificationOutput]: """ Perform sentiment-analysis on the given text. @@ -809,7 +809,7 @@ async def text_classification(self, text: str, *, model: Optional[str] = None) - model=model, task="text-classification", ) - return _bytes_to_dict(response)[0][0] + return _bytes_to_list(response) @overload async def text_generation( # type: ignore diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index 427ef13ad1..af9710c354 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -202,10 +202,11 @@ def test_summarization(self) -> None: def test_text_classification(self) -> None: output = self.client.text_classification("I like you") - self.assertIsInstance(output, dict) + self.assertIsInstance(output, list) self.assertEqual(len(output), 2) - self.assertIsInstance(output["score"], float) - self.assertIsInstance(output["label"], str) + for item in output: + self.assertIsInstance(item[0]["score"], float) + self.assertIsInstance(item[0]["label"], str) def test_text_generation(self) -> None: """Tested separately in `test_inference_text_generation.py`.""" From cf8e176e614c829a42b0df9294a6b03220a3f12f Mon Sep 17 00:00:00 2001 From: Lucain Date: Wed, 6 Sep 2023 15:54:33 +0200 Subject: [PATCH 4/5] Apply suggestions from code review --- src/huggingface_hub/inference/_client.py | 12 +++--------- .../inference/_generated/_async_client.py | 12 +++--------- ...erenceClientVCRTest.test_text_classification.yaml | 4 ++-- tests/test_inference_client.py | 4 ++-- 4 files changed, 10 insertions(+), 22 deletions(-) diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 26cd10e222..ca08f2dc28 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -791,17 +791,11 @@ def text_classification(self, text: str, *, model: Optional[str] = None) -> List >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> output = client.text_classification("I like you") - >>> output - {'label': 'POSITIVE', 'score': 0.9998695850372314} + [[{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}]] ``` """ - payload: Dict[str, Any] = {"inputs": text} - response = self.post( - json=payload, - model=model, - task="text-classification", - ) - return _bytes_to_list(response) + response = self.post(json={"inputs": text}, model=model, task="text-classification") + return _bytes_to_list(response)[0] @overload def text_generation( # type: ignore diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index 3d66487e42..70341d6df2 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -799,17 +799,11 @@ async def text_classification(self, text: str, *, model: Optional[str] = None) - >>> from huggingface_hub import AsyncInferenceClient >>> client = AsyncInferenceClient() >>> output = await client.text_classification("I like you") - >>> output - {'label': 'POSITIVE', 'score': 0.9998695850372314} + [{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}] ``` """ - payload: Dict[str, Any] = {"inputs": text} - response = await self.post( - json=payload, - model=model, - task="text-classification", - ) - return _bytes_to_list(response) + response = await self.post(json={"inputs": text}, model=model, task="text-classification") + return _bytes_to_list(response)[0] @overload async def text_generation( # type: ignore diff --git a/tests/cassettes/InferenceClientVCRTest.test_text_classification.yaml b/tests/cassettes/InferenceClientVCRTest.test_text_classification.yaml index bf6d63347b..67e8ac807e 100644 --- a/tests/cassettes/InferenceClientVCRTest.test_text_classification.yaml +++ b/tests/cassettes/InferenceClientVCRTest.test_text_classification.yaml @@ -1,6 +1,6 @@ interactions: - request: - body: '{"inputs": ["I like you", "I love you."]}' + body: '{"inputs": ["I like you"]}' headers: Accept: - '*/*' @@ -20,7 +20,7 @@ interactions: uri: https://api-inference.huggingface.co/models/distilbert-base-uncased-finetuned-sst-2-english response: body: - string: '[[{"label":"POSITIVE","score":0.9998695850372314},{"label":"NEGATIVE","score":0.0001304351753788069}],[{"label":"POSITIVE","score":0.9998705387115479},{"label":"NEGATIVE","score":0.00012938841246068478}]]' + string: '[[{"label":"POSITIVE","score":0.9998695850372314},{"label":"NEGATIVE","score":0.0001304351753788069}]]' headers: Connection: - keep-alive diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index af9710c354..cca0ce9c04 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -205,8 +205,8 @@ def test_text_classification(self) -> None: self.assertIsInstance(output, list) self.assertEqual(len(output), 2) for item in output: - self.assertIsInstance(item[0]["score"], float) - self.assertIsInstance(item[0]["label"], str) + self.assertIsInstance(item["score"], float) + self.assertIsInstance(item["label"], str) def test_text_generation(self) -> None: """Tested separately in `test_inference_text_generation.py`.""" From 5c0c2831450a7a33651d5fd0c5c5976d366f8bf7 Mon Sep 17 00:00:00 2001 From: Lucain Pouget Date: Wed, 6 Sep 2023 15:56:51 +0200 Subject: [PATCH 5/5] make style --- src/huggingface_hub/inference/_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index ca08f2dc28..be04edc324 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -791,7 +791,7 @@ def text_classification(self, text: str, *, model: Optional[str] = None) -> List >>> from huggingface_hub import InferenceClient >>> client = InferenceClient() >>> output = client.text_classification("I like you") - [[{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}]] + [{'label': 'POSITIVE', 'score': 0.9998695850372314}, {'label': 'NEGATIVE', 'score': 0.0001304351753788069}] ``` """ response = self.post(json={"inputs": text}, model=model, task="text-classification")