Skip to content

Commit

Permalink
OpenAI embeddings for vector search (#1869)
Browse files Browse the repository at this point in the history
* pinning onnx runtime dep and adding openai api key setting

* adding util method to generate points

* adding docstring

* adding initial fastembed and openai encoder classes

* adding working encoder for search and similarity

* adding dummy encoder and fixing tests

* fixing model name generation in qdrant collection

* switch to using dummy encoder

* adding daily task to embed new learning resources

* adding test for new celery task

* adjusting test

* fixing openai embedding method

* fixing embed method

* fixing test flakiness

* adding litellm encoder

* adding docstrings and moving dummy encoder into conftest

* adding litellm dependency

* added ability to filter out existing embedded resources and bumped up frequency of new resource embedding task

* fix tests

* adding tests for creating qdrant collections

* fixing tests and removing period from daily embedding task

* remove period from celery task def

* fix encode

* fix encode method

* fix encode method

* fixing scroll

* fixing scroll in test

* fixing scroll in test
  • Loading branch information
shanbady authored Dec 5, 2024
1 parent 70057a0 commit 9b34340
Show file tree
Hide file tree
Showing 17 changed files with 1,405 additions and 201 deletions.
5 changes: 3 additions & 2 deletions learning_resources_search/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,16 +922,17 @@ def _qdrant_similar_results(doc, num_resources):
list of dict:
list of serialized resources
"""
from vector_search.utils import qdrant_client, vector_point_id
from vector_search.utils import dense_encoder, qdrant_client, vector_point_id

encoder = dense_encoder()
client = qdrant_client()
return [
hit.payload
for hit in client.query_points(
collection_name=f"{settings.QDRANT_BASE_COLLECTION_NAME}.resources",
query=vector_point_id(doc["readable_id"]),
limit=num_resources,
using=settings.QDRANT_SEARCH_VECTOR_NAME,
using=encoder.model_short_name(),
).points
]

Expand Down
15 changes: 9 additions & 6 deletions main/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,13 +795,16 @@ def get_all_config_keys():
name="QDRANT_COLLECTION_NAME", default="resource_embeddings"
)

QDRANT_SEARCH_VECTOR_NAME = get_string(
name="QDRANT_SEARCH_VECTOR_NAME", default="fast-bge-small-en"
)

QDRANT_DENSE_MODEL = get_string(
name="QDRANT_DENSE_MODEL", default="sentence-transformers/all-MiniLM-L6-v2"
)
QDRANT_DENSE_MODEL = get_string(name="QDRANT_DENSE_MODEL", default=None)
QDRANT_SPARSE_MODEL = get_string(
name="QDRANT_SPARSE_MODEL", default="prithivida/Splade_PP_en_v1"
)
QDRANT_ENCODER = get_string(
name="QDRANT_ENCODER", default="vector_search.encoders.fastembed.FastEmbedEncoder"
)

OPENAI_API_KEY = get_string(
name="OPENAI_API_KEY",
default=None,
)
6 changes: 6 additions & 0 deletions main/settings_celery.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,12 @@
"schedule": crontab(minute=30, hour=18), # 2:30pm EST
"kwargs": {"period": "daily", "subscription_type": "channel_subscription_type"},
},
"daily_embed_new_learning_resources": {
"task": "vector_search.tasks.embed_new_learning_resources",
"schedule": get_int(
"EMBED_NEW_RESOURCES_SCHEDULE_SECONDS", 60 * 30
), # default is every 30 minutes
},
"send-search-subscription-emails-every-1-days": {
"task": "learning_resources_search.tasks.send_subscription_emails",
"schedule": crontab(minute=0, hour=19), # 3:00pm EST
Expand Down
1,100 changes: 935 additions & 165 deletions poetry.lock

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ uwsgitop = "^0.12"
pytest-lazy-fixtures = "^1.1.1"
pycountry = "^24.6.1"
qdrant-client = {extras = ["fastembed"], version = "^1.12.0"}
onnxruntime = "1.19.2"
openai = "^1.55.3"
litellm = "^1.53.5"


[tool.poetry.group.dev.dependencies]
Expand Down
39 changes: 39 additions & 0 deletions vector_search/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import numpy as np
import pytest

from vector_search.encoders.base import BaseEncoder


class DummyEmbedEncoder(BaseEncoder):
    """
    A dummy encoder that returns random 10-dimensional vectors so tests
    never download a real model or call an external embedding API.
    """

    def __init__(self, model_name: str = "dummy-embedding"):
        # model_name is only used for naming the qdrant vector
        self.model_name = model_name

    def encode(self, text: str) -> list:  # noqa: ARG002
        # single random vector; len() of this array is 10, so dim() == 10
        return np.random.random((10, 1))

    def encode_batch(self, texts: list[str]) -> list[list[float]]:
        # One 10-dimensional vector per input text.
        # BUGFIX: the original returned shape (10, len(texts)) — one row per
        # *dimension* — so iterating it did not yield one vector per text.
        return np.random.random((len(texts), 10))


@pytest.fixture(autouse=True)
def _use_dummy_encoder(settings):
    # Autouse: point every test at the in-repo dummy encoder so nothing in
    # the suite loads a real embedding model or hits an embedding API.
    settings.QDRANT_ENCODER = "vector_search.conftest.DummyEmbedEncoder"


@pytest.fixture(autouse=True)
def _use_test_qdrant_settings(settings, mocker):
    """Stub qdrant connection settings and the qdrant client for tests."""
    settings.QDRANT_HOST = "https://test"
    settings.QDRANT_BASE_COLLECTION_NAME = "test"
    mock_qdrant = mocker.patch("qdrant_client.QdrantClient")
    # scroll returns (points, next_page_offset): empty page, no further pages
    mock_qdrant.scroll.return_value = [[], None]
    mocker.patch("vector_search.utils.qdrant_client", return_value=mock_qdrant)
Empty file.
37 changes: 37 additions & 0 deletions vector_search/encoders/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from abc import ABC, abstractmethod


class BaseEncoder(ABC):
    """
    Abstract base class for text embedding encoders.

    Subclasses must implement ``encode_batch`` and are expected to set a
    ``model_name`` attribute; ``encode`` and ``dim`` are derived from
    ``encode_batch``.
    """

    def model_short_name(self):
        """
        Return the short name of the model (the part after the first "/"),
        used as the vector name in qdrant.
        """
        split_model_name = self.model_name.split("/")
        model_name = self.model_name
        if len(split_model_name) > 1:
            model_name = split_model_name[1]
        return model_name

    def encode(self, text):
        """
        Embed a single text and return its vector (first batch result).
        """
        return next(iter(self.encode_batch([text])))

    @abstractmethod
    def encode_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Embed multiple texts, returning one vector per input text.
        """
        # BUGFIX: the original default body returned
        # [self.encode(text) for text in texts], but encode() itself calls
        # encode_batch() — infinite mutual recursion if a subclass ever
        # deferred to super(). Subclasses must implement this directly.
        raise NotImplementedError

    def dim(self):
        """
        Return the dimension of the embeddings (length of a probe vector).
        """
        return len(self.encode("test"))
27 changes: 27 additions & 0 deletions vector_search/encoders/fastembed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from fastembed import TextEmbedding

from vector_search.encoders.base import BaseEncoder


class FastEmbedEncoder(BaseEncoder):
    """
    Encoder backed by a local fastembed TextEmbedding model.
    """

    def __init__(self, model_name="BAAI/bge-small-en-v1.5"):
        self.model_name = model_name
        # lazy_load defers downloading/loading model weights until first use
        self.model = TextEmbedding(model_name=model_name, lazy_load=True)

    def encode_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Embed multiple texts, returning one vector per input text.
        """
        # BUGFIX: TextEmbedding.embed returns a generator; the original
        # returned it directly, violating the list[list[float]] annotation
        # and making the result single-consumption. Materialize it.
        return list(self.model.embed(texts))

    def dim(self):
        """
        Return the embedding dimension from fastembed's model registry
        (avoids running an embedding just to measure its length).
        """
        supported_models = [
            model_config
            for model_config in self.model.list_supported_models()
            if model_config["model"] == self.model.model_name
        ]
        # NOTE(review): raises IndexError if the model name is not in the
        # supported-model registry — presumably unreachable since
        # TextEmbedding() would have failed first; confirm.
        return supported_models[0]["dim"]
20 changes: 20 additions & 0 deletions vector_search/encoders/litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from litellm import embedding

from vector_search.encoders.base import BaseEncoder


class LiteLLMEncoder(BaseEncoder):
    """
    Encoder that delegates embedding to the LiteLLM API.
    """

    def __init__(self, model_name="text-embedding-3-small"):
        # model_name is passed straight through to litellm, which routes it
        # to the matching provider (e.g. OpenAI for text-embedding-3-small)
        self.model_name = model_name

    def encode_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Embed multiple texts via litellm, one vector per input text.
        """
        response = embedding(model=self.model_name, input=texts)
        return [item["embedding"] for item in response.to_dict()["data"]]
12 changes: 12 additions & 0 deletions vector_search/encoders/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from django.conf import settings
from django.utils.module_loading import import_string


def dense_encoder():
    """
    Instantiate the dense encoder configured in Django settings.

    QDRANT_ENCODER is the dotted path of the encoder class; when
    QDRANT_DENSE_MODEL is set it overrides the class's default model name.
    """
    encoder_class = import_string(settings.QDRANT_ENCODER)
    model_name = settings.QDRANT_DENSE_MODEL
    if model_name:
        return encoder_class(model_name=model_name)
    return encoder_class()
4 changes: 2 additions & 2 deletions vector_search/management/commands/generate_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from django.core.management.base import BaseCommand, CommandError

from learning_resources_search.constants import LEARNING_RESOURCE_TYPES
from main.utils import now_in_utc
from main.utils import clear_search_cache, now_in_utc
from vector_search.tasks import start_embed_resources
from vector_search.utils import (
create_qdrand_collections,
Expand Down Expand Up @@ -81,7 +81,7 @@ def handle(self, *args, **options): # noqa: ARG002
if error:
msg = f"Geenerate embeddings errored: {error}"
raise CommandError(msg)

clear_search_cache()
total_seconds = (now_in_utc() - start).total_seconds()
self.stdout.write(
f"Embeddings generated and stored, took {total_seconds} seconds"
Expand Down
29 changes: 28 additions & 1 deletion vector_search/tasks.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import datetime
import logging

import celery
Expand Down Expand Up @@ -27,8 +28,9 @@
from main.celery import app
from main.utils import (
chunks,
now_in_utc,
)
from vector_search.utils import embed_learning_resources
from vector_search.utils import embed_learning_resources, filter_existing_qdrant_points

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -148,3 +150,28 @@ def start_embed_resources(self, indexes, skip_content_files):
# Use self.replace so that code waiting on this task will also wait on the embedding
# and finish tasks
return self.replace(celery.chain(*index_tasks))


@app.task(bind=True)
def embed_new_learning_resources(self):
    """
    Embed published learning resources created within the last day
    (content files excluded), skipping any already present in qdrant.
    """
    log.info("Running new resource embedding task")
    cutoff = now_in_utc() - datetime.timedelta(days=1)
    recent_resources = LearningResource.objects.filter(
        published=True,
        created_on__gt=cutoff,
    ).exclude(resource_type=CONTENT_FILE_TYPE)
    # drop resources whose points already exist in the qdrant collection
    to_embed = filter_existing_qdrant_points(recent_resources)
    resource_ids = to_embed.order_by("id").values_list("id", flat=True)
    # NOTE(review): every chunk is tagged COURSE_TYPE even though the
    # queryset may contain other resource types — confirm intended.
    embed_tasks = celery.group(
        generate_embeddings.si(id_chunk, COURSE_TYPE)
        for id_chunk in chunks(
            resource_ids,
            chunk_size=settings.OPENSEARCH_INDEXING_CHUNK_SIZE,
        )
    )
    # replace() so callers waiting on this task also wait on the group
    return self.replace(embed_tasks)
49 changes: 46 additions & 3 deletions vector_search/tasks_test.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import datetime

import pytest
from django.conf import settings

from learning_resources.etl.constants import ETLSource
from learning_resources.factories import (
ContentFileFactory,
CourseFactory,
LearningResourceFactory,
ProgramFactory,
)
from learning_resources.models import LearningResource
from learning_resources_search.constants import (
COURSE_TYPE,
)
from vector_search.tasks import start_embed_resources
from main.utils import now_in_utc
from vector_search.tasks import embed_new_learning_resources, start_embed_resources

pytestmark = pytest.mark.django_db

Expand All @@ -23,8 +28,7 @@ def test_start_embed_resources(mocker, mocked_celery, index):
"""
start_embed_resources should generate embeddings for each resource type
"""
settings.QDRANT_HOST = "http://test"
settings.QDRANT_BASE_COLLECTION_NAME = "test"

mocker.patch("vector_search.tasks.load_course_blocklist", return_value=[])

if index == COURSE_TYPE:
Expand Down Expand Up @@ -95,3 +99,42 @@ def test_start_embed_resources_without_settings(mocker, mocked_celery, index):
start_embed_resources.delay([index], skip_content_files=True)

generate_embeddings_mock.si.assert_not_called()


def test_embed_new_learning_resources(mocker, mocked_celery):
    """
    embed_new_learning_resources should generate embeddings for new resources
    based on the period
    """
    mocker.patch("vector_search.tasks.load_course_blocklist", return_value=[])

    # resources inside the one-day window — these should be embedded
    recent_timestamp = now_in_utc() - datetime.timedelta(hours=5)
    LearningResourceFactory.create_batch(
        4, created_on=recent_timestamp, resource_type=COURSE_TYPE, published=True
    )
    # resources older than a day — these should be skipped
    stale_timestamp = now_in_utc() - datetime.timedelta(days=5)
    LearningResourceFactory.create_batch(
        4, created_on=stale_timestamp, resource_type=COURSE_TYPE, published=True
    )

    expected_ids = [
        resource.id
        for resource in LearningResource.objects.filter(
            created_on__gt=now_in_utc() - datetime.timedelta(days=1)
        )
    ]

    generate_embeddings_mock = mocker.patch(
        "vector_search.tasks.generate_embeddings", autospec=True
    )

    # the task replaces itself with a celery group; the mocked replace raises
    with pytest.raises(mocked_celery.replace_exception_class):
        embed_new_learning_resources.delay()
    # force evaluation of the generator handed to celery.group
    list(mocked_celery.group.call_args[0][0])

    embedded_ids = generate_embeddings_mock.si.mock_calls[0].args[0]
    assert sorted(expected_ids) == sorted(embedded_ids)
Loading

0 comments on commit 9b34340

Please sign in to comment.