Skip to content

Commit

Permalink
Merge branch 'main' into rp_ConvBasedProviderLoad
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitprasad15 authored Oct 3, 2024
2 parents 22be8fc + ad46773 commit a49c100
Show file tree
Hide file tree
Showing 11 changed files with 174 additions and 256 deletions.
6 changes: 0 additions & 6 deletions aisuite/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +0,0 @@
"""Provides the individual provider interfaces for each FM provider."""

from .fireworks_interface import FireworksInterface
from .octo_interface import OctoInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
35 changes: 0 additions & 35 deletions aisuite/providers/fireworks_interface.py

This file was deleted.

65 changes: 65 additions & 0 deletions aisuite/providers/fireworks_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
import httpx
from aisuite.provider import Provider, LLMError
from aisuite.framework import ChatCompletionResponse


class FireworksProvider(Provider):
    """
    Fireworks AI Provider using httpx for direct API calls.

    Sends OpenAI-style chat-completion requests to the Fireworks inference
    endpoint and normalizes the JSON response into a ChatCompletionResponse.
    """

    BASE_URL = "https://api.fireworks.ai/inference/v1/chat/completions"

    def __init__(self, **config):
        """
        Initialize the Fireworks provider with the given configuration.

        Args:
            **config: Optional settings. Recognized keys:
                api_key (str): Fireworks API key. Falls back to the
                    FIREWORKS_API_KEY environment variable when absent.
                timeout (float | int): Request timeout in seconds (default 30).

        Raises:
            ValueError: If no API key is provided via config or environment.
        """
        self.api_key = config.get("api_key", os.getenv("FIREWORKS_API_KEY"))
        if not self.api_key:
            raise ValueError(
                "Fireworks API key is missing. Please provide it in the config or set the FIREWORKS_API_KEY environment variable."
            )

        # Optionally set a custom timeout (default to 30s)
        self.timeout = config.get("timeout", 30)

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Makes a request to the Fireworks AI chat completions endpoint using httpx.

        Args:
            model (str): Fireworks model identifier.
            messages (list): OpenAI-style message dicts ({"role", "content"}).
            **kwargs: Extra parameters (temperature, top_p, ...) forwarded
                verbatim in the request body.

        Returns:
            ChatCompletionResponse: Normalized response object.

        Raises:
            LLMError: If the HTTP request fails or returns a non-2xx status.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        data = {
            "model": model,
            "messages": messages,
            **kwargs,  # Pass any additional arguments to the API
        }

        try:
            # Make the request to Fireworks AI endpoint.
            response = httpx.post(
                self.BASE_URL, json=data, headers=headers, timeout=self.timeout
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as http_err:
            # Chain the cause so the HTTP status/response details survive.
            raise LLMError(f"Fireworks AI request failed: {http_err}") from http_err
        except Exception as e:
            # Network errors, timeouts, etc. — keep the original traceback.
            raise LLMError(f"An error occurred: {e}") from e

        # Return the normalized response
        return self._normalize_response(response.json())

    def _normalize_response(self, response_data):
        """
        Normalize the response to a common format (ChatCompletionResponse).

        Extracts choices[0].message.content from the raw Fireworks JSON.
        """
        normalized_response = ChatCompletionResponse()
        normalized_response.choices[0].message.content = response_data["choices"][0][
            "message"
        ]["content"]
        return normalized_response
40 changes: 0 additions & 40 deletions aisuite/providers/octo_interface.py

This file was deleted.

40 changes: 0 additions & 40 deletions aisuite/providers/replicate_interface.py

This file was deleted.

40 changes: 0 additions & 40 deletions aisuite/providers/together_interface.py

This file was deleted.

65 changes: 65 additions & 0 deletions aisuite/providers/together_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
import httpx
from aisuite.provider import Provider, LLMError
from aisuite.framework import ChatCompletionResponse


class TogetherProvider(Provider):
    """
    Together AI Provider using httpx for direct API calls.

    Sends OpenAI-style chat-completion requests to the Together AI endpoint
    and normalizes the JSON response into a ChatCompletionResponse.
    """

    BASE_URL = "https://api.together.xyz/v1/chat/completions"

    def __init__(self, **config):
        """
        Initialize the Together provider with the given configuration.

        Args:
            **config: Optional settings. Recognized keys:
                api_key (str): Together API key. Falls back to the
                    TOGETHER_API_KEY environment variable when absent.
                timeout (float | int): Request timeout in seconds (default 30).

        Raises:
            ValueError: If no API key is provided via config or environment.
        """
        self.api_key = config.get("api_key", os.getenv("TOGETHER_API_KEY"))
        if not self.api_key:
            raise ValueError(
                "Together API key is missing. Please provide it in the config or set the TOGETHER_API_KEY environment variable."
            )

        # Optionally set a custom timeout (default to 30s)
        self.timeout = config.get("timeout", 30)

    def chat_completions_create(self, model, messages, **kwargs):
        """
        Makes a request to the Together AI chat completions endpoint using httpx.

        Args:
            model (str): Together model identifier.
            messages (list): OpenAI-style message dicts ({"role", "content"}).
            **kwargs: Extra parameters (temperature, top_k, ...) forwarded
                verbatim in the request body.

        Returns:
            ChatCompletionResponse: Normalized response object.

        Raises:
            LLMError: If the HTTP request fails or returns a non-2xx status.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        data = {
            "model": model,
            "messages": messages,
            **kwargs,  # Pass any additional arguments to the API
        }

        try:
            # Make the request to the Together AI endpoint.
            response = httpx.post(
                self.BASE_URL, json=data, headers=headers, timeout=self.timeout
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as http_err:
            # Chain the cause so the HTTP status/response details survive.
            raise LLMError(f"Together AI request failed: {http_err}") from http_err
        except Exception as e:
            # Network errors, timeouts, etc. — keep the original traceback.
            raise LLMError(f"An error occurred: {e}") from e

        # Return the normalized response
        return self._normalize_response(response.json())

    def _normalize_response(self, response_data):
        """
        Normalize the response to a common format (ChatCompletionResponse).

        Extracts choices[0].message.content from the raw Together JSON.
        """
        normalized_response = ChatCompletionResponse()
        normalized_response.choices[0].message.content = response_data["choices"][0][
            "message"
        ]["content"]
        return normalized_response
35 changes: 28 additions & 7 deletions examples/client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@
"metadata": {},
"outputs": [],
"source": [
"aws_bedrock_llama3_8b = \"aws-bedrock:meta.llama3-1-8b-instruct-v1:0\"\n",
"aws_bedrock_llama3_8b = \"aws:meta.llama3-1-8b-instruct-v1:0\"\n",
"response = client.chat.completions.create(model=aws_bedrock_llama3_8b, messages=messages)\n",
"print(response.choices[0].message.content)"
]
Expand All @@ -124,13 +124,12 @@
"# The model name is the deployment name in Project/Deployments.\n",
"# In the exmaple below, the model is \"mistral-large-2407\", but the name given to the\n",
"# deployment is \"aisuite-mistral-large-2407\" under the deployments section in Azure.\n",
"client2 = ai.Client()\n",
"client2.configure({\"azure\" : {\n",
"client.configure({\"azure\" : {\n",
" \"api_key\": os.environ[\"AZURE_API_KEY\"],\n",
" \"base_url\": \"https://aisuite-mistral-large-2407.westus3.models.ai.azure.com/v1/\",\n",
"}});\n",
"azure_model = \"azure:aisuite-mistral-large-2407\"\n",
"response = client2.chat.completions.create(model=azure_model, messages=messages)\n",
"response = client.chat.completions.create(model=azure_model, messages=messages)\n",
"print(response.choices[0].message.content)"
]
},
Expand All @@ -145,9 +144,8 @@
"# The model name is the full name of the model in HuggingFace.\n",
"# In the exmaple below, the model is \"mistralai/Mistral-7B-Instruct-v0.3\".\n",
"# The model is deployed as serverless inference endpoint in HuggingFace.\n",
"client3 = ai.Client()\n",
"hf_model = \"huggingface:mistralai/Mistral-7B-Instruct-v0.3\"\n",
"response = client3.chat.completions.create(model=hf_model, messages=messages)\n",
"response = client.chat.completions.create(model=hf_model, messages=messages)\n",
"print(response.choices[0].message.content)"
]
},
Expand All @@ -164,7 +162,6 @@
"# In the exmaple below, the model is \"llama3-8b-8192\".\n",
"groq_llama3_8b = \"groq:llama3-8b-8192\"\n",
"# groq_llama3_70b = \"groq:llama3-70b-8192\"\n",
"\n",
"response = client.chat.completions.create(model=groq_llama3_8b, messages=messages)\n",
"print(response.choices[0].message.content)"
]
Expand Down Expand Up @@ -210,6 +207,30 @@
"response = client.chat.completions.create(model=openai_gpt35, messages=messages, temperature=0.75)\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "321783ae",
"metadata": {},
"outputs": [],
"source": [
"fireworks_model = \"fireworks:accounts/fireworks/models/llama-v3p2-3b-instruct\"\n",
"response = client.chat.completions.create(model=fireworks_model, messages=messages, temperature=0.75, presence_penalty=0.5, frequency_penalty=0.5)\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e30e5ae0",
"metadata": {},
"outputs": [],
"source": [
"togetherai_model = \"together:meta-llama/Llama-3.2-3B-Instruct-Turbo\"\n",
"response = client.chat.completions.create(model=togetherai_model, messages=messages, temperature=0.75, top_p=0.7, top_k=50)\n",
"print(response.choices[0].message.content)"
]
}
],
"metadata": {
Expand Down
Loading

0 comments on commit a49c100

Please sign in to comment.