Skip to content

Commit

Permalink
Merge pull request #5 from andrewyng/add-mistral-provider
Browse files Browse the repository at this point in the history
  • Loading branch information
ksolo authored Jul 4, 2024
2 parents cdacc8a + f0f9614 commit 43f8a8a
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 33 deletions.
1 change: 1 addition & 0 deletions .env.sample
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
ANTHROPIC_API_KEY=""
GROQ_API_KEY=""
MISTRAL_API_KEY=""
OPENAI_API_KEY=""
OLLAMA_API_URL="http://localhost:11434"
8 changes: 5 additions & 3 deletions aimodels/client/multi_fm_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from .chat import Chat
from ..providers import (
AnthropicInterface,
OpenAIInterface,
GroqInterface,
MistralInterface,
OllamaInterface,
OpenAIInterface,
)


Expand All @@ -32,10 +33,11 @@ def __init__(self):
self.chat = Chat(self)
self.all_interfaces = {}
self.all_factories = {
"openai": OpenAIInterface,
"groq": GroqInterface,
"anthropic": AnthropicInterface,
"groq": GroqInterface,
"mistral": MistralInterface,
"ollama": OllamaInterface,
"openai": OpenAIInterface,
}

def get_provider_interface(self, model):
Expand Down
5 changes: 3 additions & 2 deletions aimodels/providers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Provides the individual provider interfaces for each FM provider."""

from .openai_interface import OpenAIInterface
from .groq_interface import GroqInterface
from .anthropic_interface import AnthropicInterface
from .groq_interface import GroqInterface
from .mistral_interface import MistralInterface
from .ollama_interface import OllamaInterface
from .openai_interface import OpenAIInterface
38 changes: 38 additions & 0 deletions aimodels/providers/mistral_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os

from aimodels.framework import ProviderInterface


class MistralInterface(ProviderInterface):
"""Implements the provider interface for Mistral."""

def __init__(self):
from mistralai.client import MistralClient

self.mistral_client = MistralClient(api_key=os.getenv("MISTRAL_API_KEY"))

def chat_completion_create(self, messages=None, model=None, temperature=0):
"""Request chat completions from the Mistral API.
Args:
----
model (str): Identifies the specific provider/model to use.
messages (list of dict): A list of message objects in chat history.
temperature (float): The temperature to use in the completion.
Returns:
-------
The API response with the completion result.
"""
from mistralai.models.chat_completion import ChatMessage

messages = [
ChatMessage(role=message["role"], content=message["content"])
for message in messages
]
return self.mistral_client.chat(
model=model,
messages=messages,
temperature=temperature,
)
94 changes: 67 additions & 27 deletions examples/multi_fm_client.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,22 @@
},
{
"cell_type": "code",
"execution_count": 1,
"id": "initial_id",
"metadata": {
"collapsed": true,
"ExecuteTime": {
"end_time": "2024-07-02T23:20:19.015491Z",
"start_time": "2024-07-02T23:20:19.004272Z"
},
"collapsed": true
"end_time": "2024-07-04T15:30:02.064319Z",
"start_time": "2024-07-04T15:30:02.051986Z"
}
},
"source": [
"import sys\n",
"sys.path.append('../aimodels')\n",
"\n",
"from dotenv import load_dotenv, find_dotenv\n",
"\n",
"load_dotenv(find_dotenv())"
],
"outputs": [
{
"data": {
Expand All @@ -35,21 +42,17 @@
"output_type": "execute_result"
}
],
"source": [
"import sys\n",
"sys.path.append('../aimodels')\n",
"\n",
"from dotenv import load_dotenv, find_dotenv\n",
"\n",
"load_dotenv(find_dotenv())"
]
"execution_count": 1
},
{
"cell_type": "code",
"execution_count": 2,
"id": "4de3a24f",
"metadata": {},
"outputs": [],
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-04T15:31:12.914321Z",
"start_time": "2024-07-04T15:31:12.796445Z"
}
},
"source": [
"from aimodels.client import MultiFMClient\n",
"\n",
Expand All @@ -59,18 +62,26 @@
" {\"role\": \"system\", \"content\": \"Respond in Pirate English.\"},\n",
" {\"role\": \"user\", \"content\": \"Tell me a joke\"},\n",
"]"
]
],
"outputs": [],
"execution_count": 3
},
{
"cell_type": "code",
"execution_count": 3,
"id": "adebd2f0b578a909",
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-03T02:22:26.282827Z",
"start_time": "2024-07-03T02:22:18.193996Z"
"end_time": "2024-07-04T15:31:25.060689Z",
"start_time": "2024-07-04T15:31:16.131205Z"
}
},
"source": [
"anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n",
"\n",
"response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
],
"outputs": [
{
"name": "stdout",
Expand All @@ -84,13 +95,7 @@
]
}
],
"source": [
"anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n",
"\n",
"response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n",
"\n",
"print(response.choices[0].message.content)"
]
"execution_count": 4
},
{
"cell_type": "code",
Expand All @@ -117,6 +122,41 @@
"\n",
"print(response.choices[0].message.content)"
]
},
{
"metadata": {
"ExecuteTime": {
"end_time": "2024-07-04T15:31:39.472675Z",
"start_time": "2024-07-04T15:31:38.283368Z"
}
},
"cell_type": "code",
"source": [
"mistral_7b = \"mistral:open-mistral-7b\"\n",
"\n",
"response = client.chat.completions.create(model=mistral_7b, messages=messages, temperature=0.2)\n",
"\n",
"print(response.choices[0].message.content)"
],
"id": "4a94961b2bddedbb",
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Arr matey, I've got a jest fer ye, if ye be ready for a laugh! Why did the pirate bring a clock to the island? Because he wanted to catch the time! Aye, that be a good one, I be thinkin'. Arrr!\n"
]
}
],
"execution_count": 5
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "",
"id": "611210a4dc92845f"
}
],
"metadata": {
Expand Down
78 changes: 77 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ groq = "^0.9.0"
anthropic = "^0.30.1"
notebook = "^7.2.1"
ollama = "^0.2.1"
mistralai = "^0.4.2"

[build-system]
requires = ["poetry-core"]
Expand Down

0 comments on commit 43f8a8a

Please sign in to comment.