Skip to content

Commit

Permalink
better chat interface
Browse files Browse the repository at this point in the history
  • Loading branch information
SubhadityaMukherjee committed Aug 23, 2024
1 parent 52042bc commit c8cd3ce
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 44 deletions.
11 changes: 10 additions & 1 deletion frontend/.streamlit/config.toml
Original file line number Diff line number Diff line change
@@ -1,2 +1,11 @@
[server]
headless = true
headless = true

# Max size, in megabytes, for files uploaded with the file_uploader.
# Default: 200
maxUploadSize = 200

# Max size, in megabytes, of messages that can be sent via the WebSocket
# connection.
# Default: 200
maxMessageSize = 3500
37 changes: 16 additions & 21 deletions frontend/ui.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
import json
from pathlib import Path

import pandas as pd
import streamlit as st
from streamlit_feedback import streamlit_feedback
from ui_utils import *

# Streamlit Chat Interface
Expand All @@ -12,10 +9,24 @@
info = """
<p style='text-align: center; color: white;'>Machine learning research should be easily accessible and reusable. OpenML is an open platform for sharing datasets, algorithms, and experiments - to learn how to learn better, together. <br>Ask me anything about OpenML or search for a dataset ... </p>
"""
chatbot_display = "How do I do X using OpenML? / Find me a dataset about Y"
chatbot_max_chars = 500

st.set_page_config(page_title=page_title, page_icon=logo)
st.title("OpenML AI Search")
# message_box = st.container()

with st.spinner("Loading Required Data"):
config_path = Path("../backend/config.json")
ui_loader = UILoader(config_path)

# container for company description and logo
with st.sidebar:
query_type = st.radio(
"Select Query Type", ["General Query", "Dataset", "Flow"], key="query_type_2"
)

user_input = st.chat_input(placeholder=chatbot_display, max_chars=chatbot_max_chars)
col1, col2 = st.columns([1, 4])
with col1:
st.image(logo, width=100)
Expand All @@ -24,22 +35,6 @@
info,
unsafe_allow_html=True,
)

with st.spinner("Loading Required Data"):
config_path = Path("../backend/config.json")
ui_loader = UILoader(config_path)

# Chat input box
user_input = ui_loader.chat_entry()

ui_loader.create_chat_interface(None)
query_type = st.selectbox("Select Query Type", ["General Query","Dataset", "Flow"], key="query_type_2")
llm_filter = st.toggle("LLM Filter")
# Chat interface
ui_loader.create_chat_interface(user_input=None)
if user_input:
ui_loader.create_chat_interface(
user_input, query_type=query_type, llm_filter=llm_filter
)
ui_loader.query_type = st.selectbox("Select Query Type", ["General Query","Dataset", "Flow"], key="query_type_3")
ui_loader.llm_filter = st.toggle("LLM Filter", key="llm_filter_2")

ui_loader.create_chat_interface(user_input, query_type=query_type)
48 changes: 26 additions & 22 deletions frontend/ui_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,6 @@ def feedback_cb():
json.dump(data, file, indent=4)


def display_results(initial_response):
"""
Description: Display the results in a DataFrame
"""
# st.write("OpenML Agent: ")
try:
st.dataframe(initial_response)
except:
st.write(initial_response)

class LLMResponseParser:
"""
Expand Down Expand Up @@ -352,11 +343,11 @@ def __init__(self, config_path):
with open(config_path, "r") as file:
# Load config
self.config = json.load(file)
# self.message_box = message_box

# Paths and display information

self.chatbot_display = "How do I do X using OpenML? / Find me a dataset about Y"
self.chatbot_input_max_chars = 500
# self.chatbot_input_max_chars = 500

# Load metadata chroma database for structured query
self.collec = load_chroma_metadata()
Expand All @@ -378,22 +369,22 @@ def __init__(self, config_path):
if "messages" not in st.session_state:
st.session_state.messages = []

def chat_entry(self):
"""
Description: Create the chat input box with a maximum character limit
# def chat_entry(self):
# """
# Description: Create the chat input box with a maximum character limit

"""
return st.chat_input(
self.chatbot_display, max_chars=self.chatbot_input_max_chars
)
# """
# return st.chat_input(
# self.chatbot_display, max_chars=self.chatbot_input_max_chars
# )

def create_chat_interface(self, user_input, query_type=None, llm_filter=None):
def create_chat_interface(self, user_input, query_type=None):
"""
Description: Create the chat interface and display the chat history and results. Show the user input and the response from the OpenML Agent.
"""
self.query_type = query_type
self.llm_filter = llm_filter
# self.llm_filter = llm_filter
if user_input is None:
with st.chat_message(name = "ai"):
st.write("OpenML Agent: ", "Hello! How can I help you today?")
Expand All @@ -412,10 +403,23 @@ def create_chat_interface(self, user_input, query_type=None, llm_filter=None):
for message in st.session_state.messages:
if message["role"] == "user":
with st.chat_message(name = "user"):
display_results(message["content"])
self.display_results(message["content"], "user")
else:
with st.chat_message(name = "ai"):
display_results(message["content"])
self.display_results(message["content"], "ai")

def display_results(self, initial_response, role):
    """
    Description: Render a chat response, as an interactive DataFrame when the
    payload is tabular, otherwise as plain text/markdown.

    Args:
        initial_response: The message content to display. May be a
            DataFrame-like object or any value `st.write` can render.
        role: The chat role ("user" or "ai") of the message. Currently unused;
            kept for backward compatibility with existing callers that pass it.
    """
    try:
        # Tabular payloads render best as an interactive table.
        st.dataframe(initial_response)
    except Exception:
        # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt are
        # not swallowed. Non-tabular content falls back to st.write.
        st.write(initial_response)

# Function to handle query processing
def process_query_chat(self, query):
Expand Down

0 comments on commit c8cd3ce

Please sign in to comment.