From 4b7d1d88afe185f4f7cdd4d17ba399fedb22d878 Mon Sep 17 00:00:00 2001
From: standsleeping
Date: Wed, 3 Jul 2024 14:06:18 -0500
Subject: [PATCH 1/3] Add OllamaInterface

---
 aimodels/client/multi_fm_client.py     |  10 +-
 aimodels/providers/__init__.py         |   1 +
 aimodels/providers/ollama_interface.py |  51 +++++++++
 examples/multi_fm_client.ipynb         | 143 +++++++++++++++++++++++++
 examples/test_anthropic.ipynb          |  92 ----------------
 poetry.lock                            |  16 ++-
 pyproject.toml                         |   1 +
 7 files changed, 219 insertions(+), 95 deletions(-)
 create mode 100644 aimodels/providers/ollama_interface.py
 create mode 100644 examples/multi_fm_client.ipynb
 delete mode 100644 examples/test_anthropic.ipynb

diff --git a/aimodels/client/multi_fm_client.py b/aimodels/client/multi_fm_client.py
index 210e7e6..9773aad 100644
--- a/aimodels/client/multi_fm_client.py
+++ b/aimodels/client/multi_fm_client.py
@@ -1,7 +1,12 @@
 """MultiFMClient manages a Chat across multiple provider interfaces."""
 
-from ..providers import AnthropicInterface, OpenAIInterface, GroqInterface
 from .chat import Chat
+from ..providers import (
+    AnthropicInterface,
+    OpenAIInterface,
+    GroqInterface,
+    OllamaInterface,
+)
 
 
 class MultiFMClient:
@@ -30,6 +35,7 @@ def __init__(self):
             "openai": OpenAIInterface,
             "groq": GroqInterface,
             "anthropic": AnthropicInterface,
+            "ollama": OllamaInterface,
         }
 
     def get_provider_interface(self, model):
@@ -59,7 +65,7 @@ def get_provider_interface(self, model):
         model_name = model_parts[1]
 
         if provider in self.all_interfaces:
-            return self.all_interfaces[provider]
+            return self.all_interfaces[provider], model_name
 
         if provider not in self.all_factories:
             raise Exception(
diff --git a/aimodels/providers/__init__.py b/aimodels/providers/__init__.py
index 04e2bd3..87ad296 100644
--- a/aimodels/providers/__init__.py
+++ b/aimodels/providers/__init__.py
@@ -3,3 +3,4 @@
 from .openai_interface import OpenAIInterface
 from .groq_interface import GroqInterface
 from .anthropic_interface import AnthropicInterface
+from .ollama_interface import OllamaInterface
diff --git a/aimodels/providers/ollama_interface.py b/aimodels/providers/ollama_interface.py
new file mode 100644
index 0000000..a34e6d3
--- /dev/null
+++ b/aimodels/providers/ollama_interface.py
@@ -0,0 +1,51 @@
+"""The interface to the Ollama API."""
+
+from aimodels.framework import ProviderInterface, ChatCompletionResponse
+from httpx import ConnectError
+
+
+class OllamaInterface(ProviderInterface):
+    """Implements the ProviderInterface for interacting with the Ollama API."""
+
+    _OLLAMA_STATUS_ERROR_MESSAGE = "Ollama is likely not running. Start Ollama by running `ollama serve` on your host."
+
+    def __init__(self, server_url="http://localhost:11434"):
+        """Set up the Ollama API client with the key from the user's environment."""
+        from ollama import Client
+
+        self.ollama_client = Client(host=server_url)
+
+    def chat_completion_create(self, messages=None, model=None, temperature=0):
+        """Request chat completions from Ollama.
+
+        Args:
+        ----
+            model (str): Identifies the specific provider/model to use.
+            messages (list of dict): A list of message objects in chat history.
+            temperature (float): The temperature to use in the completion.
+
+        Raises:
+        ------
+            RuntimeError: If the Ollama server is not reachable. The
+            ConnectError raised by the underlying httpx client is caught
+            and re-raised as a RuntimeError.
+
+        Returns:
+        -------
+            The ChatCompletionResponse with the completion result.
+
+        """
+        try:
+            response = self.ollama_client.chat(
+                model=model,
+                messages=messages,
+                options={"temperature": temperature},
+            )
+        except ConnectError:
+            raise RuntimeError(self._OLLAMA_STATUS_ERROR_MESSAGE)
+
+        text_response = response["message"]["content"]
+        chat_completion_response = ChatCompletionResponse()
+        chat_completion_response.choices[0].message.content = text_response
+
+        return chat_completion_response
diff --git a/examples/multi_fm_client.ipynb b/examples/multi_fm_client.ipynb
new file mode 100644
index 0000000..06c280a
--- /dev/null
+++ b/examples/multi_fm_client.ipynb
@@ -0,0 +1,143 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "60c7fb39",
+   "metadata": {},
+   "source": [
+    "# MultiFMClient\n",
+    "\n",
+    "MultiFMClient provides a uniform interface for interacting with LLMs from various providers. It adapts the official Python libraries from providers such as Mistral, OpenAI, Meta, Anthropic, etc. to conform to the OpenAI chat completion interface.\n",
+    "\n",
+    "Below are some examples of how to use MultiFMClient to interact with different LLMs."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "initial_id",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-07-02T23:20:19.015491Z",
+     "start_time": "2024-07-02T23:20:19.004272Z"
+    },
+    "collapsed": true
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "True"
+      ]
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import sys\n",
+    "sys.path.append('../aimodels')\n",
+    "\n",
+    "from dotenv import load_dotenv, find_dotenv\n",
+    "\n",
+    "load_dotenv(find_dotenv())"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "4de3a24f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from aimodels.client import MultiFMClient\n",
+    "\n",
+    "client = MultiFMClient()\n",
+    "\n",
+    "messages = [\n",
+    "    {\"role\": \"system\", \"content\": \"Respond in Pirate English.\"},\n",
+    "    {\"role\": \"user\", \"content\": \"Tell me a joke\"},\n",
+    "]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "adebd2f0b578a909",
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-07-03T02:22:26.282827Z",
+     "start_time": "2024-07-03T02:22:18.193996Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Arrr, me bucko, 'ere be a jolly jest fer ye!\n",
+      "\n",
+      "What did th' pirate say on 'is 80th birthday? \"Aye matey!\"\n",
+      "\n",
+      "Ye see, it be a play on words, as \"Aye matey\" sounds like \"I'm eighty\". Har har har! 'Tis a clever bit o' pirate humor, if I do say so meself. Now, 'ow about ye fetch me a mug o' grog while I spin ye another yarn?\n"
+     ]
+    }
+   ],
+   "source": [
+    "anthropic_claude_3_opus = \"anthropic:claude-3-opus-20240229\"\n",
+    "\n",
+    "response = client.chat.completions.create(model=anthropic_claude_3_opus, messages=messages)\n",
+    "\n",
+    "print(response.choices[0].message.content)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "6819ac17",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Arrrr, here be a joke fer ye!\n",
+      "\n",
+      "Why did the pirate take a parrot on his ship?\n",
+      "\n",
+      "Because it were a hootin' good bird to have around, savvy? Aye, and it kept 'im company while he were swabbin' the decks! \n",
+      "Arrrgh, I hope that made ye laugh, matey!\n"
+     ]
+    }
+   ],
+   "source": [
+    "ollama_llama3 = \"ollama:llama3\"\n",
+    "\n",
+    "response = client.chat.completions.create(model=ollama_llama3, messages=messages, temperature=0.75)\n",
+    "\n",
+    "print(response.choices[0].message.content)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/test_anthropic.ipynb b/examples/test_anthropic.ipynb
deleted file mode 100644
index 63f03c5..0000000
--- a/examples/test_anthropic.ipynb
+++ /dev/null
@@ -1,92 +0,0 @@
-{
- "cells": [
-  {
-   "cell_type": "code",
-   "id": "initial_id",
-   "metadata": {
-    "collapsed": true,
-    "ExecuteTime": {
-     "end_time": "2024-07-02T23:20:19.015491Z",
-     "start_time": "2024-07-02T23:20:19.004272Z"
-    }
-   },
-   "source": [
-    "import sys\n",
-    "sys.path.append('../aimodels')\n",
-    "\n",
-    "from dotenv import load_dotenv, find_dotenv\n",
-    "\n",
-    "load_dotenv(find_dotenv())"
-   ],
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "True"
-      ]
-     },
-     "execution_count": 1,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "execution_count": 1
-  },
-  {
-   "metadata": {
-    "ExecuteTime": {
-     "end_time": "2024-07-03T02:22:26.282827Z",
-     "start_time": "2024-07-03T02:22:18.193996Z"
-    }
-   },
-   "cell_type": "code",
-   "source": [
-    "from aimodels.client import MultiFMClient\n",
-    "\n",
-    "client = MultiFMClient()\n",
-    "model_string = \"anthropic:claude-3-opus-20240229\"\n",
-    "messages=[{\"role\": \"system\", \"content\": \"Respond in Pirate English.\"}, \n",
-    "          {\"role\": \"user\", \"content\": \"Tell me a joke\"} ] \n",
-    "\n",
-    "response = client.chat.completions.create(model=model_string, messages=messages)\n",
-    "print(response.choices[0].message.content)"
-   ],
-   "id": "adebd2f0b578a909",
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Arrr, me bucko, 'ere be a jolly jest fer ye!\n",
-      "\n",
-      "What did th' pirate say on 'is 80th birthday? \"Aye matey!\"\n",
-      "\n",
-      "Ye see, it be a play on words, as \"Aye matey\" sounds like \"I'm eighty\". Har har har! 'Tis a clever bit o' pirate humor, if I do say so meself. Now, 'ow about ye fetch me a mug o' grog while I spin ye another yarn?\n"
-     ]
-    }
-   ],
-   "execution_count": 6
-  }
- ],
- "metadata": {
-  "kernelspec": {
-   "display_name": "Python 3",
-   "language": "python",
-   "name": "python3"
-  },
-  "language_info": {
-   "codemirror_mode": {
-    "name": "ipython",
-    "version": 2
-   },
-   "file_extension": ".py",
-   "mimetype": "text/x-python",
-   "name": "python",
-   "nbconvert_exporter": "python",
-   "pygments_lexer": "ipython2",
-   "version": "2.7.6"
-  }
- },
- "nbformat": 4,
- "nbformat_minor": 5
-}
diff --git a/poetry.lock b/poetry.lock
index cc7c959..4ebd519 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1559,6 +1559,20 @@ jupyter-server = ">=1.8,<3"
 
 [package.extras]
 test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync"]
 
+[[package]]
+name = "ollama"
+version = "0.2.1"
+description = "The official Python client for Ollama."
+optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "ollama-0.2.1-py3-none-any.whl", hash = "sha256:b6e2414921c94f573a903d1069d682ba2fb2607070ea9e19ca4a7872f2a460ec"}, + {file = "ollama-0.2.1.tar.gz", hash = "sha256:fa316baa9a81eac3beb4affb0a17deb3008fdd6ed05b123c26306cfbe4c349b6"}, +] + +[package.dependencies] +httpx = ">=0.27.0,<0.28.0" + [[package]] name = "openai" version = "1.35.8" @@ -2795,4 +2809,4 @@ test = ["websockets"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "de419b2a192211f40cffd1350cd58103ea21bd22009a641e5d5994066ac023d7" +content-hash = "f69d154a3134b90dbfb0bfb43d2a0cfdf79e80a0977bd437004efc49b60c90d8" diff --git a/pyproject.toml b/pyproject.toml index 4008d51..35025e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,7 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.10" +ollama = "^0.2.1" [tool.poetry.group.dev.dependencies] From abeb1deec687fb84b5c51545575fd4db98c368cb Mon Sep 17 00:00:00 2001 From: standsleeping Date: Wed, 3 Jul 2024 15:37:09 -0500 Subject: [PATCH 2/3] Move dependency to dev group --- poetry.lock | 2 +- pyproject.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index 4ebd519..6fac250 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2809,4 +2809,4 @@ test = ["websockets"] [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "f69d154a3134b90dbfb0bfb43d2a0cfdf79e80a0977bd437004efc49b60c90d8" +content-hash = "edb9cb7c3cf3edcfb8c46efc04a9001c2a44e19f468e5d559416fafbac452add" diff --git a/pyproject.toml b/pyproject.toml index 35025e1..5c11ee0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,6 @@ readme = "README.md" [tool.poetry.dependencies] python = "^3.10" -ollama = "^0.2.1" [tool.poetry.group.dev.dependencies] @@ -19,6 +18,7 @@ openai = "^1.35.8" groq = "^0.9.0" anthropic = "^0.30.1" notebook = "^7.2.1" +ollama = "^0.2.1" [build-system] requires = ["poetry-core"] From 97c8e18d0a5f99d7c2d96eb5ff6719aa1e376fc2 Mon Sep 17 00:00:00 2001 From: standsleeping Date: Wed, 3 Jul 2024 15:41:26 -0500 Subject: [PATCH 3/3] Fetch server_url from env --- .env.sample | 1 + aimodels/providers/ollama_interface.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/.env.sample b/.env.sample index a243234..cd7561a 100644 --- a/.env.sample +++ b/.env.sample @@ -1,3 +1,4 @@ ANTHROPIC_API_KEY="" GROQ_API_KEY="" OPENAI_API_KEY="" +OLLAMA_API_URL="http://localhost:11434" diff --git a/aimodels/providers/ollama_interface.py b/aimodels/providers/ollama_interface.py index a34e6d3..e9f7828 100644 --- a/aimodels/providers/ollama_interface.py +++ b/aimodels/providers/ollama_interface.py @@ -2,6 +2,7 @@ from aimodels.framework import ProviderInterface, ChatCompletionResponse from httpx import ConnectError +import os class OllamaInterface(ProviderInterface): @@ -9,7 +10,9 @@ class OllamaInterface(ProviderInterface): _OLLAMA_STATUS_ERROR_MESSAGE = "Ollama is likely not running. Start Ollama by running `ollama serve` on your host." - def __init__(self, server_url="http://localhost:11434"): + def __init__( + self, server_url=os.getenv("OLLAMA_API_URL", "http://localhost:11434") + ): """Set up the Ollama API client with the key from the user's environment.""" from ollama import Client
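A quick way to exercise the new provider once the series is applied (a minimal sketch, not part of the patches themselves: it assumes a local Ollama server started with `ollama serve`, a model pulled via `ollama pull llama3`, and the package importable as `aimodels`; the model string and message format mirror the notebook added in PATCH 1/3):

    from aimodels.client import MultiFMClient

    client = MultiFMClient()

    messages = [
        {"role": "system", "content": "Respond in Pirate English."},
        {"role": "user", "content": "Tell me a joke"},
    ]

    # "ollama:llama3" is split on ":": "ollama" selects OllamaInterface,
    # "llama3" is passed through to the Ollama server as the model name.
    response = client.chat.completions.create(
        model="ollama:llama3", messages=messages, temperature=0.75
    )

    print(response.choices[0].message.content)

If the server is not reachable, chat_completion_create raises a RuntimeError reminding you to start Ollama with `ollama serve`; after PATCH 3/3, the server URL can also be supplied through the OLLAMA_API_URL environment variable.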