
Commit

added AWS Bedrock provider for Llama 3 - with model input converted to using Llama 3 prompting format and output converted to use simulated OpenAI chat completion output type
jeffxtang committed Jul 17, 2024
1 parent dca54bd commit 2e28e85
Showing 5 changed files with 135 additions and 2 deletions.
4 changes: 3 additions & 1 deletion .env.sample
@@ -6,4 +6,6 @@ OPENAI_API_KEY=""
OLLAMA_API_URL="http://localhost:11434"
REPLICATE_API_KEY=""
TOGETHER_API_KEY=""
OCTO_API_KEY=""
OCTO_API_KEY=""
AWS_ACCESS_KEY_ID=""
AWS_SECRET_ACCESS_KEY=""
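
AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY are the standard AWS credential variable names, and they follow the same pattern as the other keys in .env.sample: copy the file to .env, fill in real values, and get them into the process environment before the client is created. A minimal sketch of that loading step, assuming python-dotenv (which is not part of this commit):

# Hypothetical loading step -- python-dotenv is an assumption, not used in this commit.
import os

from dotenv import load_dotenv

load_dotenv()  # reads .env and copies its key/value pairs into os.environ

# AWSBedrockInterface (added below) reads the same variables via os.getenv(...)
assert os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY")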
2 changes: 2 additions & 0 deletions aimodels/client/multi_fm_client.py
@@ -11,6 +11,7 @@
ReplicateInterface,
TogetherInterface,
OctoInterface,
AWSBedrockInterface,
)


@@ -46,6 +47,7 @@ def __init__(self):
"replicate": ReplicateInterface,
"together": TogetherInterface,
"octo": OctoInterface,
"aws": AWSBedrockInterface
}

def get_provider_interface(self, model):
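With the new "aws" entry, any model string that starts with the aws: prefix is routed to AWSBedrockInterface. The routing code itself is outside this hunk, so the following is only a plausible sketch of what get_provider_interface might do, inferred from the aws:meta.llama3-8b-instruct-v1:0 strings used in the notebook further down; the dict name and return shape are assumptions:

# Hedged sketch -- the real get_provider_interface may differ.
def get_provider_interface(self, model):
    # "aws:meta.llama3-8b-instruct-v1:0" -> ("aws", "meta.llama3-8b-instruct-v1:0")
    provider_key, model_name = model.split(":", 1)
    interface_class = self.provider_interfaces[provider_key]  # dict name assumed
    return interface_class(), model_name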
1 change: 1 addition & 0 deletions aimodels/providers/__init__.py
@@ -9,3 +9,4 @@
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
from .octo_interface import OctoInterface
from .aws_bedrock_interface import AWSBedrockInterface
101 changes: 101 additions & 0 deletions aimodels/providers/aws_bedrock_interface.py
@@ -0,0 +1,101 @@
"""The interface to the Together API."""

import json
import os

import boto3

from ..framework.provider_interface import ProviderInterface

def convert_messages_to_llama3_prompt(messages):
"""
Convert a list of messages to a prompt in Llama 3 instruction format.
Args:
messages (list of dict): List of messages where each message is a dictionary
with 'role' ('system', 'user', 'assistant') and 'content'.
Returns:
str: Formatted prompt for Llama 3.
"""
prompt = "<|begin_of_text|>"
for message in messages:
prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>{message['content']}<|eot_id|>\n"

prompt += "<|start_header_id|>assistant<|end_header_id|>"

return prompt

class RecursiveNamespace:
"""
Convert dictionaries to objects with attribute access, including nested dictionaries.
This class is used to simulate the return type of OpenAI's chat.completions.create, so that
response.choices[0].message.content works consistently even though AWS Bedrock returns a plain string.
"""
def __init__(self, data):
for key, value in data.items():
if isinstance(value, dict):
value = RecursiveNamespace(value)
elif isinstance(value, list):
value = [RecursiveNamespace(item) if isinstance(item, dict) else item for item in value]
setattr(self, key, value)

@classmethod
def from_dict(cls, data):
return cls(data)

def to_dict(self):
result = {}
for key, value in self.__dict__.items():
if isinstance(value, RecursiveNamespace):
value = value.to_dict()
elif isinstance(value, list):
value = [item.to_dict() if isinstance(item, RecursiveNamespace) else item for item in value]
result[key] = value
return result

class AWSBedrockInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with AWS Bedrock's APIs."""

def __init__(self):
"""Set up the AWS Bedrock client using the AWS access key id and secret access key obtained from the user's environment."""
self.aws_bedrock_client = boto3.client(
service_name="bedrock-runtime",
region_name="us-west-2",
aws_access_key_id=os.getenv("AWS_ACCESS_KEY_ID"),
aws_secret_access_key=os.getenv("AWS_SECRET_ACCESS_KEY"),
)

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the AWS Bedrock API.
Args:
----
messages (list of dict): A list of message objects in chat history.
model (str): Identifies the specific provider/model to use.
temperature (float): The temperature to use in the completion.
Returns:
-------
The API response with the completion result.
"""
body = json.dumps({
"prompt": convert_messages_to_llama3_prompt(messages),
"temperature": temperature
})
accept = 'application/json'
content_type = 'application/json'
response = self.aws_bedrock_client.invoke_model(body=body, modelId=model, accept=accept, contentType=content_type)
response_body = json.loads(response.get('body').read())
generation = response_body.get('generation')

response_data = {
"choices": [
{
"message": {"content": generation},
}
],
}

return RecursiveNamespace.from_dict(response_data)
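
The two helpers above can be exercised on their own. The snippet below is not part of the commit, just an illustration built from the code above: it shows the prompt string convert_messages_to_llama3_prompt produces for a short chat history, and how RecursiveNamespace makes a plain string returned by Bedrock readable as response.choices[0].message.content.

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is AWS Bedrock?"},
]

print(convert_messages_to_llama3_prompt(messages))
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>You are a helpful assistant.<|eot_id|>
# <|start_header_id|>user<|end_header_id|>What is AWS Bedrock?<|eot_id|>
# <|start_header_id|>assistant<|end_header_id|>

# Simulated wrapping of a Bedrock "generation" string, mirroring chat_completion_create above.
response = RecursiveNamespace.from_dict(
    {"choices": [{"message": {"content": "Bedrock is a managed model-hosting service."}}]}
)
print(response.choices[0].message.content)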
29 changes: 28 additions & 1 deletion examples/multi_fm_client.ipynb
@@ -45,7 +45,9 @@
"os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n",
"os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens\n",
"os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai\n",
"os.environ['OCTO_API_KEY'] = 'xxx' # get a free key at https://octoai.cloud/settings"
"os.environ['OCTO_API_KEY'] = 'xxx' # get a free key at https://octoai.cloud/settings\n",
"os.environ['AWS_ACCESS_KEY_ID'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home\n",
"os.environ['AWS_SECRET_ACCESS_KEY'] = 'xxx' # get or create at https://console.aws.amazon.com/iam/home"
]
},
{
@@ -70,6 +72,31 @@
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1ffe9a49-638e-4304-b9de-49ee21d9ac8d",
"metadata": {},
"outputs": [],
"source": [
"#!pip install boto3"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9893c7e4-799a-42c9-84de-f9e643044462",
"metadata": {},
"outputs": [],
"source": [
"aws_bedrock_llama3_8b = \"aws:meta.llama3-8b-instruct-v1:0\"\n",
"#aws_bedrock_llama3_8b = \"aws:meta.llama3-70b-instruct-v1:0\"\n",
"\n",
"response = client.chat.completions.create(model=aws_bedrock_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
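The new notebook cells install boto3 and then call client.chat.completions.create with the aws: prefix, exactly as for the other providers. AWSBedrockInterface.chat_completion_create accepts a temperature argument, but whether the client forwards it is not visible in this diff, so the variant below is an assumption:

# Assumed temperature pass-through -- not confirmed by this commit.
response = client.chat.completions.create(
    model="aws:meta.llama3-8b-instruct-v1:0",
    messages=messages,
    temperature=0.7,
)
print(response.choices[0].message.content)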
