Skip to content

Commit

Permalink
Http Impl for Hugging Face.
Browse files Browse the repository at this point in the history
  • Loading branch information
rohit-rptless committed Sep 17, 2024
1 parent 2fbbcbc commit a33c79d
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 2 deletions.
9 changes: 7 additions & 2 deletions aisuite/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,23 @@ class ProviderNames(str, Enum):
AZURE = "azure"
GROQ = "groq"
GOOGLE = "google"
HUGGINGFACE = "huggingface"


class ProviderFactory:
"""Factory to register and create provider instances based on keys."""

_provider_info = {
ProviderNames.OPENAI: (
"aisuite.providers.openai_http_provider",
"aisuite.providers.openai_provider",
"OpenAIProvider",
),
ProviderNames.AWS_BEDROCK: (
"aisuite.providers.aws_bedrock_provider",
"AWSBedrockProvider",
),
ProviderNames.ANTHROPIC: (
"aisuite.providers.anthropic_http_provider",
"aisuite.providers.anthropic_provider",
"AnthropicProvider",
),
ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"),
Expand All @@ -48,6 +49,10 @@ class ProviderFactory:
"aisuite.providers.google_http_provider",
"GoogleHttpProvider",
),
ProviderNames.HUGGINGFACE: (
"aisuite.providers.huggingface_provider",
"HuggingFaceProvider",
),
}

@classmethod
Expand Down
67 changes: 67 additions & 0 deletions aisuite/providers/huggingface_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import httpx
from aisuite.provider import Provider, LLMError
from aisuite.framework import ChatCompletionResponse


class HuggingFaceProvider(Provider):
    """
    HuggingFace Provider using httpx for direct API calls.

    Currently, this provider supports calls to HF serverless Inference
    Endpoints, which use Text Generation Inference (TGI) as the backend.
    TGI is OpenAI protocol compliant.
    https://huggingface.co/inference-endpoints/
    """

    def __init__(self, **config):
        """
        Initialize the provider with the given configuration.

        The API token is read from ``config["token"]`` if present, otherwise
        from the HUGGINGFACE_TOKEN environment variable.

        Raises:
            ValueError: If no token is found in either location.
        """
        # Ensure API key is provided either in config or via environment variable
        self.token = config.get("token") or os.getenv("HUGGINGFACE_TOKEN")
        if not self.token:
            raise ValueError(
                "Hugging Face token is missing. Please provide it in the config or set the HUGGINGFACE_TOKEN environment variable."
            )

        # Optionally set a custom request timeout in seconds (default to 30s)
        self.timeout = config.get("timeout", 30)

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Make a chat-completion request to the HF Inference API using httpx.

        Args:
            model: Model repository id (also embedded in the endpoint URL).
            messages: OpenAI-style list of {"role": ..., "content": ...} dicts.
            **kwargs: Additional parameters forwarded verbatim in the request
                body (e.g. temperature, max_tokens).

        Returns:
            ChatCompletionResponse: Normalized response with the first
            choice's message content populated.

        Raises:
            LLMError: On HTTP status errors, network failures, or an
                unexpected response payload.
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.token}",
        }

        data = {
            "model": model,
            "messages": messages,
            **kwargs,  # Pass any additional arguments to the API
        }

        url = f"https://api-inference.huggingface.co/models/{model}/v1/chat/completions"
        try:
            # Make the request to Hugging Face endpoint.
            response = httpx.post(url, json=data, headers=headers, timeout=self.timeout)
            response.raise_for_status()
        except httpx.HTTPStatusError as http_err:
            # Chain the cause so status code and response details are preserved.
            raise LLMError(f"Hugging Face request failed: {http_err}") from http_err
        except Exception as e:
            raise LLMError(f"An error occurred: {e}") from e

        # Return the normalized response
        return self._normalize_response(response.json())

    def _normalize_response(self, response_data):
        """
        Normalize the raw API response to a common format (ChatCompletionResponse).

        Raises:
            LLMError: If the payload lacks the expected
                choices[0].message.content structure (e.g. an error body).
        """
        normalized_response = ChatCompletionResponse()
        try:
            normalized_response.choices[0].message.content = response_data["choices"][
                0
            ]["message"]["content"]
        except (KeyError, IndexError, TypeError) as err:
            # Surface malformed payloads as the provider's own error type
            # instead of leaking a raw KeyError/IndexError to callers.
            raise LLMError(
                f"Unexpected response format from Hugging Face: {response_data!r}"
            ) from err
        return normalized_response

0 comments on commit a33c79d

Please sign in to comment.