Skip to content

Commit

Permalink
add metadata filtering
Browse files Browse the repository at this point in the history
  • Loading branch information
omar-sol committed Feb 20, 2024
1 parent e0aadb4 commit 0cfc98f
Showing 1 changed file with 45 additions and 31 deletions.
76 changes: 45 additions & 31 deletions scripts/gradio-ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
from llama_index.core import VectorStoreIndex
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.core.vector_stores import (
MetadataFilters,
MetadataFilter,
FilterCondition,
)
import gradio as gr
from gradio.themes.utils import (
fonts,
Expand Down Expand Up @@ -62,34 +67,9 @@
# Initialize query engine
llm = OpenAI(temperature=0, model="gpt-3.5-turbo-0125", max_tokens=None)
embeds = OpenAIEmbedding(model="text-embedding-3-large", mode="text_search")
query_engine = index.as_query_engine(
llm=llm, similarity_top_k=5, embed_model=embeds, streaming=True
)


AVAILABLE_SOURCES_UI = [
"Gen AI 360: LLMs",
"Gen AI 360: LangChain",
"Gen AI 360: Advanced RAG",
"Towards AI Blog",
"Activeloop Docs",
"HF Transformers Docs",
"Wikipedia",
"OpenAI Docs",
"LangChain Docs",
]

AVAILABLE_SOURCES = [
"llm_course",
"langchain_course",
"advanced_rag_course",
"towards_ai",
"activeloop",
"hf_transformers",
"wikipedia",
"openai",
"langchain_docs",
]
# query_engine = index.as_query_engine(
# llm=llm, similarity_top_k=5, embed_model=embeds, streaming=True
# )


def save_completion(completion, history):
Expand Down Expand Up @@ -178,6 +158,8 @@ def format_sources(completion) -> str:


def add_sources(history, completion):
if history[-1][1] == "No sources selected. Please select sources to search.":
return history

formatted_sources = format_sources(completion)
history.append([None, formatted_sources])
Expand All @@ -192,10 +174,35 @@ def user(user_input, history):

def get_answer(history, sources: Optional[list[str]] = None):
user_input = history[-1][0]
history[-1][1] = ""

if len(sources) == 0:
history[-1][1] = "No sources selected. Please select sources to search."
yield history, "No sources selected. Please select sources to search."
return

# Dynamically create filters list
display_ui_to_source = {
ui: src for ui, src in zip(AVAILABLE_SOURCES_UI, AVAILABLE_SOURCES)
}
sources_renamed = [display_ui_to_source[disp] for disp in sources]
dynamic_filters = [
MetadataFilter(key="source", value=source) for source in sources_renamed
]

filters = MetadataFilters(
filters=dynamic_filters,
condition=FilterCondition.OR,
)
query_engine = index.as_query_engine(
llm=llm,
similarity_top_k=5,
embed_model=embeds,
streaming=True,
filters=filters,
)
completion = query_engine.query(user_input)

history[-1][1] = ""
for token in completion.response_gen:
history[-1][1] += token
yield history, completion
Expand Down Expand Up @@ -224,6 +231,13 @@ def get_answer(history, sources: Optional[list[str]] = None):

latest_completion = gr.State()

source_selection = gr.Dropdown(
choices=AVAILABLE_SOURCES_UI,
label="Select Sources",
value=AVAILABLE_SOURCES_UI,
multiselect=True,
)

chatbot = gr.Chatbot(
elem_id="chatbot", show_copy_button=True, scale=2, likeable=True
)
Expand Down Expand Up @@ -257,14 +271,14 @@ def get_answer(history, sources: Optional[list[str]] = None):
completion = gr.State()

submit.click(user, [question, chatbot], [question, chatbot], queue=False).then(
get_answer, inputs=[chatbot], outputs=[chatbot, completion]
get_answer, inputs=[chatbot, source_selection], outputs=[chatbot, completion]
).then(add_sources, inputs=[chatbot, completion], outputs=[chatbot])
# .then(
# save_completion, inputs=[completion, chatbot]
# )

question.submit(user, [question, chatbot], [question, chatbot], queue=False).then(
get_answer, inputs=[chatbot], outputs=[chatbot, completion]
get_answer, inputs=[chatbot, source_selection], outputs=[chatbot, completion]
).then(add_sources, inputs=[chatbot, completion], outputs=[chatbot])
# .then(
# save_completion, inputs=[completion, chatbot]
Expand Down

0 comments on commit 0cfc98f

Please sign in to comment.