-
Notifications
You must be signed in to change notification settings - Fork 886
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
A few more HTTP-based Provider implementations (#32)
* HTTP Impl for Anthropic, OpenAI and Google. * Http Impl for Hugging Face. * Update the examples with an example from HF. * Removing Google Http provider, and using the sdk provider. * Removed debugging prints, and resolved conflicts. * Update examples/client.ipynb --------- Co-authored-by: rohit-rptless <[email protected]> Co-authored-by: Kevin Solorio <[email protected]>
- Loading branch information
1 parent
e969692
commit f8b921d
Showing
6 changed files
with
354 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
import os | ||
import httpx | ||
from aisuite.provider import Provider, LLMError | ||
from aisuite.framework import ChatCompletionResponse | ||
|
||
# Default for the max_tokens request field when the caller does not supply
# one (the Anthropic messages endpoint requires max_tokens — see
# chat_completions_create, which applies this via kwargs.setdefault).
DEFAULT_MAX_TOKENS = 4096
|
||
|
||
class AnthropicProvider(Provider):
    """
    Anthropic Provider using httpx for direct API calls instead of the SDK.
    """

    BASE_URL = "https://api.anthropic.com/v1/messages"
    API_VERSION = "2023-06-01"

    def __init__(self, **config):
        """
        Initialize the Anthropic provider with the given configuration.

        The API key is fetched from ``config["api_key"]`` or, failing that,
        the ANTHROPIC_API_KEY environment variable.

        Raises:
            ValueError: If no API key can be found.
        """
        self.api_key = config.get("api_key") or os.getenv("ANTHROPIC_API_KEY")
        if not self.api_key:
            raise ValueError(
                "Anthropic API key is missing. Please provide it in the config or set the ANTHROPIC_API_KEY environment variable."
            )

        # Optionally set a custom timeout (default to 30s)
        self.timeout = config.get("timeout", 30)

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Makes a request to the Anthropic chat completions endpoint using httpx.

        Args:
            model: Anthropic model name to invoke.
            messages: OpenAI-style list of {"role", "content"} dicts; a
                leading "system" message is lifted into the dedicated
                ``system`` request field.
            **kwargs: Extra fields forwarded verbatim to the API body.

        Returns:
            ChatCompletionResponse with the first content block's text.

        Raises:
            LLMError: On any HTTP or transport failure.
        """
        headers = {
            "x-api-key": self.api_key,
            "anthropic-version": self.API_VERSION,
            "content-type": "application/json",
        }

        # Extract an optional leading system message.  Guard against an
        # empty messages list — the previous code indexed messages[0]
        # unconditionally and raised IndexError for [].
        system_message = None
        if messages and messages[0]["role"] == "system":
            system_message = messages[0]["content"]
            messages = messages[1:]  # slicing copies; caller's list is untouched

        # The messages endpoint requires max_tokens; supply a default.
        kwargs.setdefault("max_tokens", DEFAULT_MAX_TOKENS)

        data = {
            "model": model,
            "messages": messages,
            **kwargs,  # Pass any additional arguments to the API
        }
        # Only include "system" when one was actually supplied; per the
        # Anthropic API reference, system must be a string, so sending
        # "system": null risks a validation error.
        if system_message is not None:
            data["system"] = system_message

        try:
            # Make the request to the Anthropic API
            response = httpx.post(
                self.BASE_URL, json=data, headers=headers, timeout=self.timeout
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as http_err:
            # Chain the cause so the original HTTP error stays visible.
            raise LLMError(f"Anthropic API request failed: {http_err}") from http_err
        except Exception as e:
            raise LLMError(f"An error occurred: {e}") from e

        # Return the normalized response
        return self._normalize_response(response.json())

    def _normalize_response(self, response_data):
        """
        Normalize the Anthropic API response to a common format (ChatCompletionResponse).

        Only the first content block's text is surfaced.
        """
        normalized_response = ChatCompletionResponse()
        normalized_response.choices[0].message.content = response_data["content"][0][
            "text"
        ]
        return normalized_response
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import os | ||
import json | ||
import httpx | ||
import google.auth | ||
from google.auth.transport.requests import Request | ||
from aisuite.provider import Provider | ||
from aisuite.framework import ChatCompletionResponse | ||
|
||
|
||
class GoogleHttpProvider(Provider):
    """HTTP (httpx) provider for Gemini models on Vertex AI."""

    def __init__(self, **config):
        """
        Initializes the GoogleHttpProvider for Gemini.

        Authentication, checked in order:
          1. GCLOUD_ACCESS_TOKEN env var (e.g. output of
             ``gcloud auth print-access-token``) — convenient locally, not
             recommended for production since the token expires.
          2. Application-default credentials via google-auth; this path is
             gated on the GCLOUD_APPLICATION_CREDENTIALS env var.
             NOTE(review): google-auth itself reads
             GOOGLE_APPLICATION_CREDENTIALS — confirm the GCLOUD_ prefix
             here is intentional.

        Region and project id come from the config dict or from the
        GOOGLE_REGION / GOOGLE_PROJECT_ID environment variables.

        Raises:
            EnvironmentError: If neither auth mechanism is configured.
            ValueError: If region or project id cannot be determined.
        """
        # Manual token path; bypasses google-auth entirely when set.
        self.access_token = os.environ.get("GCLOUD_ACCESS_TOKEN")
        self.project_id = ""

        # If no manual token is provided, use google-auth for credentials
        if not self.access_token:
            if "GCLOUD_APPLICATION_CREDENTIALS" not in os.environ:
                raise EnvironmentError(
                    "Neither 'GCLOUD_ACCESS_TOKEN' nor 'GCLOUD_APPLICATION_CREDENTIALS' is set. "
                    "Please set 'GCLOUD_ACCESS_TOKEN' by running 'gcloud auth print-access-token' or "
                    "set 'GCLOUD_APPLICATION_CREDENTIALS' to the path of your service account JSON key file."
                )

            # Load default credentials and project information from google-auth
            self.credentials, self.project_id = google.auth.default()

            # Refresh credentials to get the access token
            self.credentials.refresh(Request())
            self.access_token = self.credentials.token

        # Set region, and project_id from config or environment variables
        self.region = config.get("region", os.environ.get("GOOGLE_REGION"))
        self.project_id = config.get(
            "project_id", os.environ.get("GOOGLE_PROJECT_ID", self.project_id)
        )

        # Validate that all required values are present.  The messages now
        # name the environment variables the code actually reads (the
        # originals said 'REGION' and 'PROJECT_ID').
        if not self.region:
            raise ValueError(
                "Missing 'region'. Please set the 'GOOGLE_REGION' environment variable or provide it in the config."
            )
        if not self.project_id:
            raise ValueError(
                "Missing 'project_id'. Please set the 'GOOGLE_PROJECT_ID' environment variable or provide it in the config."
            )

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Creates chat completions by sending a request to the Google Cloud API for Gemini.
        Adapts the message structure to match Gemini's input format.

        Returns:
            ChatCompletionResponse holding the first candidate's first text part.

        Raises:
            Exception: With a descriptive message on HTTP, transport, or
                JSON-parsing failure (chained to the underlying error).
        """
        url = f"https://{self.region}-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.region}/publishers/google/models/{model}:generateContent"

        contents = []
        for message in messages:
            role = message["role"]
            if role == "system":
                role = "user"  # Gemini doesn't have a system role, map it to user
            elif role == "assistant":
                role = "model"  # Gemini names the assistant role "model"

            contents.append({"role": role, "parts": [{"text": message["content"]}]})

        data = {"contents": contents, **kwargs}

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.access_token}",
        }

        try:
            with httpx.Client() as client:
                # timeout=None: generation can be slow, so no client timeout
                resp = client.post(url, json=data, headers=headers, timeout=None)

            # Raise for any HTTP error status
            resp.raise_for_status()

            # Parse the JSON response
            resp_json = resp.json()

            # Create the single choice with the first candidate's text
            completion_response = ChatCompletionResponse()
            completion_response.choices[0].message.content = resp_json[
                "candidates"
            ][0]["content"]["parts"][0]["text"]

            return completion_response

        except httpx.HTTPStatusError as e:
            # Handle non-2xx HTTP status codes; chain for full traceback.
            error_message = f"Request failed with status code {e.response.status_code}: {e.response.text}"
            raise Exception(error_message) from e

        except httpx.RequestError as e:
            # Handle connection-related errors
            error_message = (
                f"An error occurred while requesting {e.request.url!r}: {str(e)}"
            )
            raise Exception(error_message) from e

        except json.JSONDecodeError as e:
            # Handle issues with parsing the response
            error_message = "Failed to parse JSON response: " + str(e)
            raise Exception(error_message) from e

        except Exception as e:
            # Catch-all for any other exceptions (e.g. missing keys above)
            error_message = f"An unexpected error occurred: {str(e)}"
            raise Exception(error_message) from e
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import os | ||
import httpx | ||
from aisuite.provider import Provider, LLMError | ||
from aisuite.framework import ChatCompletionResponse | ||
|
||
|
||
class HuggingFaceProvider(Provider):
    """Provider for Hugging Face serverless Inference Endpoints.

    Talks directly to the Text Generation Inference (TGI) backend over
    HTTP via httpx rather than through a vendor SDK; TGI exposes an
    OpenAI-compatible chat-completions protocol.
    https://huggingface.co/inference-endpoints/
    """

    def __init__(self, **config):
        """Resolve the access token and request timeout from config/env.

        Raises:
            ValueError: When no token is available in either place.
        """
        token = config.get("token") or os.getenv("HUGGINGFACE_TOKEN")
        if not token:
            raise ValueError(
                "Hugging Face token is missing. Please provide it in the config or set the HUGGINGFACE_TOKEN environment variable."
            )
        self.token = token
        # Request timeout in seconds; callers may override via config.
        self.timeout = config.get("timeout", 30)

    def chat_completions_create(self, model, messages, **kwargs):
        """POST a chat-completions request for *model* and normalize the reply.

        Raises:
            LLMError: On any HTTP or transport failure.
        """
        url = f"https://api-inference.huggingface.co/models/{model}/v1/chat/completions"
        request_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.token}",
        }
        payload = {
            "model": model,
            "messages": messages,
            **kwargs,  # forward any extra API parameters verbatim
        }

        try:
            resp = httpx.post(
                url, json=payload, headers=request_headers, timeout=self.timeout
            )
            resp.raise_for_status()
        except httpx.HTTPStatusError as http_err:
            raise LLMError(f"Hugging Face request failed: {http_err}")
        except Exception as e:
            raise LLMError(f"An error occurred: {e}")

        return self._normalize_response(resp.json())

    def _normalize_response(self, response_data):
        """Map the OpenAI-style JSON body onto a ChatCompletionResponse."""
        result = ChatCompletionResponse()
        first_choice = response_data["choices"][0]
        result.choices[0].message.content = first_choice["message"]["content"]
        return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import os | ||
import httpx | ||
from aisuite.provider import Provider, LLMError | ||
from aisuite.framework import ChatCompletionResponse | ||
|
||
|
||
class OpenAIProvider(Provider):
    """OpenAI provider that calls the REST API directly with httpx
    instead of going through the OpenAI SDK.
    """

    BASE_URL = "https://api.openai.com/v1/chat/completions"

    def __init__(self, **config):
        """Resolve the API key and request timeout from config/env.

        Raises:
            ValueError: When no API key is available in either place.
        """
        key = config.get("api_key") or os.getenv("OPENAI_API_KEY")
        if not key:
            raise ValueError(
                "OpenAI API key is missing. Please provide it in the config or set the OPENAI_API_KEY environment variable."
            )
        self.api_key = key
        # Request timeout in seconds; callers may override via config.
        self.timeout = config.get("timeout", 30)

    def chat_completions_create(self, model, messages, **kwargs):
        """POST a chat-completions request and normalize the reply.

        Raises:
            LLMError: On any HTTP or transport failure.
        """
        request_headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        payload = {
            "model": model,
            "messages": messages,
            **kwargs,  # forward any extra API parameters verbatim
        }

        try:
            resp = httpx.post(
                self.BASE_URL,
                json=payload,
                headers=request_headers,
                timeout=self.timeout,
            )
            resp.raise_for_status()
        except httpx.HTTPStatusError as http_err:
            raise LLMError(f"OpenAI API request failed: {http_err}")
        except Exception as e:
            raise LLMError(f"An error occurred: {e}")

        return self._normalize_response(resp.json())

    def _normalize_response(self, response_data):
        """Map the OpenAI JSON body onto a ChatCompletionResponse."""
        result = ChatCompletionResponse()
        first_choice = response_data["choices"][0]
        result.choices[0].message.content = first_choice["message"]["content"]
        return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters