Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added Octo.ai provider #13

Merged
merged 5 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@ MISTRAL_API_KEY=""
OPENAI_API_KEY=""
OLLAMA_API_URL="http://localhost:11434"
REPLICATE_API_KEY=""
TOGETHER_API_KEY=""
TOGETHER_API_KEY=""
OCTO_API_KEY=""
AWS_ACCESS_KEY_ID=""
AWS_SECRET_ACCESS_KEY=""
4 changes: 4 additions & 0 deletions aimodels/client/multi_fm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
OpenAIInterface,
ReplicateInterface,
TogetherInterface,
OctoInterface,
AWSBedrockInterface,
)


Expand Down Expand Up @@ -44,6 +46,8 @@ def __init__(self):
"openai": OpenAIInterface,
"replicate": ReplicateInterface,
"together": TogetherInterface,
"octo": OctoInterface,
"aws": AWSBedrockInterface,
}

def get_provider_interface(self, model):
Expand Down
2 changes: 2 additions & 0 deletions aimodels/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@
from .openai_interface import OpenAIInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
from .octo_interface import OctoInterface
from .aws_bedrock_interface import AWSBedrockInterface
115 changes: 115 additions & 0 deletions aimodels/providers/aws_bedrock_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""The interface to the AWS Bedrock API."""

import os
from urllib.request import urlopen
import boto3
import json

from ..framework.provider_interface import ProviderInterface


def convert_messages_to_llama3_prompt(messages):
    """Render a chat history as a single Llama 3 instruction-format prompt.

    Args:
        messages (list of dict): Each message is a dictionary with a 'role'
            ('system', 'user', or 'assistant') and a 'content' string.

    Returns:
        str: The formatted prompt, opened with <|begin_of_text|> and closed
        with an assistant header so the model continues as the assistant.
    """
    turns = [
        f"<|start_header_id|>{msg['role']}<|end_header_id|>{msg['content']}<|eot_id|>\n"
        for msg in messages
    ]
    # The trailing assistant header (no <|eot_id|>) cues the model to reply.
    return (
        "<|begin_of_text|>"
        + "".join(turns)
        + "<|start_header_id|>assistant<|end_header_id|>"
    )


class RecursiveNamespace:
    """Wrap (possibly nested) dictionaries as objects with attribute access.

    This mimics the shape of the OpenAI chat.completions.create return type,
    so response.choices[0].message.content works consistently when a provider
    (e.g. AWS Bedrock) returns a plain string instead of an OpenAI object.
    """

    def __init__(self, data):
        for name, val in data.items():
            if isinstance(val, dict):
                setattr(self, name, RecursiveNamespace(val))
            elif isinstance(val, list):
                # Only dict elements are wrapped; other items pass through.
                setattr(
                    self,
                    name,
                    [
                        RecursiveNamespace(v) if isinstance(v, dict) else v
                        for v in val
                    ],
                )
            else:
                setattr(self, name, val)

    @classmethod
    def from_dict(cls, data):
        """Alternate constructor; equivalent to calling the class directly."""
        return cls(data)

    def to_dict(self):
        """Recursively convert this namespace back into plain dictionaries."""
        plain = {}
        for name, val in vars(self).items():
            if isinstance(val, RecursiveNamespace):
                plain[name] = val.to_dict()
            elif isinstance(val, list):
                plain[name] = [
                    v.to_dict() if isinstance(v, RecursiveNamespace) else v
                    for v in val
                ]
            else:
                plain[name] = val
        return plain


class AWSBedrockInterface(ProviderInterface):
    """Implements the ProviderInterface for interacting with AWS Bedrock's APIs."""

    def __init__(self):
        """Set up the AWS Bedrock runtime client.

        Credentials are read from AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY in
        the user's environment. The region defaults to us-west-2 but can be
        overridden with the standard AWS_REGION environment variable.
        """
        self.aws_bedrock_client = boto3.client(
            service_name="bedrock-runtime",
            region_name=os.getenv("AWS_REGION", "us-west-2"),
            aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
            aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
        )

    def chat_completion_create(self, messages=None, model=None, temperature=0):
        """Request chat completions from the AWS Bedrock API.

        Args:
        ----
            model (str): Bedrock model id, e.g. "meta.llama3-8b-instruct-v1:0".
            messages (list of dict): A list of message objects in chat history.
            temperature (float): The temperature to use in the completion.

        Returns:
        -------
            A RecursiveNamespace shaped like an OpenAI chat completion, so
            response.choices[0].message.content holds the generated text.

        """
        # Bedrock's Llama models take a single prompt string, not a message list.
        body = json.dumps(
            {
                "prompt": convert_messages_to_llama3_prompt(messages),
                "temperature": temperature,
            }
        )
        response = self.aws_bedrock_client.invoke_model(
            body=body,
            modelId=model,
            accept="application/json",
            contentType="application/json",
        )
        response_body = json.loads(response.get("body").read())
        # "generation" carries the raw completion text for Llama models.
        generation = response_body.get("generation")

        # Re-shape into the OpenAI-style structure the client code expects.
        return RecursiveNamespace.from_dict(
            {
                "choices": [
                    {
                        "message": {"content": generation},
                    }
                ],
            }
        )
40 changes: 40 additions & 0 deletions aimodels/providers/octo_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""The interface to the Octo API."""

import os

from ..framework.provider_interface import ProviderInterface

_OCTO_BASE_URL = "https://text.octoai.run/v1"


class OctoInterface(ProviderInterface):
    """Implements the ProviderInterface for interacting with Octo's APIs."""

    def __init__(self):
        """Set up the Octo client using the API key obtained from the user's environment.

        Octo exposes an OpenAI-compatible endpoint, so the OpenAI SDK is
        reused with its base URL pointed at Octo.
        """
        from openai import OpenAI

        self.octo_client = OpenAI(
            api_key=os.getenv("OCTO_API_KEY"),
            base_url=_OCTO_BASE_URL,
        )

    def chat_completion_create(self, messages=None, model=None, temperature=0):
        """Request chat completions from the Octo API.

        Args:
        ----
            model (str): Identifies the specific provider/model to use.
            messages (list of dict): A list of message objects in chat history.
            temperature (float): The temperature to use in the completion.

        Returns:
        -------
            The API response with the completion result.

        """
        return self.octo_client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
        )
47 changes: 45 additions & 2 deletions examples/multi_fm_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('../aimodels')\n",
"sys.path.append('../../aimodels')\n",
"\n",
"from dotenv import load_dotenv, find_dotenv\n",
"\n",
Expand All @@ -44,7 +44,10 @@
"os.environ['GROQ_API_KEY'] = 'xxx' # get a free key at https://console.groq.com/keys\n",
"os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n",
"os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens\n",
"os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai/"
"os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai\n",
"os.environ['OCTO_API_KEY'] = 'xxx' # get a free key at https://octoai.cloud/settings\n",
"os.environ['AWS_ACCESS_KEY_ID'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home\n",
"os.environ['AWS_SECRET_ACCESS_KEY'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home"
]
},
{
Expand All @@ -69,6 +72,46 @@
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ffe9a49-638e-4304-b9de-49ee21d9ac8d",
"metadata": {},
"outputs": [],
"source": [
"#!pip install boto3"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9893c7e4-799a-42c9-84de-f9e643044462",
"metadata": {},
"outputs": [],
"source": [
"aws_bedrock_llama3_8b = \"aws:meta.llama3-8b-instruct-v1:0\"\n",
"#aws_bedrock_llama3_8b = \"aws:meta.llama3-70b-instruct-v1:0\"\n",
"\n",
"response = client.chat.completions.create(model=aws_bedrock_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5388efc4-3fd2-4dc6-ab58-7b179ce07943",
"metadata": {},
"outputs": [],
"source": [
"octo_llama3_8b = \"octo:meta-llama-3-8b-instruct\"\n",
"#octo_llama3_70b = \"octo:meta-llama-3-70b-instruct\"\n",
"\n",
"response = client.chat.completions.create(model=octo_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down
68 changes: 67 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ anthropic = "^0.30.1"
notebook = "^7.2.1"
ollama = "^0.2.1"
mistralai = "^0.4.2"
boto3 = "^1.34.144"

[build-system]
requires = ["poetry-core"]
Expand Down
Loading