feat(llm): Use Ollama as the LLM Provider
thompsonson committed Mar 1, 2024
1 parent 6a539bb commit 532610d
Showing 8 changed files with 183 additions and 38 deletions.
24 changes: 24 additions & 0 deletions .github/workflows/ci-cd.yml
@@ -36,3 +36,27 @@ jobs:
 
       - name: Run the tests
         run: make test
+
+  release:
+    needs: test
+    runs-on: ubuntu-latest
+    if: github.event_name == 'push' && github.ref == 'refs/heads/main'
+    steps:
+      - uses: actions/checkout@v2
+        with:
+          fetch-depth: 0  # Important for semantic-release to analyze commits
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: '3.10'
+
+      - name: Install python-semantic-release
+        run: pip install python-semantic-release
+
+      - name: Semantic Release
+        env:
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}  # This token is provided by Actions, no need to create it
+          PYPI_TOKEN: ${{ secrets.PYPI_TOKEN }}  # This token should be created in PyPI and set in your repo's secrets
+        run: |
+          semantic-release publish
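For the `semantic-release publish` step to bump versions and upload to PyPI, python-semantic-release also needs project configuration that this commit does not show. A minimal sketch of what that could look like in pyproject.toml, assuming the v7-style interface that matches the `semantic-release publish` / PYPI_TOKEN usage above (key names and paths are assumptions, not part of this commit):

[tool.semantic_release]
# Assumed location of the version string; adjust to the real package layout.
version_variable = "det/__init__.py:__version__"
version_toml = "pyproject.toml:tool.poetry.version"
branch = "main"
build_command = "poetry build"  # assumes Poetry, which this repo already uses
upload_to_pypi = true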
2 changes: 1 addition & 1 deletion .github/workflows/pr_agent.yml
@@ -18,4 +18,4 @@ jobs:
         uses: Codium-ai/pr-agent@main
         env:
           OPENAI_KEY: ${{ secrets.OPENAI_API_KEY }}
-          GITHUB_TOKEN: ${{ secrets.GH_TOKEN }}
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}  # This token is provided by Actions, no need to create it
9 changes: 3 additions & 6 deletions det/embeddings/cache.py
@@ -82,12 +82,11 @@ def generate_embeddings(self, texts):
         texts_without_embeddings = []
 
         for text in texts:
-            print(f"Checking cache for text: {text}")
             if text in self.embeddings_cache:
-                print("Cache hit for text.")
+                logger.debug("Cache hit for text.")
                 embeddings_to_return.append(self.embeddings_cache[text])
             else:
-                print("Cache miss for text.")
+                logger.debug("Cache miss for text.")
                 texts_without_embeddings.append(text)
 
         if texts_without_embeddings:
@@ -96,11 +95,9 @@ def generate_embeddings(self, texts):
             )
             for text, embedding in zip(texts_without_embeddings, new_embeddings):
                 self.embeddings_cache[text] = embedding
-                print(f"Added new embeddings to cache for text: {text}")
+                logger.debug("Added new embeddings to cache")
                 embeddings_to_return.append(embedding)
 
-        print(f"embeddings_to_return: {embeddings_to_return}")
-
         return embeddings_to_return
 
     def _save_cache(self):
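Because the cache hit/miss messages now go through `logging` at DEBUG level instead of `print`, they are silent by default. A minimal sketch of how a caller might surface them, assuming the module-level `logger = logging.getLogger(__name__)` these calls rely on:

import logging

# Route DEBUG records to the console; without some such configuration,
# the logger.debug calls above produce no visible output.
logging.basicConfig(level=logging.DEBUG)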
42 changes: 17 additions & 25 deletions det/llm/base.py
@@ -1,45 +1,37 @@
-"""
-This module defines the BaseLLMClient class, serving as an interface for interactions with
-various Large Language Models (LLMs). It establishes a standardized method for generating text
-responses across different LLM providers, ensuring flexibility and extendability in integrating
-multiple LLM services.
-
-By implementing the `generate_response` method, subclasses can provide specific functionality
-for any LLM provider, such as OpenAI, Google, Anthropic, or others, adhering to a unified API.
-This design promotes code reuse and simplifies the process of swapping or combining LLM services
-in applications requiring natural language generation.
-
-Example Usage:
---------------
-class MyLLMClient(BaseLLMClient):
-    def generate_response(self, prompt: str, **kwargs):
-        # Implementation for a specific LLM provider
-        pass
-
-Implementing this interface allows for easy integration and maintenance of LLM-based features,
-supporting a wide range of applications from chatbots to content generation tools.
-"""
+# det/llm/base.py
 
 from abc import ABC, abstractmethod
 
 
-class BaseLLMClient(ABC):
+class LLMGeneratorInterface(ABC):
     """
     Base class for LLM clients.
 
     Example Usage:
     --------------
-    class MyLLMClient(BaseLLMClient):
-        def generate_response(self, prompt: str, **kwargs):
+    class MyLLMClient(LLMGeneratorInterface):
+        def __init__(self, **kwargs):
+            # Initialize any necessary variables or state
+            pass
+
+        def generate_response(self, prompt: str, **kwargs) -> str:
             # Implementation for a specific LLM provider
             pass
     """
 
     def __init__(self, **kwargs):
         # Initialize any necessary variables or state
         pass
 
     @abstractmethod
-    def generate_response(self, prompt: str, **kwargs):
+    def generate_response(self, prompt: str, **kwargs) -> str:
         """
         Generates a response to a given prompt using the LLM.
 
         :param prompt: The input prompt to generate text for.
         :param kwargs: Additional parameters specific to the LLM provider.
+            - `temperature`: default is 0.
+            - `max_tokens`: default is 256.
 
         :return: The generated text response.
         """
-        raise NotImplementedError("This method should be implemented by subclasses.")
+        pass
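A sketch of how a provider-specific client would implement the renamed interface; `EchoClient` is a hypothetical example for illustration, not part of this commit:

from det.llm.base import LLMGeneratorInterface


class EchoClient(LLMGeneratorInterface):
    """Trivial stand-in provider used only to illustrate the contract."""

    def generate_response(self, prompt: str, **kwargs) -> str:
        # A real client would call its provider's API here.
        return prompt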
117 changes: 117 additions & 0 deletions det/llm/llm_ollama.py
@@ -0,0 +1,117 @@
+# det/llm/llm_ollama.py
+
+import logging
+from abc import ABC, abstractmethod
+
+from ollama import Client
+
+from det.llm.base import LLMGeneratorInterface
+
+logger = logging.getLogger(__name__)
+
+
+class OllamaClient(LLMGeneratorInterface):
+    """
+    The `OllamaClient` class is a subclass of the `LLMGeneratorInterface` abstract class.
+    It is used to generate text responses using an Ollama-hosted language model (LLM).
+    The class initializes with a specified model and host, and provides a method
+    to generate a response to a given prompt.
+    """
+
+    def __init__(self, model: str = "llama2", host: str = "http://localhost:11434"):
+        """
+        Initializes the `OllamaClient` with the specified model and host.
+
+        Parameters:
+        - model (str): The model to request from the Ollama server.
+        - host (str): The host URL of the Ollama server.
+
+        Raises:
+        - TypeError: If the model or host parameter is not a string.
+        """
+        if not isinstance(model, str):
+            raise TypeError("Model parameter must be a string.")
+        if not isinstance(host, str):
+            raise TypeError("Host parameter must be a string.")
+        self.model = model
+        self.client = Client(host=host)
+
+    def generate_response(self, prompt: str, **kwargs) -> str:
+        """
+        Generates a response to a given prompt using the Ollama LLM.
+
+        Parameters:
+        - prompt (str): The prompt to generate a response for.
+        - **kwargs: Additional parameters specific to the LLM provider.
+
+        Raises:
+        - ValueError: If the prompt is not a string.
+
+        Returns:
+        - str: The generated text response.
+        """
+        if not isinstance(prompt, str):
+            raise ValueError("Prompt must be a string.")
+        try:
+            response = self.client.chat(
+                model=self.model,  # Use the model specified during initialization
+                messages=[{"role": "user", "content": prompt}],
+                stream=False,
+                options={"temperature": 0},
+            )
+            return response["message"]["content"]
+        except Exception as e:
+            # Log through the module logger and re-raise, so callers are not
+            # silently handed None in place of the promised string.
+            logger.error(f"An error occurred: {e}")
+            raise
+
+
+class LLMAdapterInterface(ABC):
+    """
+    Interface for LLM adapters. Defines the contract that all LLM adapters must follow.
+    """
+
+    @abstractmethod
+    def generate(self, prompt: str, **kwargs):
+        """
+        Generates a response to a given prompt using the LLM.
+
+        :param prompt: The input prompt to generate text for.
+        :param kwargs: Additional parameters specific to the LLM provider.
+        :return: The generated text response.
+        """
+        pass
+
+
+class OllamaAdapter(LLMAdapterInterface):
+    """
+    The `OllamaAdapter` class adapts the Ollama client to the `LLMAdapterInterface`,
+    keeping generators independent of the underlying provider.
+    """
+
+    def __init__(self, model: str = "mistral", host: str = "http://localhost:11434"):
+        self.model = model
+        self.client = Client(host=host)
+
+    def generate(self, prompt: str, **kwargs):
+        response = self.client.chat(
+            model=self.model,  # Use the model specified during initialization
+            messages=[{"role": "user", "content": prompt}],
+            stream=False,
+            options={"temperature": 0},
+        )
+        # The ollama client returns a dict, so index into it rather than using
+        # attribute access.
+        return response["message"]["content"]
+
+
+class OllamaGenerator(LLMGeneratorInterface):
+    """
+    The `OllamaGenerator` class implements `LLMGeneratorInterface` and delegates
+    to an `OllamaAdapter` to generate text responses with the Ollama LLM.
+    """
+
+    def __init__(self, model: str = "llama2", host: str = "http://localhost:11434"):
+        self.adapter = OllamaAdapter(model, host)
+
+    def generate_response(self, prompt: str, **kwargs) -> str:
+        return self.adapter.generate(prompt, **kwargs)
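A usage sketch for the classes above, assuming a local Ollama server on the default port with the relevant models already pulled:

# Direct client: talks to the Ollama server itself.
client = OllamaClient(model="llama2")
print(client.generate_response("Why is the sky blue?"))

# Adapter route: OllamaGenerator delegates to OllamaAdapter, so swapping
# models or providers is confined to the adapter layer.
generator = OllamaGenerator(model="mistral")
print(generator.generate_response("Summarise the adapter pattern in one line."))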
4 changes: 2 additions & 2 deletions det/llm/llm_openai.py
@@ -23,10 +23,10 @@
 
 from openai import OpenAI
 
-from det.llm.base import BaseLLMClient
+from det.llm.base import LLMGeneratorInterface
 
 
-class OpenAIClient(BaseLLMClient):
+class OpenAIClient(LLMGeneratorInterface):
     """
     Example:
     --------
22 changes: 18 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -14,6 +14,7 @@ openai = "^1.12.0"
 typer = "^0.9.0"
 rich = "^13.7.0"
 numpy = "^1.26.4"
+ollama = "^0.1.6"
 
 [tool.poetry.group.dev.dependencies]
 pre-commit = "^3.6.2"
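With the `ollama` dependency added, a quick sanity check might be to list the models the client can reach (assumes a running local Ollama server; this snippet is an illustration, not part of the commit):

from ollama import Client

# Should print the models available on the local Ollama server.
client = Client(host="http://localhost:11434")
print(client.list())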
