Add bm25 (#2)
* update

* update to mistral

* format
ValMobYKang authored Dec 19, 2023
1 parent 8789abc commit 7eb9b7e
Showing 4 changed files with 27 additions and 7 deletions.
.env.example: 1 change (1 addition & 0 deletions)
@@ -4,3 +4,4 @@ CONFLUENCE_SPACE=YOUR_SPACE_NAME
 MODEL=YOUR_GGUF_MODEL
 BITBUCKET_URL=YOUR_BITBUCKET_URL
 BITBUCKET_PROJECT=YOUR_BITBUCKET_PROJECT
+PHOENIX_PORT=YOUR_PHOENIX_PORT
requirements.txt: 3 changes (2 additions & 1 deletion)
@@ -12,4 +12,5 @@ svglib
 streamlit
 watchdog
 streamlit-authenticator
-sentence-transformers
+sentence-transformers
+rank_bm25
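The new rank_bm25 dependency is the package behind llama_index's BM25Retriever, which src/backend.py starts using below. A minimal, self-contained sketch of that retriever, assuming the llama_index 0.9.x import paths this repo already uses; the documents and query are invented for illustration, and backend.py builds the retriever from an existing index rather than from raw nodes:

# Keyword (BM25) retrieval over a tiny in-memory corpus; no embedding model needed.
from llama_index import Document
from llama_index.node_parser import SimpleNodeParser
from llama_index.retrievers import BM25Retriever

docs = [
    Document(text="Confluence holds the team's design notes."),
    Document(text="Bitbucket hosts the project's source repositories."),
]
nodes = SimpleNodeParser.from_defaults().get_nodes_from_documents(docs)

bm25 = BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=2)
for result in bm25.retrieve("Where is the source code hosted?"):
    print(result.score, result.node.get_content())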
src/backend.py: 28 changes (23 additions & 5 deletions)
@@ -10,7 +10,7 @@
 from llama_index.embeddings import HuggingFaceEmbedding
 from llama_index.indices.prompt_helper import PromptHelper
 from llama_index.query_engine import CustomQueryEngine
-from llama_index.retrievers import BaseRetriever
+from llama_index.retrievers import BaseRetriever, BM25Retriever
 from llama_index.response_synthesizers import (
     get_response_synthesizer,
     BaseSynthesizer,
@@ -54,7 +54,9 @@ def custom_query(self, query_str: str):


 def service_context():
-    LLM = OpenAI(temperature=0.1, max_tokens=2048, callback_manager=cb_manager)
+    LLM = OpenAI(
+        temperature=0.1, max_tokens=2048, stop=["</s>"], callback_manager=cb_manager
+    )
     EMBEDDING = HuggingFaceEmbedding(
         model_name="BAAI/bge-base-en-v1.5", callback_manager=cb_manager
     )
@@ -70,7 +72,7 @@ def service_context():

 def init_index(persist_dir: Literal["confluence_store", "bitbucket_store"]):
     if os.path.exists(persist_dir):
-        print(f"Loading {persist_dir} ...")
+        print(f"... Loading {persist_dir}")
         return load_index_from_storage(
             storage_context=StorageContext.from_defaults(persist_dir=persist_dir),
             service_context=service_context(),
@@ -132,17 +134,33 @@ def get_query_engine(indices: list):
         "<|im_start|>assistant"
     )

+    mistral_qa_prompt = PromptTemplate(
+        "<s>[INST] You will be presented with context. Your task is to answer the query only based on the context. "
+        "If the context cannot answer the query, respond 'I don't know' directly without any further response. "
+        "Approach this task step-by-step, take your time. This is very important to my career.\n"
+        "The Context information is below. \n"
+        "---------------------\n{context_str}\n--------------------- [/INST]</s>\n"
+        "[INST] {query_str} [/INST]"
+    )
+
     if len(indices) == 1:
         return indices[0].as_query_engine(
             similarity_top_k=5,
             service_context=service_context(),
             response_mode="compact",
             node_postprocessors=[RERANK],
-            text_qa_template=dolphin_qa_prompt,
+            text_qa_template=mistral_qa_prompt,
         )

+    retrievers = []
+    for index in indices:
+        retriever = BM25Retriever.from_defaults(index=index, similarity_top_k=5)
+        retriever.callback_manager = cb_manager
+        retrievers.append(retriever)
+
     return QueryMultiEngine(
-        retrievers=[index.as_retriever(similarity_top_k=5) for index in indices],
+        retrievers=[index.as_retriever(similarity_top_k=5) for index in indices]
+        + retrievers,
         node_postprocessors=[RERANK],
         response_synthesizer=get_response_synthesizer(
             service_context=service_context(),
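Two of the backend changes above go together: the new mistral_qa_prompt wraps the retrieved context and the question in Mistral's [INST] ... [/INST] turns, and the OpenAI client now passes stop=["</s>"] because Mistral-style instruct models end a turn with that token. A rough sketch of how the rendered prompt looks, using abbreviated instructions and a made-up context and question; PromptTemplate.format is the standard llama_index call for filling {context_str} and {query_str}:

# Illustrative only: render an abbreviated version of the Mistral-style QA prompt.
from llama_index.prompts import PromptTemplate

qa_prompt = PromptTemplate(
    "<s>[INST] Answer the query only based on the context.\n"
    "---------------------\n{context_str}\n--------------------- [/INST]</s>\n"
    "[INST] {query_str} [/INST]"
)
print(
    qa_prompt.format(
        context_str="BM25 scores documents by term frequency and inverse document frequency.",
        query_str="What does BM25 score documents by?",
    )
)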
start_server: 2 changes (1 addition & 1 deletion)
@@ -28,5 +28,5 @@ else
 fi


-python3 -m llama_cpp.server --model $MODEL --n_gpu_layers 1 --verbose False
+python3 -m llama_cpp.server --model $MODEL --n_gpu_layers 1 --n_ctx 32768 --verbose False
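Raising --n_ctx to 32768 matches Mistral 7B's advertised 32k context window, leaving room for the retrieved Confluence/Bitbucket chunks plus the [INST] wrapper. As a hypothetical smoke test against the running server (localhost:8000 is llama_cpp.server's default bind and is not pinned down anywhere in this repo):

# Hypothetical check that the local llama_cpp.server answers OpenAI-style completion requests.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "prompt": "<s>[INST] Say hello. [/INST]",
        "max_tokens": 32,
        "stop": ["</s>"],
    },
    timeout=120,
)
print(resp.json()["choices"][0]["text"])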
