Merge pull request #1 from toggle-corp/feat/embed-model
Embedding models handler
udaynwa authored Sep 13, 2024
2 parents 6760dc3 + 8ce2d77 commit b7c5e23
Showing 11 changed files with 3,106 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .dockerignore
@@ -0,0 +1,2 @@
venv
.venv
1 change: 1 addition & 0 deletions .env.sample
@@ -0,0 +1 @@
OPENAI_API_KEY=
9 changes: 9 additions & 0 deletions .flake8
@@ -0,0 +1,9 @@
[flake8]
extend-ignore = C901, W504
max-line-length = 125
# NOTE: Update in .pre-commit-config.yaml as well
extend-exclude = .git,__pycache__,old,build,dist,*/migrations/*.py,legacy/,.venv
max-complexity = 10
per-file-ignores =
    /**/tests/*_mock_data.py: E501
    **/snap_test_*.py: E501
31 changes: 31 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,31 @@
default_language_version:
  python: python3

# NOTE: Update in .flake8 and pyproject.toml as well
exclude: |
  (?x)^(
    \.git|
    __pycache__|
    .*snap_test_.*\.py|
    .+\/.+\/migrations\/.*|
    \.venv
  )
repos:
  - repo: https://github.com/psf/black
    rev: 24.4.2
    hooks:
      - id: black
        # args: ["--check"]

  - repo: https://github.com/PyCQA/isort
    rev: 5.13.2
    hooks:
      - id: isort
        # args: ["--check"]

  - repo: https://github.com/PyCQA/flake8
    rev: 7.0.0
    hooks:
      - id: flake8
24 changes: 24 additions & 0 deletions Dockerfile
@@ -0,0 +1,24 @@
FROM python:3.10-slim-buster

LABEL maintainer="TC Developers"

ENV PYTHONUNBUFFERED=1

WORKDIR /code

RUN apt-get update -y && \
    rm -rf /var/lib/apt/lists/*

COPY pyproject.toml poetry.lock /code/

# Upgrade pip and install python packages for code
RUN pip install --upgrade --no-cache-dir pip poetry \
    && poetry --version \
    # Configure to use system instead of virtualenvs
    && poetry config virtualenvs.create false \
    && poetry install --no-root \
    # Remove installer
    && rm -rf /root/.cache/pypoetry \
    && pip uninstall -y poetry virtualenv-clone virtualenv

COPY . /code/
71 changes: 71 additions & 0 deletions app.py
@@ -0,0 +1,71 @@
from enum import Enum
from typing import List, Optional, Union

from dotenv import load_dotenv
from fastapi import FastAPI
from pydantic import BaseModel

from embedding_models import (
    OllamaEmbeddingModel,
    OpenAIEmbeddingModel,
    SentenceTransformerEmbeddingModel,
)

load_dotenv()

app = FastAPI()


class EmbeddingModelType(Enum):
"""
Embedding model types
"""

SENTENCE_TRANSFORMERS = 1
OLLAMA = 2
OPENAI = 3


class RequestSchemaForEmbeddings(BaseModel):
"""Request Schema"""

type_model: EmbeddingModelType
name_model: str
texts: Union[str, List[str]]
base_url: Optional[str] = None


@app.get("/")
async def home():
return "Embedding handler using models for texts", 200


@app.post("/get_embeddings")
async def generate_embeddings(item: RequestSchemaForEmbeddings):
"""
Generates the embedding vectors for the text/documents
based on different models
"""
type_model = item.type_model
name_model = item.name_model
base_url = item.base_url
texts = item.texts

def generate(em_model, texts):
if isinstance(texts, str):
return em_model.embed_query(text=texts)
elif isinstance(texts, list):
return em_model.embed_documents(texts=texts)
return None

if type_model == EmbeddingModelType.SENTENCE_TRANSFORMERS:
embedding_model = SentenceTransformerEmbeddingModel(model=name_model)
return generate(em_model=embedding_model, texts=texts)

elif type_model == EmbeddingModelType.OLLAMA:
embedding_model = OllamaEmbeddingModel(model=name_model, base_url=base_url)
return generate(em_model=embedding_model, texts=texts)

elif type_model == EmbeddingModelType.OPENAI:
embedding_model = OpenAIEmbeddingModel(model=name_model)
return generate(em_model=embedding_model, texts=texts)
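
For reference, a minimal client sketch for the /get_embeddings endpoint above — not part of this commit — assuming the service is reachable at http://localhost:8000 (the port published in docker-compose.yml below) and that the requests package is available; the model name is only an illustrative placeholder, and the integer type_model values follow EmbeddingModelType (1 = Sentence Transformers, 2 = Ollama, 3 = OpenAI):

import requests

# Request body matching RequestSchemaForEmbeddings; the model name here is hypothetical.
payload = {
    "type_model": 1,  # EmbeddingModelType.SENTENCE_TRANSFORMERS
    "name_model": "all-MiniLM-L6-v2",
    "texts": ["first document", "second document"],
}

response = requests.post("http://localhost:8000/get_embeddings", json=payload)
embeddings = response.json()  # one embedding vector per input text
print(len(embeddings), len(embeddings[0]))
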
12 changes: 12 additions & 0 deletions docker-compose.yml
@@ -0,0 +1,12 @@
services:
  embedding_model:
    build: .
    volumes:
      - embedding_models:/opt/models
    command: bash -c 'uvicorn app:app --host=0.0.0.0 --port=8000'
    ports:
      - "8000:8000"
    restart: always

volumes:
  embedding_models:
115 changes: 115 additions & 0 deletions embedding_models.py
@@ -0,0 +1,115 @@
from dataclasses import dataclass, field
from typing import List

import numpy as np
from langchain_community.embeddings import OllamaEmbeddings
from langchain_core.embeddings import Embeddings
from langchain_openai import OpenAIEmbeddings
from sentence_transformers import SentenceTransformer
from torch import Tensor

from utils import download_models


@dataclass
class SentenceTransformerEmbeddingModel(Embeddings):
"""
Embedding model using Sentence Transformers
"""

model: str
embedding_model: SentenceTransformer = field(init=False)

def __post_init__(self):
"""
Post initialization
"""
models_info = download_models(sent_embedding_model=self.model)
self.st_embedding_model = SentenceTransformer(model_name_or_path=models_info["model_path"])

    def embed_documents(self, texts: list) -> List[List[float]]:
        """
        Generate embeddings for a list of documents
        """
        v_representation = self.st_embedding_model.encode(texts)
        return v_representation.tolist()

    def embed_query(self, text: str) -> List[float]:
        """
        Generate embedding for a piece of text
        """
        v_representation = self.st_embedding_model.encode(text)
        return v_representation.tolist()

    def check_similarity(self, embeddings_1: np.ndarray, embeddings_2: np.ndarray) -> Tensor:
        """
        Computes the cosine similarity between two embeddings
        """
        return self.st_embedding_model.similarity(embeddings_1, embeddings_2)

    def get_model(self):
        """Returns the model"""
        return self.st_embedding_model


@dataclass
class OllamaEmbeddingModel(Embeddings):
"""
Embedding model using Ollama (locally deployed)
"""

model: str
base_url: str

def __post_init__(self):
"""
Post initialization
"""
self.ollama_embed_model = OllamaEmbeddings(model=self.model, base_url=self.base_url)

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""
Generate embeddings for a list of documents.
"""
return self.ollama_embed_model.embed_documents(texts)

def embed_query(self, text: str) -> List[float]:
"""
Generate embedding for a piece of text
"""
return self.ollama_embed_model.embed_query(text=text)

def get_model(self):
"""Returns the model"""
return self.ollama_embed_model


@dataclass
class OpenAIEmbeddingModel(Embeddings):
"""
Embedding Model using OpenAI
"""

model: str

def __post_init__(self):
"""
Post initialization
"""
self.openai_embed_model = OpenAIEmbeddings(model=self.model)

def embed_documents(self, texts: List[str]):
"""
Generate embeddings for a list of documents.
"""
return self.openai_embed_model.embed_documents(texts=texts)

def embed_query(self, text: str):
"""
Generate embedding for a piece of text
"""
return self.openai_embed_model.embed_query(text=text)

def get_model(self):
"""Returns the model"""
return self.openai_embed_model
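
As a rough usage sketch for the wrappers above — not part of this commit — assuming the dependencies from pyproject.toml are installed and OPENAI_API_KEY is set as in .env.sample; the embedding model name below is purely an illustrative choice:

from dotenv import load_dotenv

from embedding_models import OpenAIEmbeddingModel

load_dotenv()  # picks up OPENAI_API_KEY from .env

# Hypothetical model name, used only for illustration.
embedder = OpenAIEmbeddingModel(model="text-embedding-3-small")

query_vector = embedder.embed_query(text="What is an embedding?")
doc_vectors = embedder.embed_documents(texts=["first document", "second document"])
print(len(query_vector), len(doc_vectors))
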