From f0f961467c8537824215f792ec02da995ffc7c70 Mon Sep 17 00:00:00 2001 From: Kevin Solorio Date: Thu, 4 Jul 2024 10:34:51 -0500 Subject: [PATCH] Add mistral to the available providers This PR: - adds the mistral interface to our providers - updates references to providers by putting them in alpha order - adds mistral to the example notebook - adds mistralai dependency --- .env.sample | 1 + aimodels/client/multi_fm_client.py | 8 ++- aimodels/providers/__init__.py | 5 +- aimodels/providers/mistral_interface.py | 38 ++++++++++ examples/multi_fm_client.ipynb | 94 ++++++++++++++++++------- poetry.lock | 78 +++++++++++++++++++- pyproject.toml | 1 + 7 files changed, 192 insertions(+), 33 deletions(-) create mode 100644 aimodels/providers/mistral_interface.py diff --git a/.env.sample b/.env.sample index cd7561a..906f1b0 100644 --- a/.env.sample +++ b/.env.sample @@ -1,4 +1,5 @@ ANTHROPIC_API_KEY="" GROQ_API_KEY="" +MISTRAL_API_KEY="" OPENAI_API_KEY="" OLLAMA_API_URL="http://localhost:11434" diff --git a/aimodels/client/multi_fm_client.py b/aimodels/client/multi_fm_client.py index 9773aad..f036e86 100644 --- a/aimodels/client/multi_fm_client.py +++ b/aimodels/client/multi_fm_client.py @@ -3,9 +3,10 @@ from .chat import Chat from ..providers import ( AnthropicInterface, - OpenAIInterface, GroqInterface, + MistralInterface, OllamaInterface, + OpenAIInterface, ) @@ -32,10 +33,11 @@ def __init__(self): self.chat = Chat(self) self.all_interfaces = {} self.all_factories = { - "openai": OpenAIInterface, - "groq": GroqInterface, "anthropic": AnthropicInterface, + "groq": GroqInterface, + "mistral": MistralInterface, "ollama": OllamaInterface, + "openai": OpenAIInterface, } def get_provider_interface(self, model): diff --git a/aimodels/providers/__init__.py b/aimodels/providers/__init__.py index 87ad296..08aff09 100644 --- a/aimodels/providers/__init__.py +++ b/aimodels/providers/__init__.py @@ -1,6 +1,7 @@ """Provides the individual provider interfaces for each FM provider.""" -from .openai_interface import OpenAIInterface -from .groq_interface import GroqInterface from .anthropic_interface import AnthropicInterface +from .groq_interface import GroqInterface +from .mistral_interface import MistralInterface from .ollama_interface import OllamaInterface +from .openai_interface import OpenAIInterface diff --git a/aimodels/providers/mistral_interface.py b/aimodels/providers/mistral_interface.py new file mode 100644 index 0000000..b6ff92e --- /dev/null +++ b/aimodels/providers/mistral_interface.py @@ -0,0 +1,38 @@ +import os + +from aimodels.framework import ProviderInterface + + +class MistralInterface(ProviderInterface): + """Implements the provider interface for Mistral.""" + + def __init__(self): + from mistralai.client import MistralClient + + self.mistral_client = MistralClient(api_key=os.getenv("MISTRAL_API_KEY")) + + def chat_completion_create(self, messages=None, model=None, temperature=0): + """Request chat completions from the Mistral API. + + Args: + ---- + model (str): Identifies the specific provider/model to use. + messages (list of dict): A list of message objects in chat history. + temperature (float): The temperature to use in the completion. + + Returns: + ------- + The API response with the completion result. + + """ + from mistralai.models.chat_completion import ChatMessage + + messages = [ + ChatMessage(role=message["role"], content=message["content"]) + for message in messages + ] + return self.mistral_client.chat( + model=model, + messages=messages, + temperature=temperature, + ) diff --git a/examples/multi_fm_client.ipynb b/examples/multi_fm_client.ipynb index 06c280a..08b3d31 100644 --- a/examples/multi_fm_client.ipynb +++ b/examples/multi_fm_client.ipynb @@ -14,15 +14,22 @@ }, { "cell_type": "code", - "execution_count": 1, "id": "initial_id", "metadata": { + "collapsed": true, "ExecuteTime": { - "end_time": "2024-07-02T23:20:19.015491Z", - "start_time": "2024-07-02T23:20:19.004272Z" - }, - "collapsed": true + "end_time": "2024-07-04T15:30:02.064319Z", + "start_time": "2024-07-04T15:30:02.051986Z" + } }, + "source": [ + "import sys\n", + "sys.path.append('../aimodels')\n", + "\n", + "from dotenv import load_dotenv, find_dotenv\n", + "\n", + "load_dotenv(find_dotenv())" + ], "outputs": [ { "data": { @@ -35,21 +42,17 @@ "output_type": "execute_result" } ], - "source": [ - "import sys\n", - "sys.path.append('../aimodels')\n", - "\n", - "from dotenv import load_dotenv, find_dotenv\n", - "\n", - "load_dotenv(find_dotenv())" - ] + "execution_count": 1 }, { "cell_type": "code", - "execution_count": 2, "id": "4de3a24f", - "metadata": {}, - "outputs": [], + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-04T15:31:12.914321Z", + "start_time": "2024-07-04T15:31:12.796445Z" + } + }, "source": [ "from aimodels.client import MultiFMClient\n", "\n", @@ -59,18 +62,26 @@ " {\"role\": \"system\", \"content\": \"Respond in Pirate English.\"},\n", " {\"role\": \"user\", \"content\": \"Tell me a joke\"},\n", "]" - ] + ], + "outputs": [], + "execution_count": 3 }, { "cell_type": "code", - "execution_count": 3, "id": "adebd2f0b578a909", "metadata": { "ExecuteTime": { - "end_time": "2024-07-03T02:22:26.282827Z", - "start_time": "2024-07-03T02:22:18.193996Z" + "end_time": "2024-07-04T15:31:25.060689Z", + "start_time": "2024-07-04T15:31:16.131205Z" } }, + "source": [ + "anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n", + "\n", + "response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n", + "\n", + "print(response.choices[0].message.content)" + ], "outputs": [ { "name": "stdout", @@ -84,13 +95,7 @@ ] } ], - "source": [ - "anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n", - "\n", - "response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n", - "\n", - "print(response.choices[0].message.content)" - ] + "execution_count": 4 }, { "cell_type": "code", @@ -117,6 +122,41 @@ "\n", "print(response.choices[0].message.content)" ] + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2024-07-04T15:31:39.472675Z", + "start_time": "2024-07-04T15:31:38.283368Z" + } + }, + "cell_type": "code", + "source": [ + "mistral_7b = \"mistral:open-mistral-7b\"\n", + "\n", + "response = client.chat.completions.create(model=mistral_7b, messages=messages, temperature=0.2)\n", + "\n", + "print(response.choices[0].message.content)" + ], + "id": "4a94961b2bddedbb", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Arr matey, I've got a jest fer ye, if ye be ready for a laugh! Why did the pirate bring a clock to the island? Because he wanted to catch the time! Aye, that be a good one, I be thinkin'. Arrr!\n" + ] + } + ], + "execution_count": 5 + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": "", + "id": "611210a4dc92845f" } ], "metadata": { diff --git a/poetry.lock b/poetry.lock index 6fac250..0e81909 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1395,6 +1395,22 @@ files = [ [package.dependencies] traitlets = "*" +[[package]] +name = "mistralai" +version = "0.4.2" +description = "" +optional = false +python-versions = "<4.0,>=3.9" +files = [ + {file = "mistralai-0.4.2-py3-none-any.whl", hash = "sha256:63c98eea139585f0a3b2c4c6c09c453738bac3958055e6f2362d3866e96b0168"}, + {file = "mistralai-0.4.2.tar.gz", hash = "sha256:5eb656710517168ae053f9847b0bb7f617eda07f1f93f946ad6c91a4d407fd93"}, +] + +[package.dependencies] +httpx = ">=0.25,<1" +orjson = ">=3.9.10,<3.11" +pydantic = ">=2.5.2,<3" + [[package]] name = "mistune" version = "3.0.2" @@ -1596,6 +1612,66 @@ typing-extensions = ">=4.7,<5" [package.extras] datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] +[[package]] +name = "orjson" +version = "3.10.6" +description = "Fast, correct Python JSON library supporting dataclasses, datetimes, and numpy" +optional = false +python-versions = ">=3.8" +files = [ + {file = "orjson-3.10.6-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:fb0ee33124db6eaa517d00890fc1a55c3bfe1cf78ba4a8899d71a06f2d6ff5c7"}, + {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c1c4b53b24a4c06547ce43e5fee6ec4e0d8fe2d597f4647fc033fd205707365"}, + {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eadc8fd310edb4bdbd333374f2c8fec6794bbbae99b592f448d8214a5e4050c0"}, + {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:61272a5aec2b2661f4fa2b37c907ce9701e821b2c1285d5c3ab0207ebd358d38"}, + {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:57985ee7e91d6214c837936dc1608f40f330a6b88bb13f5a57ce5257807da143"}, + {file = "orjson-3.10.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:633a3b31d9d7c9f02d49c4ab4d0a86065c4a6f6adc297d63d272e043472acab5"}, + {file = "orjson-3.10.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:1c680b269d33ec444afe2bdc647c9eb73166fa47a16d9a75ee56a374f4a45f43"}, + {file = "orjson-3.10.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f759503a97a6ace19e55461395ab0d618b5a117e8d0fbb20e70cfd68a47327f2"}, + {file = "orjson-3.10.6-cp310-none-win32.whl", hash = "sha256:95a0cce17f969fb5391762e5719575217bd10ac5a189d1979442ee54456393f3"}, + {file = "orjson-3.10.6-cp310-none-win_amd64.whl", hash = "sha256:df25d9271270ba2133cc88ee83c318372bdc0f2cd6f32e7a450809a111efc45c"}, + {file = "orjson-3.10.6-cp311-cp311-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:b1ec490e10d2a77c345def52599311849fc063ae0e67cf4f84528073152bb2ba"}, + {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:55d43d3feb8f19d07e9f01e5b9be4f28801cf7c60d0fa0d279951b18fae1932b"}, + {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ac3045267e98fe749408eee1593a142e02357c5c99be0802185ef2170086a863"}, + {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c27bc6a28ae95923350ab382c57113abd38f3928af3c80be6f2ba7eb8d8db0b0"}, + {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d27456491ca79532d11e507cadca37fb8c9324a3976294f68fb1eff2dc6ced5a"}, + {file = "orjson-3.10.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05ac3d3916023745aa3b3b388e91b9166be1ca02b7c7e41045da6d12985685f0"}, + {file = "orjson-3.10.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1335d4ef59ab85cab66fe73fd7a4e881c298ee7f63ede918b7faa1b27cbe5212"}, + {file = "orjson-3.10.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:4bbc6d0af24c1575edc79994c20e1b29e6fb3c6a570371306db0993ecf144dc5"}, + {file = "orjson-3.10.6-cp311-none-win32.whl", hash = "sha256:450e39ab1f7694465060a0550b3f6d328d20297bf2e06aa947b97c21e5241fbd"}, + {file = "orjson-3.10.6-cp311-none-win_amd64.whl", hash = "sha256:227df19441372610b20e05bdb906e1742ec2ad7a66ac8350dcfd29a63014a83b"}, + {file = "orjson-3.10.6-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:ea2977b21f8d5d9b758bb3f344a75e55ca78e3ff85595d248eee813ae23ecdfb"}, + {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b6f3d167d13a16ed263b52dbfedff52c962bfd3d270b46b7518365bcc2121eed"}, + {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f710f346e4c44a4e8bdf23daa974faede58f83334289df80bc9cd12fe82573c7"}, + {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7275664f84e027dcb1ad5200b8b18373e9c669b2a9ec33d410c40f5ccf4b257e"}, + {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0943e4c701196b23c240b3d10ed8ecd674f03089198cf503105b474a4f77f21f"}, + {file = "orjson-3.10.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:446dee5a491b5bc7d8f825d80d9637e7af43f86a331207b9c9610e2f93fee22a"}, + {file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:64c81456d2a050d380786413786b057983892db105516639cb5d3ee3c7fd5148"}, + {file = "orjson-3.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:960db0e31c4e52fa0fc3ecbaea5b2d3b58f379e32a95ae6b0ebeaa25b93dfd34"}, + {file = "orjson-3.10.6-cp312-none-win32.whl", hash = "sha256:a6ea7afb5b30b2317e0bee03c8d34c8181bc5a36f2afd4d0952f378972c4efd5"}, + {file = "orjson-3.10.6-cp312-none-win_amd64.whl", hash = "sha256:874ce88264b7e655dde4aeaacdc8fd772a7962faadfb41abe63e2a4861abc3dc"}, + {file = "orjson-3.10.6-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:66680eae4c4e7fc193d91cfc1353ad6d01b4801ae9b5314f17e11ba55e934183"}, + {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:caff75b425db5ef8e8f23af93c80f072f97b4fb3afd4af44482905c9f588da28"}, + {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:3722fddb821b6036fd2a3c814f6bd9b57a89dc6337b9924ecd614ebce3271394"}, + {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2c116072a8533f2fec435fde4d134610f806bdac20188c7bd2081f3e9e0133f"}, + {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6eeb13218c8cf34c61912e9df2de2853f1d009de0e46ea09ccdf3d757896af0a"}, + {file = "orjson-3.10.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:965a916373382674e323c957d560b953d81d7a8603fbeee26f7b8248638bd48b"}, + {file = "orjson-3.10.6-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:03c95484d53ed8e479cade8628c9cea00fd9d67f5554764a1110e0d5aa2de96e"}, + {file = "orjson-3.10.6-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:e060748a04cccf1e0a6f2358dffea9c080b849a4a68c28b1b907f272b5127e9b"}, + {file = "orjson-3.10.6-cp38-none-win32.whl", hash = "sha256:738dbe3ef909c4b019d69afc19caf6b5ed0e2f1c786b5d6215fbb7539246e4c6"}, + {file = "orjson-3.10.6-cp38-none-win_amd64.whl", hash = "sha256:d40f839dddf6a7d77114fe6b8a70218556408c71d4d6e29413bb5f150a692ff7"}, + {file = "orjson-3.10.6-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:697a35a083c4f834807a6232b3e62c8b280f7a44ad0b759fd4dce748951e70db"}, + {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd502f96bf5ea9a61cbc0b2b5900d0dd68aa0da197179042bdd2be67e51a1e4b"}, + {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f215789fb1667cdc874c1b8af6a84dc939fd802bf293a8334fce185c79cd359b"}, + {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a2debd8ddce948a8c0938c8c93ade191d2f4ba4649a54302a7da905a81f00b56"}, + {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5410111d7b6681d4b0d65e0f58a13be588d01b473822483f77f513c7f93bd3b2"}, + {file = "orjson-3.10.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb1f28a137337fdc18384079fa5726810681055b32b92253fa15ae5656e1dddb"}, + {file = "orjson-3.10.6-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:bf2fbbce5fe7cd1aa177ea3eab2b8e6a6bc6e8592e4279ed3db2d62e57c0e1b2"}, + {file = "orjson-3.10.6-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:79b9b9e33bd4c517445a62b90ca0cc279b0f1f3970655c3df9e608bc3f91741a"}, + {file = "orjson-3.10.6-cp39-none-win32.whl", hash = "sha256:30b0a09a2014e621b1adf66a4f705f0809358350a757508ee80209b2d8dae219"}, + {file = "orjson-3.10.6-cp39-none-win_amd64.whl", hash = "sha256:49e3bc615652617d463069f91b867a4458114c5b104e13b7ae6872e5f79d0844"}, + {file = "orjson-3.10.6.tar.gz", hash = "sha256:e54b63d0a7c6c54a5f5f726bc93a2078111ef060fec4ecbf34c5db800ca3b3a7"}, +] + [[package]] name = "overrides" version = "7.7.0" @@ -2809,4 +2885,4 @@ test = ["websockets"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "edb9cb7c3cf3edcfb8c46efc04a9001c2a44e19f468e5d559416fafbac452add" +content-hash = "884c99afd046e16d927fb98a1501268862cfd1b0f8254907273743a6fc7c05f0" diff --git a/pyproject.toml b/pyproject.toml index 5c11ee0..4bd47c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ groq = "^0.9.0" anthropic = "^0.30.1" notebook = "^7.2.1" ollama = "^0.2.1" +mistralai = "^0.4.2" [build-system] requires = ["poetry-core"]