Commit

add tags for keyword search
rohitgr7 committed Mar 26, 2023
1 parent 8dc2129 commit ef8cd21
Showing 7 changed files with 126 additions and 37 deletions.
1 change: 1 addition & 0 deletions .streamlit/secrets.toml
@@ -0,0 +1 @@
URL="http://0.0.0.0:8000"
121 changes: 88 additions & 33 deletions app.py
@@ -6,11 +6,9 @@
import streamlit as st

from utils.audio import convert_to_hz, extract_audio_from_video
from utils.text import get_top_timestamp_for_question, get_top_timestamps_for_keyword
from utils.text import get_top_keywords, get_top_timestamp_for_question, get_top_timestamps_for_keyword
from utils.video import download_video, valid_link

stage = 1


@st.cache_data(show_spinner=False, ttl=None)
def _start_server():
@@ -19,7 +17,7 @@ def _start_server():

def _display_input_type():
file_type = st.selectbox(label="File type", options=["<select>", "Audio", "YT Video", "Existing Sample"])
return file_type
return file_type if file_type != "<select>" else None


def _sample_upload():
@@ -111,20 +109,32 @@ def _display_media_at_timestamp(media_bytes, timestamp, media_type, st_obj):
_display_media_at_timestamp(media_bytes, timestamp, media_type, col2)


def _initialize_session_state():
if "stage" not in st.session_state:
st.session_state["stage"] = 1

if "search_query" not in st.session_state:
st.session_state["search_query"] = ""

if "search_type" not in st.session_state:
st.session_state["search_type"] = ""


def _main():
_start_server()
global stage

file_type = None
if stage == 1:
if st.session_state["stage"] >= 1:
st.header("Earwise (Only English)")
st.subheader("Search within Audio")
st.info("To restart the app, please refresh :)")
st.warning("Please don't overuse, it's running on free-tier :)")
file_type = _display_input_type()
stage = 2

if stage == 2:
if file_type:
st.session_state["stage"] = 2

if st.session_state["stage"] >= 2:
media_path = None
if file_type == "Audio":
media_path = _audio_upload()
@@ -134,50 +144,94 @@ def _main():
media_path = _sample_upload()

if media_path:
stage = 3
st.session_state["stage"] = 3

if stage == 3:
if st.session_state["stage"] >= 3:
with st.spinner("Processing..."):
url = f"{st.secrets['URL']}/transcribe"
transcriptions = _whisper_recognize(url, media_path, file_type)

stage = 4
st.session_state["stage"] = 4

if st.session_state["stage"] >= 4:
if not st.session_state["search_query"]:
search_type = st.selectbox(
label="What do you want to do?", options=["<select>", "Keyword Search", "Ask a question"]
)
st.session_state["search_type"] = search_type

if st.session_state["search_type"] == "Keyword Search":
# Set up some sample tags
selected_tag = ""
if not st.session_state["search_query"]:
url = f"{st.secrets['URL']}/extract_keywords"
keywords = get_top_keywords(url, transcriptions)
num_cols = len(keywords)
container = st.container()
cols = container.columns(num_cols)

for i, tag in enumerate(keywords):
col_idx = i % num_cols
button = cols[col_idx].button(tag)
if button:
selected_tag = tag
st.session_state["search_query"] = selected_tag
st.experimental_rerun()

search_query = ""
if not st.session_state["search_query"]:
search_query = st.text_input(label="Search a keyword", placeholder="weekend routine")

if stage == 4:
search_type = st.selectbox(
label="What do you want to do?", options=["<select>", "Keyword Search", "Ask a question"]
)
if search_type == "Keyword Search":
search_query = st.text_input(label="Search a keyword", placeholder="weekend routine")
clear = st.button("Clear results")
st.warning("Make sure to clear the results before new search :)")
placeholder = st.empty()
if clear:
placeholder.empty()

if search_query and not clear:
if st.session_state["search_query"]:
clear = st.button("Clear results")
if clear:
placeholder.empty()
st.session_state["search_query"] = ""
st.session_state["search_type"] = ""
st.experimental_rerun()

if search_query:
st.session_state["search_query"] = search_query
st.experimental_rerun()

if st.session_state["search_query"] and not clear:
with st.spinner("Searching audio..."):
url = f"{st.secrets['URL']}/keyword_query"
timestamps = get_top_timestamps_for_keyword(url, transcriptions, search_query, threshold=0.5)
timestamps = get_top_timestamps_for_keyword(
url, transcriptions, st.session_state["search_query"], threshold=0.5
)

if not timestamps:
st.text("No result. Please try something else :)")
else:
media_type = "video" if file_type in ("YT Video", "Existing Sample") else "audio"
display_media(media_path, timestamps, media_type, placeholder)

elif search_type == "Ask a question":
search_query = st.text_input(label="Ask a question", placeholder="What do you do on weekend?")
clear = st.button("Clear results")
st.warning("Make sure to clear the results before new search :)")
placeholder = st.empty()
if clear:
placeholder.empty()
elif st.session_state["search_type"] == "Ask a question":
search_query = ""
if not st.session_state["search_query"]:
search_query = st.text_input(label="Ask a question", placeholder="What do you do on weekend?")

if search_query and not clear:
placeholder = st.empty()
if st.session_state["search_query"]:
clear = st.button("Clear results")
if clear:
placeholder.empty()
st.session_state["search_query"] = ""
st.session_state["search_type"] = ""
st.experimental_rerun()

if search_query:
st.session_state["search_query"] = search_query
st.experimental_rerun()

if st.session_state["search_query"] and not clear:
with st.spinner("Searching audio..."):
url = f"{st.secrets['URL']}/question_query"
timestamp = get_top_timestamp_for_question(url, transcriptions, search_query, threshold=0.1)
timestamp = get_top_timestamp_for_question(
url, transcriptions, st.session_state["search_query"], threshold=0.1
)

if not timestamp:
st.text("No result. Please try something else :)")
@@ -187,4 +241,5 @@ def _main():


if __name__ == "__main__":
_initialize_session_state()
_main()
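The core app.py change swaps the module-level stage global for st.session_state, so progress survives Streamlit's top-to-bottom script reruns, and renders the extracted keywords as buttons that store the clicked tag in search_query before forcing a rerun. A minimal, self-contained sketch of that stage-plus-tags pattern (the labels and sample tags are illustrative, and st.experimental_rerun() is the rerun call in the pinned Streamlit 1.19.0):

import streamlit as st

# Streamlit reruns the whole script on every interaction, so progress lives in session_state.
if "stage" not in st.session_state:
    st.session_state["stage"] = 1
if "search_query" not in st.session_state:
    st.session_state["search_query"] = ""

if st.session_state["stage"] >= 1:
    if st.button("Load media"):
        st.session_state["stage"] = 2

if st.session_state["stage"] >= 2 and not st.session_state["search_query"]:
    # Render clickable tags; a click stores the tag and reruns the script.
    tags = ["weekend routine", "music", "cooking"]
    cols = st.columns(len(tags))
    for i, tag in enumerate(tags):
        if cols[i % len(tags)].button(tag):
            st.session_state["search_query"] = tag
            st.experimental_rerun()

if st.session_state["search_query"]:
    st.write(f"Searching for: {st.session_state['search_query']}")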
13 changes: 13 additions & 0 deletions backend/main.py
@@ -7,8 +7,10 @@
from pydantic import BaseModel

from backend.models.nlp import (
predict_keywords,
predict_top_timestamp_for_question,
predict_top_timestamps_for_keyword,
prepare_keyword_extractor,
prepare_qna_pipeline,
prepare_similarity_model,
)
@@ -17,6 +19,7 @@
prepare_whisper_model()
sim_model, sim_tokenizer = prepare_similarity_model()
qna_pipeline = prepare_qna_pipeline()
keyword_extractor = prepare_keyword_extractor()


app = FastAPI()
@@ -41,6 +44,10 @@ class QuestionPayload(BaseModel):
question: str


class KeywordSearchPayload(BaseModel):
context: str


@app.get("/healthcheck")
async def pong():
return {"status": "alive"}
@@ -73,3 +80,9 @@ async def predict_keyword_timestamps(data: KeywordPayload):
async def predict_question_timestamp(data: QuestionPayload):
result = predict_top_timestamp_for_question(data.context, data.question, qna_pipeline)
return {"result": result}


@app.post("/extract_keywords")
async def extract_keywords(data: KeywordSearchPayload):
keywords = predict_keywords(data.context, keyword_extractor)
return {"result": [x[0] for x in keywords]}
9 changes: 9 additions & 0 deletions backend/models/nlp.py
@@ -1,5 +1,6 @@
import torch
import torch.nn.functional as F
from keybert import KeyBERT
from transformers import AutoModel, AutoTokenizer, pipeline


@@ -15,6 +16,10 @@ def prepare_qna_pipeline():
return question_answerer


def prepare_keyword_extractor():
return KeyBERT()


def _mean_pooling(model_output, attention_mask):
token_embeddings = model_output[0] # First element of model_output contains all token embeddings
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
@@ -47,3 +52,7 @@ def predict_top_timestamps_for_keyword(texts, search_query, threshold, model, to
def predict_top_timestamp_for_question(context, question, pipeline):
result = pipeline(question=question, context=context)
return result


def predict_keywords(docs, model):
return model.extract_keywords(docs, keyphrase_ngram_range=(1, 2), stop_words="english", use_mmr=True, diversity=0.6)
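predict_keywords delegates to KeyBERT's extract_keywords, which returns a list of (keyphrase, score) tuples; use_mmr=True with diversity=0.6 applies Maximal Marginal Relevance so the returned tags overlap less. A small illustration of that call in isolation (the sample document is made up, and the printed phrases only show the output shape, not real results):

from keybert import KeyBERT

model = KeyBERT()  # uses KeyBERT's default sentence-transformers backbone
doc = "I spend weekend mornings running by the river and the evenings learning guitar."

keywords = model.extract_keywords(
    doc,
    keyphrase_ngram_range=(1, 2),  # single words and two-word phrases
    stop_words="english",
    use_mmr=True,                  # Maximal Marginal Relevance for more diverse phrases
    diversity=0.6,
)
print(keywords)                  # e.g. [("weekend mornings", 0.62), ("learning guitar", 0.48), ...]
print([k for k, _ in keywords])  # just the phrases, which is what the API endpoint returns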
4 changes: 2 additions & 2 deletions requirements.txt
@@ -1,5 +1,5 @@
-r requirements/backend.txt
streamlit==1.19.0
yt-dlp==2023.1.6
moviepy==1.0.3
pydub==0.25.1
streamlit==1.19.0
yt-dlp==2023.1.6
5 changes: 3 additions & 2 deletions requirements/backend.txt
@@ -1,6 +1,7 @@
fastapi==0.91.0
keybert==0.7.0
pysrt==1.1.2
python-multipart==0.0.5
torch==1.13.1
transformers==4.26.0
uvicorn==0.20.0
python-multipart==0.0.5
pysrt==1.1.2
10 changes: 10 additions & 0 deletions utils/text.py
@@ -68,3 +68,13 @@ def get_top_timestamp_for_question(url, transcriptions, question, threshold=0.4)
timestamp = x[0]

return timestamp


def get_top_keywords(url, transcriptions):
context = "".join(x["text"] for x in transcriptions)
json_data = {
"context": context,
}
response = requests.post(url, json=json_data)
result = response.json()["result"]
return result
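get_top_keywords joins the transcription segments into a single context string and POSTs it to the backend's /extract_keywords route. A short usage sketch against a running backend; the segment dicts below are hypothetical Whisper-style output (only the "text" field is actually read here), and the base URL comes from .streamlit/secrets.toml:

from utils.text import get_top_keywords

# Hypothetical transcription segments; get_top_keywords only reads the "text" field.
transcriptions = [
    {"text": " I usually spend my weekends outdoors.", "start": 0.0, "end": 3.2},
    {"text": " On Sundays I cook with my family.", "start": 3.2, "end": 6.5},
]

tags = get_top_keywords("http://0.0.0.0:8000/extract_keywords", transcriptions)
print(tags)  # list of keyword strings returned by the backend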
