Skip to content

Commit

Permalink
added Octo.ai provider
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffxtang committed Jul 14, 2024
1 parent 5527a07 commit dca54bd
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 3 deletions.
3 changes: 2 additions & 1 deletion .env.sample
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ MISTRAL_API_KEY=""
OPENAI_API_KEY=""
OLLAMA_API_URL="http://localhost:11434"
REPLICATE_API_KEY=""
TOGETHER_API_KEY=""
TOGETHER_API_KEY=""
OCTO_API_KEY=""
2 changes: 2 additions & 0 deletions aimodels/client/multi_fm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
OpenAIInterface,
ReplicateInterface,
TogetherInterface,
OctoInterface,
)


Expand Down Expand Up @@ -44,6 +45,7 @@ def __init__(self):
"openai": OpenAIInterface,
"replicate": ReplicateInterface,
"together": TogetherInterface,
"octo": OctoInterface,
}

def get_provider_interface(self, model):
Expand Down
1 change: 1 addition & 0 deletions aimodels/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
from .openai_interface import OpenAIInterface
from .replicate_interface import ReplicateInterface
from .together_interface import TogetherInterface
from .octo_interface import OctoInterface
40 changes: 40 additions & 0 deletions aimodels/providers/octo_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""The interface to the Octo API."""

import os

from ..framework.provider_interface import ProviderInterface

_OCTO_BASE_URL = "https://text.octoai.run/v1"


class OctoInterface(ProviderInterface):
"""Implements the ProviderInterface for interacting with Octo's APIs."""

def __init__(self):
"""Set up the Octo client using the API key obtained from the user's environment."""
from openai import OpenAI

self.octo_client = OpenAI(
api_key=os.getenv("OCTO_API_KEY"),
base_url=_OCTO_BASE_URL,
)

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the Together API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.
Returns:
-------
The API response with the completion result.
"""
return self.octo_client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
)
20 changes: 18 additions & 2 deletions examples/multi_fm_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"outputs": [],
"source": [
"import sys\n",
"sys.path.append('../aimodels')\n",
"sys.path.append('../../aimodels')\n",
"\n",
"from dotenv import load_dotenv, find_dotenv\n",
"\n",
Expand All @@ -44,7 +44,8 @@
"os.environ['GROQ_API_KEY'] = 'xxx' # get a free key at https://console.groq.com/keys\n",
"os.environ['FIREWORKS_API_KEY'] = 'xxx' # get a free key at https://fireworks.ai/api-keys\n",
"os.environ['REPLICATE_API_KEY'] = 'xxx' # get a free key at https://replicate.com/account/api-tokens\n",
"os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai/"
"os.environ['TOGETHER_API_KEY'] = 'xxx' # get a free key at https://api.together.ai\n",
"os.environ['OCTO_API_KEY'] = 'xxx' # get a free key at https://octoai.cloud/settings"
]
},
{
Expand All @@ -69,6 +70,21 @@
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5388efc4-3fd2-4dc6-ab58-7b179ce07943",
"metadata": {},
"outputs": [],
"source": [
"octo_llama3_8b = \"octo:meta-llama-3-8b-instruct\"\n",
"#octo_llama3_70b = \"octo:meta-llama-3-70b-instruct\"\n",
"\n",
"response = client.chat.completions.create(model=octo_llama3_8b, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down

0 comments on commit dca54bd

Please sign in to comment.