[Embedding] Support Ollama embedding (#81)
* [emb] Support Ollama as an embedding provider
* Add a new embedding_utils module for easier filtering/sorting of relevant candidates according to the metric type
* Fix cleanup of the new Ollama embedding collections
* Enhance logging
* Update README
finaldie authored Jul 17, 2024
1 parent 6c57e91 commit 5177087
Showing 9 changed files with 235 additions and 17 deletions.
13 changes: 10 additions & 3 deletions .env.template
@@ -35,19 +35,22 @@ OLLAMA_MODEL=llama3
OLLAMA_URL=http://localhost:11434

# The generic Text embedding provider. Supported providers:
# openai, hf, hf_inst
# openai, hf, hf_inst, ollama
EMBEDDING_PROVIDER=openai

# models: text-embedding-ada-002, text-embedding-3-small
# models
# - openai: text-embedding-ada-002, text-embedding-3-small, ...
# - ollama: nomic-embed-text, ...
EMBEDDING_MODEL=text-embedding-ada-002

EMBEDDING_MAX_LENGTH=5000

TEXT_CHUNK_SIZE=10240
TEXT_CHUNK_OVERLAP=256

# For any summary, specify the translation language if needed
TRANSLATION_LANG=

EMBEDDING_MAX_LENGTH=5000
SUMMARY_MAX_LENGTH=20000

#########################################
@@ -100,6 +103,10 @@ RSS_ENABLE_CLASSIFICATION=false
# Milvus database
#########################################
MILVUS_HOST=milvus-standalone
MILVUS_PORT=19530

# L2, IP, COSINE
MILVUS_SIMILARITY_METRICS=L2

#########################################
# MySQL database
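For reference, a minimal configuration switching the template above to the new Ollama provider might look like the following (illustrative values; the embedding model must already be pulled into the Ollama server reachable at `OLLAMA_URL`, which the new `EmbeddingOllama` class reuses as its base URL):

```
EMBEDDING_PROVIDER=ollama
EMBEDDING_MODEL=nomic-embed-text
OLLAMA_URL=http://localhost:11434

# L2, IP and COSINE all rank candidates consistently here, since the new
# Ollama embeddings are normalized before insertion.
MILVUS_SIMILARITY_METRICS=L2
```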
13 changes: 10 additions & 3 deletions .env.template.k8s
@@ -35,19 +35,22 @@ OLLAMA_MODEL=llama3
OLLAMA_URL=http://localhost:11434

# The generic Text embedding provider. Supported providers:
# openai, hf, hf_inst
# openai, hf, hf_inst, ollama
EMBEDDING_PROVIDER=openai

# models: text-embedding-ada-002, text-embedding-3-small
# models:
# - openai: text-embedding-ada-002, text-embedding-3-small, ...
# - ollama: nomic-embed-text, ...
EMBEDDING_MODEL=text-embedding-ada-002

EMBEDDING_MAX_LENGTH=5000

TEXT_CHUNK_SIZE=10240
TEXT_CHUNK_OVERLAP=256

# For any summary, specify the translation language if needed
TRANSLATION_LANG=

EMBEDDING_MAX_LENGTH=5000
SUMMARY_MAX_LENGTH=20000

#########################################
@@ -100,6 +103,10 @@ RSS_ENABLE_CLASSIFICATION=false
# Milvus database
#########################################
MILVUS_HOST=auto-news-milvus
MILVUS_PORT=19530

# L2, IP, COSINE
MILVUS_SIMILARITY_METRICS=L2

#########################################
# MySQL database
2 changes: 1 addition & 1 deletion README.md
@@ -27,7 +27,7 @@ In the AI era, speed and productivity are extremely important. We need AI tools

For more background, see this [Blog post](https://finaldie.com/blog/auto-news-an-automated-news-aggregator-with-llm/) and these videos [Introduction](https://www.youtube.com/watch?v=hKFIyfAF4Z4), [Data flows](https://www.youtube.com/watch?v=WAGlnRht8LE).

https://github.com/finaldie/auto-news/assets/1088543/4387f688-61d3-4270-b5a6-105aa8ee0ea9
[<img src="https://img.youtube.com/vi/hKFIyfAF4Z4/0.jpg" width="80%" />](https://www.youtube.com/watch?v=hKFIyfAF4Z4 "AutoNews Intro on YouTube")

## Features
- Aggregate feed sources (including RSS, Reddit, Tweets, etc.) and proactively generate insights
2 changes: 1 addition & 1 deletion helm/values.yaml
@@ -8,7 +8,7 @@ airflow:
  images:
    airflow:
      repository: finaldie/auto-news
      tag: 0.9.10
      tag: 0.9.11

  useDefaultImageForMigration: true

7 changes: 6 additions & 1 deletion src/embedding_agent.py
@@ -2,6 +2,7 @@
from embedding_openai import EmbeddingOpenAI
from embedding_hf import EmbeddingHuggingFace
from embedding_hf_inst import EmbeddingHuggingFaceInstruct
from embedding_ollama import EmbeddingOllama


class EmbeddingAgent:
@@ -30,8 +31,12 @@ def __init__(

        elif self.provider == "hkunlp/instructor-xl":
            self.model = EmbeddingHuggingFaceInstruct(model_name=self.model_name)

        elif self.provider == "ollama":
            self.model = EmbeddingOllama(model_name=self.model_name)

        else:
            print(f"[ERROR] Unknown embedding model: {self.model_name}")
            print(f"[ERROR] Unknown embedding provider: {self.provider}")
            return None

    def dim(self):
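A hedged usage sketch of the updated agent; the constructor's full signature is not visible in this hunk, so the keyword arguments below are assumptions based on the attributes used in the dispatch code:

```python
# Sketch only: the 'provider' and 'model_name' kwargs are assumed, not confirmed.
import os

from embedding_agent import EmbeddingAgent

agent = EmbeddingAgent(
    provider=os.getenv("EMBEDDING_PROVIDER", "ollama"),
    model_name=os.getenv("EMBEDDING_MODEL", "nomic-embed-text"),
)

# Presumably delegates to the selected backend, e.g. EmbeddingOllama.dim()
print(agent.dim())
```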
125 changes: 125 additions & 0 deletions src/embedding_ollama.py
@@ -0,0 +1,125 @@
import os
import json
import time

import numpy as np

from embedding import Embedding
from langchain_community.embeddings import OllamaEmbeddings
import utils


class EmbeddingOllama(Embedding):
    """
    Embedding via Ollama
    """
    def __init__(self, model_name="nomic-embed-text", base_url=""):
        super().__init__(model_name)

        self.base_url = base_url or os.getenv("OLLAMA_URL")
        self.dimensions = -1

        self.client = OllamaEmbeddings(
            base_url=self.base_url,
            model=self.model_name,
        )

        print(f"Initialized EmbeddingOllama: model_name: {self.model_name}, base_url: {self.base_url}")

    def dim(self):
        if self.dimensions > 0:
            return self.dimensions

        text = "This is a test query"
        query_result = self.client.embed_query(text)
        self.dimensions = len(query_result)
        return self.dimensions

    def getname(self, start_date, prefix="ollama"):
        """
        Get an embedding collection name for Milvus
        """
        return f"embedding__{prefix}__ollama_{self.model_name}__{start_date}".replace("-", "_")

    def create(
        self,
        text: str,
        num_retries=3,
        retry_wait_time=0.5,
        error_wait_time=0.5,

        # Ollama embedding query results are not normalized. Most vector
        # databases recommend normalizing embeddings before insertion,
        # so apply normalization as a post-processing step here.
        normalize=True,
    ):
        emb = None

        for i in range(1, num_retries + 1):
            try:
                emb = self.client.embed_query(text)

                if normalize:
                    emb = (np.array(emb) / np.linalg.norm(emb)).tolist()

                break

            except Exception as e:
                print(f"[ERROR] APIError during embedding ({i}/{num_retries}): {e}")

                if i == num_retries:
                    raise

                time.sleep(error_wait_time)

        return emb

    def get_or_create(
        self,
        text: str,
        source="",
        page_id="",
        db_client=None,
        key_ttl=86400 * 30
    ):
        """
        Get the embedding from the cache (or create it if it does not exist)
        """
        client = db_client
        embedding = None

        if client:
            # Tip: the quickest way to invalidate all previously cached
            # entries is to change the provider name (1st arg)
            embedding = client.get_milvus_embedding_item_id(
                "ollama-norm",
                self.model_name,
                source,
                page_id)

        if embedding:
            print("[EmbeddingOllama] Embedding got from cache")
            return utils.fix_and_parse_json(embedding)

        # Not found in cache, generate one
        print("[EmbeddingOllama] Embedding not found, create a new one and cache it")

        # Most embedding models have an ~8k-token limit; exceeding it
        # throws an exception. Here we simply cap the input at
        # EMBEDDING_MAX_LENGTH characters (default 5000).
        EMBEDDING_MAX_LENGTH = int(os.getenv("EMBEDDING_MAX_LENGTH", 5000))
        embedding = self.create(text[:EMBEDDING_MAX_LENGTH])

        # store the embedding in Redis (default ttl = 1 month)
        if client:
            client.set_milvus_embedding_item_id(
                "ollama-norm",
                self.model_name,
                source,
                page_id,
                json.dumps(embedding),
                expired_time=key_ttl)

        return embedding
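A brief usage sketch of the new class (not part of the commit; it assumes a local Ollama server at http://localhost:11434 with the `nomic-embed-text` model already pulled):

```python
# Sketch only: exercises EmbeddingOllama as added above.
import numpy as np

from embedding_ollama import EmbeddingOllama

emb_model = EmbeddingOllama(model_name="nomic-embed-text",
                            base_url="http://localhost:11434")

print(emb_model.dim())        # embedding dimension, e.g. 768 for nomic-embed-text

vec = emb_model.create("Ollama embeddings in auto-news")
print(np.linalg.norm(vec))    # ~1.0, since normalize=True by default

raw = emb_model.create("Ollama embeddings in auto-news", normalize=False)
print(np.linalg.norm(raw))    # typically != 1.0 without the post-step

# Collection name used on the Milvus side, e.g.
# embedding__ollama__ollama_nomic_embed_text__2024_07_17
print(emb_model.getname("2024-07-17"))
```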
66 changes: 66 additions & 0 deletions src/embedding_utils.py
@@ -0,0 +1,66 @@
###############################################################################
# Embedding Utils
###############################################################################

def similarity_topk(embedding_items: list, metric_type, threshold=None, k=3):
    """
    @param embedding_items [{item_id, distance}, ...]
    @param metric_type     L2, IP or COSINE
    @param threshold       filter the results by distance
    @param k               max number of items to return
    """
    if metric_type == "L2":
        return similarity_topk_l2(embedding_items, threshold, k)
    elif metric_type in ("IP", "COSINE"):
        # for IP, assume all embeddings have been normalized
        return similarity_topk_cosine(embedding_items, threshold, k)
    else:
        raise Exception(f"Unknown metric_type: {metric_type}")


def similarity_topk_l2(items: list, threshold, k):
    """
    For metric_type L2, the value range is [0, +inf)
    * The smaller (closer to 0), the more similar
    * The larger, the less similar
    So, filter by distance <= threshold first, then take the top-k
    """
    valid_items = items

    if threshold is not None:
        valid_items = [x for x in items if x["distance"] <= threshold]

    # sort in ASC order
    sorted_items = sorted(
        valid_items,
        key=lambda item: item["distance"],
    )

    # The returned list is sorted from most similar to least similar
    return sorted_items[:k]


def similarity_topk_cosine(items: list, threshold, k):
    """
    For metric_type IP (normalized) or COSINE, the value range is [-1, 1]
    * 1 indicates that the vectors are identical in direction.
    * 0 indicates orthogonality (no similarity in direction).
    * -1 indicates that the vectors are opposite in direction.
    So, filter by distance >= threshold first, then take the top-k
    """
    valid_items = items

    if threshold is not None:
        valid_items = [x for x in items if x["distance"] >= threshold]

    # sort in DESC order
    sorted_items = sorted(
        valid_items,
        key=lambda item: item["distance"],
        reverse=True,
    )

    # The returned list is sorted from most similar to least similar
    return sorted_items[:k]
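A small illustration of the helper with made-up distances; note how the filter direction flips between the two metric families:

```python
# Sketch only: hypothetical search results for illustration.
import embedding_utils as emb_utils

hits = [
    {"item_id": "a", "distance": 0.12},
    {"item_id": "b", "distance": 0.45},
    {"item_id": "c", "distance": 0.90},
]

# L2: smaller distance means more similar, keep distance <= threshold
print(emb_utils.similarity_topk(hits, "L2", threshold=0.5, k=2))
# -> [{'item_id': 'a', 'distance': 0.12}, {'item_id': 'b', 'distance': 0.45}]

# COSINE / normalized IP: larger value means more similar, keep distance >= threshold
print(emb_utils.similarity_topk(hits, "COSINE", threshold=0.5, k=2))
# -> [{'item_id': 'c', 'distance': 0.9}]
```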
7 changes: 5 additions & 2 deletions src/milvus_cli.py
@@ -49,8 +49,10 @@ def createCollection(
        name="embedding_table",
        desc="embeddings",
        dim=1536,
        distance_metric="L2",
        distance_metric="",
    ):
        distance_metric = distance_metric or os.getenv("MILVUS_SIMILARITY_METRICS", "L2")

        # Create table schema
        self.fields = [
            FieldSchema(name="pk",
@@ -149,9 +151,10 @@ def get(
        topk=1,
        fallback=None,
        emb=None,
        distance_metric="L2",
        distance_metric="",
        timeout=60,  # timeout (unit second)
    ):
        distance_metric = distance_metric or os.getenv("MILVUS_SIMILARITY_METRICS", "L2")
        collection = None

        try:
17 changes: 11 additions & 6 deletions src/ops_milvus.py
@@ -1,3 +1,4 @@
import os
import json
import copy
import traceback
@@ -7,6 +8,7 @@
from notion import NotionAgent
from milvus_cli import MilvusClient
from embedding_agent import EmbeddingAgent
import embedding_utils as emb_utils
import utils


@@ -154,22 +156,25 @@ def get_relevant(
            db_client=client,
            key_ttl=key_ttl)

        # response_arr: [{item_id, distance}, ...]
        response_arr = milvus_client.get(
            collection_name, text, topk=topk,
            fallback=fallback, emb=embedding)

        # filter by distance (similarity value) according to the
        # metric type
        metric_type = os.getenv("MILVUS_SIMILARITY_METRICS", "L2")
        valid_embs = emb_utils.similarity_topk(response_arr, metric_type, max_distance, topk)
        print(f"[get_relevant] metric_type: {metric_type}, max_distance: {max_distance}, raw emb response_arr size: {len(response_arr)}, post emb_utils.topk: {len(valid_embs)}")

        res = []

        for response in response_arr:
        for response in valid_embs:
            print(f"[get_relevant] Processing response: {response}")

            page_id = response["item_id"]
            distance = response["distance"]

            if distance > max_distance:
                print(f"[get_relevant] Filtered it out due to the distance: {distance} > max_distance {max_distance}, page_id: {page_id}")
                continue

            page_metadata = client.get_page_item_id(page_id)

            if not page_metadata:
@@ -301,7 +306,7 @@ def clear(self, cleanup_date):
        print(f"Collections: {collections}")

        for name in collections:
            suffix = name.split("__")[1]
            suffix = name.split("__")[-1]
            dt = date.fromisoformat(suffix.replace("_", "-"))
            stats = milvus_client.get_stats(name)

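The `split("__")[1]` to `split("__")[-1]` change is what fixes cleanup for the new Ollama collections: their names contain more `__` separators, and only the last segment is the date. A quick sketch using the naming pattern from `EmbeddingOllama.getname()`:

```python
# Sketch only: why cleanup now takes the last "__"-separated segment.
from datetime import date

name = "embedding__ollama__ollama_nomic_embed_text__2024_07_17"

print(name.split("__")[1])     # 'ollama' -- not a date, fromisoformat() would raise
suffix = name.split("__")[-1]  # '2024_07_17'
print(date.fromisoformat(suffix.replace("_", "-")))  # 2024-07-17
```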
