Skip to content

Commit

Permalink
A few more HTTP-based Provider implementations (#32)
Browse files Browse the repository at this point in the history
* HTTP Impl for Anthropic, OpenAI and Google.

* Http Impl for Hugging Face.

* Update the examples with an example from HF.

* Removing Google Http provider, and using the sdk provider.

* Removed debugging prints, and resolved conflicts.

* Update examples/client.ipynb

---------

Co-authored-by: rohit-rptless <[email protected]>
Co-authored-by: Kevin Solorio <[email protected]>
  • Loading branch information
3 people authored Sep 20, 2024
1 parent e969692 commit f8b921d
Show file tree
Hide file tree
Showing 6 changed files with 354 additions and 3 deletions.
7 changes: 6 additions & 1 deletion aisuite/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class ProviderNames(str, Enum):
AZURE = "azure"
GROQ = "groq"
GOOGLE = "google"
HUGGINGFACE = "huggingface"
MISTRAL = "mistral"
OLLAMA = "ollama"
OPENAI = "openai"
Expand All @@ -41,8 +42,12 @@ class ProviderFactory:
"AWSBedrockProvider",
),
ProviderNames.AZURE: ("aisuite.providers.azure_provider", "AzureProvider"),
ProviderNames.GROQ: ("aisuite.providers.groq_provider", "GroqProvider"),
ProviderNames.GOOGLE: ("aisuite.providers.google_provider", "GoogleProvider"),
ProviderNames.GROQ: ("aisuite.providers.groq_provider", "GroqProvider"),
ProviderNames.HUGGINGFACE: (
"aisuite.providers.huggingface_provider",
"HuggingFaceProvider",
),
ProviderNames.MISTRAL: (
"aisuite.providers.mistral_provider",
"MistralProvider",
Expand Down
80 changes: 80 additions & 0 deletions aisuite/providers/anthropic_http_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import os
import httpx
from aisuite.provider import Provider, LLMError
from aisuite.framework import ChatCompletionResponse

# Define a constant for the default max_tokens value
DEFAULT_MAX_TOKENS = 4096


class AnthropicProvider(Provider):
    """
    Anthropic Provider using httpx for direct API calls instead of the SDK.
    """

    BASE_URL = "https://api.anthropic.com/v1/messages"
    API_VERSION = "2023-06-01"

    def __init__(self, **config):
        """
        Initialize the Anthropic provider with the given configuration.

        The API key is fetched from ``config["api_key"]`` or the
        ANTHROPIC_API_KEY environment variable. ``config["timeout"]``
        (seconds, default 30) overrides the request timeout.

        Raises:
            ValueError: If no API key can be found.
        """
        self.api_key = config.get("api_key") or os.getenv("ANTHROPIC_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Anthropic API key is missing. Please provide it in the config or set the ANTHROPIC_API_KEY environment variable."
            )

        # Optionally set a custom timeout (default to 30s)
        self.timeout = config.get("timeout", 30)

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Makes a request to the Anthropic chat completions endpoint using httpx.

        Args:
            model: Anthropic model identifier.
            messages: OpenAI-style list of ``{"role", "content"}`` dicts; a
                leading "system" message is moved into Anthropic's dedicated
                top-level ``system`` field.
            **kwargs: Additional parameters forwarded verbatim to the API
                (``max_tokens`` defaults to DEFAULT_MAX_TOKENS, which the
                Messages API requires).

        Returns:
            ChatCompletionResponse: Normalized response.

        Raises:
            LLMError: If the HTTP request fails for any reason.
        """
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": self.API_VERSION,
            "content-type": "application/json",
        }

        # Anthropic has no "system" role inside `messages`; peel a leading
        # system message off into the dedicated top-level field.
        # Guard against an empty message list to avoid an IndexError.
        system_message = None
        if messages and messages[0]["role"] == "system":
            system_message = messages[0]["content"]
            messages = messages[1:]

        # Set default max_tokens if not provided (required by the API)
        kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS)

        data = {
            "model": model,
            "messages": messages,
            **kwargs,  # Pass any additional arguments to the API
        }
        # Only include "system" when one was actually supplied; sending an
        # explicit null value for this field is rejected by the API.
        if system_message is not None:
            data["system"] = system_message

        try:
            # Make the request to the Anthropic API
            response = httpx.post(
                self.BASE_URL, json=data, headers=headers, timeout=self.timeout
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as http_err:
            raise LLMError(f"Anthropic API request failed: {http_err}")
        except Exception as e:
            raise LLMError(f"An error occurred: {e}")

        # Return the normalized response
        return self._normalize_response(response.json())

    def _normalize_response(self, response_data):
        """
        Normalize the Anthropic API response to a common format (ChatCompletionResponse).

        Only the text of the first content block is extracted.
        """
        normalized_response = ChatCompletionResponse()
        normalized_response.choices[0].message.content = response_data["content"][0][
            "text"
        ]
        return normalized_response
116 changes: 116 additions & 0 deletions aisuite/providers/google_http_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os
import json
import httpx
import google.auth
from google.auth.transport.requests import Request
from aisuite.provider import Provider
from aisuite.framework import ChatCompletionResponse


class GoogleHttpProvider(Provider):
    def __init__(self, **config):
        """
        Initializes the GoogleHttpProvider for Gemini.

        Authentication uses either the GCLOUD_ACCESS_TOKEN environment
        variable (a manually minted token) or google-auth application
        default credentials. Region and project id come from the config
        or the GOOGLE_REGION / GOOGLE_PROJECT_ID environment variables.

        Raises:
            EnvironmentError: If no credential source is configured.
            ValueError: If the region or project id cannot be determined.
        """
        # Check for GCLOUD_ACCESS_TOKEN environment variable.
        # Run `gcloud auth print-access-token` to set this value. Not recommended for production deployment scenarios.
        self.access_token = os.environ.get("GCLOUD_ACCESS_TOKEN")
        self.project_id = ""

        # If no manual token is provided, use google-auth for credentials
        if not self.access_token:
            # NOTE(review): google-auth's default() flow reads
            # GOOGLE_APPLICATION_CREDENTIALS, not GCLOUD_APPLICATION_CREDENTIALS;
            # confirm which variable deployments actually set before relying on
            # this guard.
            if "GCLOUD_APPLICATION_CREDENTIALS" not in os.environ:
                raise EnvironmentError(
                    "Neither 'GCLOUD_ACCESS_TOKEN' nor 'GCLOUD_APPLICATION_CREDENTIALS' is set. "
                    "Please set 'GCLOUD_ACCESS_TOKEN' by running 'gcloud auth print-access-token' or "
                    "set 'GCLOUD_APPLICATION_CREDENTIALS' to the path of your service account JSON key file."
                )

            # Load default credentials and project information from google-auth
            self.credentials, self.project_id = google.auth.default()

            # Refresh credentials to get the access token.
            # NOTE(review): the token is fetched once here and never refreshed;
            # long-lived clients may see it expire — consider refreshing per
            # request.
            self.credentials.refresh(Request())
            self.access_token = self.credentials.token

        # Set region, and project_id from config or environment variables
        self.region = config.get("region", os.environ.get("GOOGLE_REGION"))
        self.project_id = config.get(
            "project_id", os.environ.get("GOOGLE_PROJECT_ID", self.project_id)
        )

        # Validate that all required values are present. The messages name the
        # environment variables this code actually reads (GOOGLE_REGION /
        # GOOGLE_PROJECT_ID).
        if not self.region:
            raise ValueError(
                "Missing 'region'. Please set the 'GOOGLE_REGION' environment variable or provide it in the config."
            )
        if not self.project_id:
            raise ValueError(
                "Missing 'project_id'. Please set the 'GOOGLE_PROJECT_ID' environment variable or provide it in the config."
            )

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Creates chat completions by sending a request to the Google Cloud API for Gemini.
        Adapts the message structure to match Gemini's input format.

        Args:
            model: Gemini model name (interpolated into the Vertex AI URL).
            messages: OpenAI-style list of {"role", "content"} dicts.
            **kwargs: Additional fields merged into the request body.

        Returns:
            ChatCompletionResponse: Normalized response holding the text of
            the first candidate's first part.

        Raises:
            Exception: On HTTP errors, connection errors, or unparsable
            responses (wrapped with a descriptive message).
        """
        url = f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.region}/publishers/google/models/{model}:generateContent"

        contents = []
        for message in messages:
            role = message["role"]
            if role == "system":
                role = "user"  # Gemini doesn't have a system role, map it to user
            elif role == "assistant":
                role = "model"  # Convert assistant to model

            contents.append({"role": role, "parts": [{"text": message["content"]}]})

        data = {"contents": contents, **kwargs}

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.access_token}",
        }

        try:
            with httpx.Client() as client:
                resp = client.post(url, json=data, headers=headers, timeout=None)

            # Raise for any HTTP error status
            resp.raise_for_status()

            # Parse the JSON response
            resp_json = resp.json()

            # Create the single choice with the concatenated message
            completion_response = ChatCompletionResponse()
            completion_response.choices[0].message.content = resp_json[
                "candidates"
            ][0]["content"]["parts"][0]["text"]

            return completion_response

        except httpx.HTTPStatusError as e:
            # Handle non-2xx HTTP status codes
            error_message = f"Request failed with status code {e.response.status_code}: {e.response.text}"
            raise Exception(error_message)

        except httpx.RequestError as e:
            # Handle connection-related errors
            error_message = (
                f"An error occurred while requesting {e.request.url!r}: {str(e)}"
            )
            raise Exception(error_message)

        except json.JSONDecodeError as e:
            # Handle issues with parsing the response
            error_message = "Failed to parse JSON response: " + str(e)
            raise Exception(error_message)

        except Exception as e:
            # Catch-all for any other exceptions
            error_message = f"An unexpected error occurred: {str(e)}"
            raise Exception(error_message)
67 changes: 67 additions & 0 deletions aisuite/providers/huggingface_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import os
import httpx
from aisuite.provider import Provider, LLMError
from aisuite.framework import ChatCompletionResponse


class HuggingFaceProvider(Provider):
    """
    HuggingFace Provider using httpx for direct API calls.

    Targets HF serverless Inference Endpoints, whose Text Generation
    Inference (TGI) backend speaks the OpenAI chat-completions protocol.
    https://huggingface.co/inference-endpoints/
    """

    def __init__(self, **config):
        """
        Initialize the provider with the given configuration.

        The access token comes from ``config["token"]`` or, failing that, the
        HUGGINGFACE_TOKEN environment variable; a missing token raises
        ValueError. ``config["timeout"]`` (seconds, default 30) overrides the
        request timeout.
        """
        token = config.get("token") or os.getenv("HUGGINGFACE_TOKEN")
        if not token:
            raise ValueError(
                "Hugging Face token is missing. Please provide it in the config or set the HUGGINGFACE_TOKEN environment variable."
            )
        self.token = token
        self.timeout = config.get("timeout", 30)

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Makes a request to the Inference API endpoint using httpx.
        """
        url = f"https://api-inference.huggingface.co/models/{model}/v1/chat/completions"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.token}",
        }
        payload = {
            "model": model,
            "messages": messages,
            **kwargs,  # extra parameters are forwarded to the API verbatim
        }

        try:
            reply = httpx.post(url, json=payload, headers=headers, timeout=self.timeout)
            reply.raise_for_status()
        except httpx.HTTPStatusError as http_err:
            raise LLMError(f"Hugging Face request failed: {http_err}")
        except Exception as e:
            raise LLMError(f"An error occurred: {e}")

        return self._normalize_response(reply.json())

    def _normalize_response(self, response_data):
        """
        Normalize the response to a common format (ChatCompletionResponse).
        """
        result = ChatCompletionResponse()
        first_choice = response_data["choices"][0]
        result.choices[0].message.content = first_choice["message"]["content"]
        return result
66 changes: 66 additions & 0 deletions aisuite/providers/openai_http_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import os
import httpx
from aisuite.provider import Provider, LLMError
from aisuite.framework import ChatCompletionResponse


class OpenAIProvider(Provider):
    """
    OpenAI Provider using httpx for direct API calls instead of the OpenAI SDK.
    """

    BASE_URL = "https://api.openai.com/v1/chat/completions"

    def __init__(self, **config):
        """
        Initialize the OpenAI provider with the given configuration.

        The API key is taken from ``config["api_key"]`` or, failing that, the
        OPENAI_API_KEY environment variable; a missing key raises ValueError.
        ``config["timeout"]`` (seconds, default 30) overrides the request
        timeout.
        """
        api_key = config.get("api_key") or os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError(
                "OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable."
            )
        self.api_key = api_key
        self.timeout = config.get("timeout", 30)

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Makes a request to the OpenAI chat completions endpoint using httpx.
        """
        request_body = {
            "model": model,
            "messages": messages,
            **kwargs,  # extra parameters are forwarded to the API verbatim
        }
        auth_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

        try:
            reply = httpx.post(
                self.BASE_URL,
                json=request_body,
                headers=auth_headers,
                timeout=self.timeout,
            )
            reply.raise_for_status()
        except httpx.HTTPStatusError as http_err:
            raise LLMError(f"OpenAI API request failed: {http_err}")
        except Exception as e:
            raise LLMError(f"An error occurred: {e}")

        return self._normalize_response(reply.json())

    def _normalize_response(self, response_data):
        """
        Normalize the OpenAI API response to a common format (ChatCompletionResponse).
        """
        normalized = ChatCompletionResponse()
        message = response_data["choices"][0]["message"]
        normalized.choices[0].message.content = message["content"]
        return normalized
21 changes: 19 additions & 2 deletions examples/client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
"\n",
"client = ai.Client()\n",
"messages = [\n",
" {\"role\": \"system\", \"content\": \"Respond in Pirate English.\"},\n",
" {\"role\": \"system\", \"content\": \"Respond in Pirate English. Always try to include the phrase - No rum No fun.\"},\n",
" {\"role\": \"user\", \"content\": \"Tell me a joke about Captain Jack Sparrow\"},\n",
"]"
]
Expand Down Expand Up @@ -126,7 +126,11 @@
"metadata": {},
"outputs": [],
"source": [
"client2 = ai.Client({\"azure\" : {\n",
"# client2 = ai.Client({\"azure\" : {\n",
"# \"api_key\": os.environ[\"AZURE_API_KEY\"],\n",
"# }});\n",
"client2 = ai.Client()\n",
"client2.configure({\"azure\" : {\n",
" \"api_key\": os.environ[\"AZURE_API_KEY\"],\n",
" \"base_url\": \"https://mistral-large-2407.westus3.models.ai.azure.com/v1/\",\n",
"}});\n",
Expand All @@ -135,6 +139,19 @@
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f996b121",
"metadata": {},
"outputs": [],
"source": [
"client3 = ai.Client()\n",
"hf_model = \"huggingface:mistralai/Mistral-7B-Instruct-v0.3\"\n",
"response = client3.chat.completions.create(model=hf_model, messages=messages)\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit f8b921d

Please sign in to comment.