Skip to content

Commit

Permalink
✨ add label in sidebar
Browse files Browse the repository at this point in the history
  • Loading branch information
baptiste-pasquier committed Apr 17, 2024
1 parent 7868535 commit 15605f2
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 19 deletions.
60 changes: 46 additions & 14 deletions backend/api_plugins/sessions/sessions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from datetime import datetime
from pathlib import Path
from typing import List, Optional, Sequence
from typing import Optional, Sequence
from uuid import uuid4

from fastapi import APIRouter, Depends, FastAPI, Response
Expand All @@ -22,7 +22,9 @@ def session_routes(
connection.run_script(Path(__file__).parent / "sessions_tables.sql")

@app.post("/session/new")
async def chat_new(current_user: User=authentication, dependencies=dependencies) -> dict:
async def chat_new(
current_user: User = authentication, dependencies=dependencies
) -> dict:
chat_id = str(uuid4())
timestamp = datetime.utcnow().isoformat()
user_id = current_user.email if current_user else "unauthenticated"
Expand All @@ -33,26 +35,51 @@ async def chat_new(current_user: User=authentication, dependencies=dependencies)
)
return {"session_id": chat_id}


@app.get("/session/list")
async def chat_list(current_user: User=authentication, dependencies=dependencies) -> List[dict]:
async def chat_list(
current_user: User = authentication, dependencies=dependencies
) -> list[dict]:
user_email = current_user.email if current_user else "unauthenticated"
chats = []
with Database() as connection:
result = connection.execute(
"SELECT id, timestamp FROM session WHERE user_id = ? ORDER BY timestamp DESC",
(user_email,),
# Check if message_history table exists (first time running the app will not
# have this table created yet)
message_history_exists = connection.fetchone(
"SELECT name FROM sqlite_master WHERE type='table' AND"
" name='message_history'"
)
chats = [{"id": row[0], "timestamp": row[1]} for row in result]
if message_history_exists:
# Join session with message_history and get the first message
result = connection.execute(
"SELECT s.id, s.timestamp, mh.message FROM session s LEFT JOIN"
" (SELECT *, ROW_NUMBER() OVER (PARTITION BY session_id ORDER BY"
" timestamp ASC) as rn FROM message_history) mh ON s.id ="
" mh.session_id AND mh.rn = 1 WHERE s.user_id = ? ORDER BY"
" s.timestamp DESC",
(user_email,),
)
for row in result:
# Extract the first message content if available
first_message_content = (
json.loads(row[2])["data"]["content"] if row[2] else ""
)
chat = {
"id": row[0],
"timestamp": row[1],
"first_message": first_message_content,
}
chats.append(chat)
return chats


@app.get("/session/{session_id}")
async def chat(session_id: str, current_user: User=authentication, dependencies=dependencies) -> dict:
messages: List[Message] = []
async def chat(
session_id: str, current_user: User = authentication, dependencies=dependencies
) -> dict:
messages: list[Message] = []
with Database() as connection:
result = connection.execute(
"SELECT id, timestamp, session_id, message FROM message_history WHERE session_id = ? ORDER BY timestamp ASC",
"SELECT id, timestamp, session_id, message FROM message_history WHERE"
" session_id = ? ORDER BY timestamp ASC",
(session_id,),
)
for row in result:
Expand All @@ -66,8 +93,13 @@ async def chat(session_id: str, current_user: User=authentication, dependencies=
content=content,
)
messages.append(message)
return {"chat_id": session_id, "messages": [message.dict() for message in messages]}
return {
"chat_id": session_id,
"messages": [message.dict() for message in messages],
}

@app.get("/session")
async def session_root(current_user: User=authentication, dependencies=dependencies) -> dict:
async def session_root(
current_user: User = authentication, dependencies=dependencies
) -> dict:
return Response("Sessions management routes are enabled.", status_code=200)
45 changes: 40 additions & 5 deletions frontend/lib/sidebar.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,24 @@
def sidebar():
with st.sidebar:
st.sidebar.title("RAG Industrialization Kit", anchor="top")
st.sidebar.markdown(f"<p style='color:grey;'>Logged in as {st.session_state['email']}</p>", unsafe_allow_html=True)
st.sidebar.markdown(
f"<p style='color:grey;'>Logged in as {st.session_state['email']}</p>",
unsafe_allow_html=True,
)

if st.sidebar.button("New Chat", use_container_width=True, key="new_chat_button"):
if st.sidebar.button(
"New Chat", use_container_width=True, key="new_chat_button"
):
st.session_state["messages"] = []

with st.empty():
chat_list = list_sessions()
chats_by_time_ago = {}
for chat in chat_list:
chat_id, timestamp = chat["id"], chat["timestamp"]
time_ago = humanize.naturaltime(datetime.utcnow() - datetime.fromisoformat(timestamp))
time_ago = humanize.naturaltime(
datetime.utcnow() - datetime.fromisoformat(timestamp)
)
if time_ago not in chats_by_time_ago:
chats_by_time_ago[time_ago] = []
chats_by_time_ago[time_ago].append(chat)
Expand All @@ -29,9 +36,22 @@ def sidebar():
st.sidebar.markdown(time_ago)
for chat in chats:
chat_id = chat["id"]
if st.sidebar.button(chat_id, key=chat_id, use_container_width=True):
chat_first_message = chat["first_message"]
label = (
truncate_label(chat_first_message, 100)
if chat_first_message
else "*No content*"
)
if st.sidebar.button(
label=label,
key=chat_id,
use_container_width=True,
):
st.session_state["chat_id"] = chat_id
messages = [Message(**message) for message in get_session(chat_id)["messages"]]
messages = [
Message(**message)
for message in get_session(chat_id)["messages"]
]
st.session_state["messages"] = messages


Expand All @@ -42,3 +62,18 @@ def list_sessions():
def get_session(session_id: str):
session = query("get", f"/session/{session_id}").json()
return session


def truncate_label(string: str, max_len: int) -> str:
"""Truncate a string to a maximum length, appending ellipsis if necessary.
Args:
string (str): String to be truncated.
max_len (int): Maximum allowed length of the string after truncation.
Returns:
str: Truncated string.
"""
if string and len(string) > max_len:
string = string[: max_len - 3] + "..."
return string

0 comments on commit 15605f2

Please sign in to comment.