Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Distinguish _bytes_to_dict and _bytes_to_list + fix issues #1641

Merged
merged 1 commit into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
_b64_to_image,
_bytes_to_dict,
_bytes_to_image,
_bytes_to_list,
_get_recommended_model,
_import_numpy,
_is_tgi_server,
Expand Down Expand Up @@ -284,7 +285,7 @@ def audio_classification(
```
"""
response = self.post(data=audio, model=model, task="audio-classification")
return _bytes_to_dict(response)
return _bytes_to_list(response)

def automatic_speech_recognition(
self,
Expand Down Expand Up @@ -381,7 +382,7 @@ def conversational(
if parameters is not None:
payload["parameters"] = parameters
response = self.post(json=payload, model=model, task="conversational")
return _bytes_to_dict(response)
return _bytes_to_dict(response) # type: ignore

def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
"""
Expand Down Expand Up @@ -453,7 +454,7 @@ def image_classification(
```
"""
response = self.post(data=image, model=model, task="image-classification")
return _bytes_to_dict(response)
return _bytes_to_list(response)

def image_segmentation(
self,
Expand Down Expand Up @@ -719,7 +720,7 @@ def sentence_similarity(
model=model,
task="sentence-similarity",
)
return _bytes_to_dict(response)
return _bytes_to_list(response)

def summarization(
self,
Expand Down Expand Up @@ -1285,7 +1286,7 @@ def zero_shot_image_classification(
model=model,
task="zero-shot-image-classification",
)
return _bytes_to_dict(response)
return _bytes_to_list(response)

def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
model = model or self.model
Expand Down
18 changes: 16 additions & 2 deletions src/huggingface_hub/inference/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,24 @@ def _b64_to_image(encoded_image: str) -> "Image":
return Image.open(io.BytesIO(base64.b64decode(encoded_image)))


def _bytes_to_dict(content: bytes) -> "Image":
def _bytes_to_list(content: bytes) -> List:
"""Parse bytes from a Response object into a Python list.

Expects the response body to be JSON-encoded data.

NOTE: This is exactly the same implementation as `_bytes_to_dict` and will not complain if the returned data is a
dictionary. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
"""
return json.loads(content.decode())


def _bytes_to_dict(content: bytes) -> Dict:
"""Parse bytes from a Response object into a Python dictionary.

Expects the response body to be encoded-JSON data.
Expects the response body to be JSON-encoded data.

NOTE: This is exactly the same implementation as `_bytes_to_list` and will not complain if the returned data is a
list. The only advantage of having both is to help the user (and mypy) understand what kind of data to expect.
"""
return json.loads(content.decode())

Expand Down
11 changes: 6 additions & 5 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
_b64_to_image,
_bytes_to_dict,
_bytes_to_image,
_bytes_to_list,
_get_recommended_model,
_import_numpy,
_is_tgi_server,
Expand Down Expand Up @@ -281,7 +282,7 @@ async def audio_classification(
```
"""
response = await self.post(data=audio, model=model, task="audio-classification")
return _bytes_to_dict(response)
return _bytes_to_list(response)

async def automatic_speech_recognition(
self,
Expand Down Expand Up @@ -380,7 +381,7 @@ async def conversational(
if parameters is not None:
payload["parameters"] = parameters
response = await self.post(json=payload, model=model, task="conversational")
return _bytes_to_dict(response)
return _bytes_to_dict(response) # type: ignore

async def feature_extraction(self, text: str, *, model: Optional[str] = None) -> "np.ndarray":
"""
Expand Down Expand Up @@ -454,7 +455,7 @@ async def image_classification(
```
"""
response = await self.post(data=image, model=model, task="image-classification")
return _bytes_to_dict(response)
return _bytes_to_list(response)

async def image_segmentation(
self,
Expand Down Expand Up @@ -725,7 +726,7 @@ async def sentence_similarity(
model=model,
task="sentence-similarity",
)
return _bytes_to_dict(response)
return _bytes_to_list(response)

async def summarization(
self,
Expand Down Expand Up @@ -1297,7 +1298,7 @@ async def zero_shot_image_classification(
model=model,
task="zero-shot-image-classification",
)
return _bytes_to_dict(response)
return _bytes_to_list(response)

def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None) -> str:
model = model or self.model
Expand Down