From 1cc815984aa69f5b67c5f640f804c3c977d3e29a Mon Sep 17 00:00:00 2001 From: srmsoumya Date: Wed, 18 Dec 2024 10:00:37 +0530 Subject: [PATCH] Pause on human input for location --- api.py | 112 ++++++++++++++++++++++------------------ frontend/app.py | 134 ++++++++++++++++++++++++++++++++++-------------- 2 files changed, 158 insertions(+), 88 deletions(-) diff --git a/api.py b/api.py index 5a5a2fe..34e8623 100644 --- a/api.py +++ b/api.py @@ -6,6 +6,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import StreamingResponse +from langgraph.types import Command from langfuse.callback import CallbackHandler from zeno.agents.maingraph.agent import graph from zeno.agents.maingraph.utils.state import GraphState @@ -31,63 +32,76 @@ def event_stream(query: str, thread_id: Optional[str]=None, query_type: Optional if not thread_id: thread_id = uuid.uuid4() - initial_state = GraphState(question=query) - config = { # "callbacks": [langfuse_handler], "configurable": {"thread_id": thread_id}, } - - if query_type == "query": - for namespace, chunk in graph.stream( - # for data in graph.stream( - initial_state, + if query_type == "human_input": + print(query) + selected_index = int(query) + current_state = graph.get_state(config) + stream = graph.stream( + Command(resume={ + "action": "update", + "option": selected_index + }), + stream_mode="updates", + subgraphs=True, + config=config, + ) + elif query_type == "query": + stream = graph.stream( + {"question": query, "route": None}, stream_mode="updates", subgraphs=True, config=config, - ): - node_name = list(chunk.keys())[0] - print(f"Namespace -> {namespace}") - print(f"Node name -> {node_name}") - - if node_name == "__interrupt__": - print(f"Waiting for human input") - interrupt_msg = chunk[node_name][0].value - question = interrupt_msg["question"] - options = interrupt_msg["options"] - artifact = interrupt_msg["artifact"] - - print("Waiting for human input") - yield pack({ - "type": "human_input", - "options": options, - "artifact": artifact, - "question": question - }) - elif node_name == "slasher": - pass - else: - if not chunk[node_name]: - continue - messages = chunk[node_name].get("messages", {}) - if not messages: - continue - for msg in messages: - if isinstance(msg, ToolMessage): + ) + else: + raise ValueError(f"Invalid query type from frontend: {query_type}") + + for namespace,chunk in stream: + node_name = list(chunk.keys())[0] + print(f"Namespace -> {namespace}") + print(f"Node name -> {node_name}") + + if node_name == "__interrupt__": + print(f"Waiting for human input") + interrupt_msg = chunk[node_name][0].value + question = interrupt_msg["question"] + options = interrupt_msg["options"] + artifact = interrupt_msg["artifact"] + + print("Waiting for human input") + yield pack({ + "type": "human_input", + "options": options, + "artifact": artifact, + "question": question + }) + elif node_name == "slasher": + pass + else: + if not chunk[node_name]: + continue + messages = chunk[node_name].get("messages", {}) + if not messages: + continue + for msg in messages: + if isinstance(msg, ToolMessage): + yield pack({ + "type": "tool_call", + "tool_name": msg.name, + "content": msg.content, + "artifact": msg.artifact if hasattr(msg, "artifact") else None, + }) + elif isinstance(msg, AIMessage): + if msg.content: yield pack({ - "type": "tool_call", - "tool_name": msg.name, - "content": msg.content, - "artifact": msg.artifact if hasattr(msg, "artifact") else None, + "type": "update", + "content": msg.content }) - elif isinstance(msg, AIMessage): - if msg.content: - yield pack({ - "type": "update", - "content": msg.content - }) - else: - raise ValueError(f"Unknown message type: {type(msg)}") + else: + raise ValueError(f"Unknown message type: {type(msg)}") @@ -120,7 +134,7 @@ async def stream( thread_id: Optional[str] = Body(None), query_type: Optional[str] = Body(None)): - print("POST...") + print("\n\nPOST...\n\n") print(query, thread_id, query_type) print("=" * 30) diff --git a/frontend/app.py b/frontend/app.py index 7c75700..4dd616e 100644 --- a/frontend/app.py +++ b/frontend/app.py @@ -2,7 +2,6 @@ import os import uuid from dotenv import load_dotenv - import folium import pandas as pd import requests @@ -13,8 +12,15 @@ API_BASE_URL = os.environ.get("API_BASE_URL") +# Initialize session state variables if "zeno_session_id" not in st.session_state: st.session_state.zeno_session_id = str(uuid.uuid4()) +if "current_options" not in st.session_state: + st.session_state.current_options = None +if "current_question" not in st.session_state: + st.session_state.current_question = None +if "messages" not in st.session_state: + st.session_state.messages = [] st.header("Zeno") st.caption("Your intelligent EcoBot, saving the forest faster than a 🐼 eats bamboo") @@ -44,59 +50,108 @@ """ ) +def display_message(message): + """Helper function to display a single message""" + if message["role"] == "user": + st.chat_message("user").write(message["content"]) + elif message["role"] == "assistant": + if message["type"] == "location": + st.chat_message("assistant").write("Found location you searched for...") + data = message["content"] + artifact = data.get("artifact", {}) + for feature in artifact["features"]: + st.chat_message("assistant").write(f"Found {feature['properties']['name']} in {feature['properties']['gadmid']}") + + geometry = artifact["features"][0]["geometry"] + if geometry["type"] == "Polygon": + pnt = geometry["coordinates"][0][0] + else: + pnt = geometry["coordinates"][0][0][0] + + m = folium.Map(location=[pnt[1], pnt[0]], zoom_start=11) + g = folium.GeoJson(artifact).add_to(m) + folium_static(m, width=700, height=500) + elif message["type"] == "alerts": + st.chat_message("assistant").write("Computing distributed alerts statistics...") + table = json.loads(message["content"]["content"]) + st.bar_chart(pd.DataFrame(table).T) + elif message["type"] == "context": + st.chat_message("assistant").write(f"Adding context layer {message['content']}") + else: + st.chat_message("assistant").write(message["content"]) + +def handle_human_input_submission(selected_index): + if st.session_state.current_options and selected_index is not None: + with requests.post( + f"{API_BASE_URL}/stream", + json={ + "query": str(selected_index), + "thread_id": st.session_state.zeno_session_id, + "query_type": "human_input" + }, + stream=True, + ) as response: + print("\n POST HUMAN INPUT...\n") + handle_stream_response(response) + def handle_stream_response(stream): for chunk in stream.iter_lines(): data = json.loads(chunk.decode("utf-8")) if data.get("type") == "human_input": - # Show a dropdown with options & a submit button - selected_option = st.selectbox(data["question"], data["options"]) - if st.button("Submit"): - # Send another POST request with the selected option - with requests.post( - f"{API_BASE_URL}/stream", - json={ - "query": selected_option, - "thread_id": st.session_state.zeno_session_id, - "query_type": "human_input" - }, - stream=True, - ) as response: - handle_stream_response(response) + # Store the options and question in session state + st.session_state.current_options = data["options"] + st.session_state.current_question = data["question"] + st.session_state.waiting_for_input = True + st.rerun() + elif data.get("type") == "tool_call": + message = None if data.get("tool_name") == "location-tool": - st.chat_message("assistant").write("Found location you searched for...") - artifact = data.get("artifact", {}) - for feature in artifact["features"]: - st.chat_message("assistant").write(f"Found {feature['properties']['name']} in {feature['properties']['gadmid']}") - - # Add the artifact to the map - geometry = artifact["features"][0]["geometry"] - if geometry["type"] == "Polygon": - pnt = geometry["coordinates"][0][0] - else: - pnt = geometry["coordinates"][0][0][0] - - m = folium.Map(location=[pnt[1], pnt[0]], zoom_start=11) - g = folium.GeoJson( - artifact, - ).add_to(m) - folium_static(m, width=700, height=500) + message = {"role": "assistant", "type": "location", "content": data} elif data.get("tool_name") == "dist-alerts-tool": - st.chat_message("assistant").write("Computing distributed alerts statistics...") - table = json.loads(data["content"]) - st.bar_chart(pd.DataFrame(table).T) + message = {"role": "assistant", "type": "alerts", "content": data} elif data.get("tool_name") == "context-layer-tool": - st.chat_message("assistant").write(f"Adding context layer {data['content']}") + message = {"role": "assistant", "type": "context", "content": data["content"]} else: - st.chat_message("assistant").write(data["content"]) + message = {"role": "assistant", "type": "text", "content": data["content"]} + + if message: + st.session_state.messages.append(message) + display_message(message) + elif data.get("type") == "update": - st.chat_message("assistant").write(data["content"]) + message = {"role": "assistant", "type": "text", "content": data["content"]} + st.session_state.messages.append(message) + display_message(message) else: raise ValueError(f"Unknown message type: {data.get('type')}") +# Display chat history +for message in st.session_state.messages: + display_message(message) + +# Handle human input interface if options are available +if st.session_state.current_options: + selected_option = st.selectbox( + st.session_state.current_question, + st.session_state.current_options, + key="selected_option" + ) + selected_index = st.session_state.current_options.index(selected_option) + if st.button("Submit"): + handle_human_input_submission(selected_index) + # Clear the options after submission + st.session_state.current_options = None + st.session_state.current_question = None + +# Main chat input if user_input := st.chat_input("Type your message here..."): - st.chat_message("user").write(user_input) + # Add user message to history + message = {"role": "user", "type": "text", "content": user_input} + st.session_state.messages.append(message) + display_message(message) + with requests.post( f"{API_BASE_URL}/stream", json={ @@ -106,4 +161,5 @@ def handle_stream_response(stream): }, stream=True, ) as stream: + print("\nPOST...\n") handle_stream_response(stream)