From 7eb9b7e7ea57da0c7bb92b8d292ee5453270174a Mon Sep 17 00:00:00 2001
From: ValMobYKang <115156261+ValMobYKang@users.noreply.github.com>
Date: Tue, 19 Dec 2023 09:23:04 +0100
Subject: [PATCH] Add bm25 (#2)

* update

* update to mistral

* format
---
 .env.example     |  1 +
 requirements.txt |  3 ++-
 src/backend.py   | 28 +++++++++++++++++++++++-----
 start_server     |  2 +-
 4 files changed, 27 insertions(+), 7 deletions(-)

diff --git a/.env.example b/.env.example
index b6a49af..1f120d7 100644
--- a/.env.example
+++ b/.env.example
@@ -4,3 +4,4 @@ CONFLUENCE_SPACE=YOUR_SPACE_NAME
 MODEL=YOUR_GGUF_MODEL
 BITBUCKET_URL=YOUR_BITBUCKET_URL
 BITBUCKET_PROJECT=YOUR_BITBUCKET_PROJECT
+PHOENIX_PORT=YOUR_PHOENIX_PORT
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index cf18974..f983a25 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,4 +12,5 @@ svglib
 streamlit
 watchdog
 streamlit-authenticator
-sentence-transformers
\ No newline at end of file
+sentence-transformers
+rank_bm25
\ No newline at end of file
diff --git a/src/backend.py b/src/backend.py
index ea70adb..9a80b21 100644
--- a/src/backend.py
+++ b/src/backend.py
@@ -10,7 +10,7 @@ from llama_index.embeddings import HuggingFaceEmbedding
 from llama_index.indices.prompt_helper import PromptHelper
 from llama_index.query_engine import CustomQueryEngine
-from llama_index.retrievers import BaseRetriever
+from llama_index.retrievers import BaseRetriever, BM25Retriever
 from llama_index.response_synthesizers import (
     get_response_synthesizer,
     BaseSynthesizer,
@@ -54,7 +54,9 @@ def custom_query(self, query_str: str):
 
 
 def service_context():
-    LLM = OpenAI(temperature=0.1, max_tokens=2048, callback_manager=cb_manager)
+    LLM = OpenAI(
+        temperature=0.1, max_tokens=2048, stop=[""], callback_manager=cb_manager
+    )
     EMBEDDING = HuggingFaceEmbedding(
         model_name="BAAI/bge-base-en-v1.5", callback_manager=cb_manager
     )
@@ -70,7 +72,7 @@ def service_context():
 
 def init_index(persist_dir: Literal["confluence_store", "bitbucket_store"]):
     if os.path.exists(persist_dir):
-        print(f"Loading {persist_dir} ...")
+        print(f"... Loading {persist_dir}")
         return load_index_from_storage(
             storage_context=StorageContext.from_defaults(persist_dir=persist_dir),
             service_context=service_context(),
@@ -132,17 +134,33 @@ def get_query_engine(indices: list):
         "<|im_start|>assistant"
     )
 
+    mistral_qa_prompt = PromptTemplate(
+        "[INST] You will be presented with context. Your task is to answer the query based only on the context. "
+        "If the context cannot answer the query, respond with 'I don't know' and nothing else. "
+        "Approach this task step-by-step, take your time. This is very important to my career.\n"
+        "The Context information is below.\n"
+        "---------------------\n{context_str}\n--------------------- [/INST]\n"
+        "[INST] {query_str} [/INST]"
+    )
+
     if len(indices) == 1:
         return indices[0].as_query_engine(
             similarity_top_k=5,
             service_context=service_context(),
             response_mode="compact",
             node_postprocessors=[RERANK],
-            text_qa_template=dolphin_qa_prompt,
+            text_qa_template=mistral_qa_prompt,
         )
 
+    retrievers = []
+    for index in indices:
+        retriever = BM25Retriever.from_defaults(index=index, similarity_top_k=5)
+        retriever.callback_manager = cb_manager
+        retrievers.append(retriever)
+
     return QueryMultiEngine(
-        retrievers=[index.as_retriever(similarity_top_k=5) for index in indices],
+        retrievers=[index.as_retriever(similarity_top_k=5) for index in indices]
+        + retrievers,
         node_postprocessors=[RERANK],
         response_synthesizer=get_response_synthesizer(
             service_context=service_context(),
diff --git a/start_server b/start_server
index 37b2b1d..feb02f6 100755
--- a/start_server
+++ b/start_server
@@ -28,5 +28,5 @@ else
 fi
 
-python3 -m llama_cpp.server --model $MODEL --n_gpu_layers 1 --verbose False
+python3 -m llama_cpp.server --model $MODEL --n_gpu_layers 1 --n_ctx 32768 --verbose False