Skip to content

Commit

Permalink
Merge pull request #10 from andrewyng/together
Browse files Browse the repository at this point in the history
  • Loading branch information
ksolo authored Jul 12, 2024
2 parents 5e7d5c5 + 93e6bff commit 5f8e1f9
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 63 deletions.
1 change: 1 addition & 0 deletions .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ MISTRAL_API_KEY=""
OPENAI_API_KEY=""
OLLAMA_API_URL="http://localhost:11434"
REPLICATE_API_KEY=""
TOGETHER_API_KEY=""
2 changes: 2 additions & 0 deletions aimodels/client/multi_fm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
OllamaInterface,
OpenAIInterface,
ReplicateInterface,
TogetherInterface,
)


Expand Down Expand Up @@ -42,6 +43,7 @@ def __init__(self):
"ollama": OllamaInterface,
"openai": OpenAIInterface,
"replicate": ReplicateInterface,
"together": TogetherInterface,
}

def get_provider_interface(self, model):
Expand Down
1 change: 1 addition & 0 deletions aimodels/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .ollama_interface import OllamaInterface
from .openai_interface import OpenAIInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
40 changes: 40 additions & 0 deletions aimodels/providers/together_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""The interface to the Together API."""

import os

from ..framework.provider_interface import ProviderInterface

_TOGETHER_BASE_URL = "https://api.together.xyz/v1"


class TogetherInterface(ProviderInterface):
    """Implements the ProviderInterface for interacting with Together's APIs."""

    def __init__(self):
        """Set up the Together client using the API key obtained from the user's environment."""
        # Imported lazily so the package stays importable even when the
        # optional `openai` dependency is not installed.
        from openai import OpenAI

        # Together exposes an OpenAI-compatible endpoint, so the OpenAI SDK
        # client is reused, pointed at Together's base URL and credentials.
        self.together_client = OpenAI(
            base_url=_TOGETHER_BASE_URL,
            api_key=os.getenv("TOGETHER_API_KEY"),
        )

    def chat_completion_create(self, messages=None, model=None, temperature=0):
        """Request chat completions from the Together API.

        Args:
        ----
            model (str): Identifies the specific provider/model to use.
            messages (list of dict): A list of message objects in chat history.
            temperature (float): The temperature to use in the completion.

        Returns:
        -------
            The API response with the completion result.

        """
        request_kwargs = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
        }
        return self.together_client.chat.completions.create(**request_kwargs)
91 changes: 28 additions & 63 deletions examples/multi_fm_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,26 +14,15 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"id": "initial_id",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-04T15:30:02.064319Z",
"start_time": "2024-07-04T15:30:02.051986Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('../aimodels')\n",
Expand All @@ -45,7 +34,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"id": "a54491b7-6aa9-4337-9aba-3a0aef263bb4",
"metadata": {},
"outputs": [],
Expand All @@ -54,12 +43,13 @@
"\n",
"os.environ['GROQ_API_KEY'] = 'xxx' # get a free key at https://console.groq.com/keys\n",
"os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n",
"os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens"
"os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens\n",
"os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai/"
]
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"id": "4de3a24f",
"metadata": {
"ExecuteTime": {
Expand All @@ -79,34 +69,37 @@
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4b3e6c41-070d-4041-9ed9-c8977790fe18",
"metadata": {},
"outputs": [],
"source": [
"together_llama3_8b = \"together:meta-llama/Llama-3-8b-chat-hf\"\n",
"#together_llama3_70b = \"together:meta-llama/Llama-3-70b-chat-hf\"\n",
"\n",
"response = client.chat.completions.create(model=together_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "668a6cfa-9011-480a-ae1b-6dbd6a51e716",
"metadata": {},
"outputs": [],
"source": [
"# !pip install fireworks-ai"
"#!pip install fireworks-ai"
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"id": "9900fdf3-a113-40fd-b42f-0e6d866838be",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arrrr, listen close me hearty! Here be a joke fer ye:\n",
"\n",
"Why did the pirate quit his job?\n",
"\n",
"Because he was sick o' all the arrrr-guments! (get it? arguments, but with an \"arrr\" like a pirate says? aye, I thought it be a good one, matey!)\n"
]
}
],
"outputs": [],
"source": [
"fireworks_llama3_8b = \"fireworks:accounts/fireworks/models/llama-v3-8b-instruct\"\n",
"#fireworks_llama3_70b = \"fireworks:accounts/fireworks/models/llama-v3-70b-instruct\"\n",
Expand All @@ -118,24 +111,10 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"id": "c9b2aad6-8603-4227-9566-778f714eb0b5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arrrr, listen close me hearty! Here be a joke fer ye:\n",
"\n",
"Why did the pirate quit his job?\n",
"\n",
"Because he were sick o' all the arrrr-guments! (get it? arguments, but with arrrr, like a pirate says \"arrgh\"! ahhahahah!)\n",
"\n",
"Yer turn, matey! Got a joke to share?\n"
]
}
],
"outputs": [],
"source": [
"groq_llama3_8b = \"groq:llama3-8b-8192\"\n",
"# groq_llama3_70b = \"groq:llama3-70b-8192\"\n",
Expand All @@ -147,24 +126,10 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"id": "6baf88b8-2ecb-4bdf-9263-4af949668d16",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arrrr, listen close me hearty! Here be a joke fer ye:\n",
"\n",
"Why did the pirate quit his job?\n",
"\n",
"Because he were sick o' all the arrrr-guments! (get it? arguments, but with arrrr, like a pirate says \"arrgh\"! ahhahahah!)\n",
"\n",
"Yer turn, matey! Got a joke to share?\n"
]
}
],
"outputs": [],
"source": [
"replicate_llama3_8b = \"replicate:meta/meta-llama-3-8b-instruct\"\n",
"#replicate_llama3_70b = \"replicate:meta/meta-llama-3-70b-instruct\"\n",
Expand Down

0 comments on commit 5f8e1f9

Please sign in to comment.