From 5177087dba976645f1e20a16e002717d06ee2750 Mon Sep 17 00:00:00 2001 From: Yuzhang Hu Date: Tue, 16 Jul 2024 18:18:58 -0700 Subject: [PATCH] [Embedding] Support Ollama embedding (#81) * [emb] Support ollama as embedding provider * A new embedding_utils for easier filtering/sorting relevant candidates according to the metric type * Fix new ollama emb collections cleanup * Enhance logging * update README --- .env.template | 13 ++++- .env.template.k8s | 13 ++++- README.md | 2 +- helm/values.yaml | 2 +- src/embedding_agent.py | 7 ++- src/embedding_ollama.py | 125 ++++++++++++++++++++++++++++++++++++++++ src/embedding_utils.py | 66 +++++++++++++++++++++ src/milvus_cli.py | 7 ++- src/ops_milvus.py | 17 ++++-- 9 files changed, 235 insertions(+), 17 deletions(-) create mode 100644 src/embedding_ollama.py create mode 100644 src/embedding_utils.py diff --git a/.env.template b/.env.template index c5469ee..0269f40 100644 --- a/.env.template +++ b/.env.template @@ -35,19 +35,22 @@ OLLAMA_MODEL=llama3 OLLAMA_URL=http://localhost:11434 # The generic Text embedding provider. Supported providers: -# openai, hf, hf_inst +# openai, hf, hf_inst, ollama EMBEDDING_PROVIDER=openai -# models: text-embedding-ada-002, text-embedding-3-small +# models +# - openai: text-embedding-ada-002, text-embedding-3-small, ... +# - ollama: nomic-embed-text, ... 
EMBEDDING_MODEL=text-embedding-ada-002 +EMBEDDING_MAX_LENGTH=5000 + TEXT_CHUNK_SIZE=10240 TEXT_CHUNK_OVERLAP=256 # For any summary, specific the translation language if needed TRANSLATION_LANG= -EMBEDDING_MAX_LENGTH=5000 SUMMARY_MAX_LENGTH=20000 ######################################### @@ -100,6 +103,10 @@ RSS_ENABLE_CLASSIFICATION=false # Milvus database ######################################### MILVUS_HOST=milvus-standalone +MILVUS_PORT=19530 + +# L2, IP, COSINE +MILVUS_SIMILARITY_METRICS=L2 ######################################### # MySQL database diff --git a/.env.template.k8s b/.env.template.k8s index 1f9b237..a408197 100644 --- a/.env.template.k8s +++ b/.env.template.k8s @@ -35,19 +35,22 @@ OLLAMA_MODEL=llama3 OLLAMA_URL=http://localhost:11434 # The generic Text embedding provider. Supported providers: -# openai, hf, hf_inst +# openai, hf, hf_inst, ollama EMBEDDING_PROVIDER=openai -# models: text-embedding-ada-002, text-embedding-3-small +# models: +# - openai: text-embedding-ada-002, text-embedding-3-small, ... +# - ollama: nomic-embed-text, ... EMBEDDING_MODEL=text-embedding-ada-002 +EMBEDDING_MAX_LENGTH=5000 + TEXT_CHUNK_SIZE=10240 TEXT_CHUNK_OVERLAP=256 # For any summary, specific the translation language if needed TRANSLATION_LANG= -EMBEDDING_MAX_LENGTH=5000 SUMMARY_MAX_LENGTH=20000 ######################################### @@ -100,6 +103,10 @@ RSS_ENABLE_CLASSIFICATION=false # Milvus database ######################################### MILVUS_HOST=auto-news-milvus +MILVUS_PORT=19530 + +# L2, IP, COSINE +MILVUS_SIMILARITY_METRICS=L2 ######################################### # MySQL database diff --git a/README.md b/README.md index 84afab2..262d1fd 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,7 @@ In the AI era, speed and productivity are extremely important. 
import os
import re
import json
import time

import numpy as np

from embedding import Embedding
from langchain_community.embeddings import OllamaEmbeddings
import utils


class EmbeddingOllama(Embedding):
    """
    Embedding via Ollama

    Wraps langchain's OllamaEmbeddings client and adds:
    - lazy detection of the embedding dimension
    - optional L2 normalization of the returned vector
    - an optional cache layer (through the passed-in db client)
    """
    def __init__(self, model_name="nomic-embed-text", base_url=""):
        """
        @param model_name  Ollama embedding model to use
        @param base_url    Ollama server url. When empty, falls back to
                           the OLLAMA_URL env var, then to the local
                           default endpoint
        """
        super().__init__(model_name)

        # Never hand a None url to the client: if neither the argument
        # nor the env var is set, use the same local endpoint the .env
        # templates ship with
        self.base_url = (
            base_url
            or os.getenv("OLLAMA_URL")
            or "http://localhost:11434"
        )

        # -1 means "not probed yet", see dim()
        self.dimensions = -1

        self.client = OllamaEmbeddings(
            base_url=self.base_url,
            model=self.model_name,
        )

        print(f"Initialized EmbeddingOllama: model_name: {self.model_name}, base_url: {self.base_url}")

    def dim(self):
        """
        Return the embedding dimension of the model.

        The dimension is discovered by embedding a short probe text
        once; the result is memoized in self.dimensions.
        """
        if self.dimensions > 0:
            return self.dimensions

        text = "This is a test query"
        query_result = self.client.embed_query(text)
        self.dimensions = len(query_result)
        return self.dimensions

    def getname(self, start_date, prefix="ollama"):
        """
        Get a embedding collection name of milvus

        Every character other than letters, digits and underscore is
        replaced by "_": model names such as "nomic-embed-text",
        "llama3.1" or "mxbai-embed-large:latest" would otherwise leak
        "-", "." or ":" into the collection name. For names that only
        contain "-" the result is the same as a plain "-" -> "_"
        replacement.
        """
        raw_name = f"embedding__{prefix}__ollama_{self.model_name}__{start_date}"
        return re.sub(r"[^0-9A-Za-z_]", "_", raw_name)

    @staticmethod
    def _normalize(emb):
        """
        Scale a vector to unit L2 norm and return it as a plain list.

        A zero-norm vector is returned unchanged, dividing it would
        only produce NaN values.
        """
        norm = np.linalg.norm(emb)

        if norm == 0:
            return list(emb)

        return (np.array(emb) / norm).tolist()

    def create(
        self,
        text: str,
        num_retries=3,
        retry_wait_time=0.5,
        error_wait_time=0.5,

        # ollama embedding query result is not normalized, for most
        # of the vector database would suggest us do the normalization
        # first before inserting into the vector database
        # here, we can apply a post-step for the normalization
        normalize=True,
    ):
        """
        Create the embedding of a text, retrying on errors.

        @param text             the text to embed
        @param num_retries      max number of attempts
        @param retry_wait_time  accepted for signature compatibility,
                                not used by this implementation
        @param error_wait_time  seconds to sleep after a failed attempt
        @param normalize        scale the vector to unit length
        @return the embedding (list of float), or None when
                num_retries < 1 (no attempt is made)
        @raise re-raises the last error once all attempts failed
        """
        emb = None

        for attempt in range(1, num_retries + 1):
            try:
                emb = self.client.embed_query(text)
                break

            except Exception as e:
                print(f"[ERROR] APIError during embedding ({attempt}/{num_retries}): {e}")

                if attempt == num_retries:
                    raise

                time.sleep(error_wait_time)

        # Normalize outside of the retry loop: a local numeric problem
        # is not fixed by calling the embedding API again
        if normalize and emb is not None:
            emb = self._normalize(emb)

        return emb

    def get_or_create(
        self,
        text: str,
        source="",
        page_id="",
        db_client=None,
        key_ttl=86400 * 30
    ):
        """
        Get embedding from cache (or create if not exist)

        @param text       the text to embed (truncated to
                          EMBEDDING_MAX_LENGTH chars before embedding)
        @param source     cache key part
        @param page_id    cache key part
        @param db_client  cache client; when None the cache is skipped
        @param key_ttl    cache entry ttl in seconds (default 30 days)
        @return the embedding
        """
        client = db_client
        embedding = None

        if client:
            # Tips: the quickest way to get rid of all previous
            # cache, change the provider (1st arg)
            embedding = client.get_milvus_embedding_item_id(
                "ollama-norm",
                self.model_name,
                source,
                page_id)

        if embedding:
            print("[EmbeddingOllama] Embedding got from cache")
            return utils.fix_and_parse_json(embedding)

        # Not found in cache, generate one
        print("[EmbeddingOllama] Embedding not found, create a new one and cache it")

        # Most of the emb models have 8k tokens, exceed it will
        # throw exceptions. Here we simply limited it <= 5000 chars
        # for the input
        #
        # `or 5000` also covers an env var that is set but empty,
        # which int() would reject
        max_length = int(os.getenv("EMBEDDING_MAX_LENGTH") or 5000)
        embedding = self.create(text[:max_length])

        # store embedding into the cache (default ttl = 1 month)
        if client:
            client.set_milvus_embedding_item_id(
                "ollama-norm",
                self.model_name,
                source,
                page_id,
                json.dumps(embedding),
                expired_time=key_ttl)

        return embedding
def similarity_topk(embedding_items: list, metric_type, threshold=None, k=3):
    """
    Filter and rank vector search hits according to the metric type.

    @param embedding_items [{item_id, distance}, ...] (None is treated
                           as an empty list)
    @param metric_type     L2, IP, COSINE. Matched case-insensitively
                           and ignoring surrounding spaces, since the
                           value usually comes from an env var
    @param threshold       to filter the result (None = no filtering)
    @param k               max number of returns
    @return up to k items, sorted by most similar -> least similar
    @raise ValueError      unknown metric_type
    """
    metric = str(metric_type or "").strip().upper()

    if metric == "L2":
        return similarity_topk_l2(embedding_items, threshold, k)

    if metric in ("IP", "COSINE"):
        # assume IP type all embeddings has been normalized
        return similarity_topk_cosine(embedding_items, threshold, k)

    # ValueError is still caught by callers using `except Exception`
    raise ValueError(f"Unknown metric_type: {metric_type}")


def _filter_and_rank(items, threshold, k, *, larger_is_closer):
    """
    Shared implementation of the per-metric top-k helpers.

    @param larger_is_closer True: keep distance >= threshold, sort DESC
                            False: keep distance <= threshold, sort ASC
    """
    candidates = items or []

    if threshold is not None:
        if larger_is_closer:
            candidates = [x for x in candidates if x["distance"] >= threshold]
        else:
            candidates = [x for x in candidates if x["distance"] <= threshold]

    ranked = sorted(
        candidates,
        key=lambda item: item["distance"],
        reverse=larger_is_closer,
    )

    # A negative k would slice from the end and silently drop the tail
    # of the list; treat it as "return nothing" instead
    if k is not None and k < 0:
        k = 0

    # The returned value is sorted by most similar -> least similar
    return ranked[:k]


def similarity_topk_l2(items: list, threshold, k):
    """
    metric_type L2, the value range [0, +inf)
    * The smaller (Close to 0), the more similar
    * The larger, the less similar

    so, we will filter in distance <= threshold first, then get top-k
    (sorted in ASC)
    """
    return _filter_and_rank(items, threshold, k, larger_is_closer=False)


def similarity_topk_cosine(items: list, threshold, k):
    """
    metric_type IP (normalized) or COSINE, the value range [-1, 1]
    * 1 indicates that the vectors are identical in direction.
    * 0 indicates orthogonality (no similarity in direction).
    * -1 indicates that the vectors are opposite in direction.

    so, we will filter in distance >= threshold first, then get top-k
    (sorted in DESC)
    """
    return _filter_and_rank(items, threshold, k, larger_is_closer=True)
response_arr = milvus_client.get( collection_name, text, topk=topk, fallback=fallback, emb=embedding) + # filter by distance (similiarity value) according to the + # metrics type + metric_type = os.getenv("MILVUS_SIMILARITY_METRICS", "L2") + valid_embs = emb_utils.similarity_topk(response_arr, metric_type, max_distance, topk) + print(f"[get_relevant] metric_type: {metric_type}, max_distance: {max_distance}, raw emb response_arr size: {len(response_arr)}, post emb_utils.topk: {len(valid_embs)}") + res = [] - for response in response_arr: + for response in valid_embs: print(f"[get_relevant] Processing response: {response}") page_id = response["item_id"] distance = response["distance"] - if distance > max_distance: - print(f"[get_relevant] Filtered it out due to the distance: {distance} > max_distance {max_distance}, page_id: {page_id}") - continue - page_metadata = client.get_page_item_id(page_id) if not page_metadata: @@ -301,7 +306,7 @@ def clear(self, cleanup_date): print(f"Collections: {collections}") for name in collections: - suffix = name.split("__")[1] + suffix = name.split("__")[-1] dt = date.fromisoformat(suffix.replace("_", "-")) stats = milvus_client.get_stats(name)