-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from toggle-corp/feat/embed-model
Embedding models handler
- Loading branch information
Showing
11 changed files
with
3,106 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
venv | ||
.venv |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
OPENAI_API_KEY= |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
[flake8] | ||
extend-ignore = C901, W504 | ||
max-line-length = 125 | ||
# NOTE: Update in .pre-commit-config.yaml as well | ||
extend-exclude = .git,__pycache__,old,build,dist,*/migrations/*.py,legacy/,.venv | ||
max-complexity = 10 | ||
per-file-ignores = | ||
/**/tests/*_mock_data.py: E501 | ||
**/snap_test_*.py: E501 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
default_language_version: | ||
python: python3 | ||
|
||
# NOTE: Update in .flake8 and pyproject.toml as well | ||
exclude: | | ||
(?x)^( | ||
\.git| | ||
__pycache__| | ||
.*snap_test_.*\.py| | ||
.+\/.+\/migrations\/.*| | ||
\.venv | ||
) | ||
repos: | ||
- repo: https://github.com/psf/black | ||
rev: 24.4.2 | ||
hooks: | ||
- id: black | ||
# args: ["--check"] | ||
|
||
- repo: https://github.com/PyCQA/isort | ||
rev: 5.13.2 | ||
hooks: | ||
- id: isort | ||
# args: ["--check"] | ||
|
||
- repo: https://github.com/PyCQA/flake8 | ||
rev: 7.0.0 | ||
hooks: | ||
- id: flake8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
FROM python:3.10-slim-buster

LABEL maintainer="TC Developers"

# key=value form; the space-separated `ENV key value` form is legacy syntax
ENV PYTHONUNBUFFERED=1

WORKDIR /code

# Copy only the dependency manifests first so the install layer below can be
# reused from cache when just the application source changes.
COPY pyproject.toml poetry.lock /code/

# Upgrade pip and install python packages for code
RUN pip install --upgrade --no-cache-dir pip poetry \
    && poetry --version \
    # Configure to use system instead of virtualenvs
    && poetry config virtualenvs.create false \
    && poetry install --no-root \
    # Remove installer
    && rm -rf /root/.cache/pypoetry \
    && pip uninstall -y poetry virtualenv-clone virtualenv

COPY . /code/
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
from enum import Enum | ||
from typing import List, Optional, Union | ||
|
||
from dotenv import load_dotenv | ||
from fastapi import FastAPI | ||
from pydantic import BaseModel | ||
|
||
from embedding_models import ( | ||
OllamaEmbeddingModel, | ||
OpenAIEmbeddingModel, | ||
SentenceTransformerEmbeddingModel, | ||
) | ||
|
||
load_dotenv() | ||
|
||
app = FastAPI() | ||
|
||
|
||
class EmbeddingModelType(Enum):
    """Integer codes identifying which embedding back-end a request targets."""

    # Served by SentenceTransformerEmbeddingModel
    SENTENCE_TRANSFORMERS = 1
    # Served by OllamaEmbeddingModel
    OLLAMA = 2
    # Served by OpenAIEmbeddingModel
    OPENAI = 3
|
||
|
||
class RequestSchemaForEmbeddings(BaseModel):
    """Body accepted by the embeddings endpoint."""

    # Which embedding back-end should handle the request
    type_model: EmbeddingModelType
    # Model identifier handed to the selected back-end
    name_model: str
    # Either one text or a list of texts to embed
    texts: Union[str, List[str]]
    # Server URL; only forwarded to the Ollama back-end
    base_url: Optional[str] = None
|
||
|
||
@app.get("/")
async def home():
    """
    Landing endpoint describing the service.

    FastAPI serialises whatever the handler returns as the response body, so
    a Flask-style ``(body, status)`` tuple would reach the client as a
    two-element JSON array. Only the message is returned here; the status
    code is left at FastAPI's default of 200.
    """
    return "Embedding handler using models for texts"
|
||
|
||
@app.post("/get_embeddings")
async def generate_embeddings(item: RequestSchemaForEmbeddings):
    """
    Generates the embedding vectors for the text/documents
    based on different models

    Args:
        item: Request body naming the model type, the model name, the
            text(s) to embed and, for Ollama, the server base URL.

    Returns:
        The result of ``embed_query`` when ``texts`` is a single string,
        otherwise the result of ``embed_documents``.

    Raises:
        HTTPException: 422 when the Ollama type is requested without a
            ``base_url``, or when the model type has no handler.
    """
    # Function-scope import keeps this block self-contained with respect to
    # the module's import list.
    from fastapi import HTTPException

    # NOTE(review): a new model wrapper is built on every request, and the
    # work below is synchronous inside an async handler.
    if item.type_model == EmbeddingModelType.SENTENCE_TRANSFORMERS:
        embedding_model = SentenceTransformerEmbeddingModel(model=item.name_model)
    elif item.type_model == EmbeddingModelType.OLLAMA:
        # base_url is optional in the schema but the Ollama wrapper needs it;
        # reject early instead of failing inside the client.
        if not item.base_url:
            raise HTTPException(
                status_code=422,
                detail="base_url is required for the Ollama embedding model",
            )
        embedding_model = OllamaEmbeddingModel(
            model=item.name_model, base_url=item.base_url
        )
    elif item.type_model == EmbeddingModelType.OPENAI:
        embedding_model = OpenAIEmbeddingModel(model=item.name_model)
    else:
        # Guards against enum members added without a matching branch.
        raise HTTPException(
            status_code=422,
            detail=f"Unsupported embedding model type: {item.type_model}",
        )

    # The schema restricts texts to a string or a list of strings.
    if isinstance(item.texts, str):
        return embedding_model.embed_query(text=item.texts)
    return embedding_model.embed_documents(texts=item.texts)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
services: | ||
embedding_model: | ||
build: . | ||
volumes: | ||
- embedding_models:/opt/models | ||
command: bash -c 'uvicorn app:app --host=0.0.0.0 --port=8000' | ||
ports: | ||
- "8000:8000" | ||
restart: always | ||
|
||
volumes: | ||
embedding_models: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
from dataclasses import dataclass, field | ||
from typing import List | ||
|
||
import numpy as np | ||
from langchain_community.embeddings import OllamaEmbeddings | ||
from langchain_core.embeddings import Embeddings | ||
from langchain_openai import OpenAIEmbeddings | ||
from sentence_transformers import SentenceTransformer | ||
from torch import Tensor | ||
|
||
from utils import download_models | ||
|
||
|
||
@dataclass
class SentenceTransformerEmbeddingModel(Embeddings):
    """
    Embedding model using Sentence Transformers

    ``model`` is passed to ``download_models``; the ``model_path`` entry of
    its result is what the underlying ``SentenceTransformer`` is loaded from.
    """

    model: str
    # Declared under the same name __post_init__ assigns, so the generated
    # dataclass methods never touch an attribute that was never set. Kept out
    # of repr/eq so those depend only on the model name.
    st_embedding_model: SentenceTransformer = field(
        init=False, repr=False, compare=False
    )

    def __post_init__(self):
        """
        Post initialization: resolve the model path and load the model
        """
        models_info = download_models(sent_embedding_model=self.model)
        self.st_embedding_model = SentenceTransformer(
            model_name_or_path=models_info["model_path"]
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Generate embeddings for a list of documents

        Returns the encoder output converted with ``tolist()``.
        """
        v_representation = self.st_embedding_model.encode(texts)
        return v_representation.tolist()

    def embed_query(self, text: str) -> List[float]:
        """
        Generate embedding for a piece of text

        Returns the encoder output converted with ``tolist()``.
        """
        v_representation = self.st_embedding_model.encode(text)
        return v_representation.tolist()

    def check_similarity(
        self, embeddings_1: np.ndarray, embeddings_2: np.ndarray
    ) -> Tensor:
        """
        Computes the similarity between two embeddings by delegating to the
        loaded model's ``similarity`` method
        """
        return self.st_embedding_model.similarity(embeddings_1, embeddings_2)

    def get_model(self):
        """Returns the model"""
        return self.st_embedding_model
|
||
|
||
@dataclass
class OllamaEmbeddingModel(Embeddings):
    """
    Ollama-backed embedding model (locally deployed); a thin wrapper that
    delegates to ``OllamaEmbeddings``.
    """

    model: str
    base_url: str

    def __post_init__(self):
        """
        Build the underlying Ollama client from the dataclass fields
        """
        client = OllamaEmbeddings(model=self.model, base_url=self.base_url)
        self.ollama_embed_model = client

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        Embed every document in ``texts`` and return the vectors.
        """
        vectors = self.ollama_embed_model.embed_documents(texts)
        return vectors

    def embed_query(self, text: str) -> List[float]:
        """
        Embed a single piece of text and return its vector
        """
        vector = self.ollama_embed_model.embed_query(text=text)
        return vector

    def get_model(self):
        """Expose the wrapped Ollama embeddings client"""
        return self.ollama_embed_model
|
||
|
||
@dataclass
class OpenAIEmbeddingModel(Embeddings):
    """
    OpenAI-backed embedding model; a thin wrapper that delegates to
    ``OpenAIEmbeddings``.
    """

    model: str

    def __post_init__(self):
        """
        Build the underlying OpenAI client for the configured model name
        """
        client = OpenAIEmbeddings(model=self.model)
        self.openai_embed_model = client

    def embed_documents(self, texts: List[str]):
        """
        Embed every document in ``texts`` and return the vectors.
        """
        vectors = self.openai_embed_model.embed_documents(texts=texts)
        return vectors

    def embed_query(self, text: str):
        """
        Embed a single piece of text and return its vector
        """
        vector = self.openai_embed_model.embed_query(text=text)
        return vector

    def get_model(self):
        """Expose the wrapped OpenAI embeddings client"""
        return self.openai_embed_model
Oops, something went wrong.