
Pause on human input for location
srmsoumya committed Dec 18, 2024
1 parent f3f4099 commit 1cc8159
Showing 2 changed files with 158 additions and 88 deletions.
112 changes: 63 additions & 49 deletions api.py
@@ -6,6 +6,7 @@
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 
+from langgraph.types import Command
 from langfuse.callback import CallbackHandler
 from zeno.agents.maingraph.agent import graph
 from zeno.agents.maingraph.utils.state import GraphState
@@ -31,63 +32,76 @@ def event_stream(query: str, thread_id: Optional[str]=None, query_type: Optional
     if not thread_id:
         thread_id = uuid.uuid4()
 
-    initial_state = GraphState(question=query)
-
     config = {
         # "callbacks": [langfuse_handler],
         "configurable": {"thread_id": thread_id},
     }
 
-    if query_type == "query":
-        for namespace, chunk in graph.stream(
-            # for data in graph.stream(
-            initial_state,
+    if query_type == "human_input":
+        print(query)
+        selected_index = int(query)
+        current_state = graph.get_state(config)
+        stream = graph.stream(
+            Command(resume={
+                "action": "update",
+                "option": selected_index
+            }),
+            stream_mode="updates",
+            subgraphs=True,
+            config=config,
+        )
+    elif query_type == "query":
+        stream = graph.stream(
+            {"question": query, "route": None},
             stream_mode="updates",
             subgraphs=True,
             config=config,
-        ):
-            node_name = list(chunk.keys())[0]
-            print(f"Namespace -> {namespace}")
-            print(f"Node name -> {node_name}")
-
-            if node_name == "__interrupt__":
-                print(f"Waiting for human input")
-                interrupt_msg = chunk[node_name][0].value
-                question = interrupt_msg["question"]
-                options = interrupt_msg["options"]
-                artifact = interrupt_msg["artifact"]
-
-                print("Waiting for human input")
-                yield pack({
-                    "type": "human_input",
-                    "options": options,
-                    "artifact": artifact,
-                    "question": question
-                })
-            elif node_name == "slasher":
-                pass
-            else:
-                if not chunk[node_name]:
-                    continue
-                messages = chunk[node_name].get("messages", {})
-                if not messages:
-                    continue
-                for msg in messages:
-                    if isinstance(msg, ToolMessage):
-                        yield pack({
-                            "type": "tool_call",
-                            "tool_name": msg.name,
-                            "content": msg.content,
-                            "artifact": msg.artifact if hasattr(msg, "artifact") else None,
-                        })
-                    elif isinstance(msg, AIMessage):
-                        if msg.content:
-                            yield pack({
-                                "type": "update",
-                                "content": msg.content
-                            })
-                    else:
-                        raise ValueError(f"Unknown message type: {type(msg)}")
+        )
+    else:
+        raise ValueError(f"Invalid query type from frontend: {query_type}")
+
+    for namespace, chunk in stream:
+        node_name = list(chunk.keys())[0]
+        print(f"Namespace -> {namespace}")
+        print(f"Node name -> {node_name}")
+
+        if node_name == "__interrupt__":
+            print(f"Waiting for human input")
+            interrupt_msg = chunk[node_name][0].value
+            question = interrupt_msg["question"]
+            options = interrupt_msg["options"]
+            artifact = interrupt_msg["artifact"]
+
+            print("Waiting for human input")
+            yield pack({
+                "type": "human_input",
+                "options": options,
+                "artifact": artifact,
+                "question": question
+            })
+        elif node_name == "slasher":
+            pass
+        else:
+            if not chunk[node_name]:
+                continue
+            messages = chunk[node_name].get("messages", {})
+            if not messages:
+                continue
+            for msg in messages:
+                if isinstance(msg, ToolMessage):
+                    yield pack({
+                        "type": "tool_call",
+                        "tool_name": msg.name,
+                        "content": msg.content,
+                        "artifact": msg.artifact if hasattr(msg, "artifact") else None,
+                    })
+                elif isinstance(msg, AIMessage):
+                    if msg.content:
+                        yield pack({
+                            "type": "update",
+                            "content": msg.content
+                        })
+                else:
+                    raise ValueError(f"Unknown message type: {type(msg)}")
 
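The "__interrupt__" chunk consumed above is produced by LangGraph's human-in-the-loop interrupt mechanism, and the Command(resume=...) call is what un-pauses the checkpointed thread. The zeno graph itself is not part of this diff, so the sketch below is a minimal, self-contained illustration of that pattern only; the node name, state keys, and option payload are assumptions, not code from this repository.

# Minimal sketch (not the zeno graph): a node calls interrupt() to pause,
# graph.stream() surfaces an "__interrupt__" chunk, and a later
# Command(resume=...) continues the same thread, mirroring event_stream().
from typing import TypedDict

from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, START, END
from langgraph.types import Command, interrupt


class State(TypedDict):
    question: str
    option: int


def ask_location(state: State) -> dict:
    # Pausing here is what shows up as an "__interrupt__" chunk downstream;
    # the payload mirrors what event_stream() forwards to the frontend.
    answer = interrupt({
        "question": "Which location did you mean?",
        "options": ["Option A", "Option B"],
        "artifact": None,
    })
    # `answer` is whatever dict was passed back via Command(resume=...).
    return {"option": answer["option"]}


builder = StateGraph(State)
builder.add_node("ask_location", ask_location)
builder.add_edge(START, "ask_location")
builder.add_edge("ask_location", END)
graph = builder.compile(checkpointer=MemorySaver())

config = {"configurable": {"thread_id": "demo"}}

# First pass: the stream pauses and yields an "__interrupt__" update.
for _, chunk in graph.stream({"question": "hi", "option": -1},
                             stream_mode="updates", subgraphs=True,
                             config=config):
    print(chunk)

# Resume exactly as api.py does for query_type == "human_input".
for _, chunk in graph.stream(Command(resume={"action": "update", "option": 1}),
                             stream_mode="updates", subgraphs=True,
                             config=config):
    print(chunk)
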
@@ -120,7 +134,7 @@ async def stream(
     thread_id: Optional[str] = Body(None),
     query_type: Optional[str] = Body(None)):
 
-    print("POST...")
+    print("\n\nPOST...\n\n")
     print(query, thread_id, query_type)
     print("=" * 30)
 
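The generator above relies on a pack() helper and on the /stream endpoint wrapping it in a StreamingResponse, neither of which appears in these hunks. A plausible minimal version is sketched below, assuming pack() emits newline-delimited JSON (consistent with the frontend decoding each line via iter_lines() and json.loads); the media type, the Body(...) default for query, and the stub generator are assumptions, not the actual api.py.

# Minimal sketch of the assumed streaming plumbing around event_stream().
import json
from typing import Optional

from fastapi import Body, FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


def pack(payload: dict) -> str:
    # One JSON object per line so the client can frame messages with iter_lines().
    return json.dumps(payload) + "\n"


def event_stream(query: str, thread_id=None, query_type=None):
    # Stand-in for the real generator in api.py; yields a single packed message.
    yield pack({"type": "update", "content": f"echo: {query}"})


@app.post("/stream")
async def stream(
    query: str = Body(...),
    thread_id: Optional[str] = Body(None),
    query_type: Optional[str] = Body(None),
):
    return StreamingResponse(
        event_stream(query, thread_id, query_type),
        media_type="application/x-ndjson",
    )
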
134 changes: 95 additions & 39 deletions frontend/app.py
@@ -2,7 +2,6 @@
 import os
 import uuid
 from dotenv import load_dotenv
-
 import folium
 import pandas as pd
 import requests
@@ -13,8 +12,15 @@
 
 API_BASE_URL = os.environ.get("API_BASE_URL")
 
+# Initialize session state variables
 if "zeno_session_id" not in st.session_state:
     st.session_state.zeno_session_id = str(uuid.uuid4())
+if "current_options" not in st.session_state:
+    st.session_state.current_options = None
+if "current_question" not in st.session_state:
+    st.session_state.current_question = None
+if "messages" not in st.session_state:
+    st.session_state.messages = []
 
 st.header("Zeno")
 st.caption("Your intelligent EcoBot, saving the forest faster than a 🐼 eats bamboo")
@@ -44,59 +50,108 @@
     """
 )
 
+def display_message(message):
+    """Helper function to display a single message"""
+    if message["role"] == "user":
+        st.chat_message("user").write(message["content"])
+    elif message["role"] == "assistant":
+        if message["type"] == "location":
+            st.chat_message("assistant").write("Found location you searched for...")
+            data = message["content"]
+            artifact = data.get("artifact", {})
+            for feature in artifact["features"]:
+                st.chat_message("assistant").write(f"Found {feature['properties']['name']} in {feature['properties']['gadmid']}")
+
+            geometry = artifact["features"][0]["geometry"]
+            if geometry["type"] == "Polygon":
+                pnt = geometry["coordinates"][0][0]
+            else:
+                pnt = geometry["coordinates"][0][0][0]
+
+            m = folium.Map(location=[pnt[1], pnt[0]], zoom_start=11)
+            g = folium.GeoJson(artifact).add_to(m)
+            folium_static(m, width=700, height=500)
+        elif message["type"] == "alerts":
+            st.chat_message("assistant").write("Computing distributed alerts statistics...")
+            table = json.loads(message["content"]["content"])
+            st.bar_chart(pd.DataFrame(table).T)
+        elif message["type"] == "context":
+            st.chat_message("assistant").write(f"Adding context layer {message['content']}")
+        else:
+            st.chat_message("assistant").write(message["content"])
+
+def handle_human_input_submission(selected_index):
+    if st.session_state.current_options and selected_index is not None:
+        with requests.post(
+            f"{API_BASE_URL}/stream",
+            json={
+                "query": str(selected_index),
+                "thread_id": st.session_state.zeno_session_id,
+                "query_type": "human_input"
+            },
+            stream=True,
+        ) as response:
+            print("\n POST HUMAN INPUT...\n")
+            handle_stream_response(response)
+
 def handle_stream_response(stream):
     for chunk in stream.iter_lines():
         data = json.loads(chunk.decode("utf-8"))
 
         if data.get("type") == "human_input":
-            # Show a dropdown with options & a submit button
-            selected_option = st.selectbox(data["question"], data["options"])
-            if st.button("Submit"):
-                # Send another POST request with the selected option
-                with requests.post(
-                    f"{API_BASE_URL}/stream",
-                    json={
-                        "query": selected_option,
-                        "thread_id": st.session_state.zeno_session_id,
-                        "query_type": "human_input"
-                    },
-                    stream=True,
-                ) as response:
-                    handle_stream_response(response)
+            # Store the options and question in session state
+            st.session_state.current_options = data["options"]
+            st.session_state.current_question = data["question"]
+            st.session_state.waiting_for_input = True
+            st.rerun()
+
         elif data.get("type") == "tool_call":
+            message = None
             if data.get("tool_name") == "location-tool":
-                st.chat_message("assistant").write("Found location you searched for...")
-                artifact = data.get("artifact", {})
-                for feature in artifact["features"]:
-                    st.chat_message("assistant").write(f"Found {feature['properties']['name']} in {feature['properties']['gadmid']}")
-
-                # Add the artifact to the map
-                geometry = artifact["features"][0]["geometry"]
-                if geometry["type"] == "Polygon":
-                    pnt = geometry["coordinates"][0][0]
-                else:
-                    pnt = geometry["coordinates"][0][0][0]
-
-                m = folium.Map(location=[pnt[1], pnt[0]], zoom_start=11)
-                g = folium.GeoJson(
-                    artifact,
-                ).add_to(m)
-                folium_static(m, width=700, height=500)
+                message = {"role": "assistant", "type": "location", "content": data}
             elif data.get("tool_name") == "dist-alerts-tool":
-                st.chat_message("assistant").write("Computing distributed alerts statistics...")
-                table = json.loads(data["content"])
-                st.bar_chart(pd.DataFrame(table).T)
+                message = {"role": "assistant", "type": "alerts", "content": data}
             elif data.get("tool_name") == "context-layer-tool":
-                st.chat_message("assistant").write(f"Adding context layer {data['content']}")
+                message = {"role": "assistant", "type": "context", "content": data["content"]}
             else:
-                st.chat_message("assistant").write(data["content"])
+                message = {"role": "assistant", "type": "text", "content": data["content"]}
+
+            if message:
+                st.session_state.messages.append(message)
+                display_message(message)
+
         elif data.get("type") == "update":
-            st.chat_message("assistant").write(data["content"])
+            message = {"role": "assistant", "type": "text", "content": data["content"]}
+            st.session_state.messages.append(message)
+            display_message(message)
         else:
             raise ValueError(f"Unknown message type: {data.get('type')}")
 
+# Display chat history
+for message in st.session_state.messages:
+    display_message(message)
+
+# Handle human input interface if options are available
+if st.session_state.current_options:
+    selected_option = st.selectbox(
+        st.session_state.current_question,
+        st.session_state.current_options,
+        key="selected_option"
+    )
+    selected_index = st.session_state.current_options.index(selected_option)
+    if st.button("Submit"):
+        handle_human_input_submission(selected_index)
+        # Clear the options after submission
+        st.session_state.current_options = None
+        st.session_state.current_question = None
+
+# Main chat input
 if user_input := st.chat_input("Type your message here..."):
-    st.chat_message("user").write(user_input)
+    # Add user message to history
+    message = {"role": "user", "type": "text", "content": user_input}
+    st.session_state.messages.append(message)
+    display_message(message)
+
     with requests.post(
         f"{API_BASE_URL}/stream",
         json={
@@ -106,4 +161,5 @@ def handle_stream_response(stream):
         },
         stream=True,
     ) as stream:
+        print("\nPOST...\n")
         handle_stream_response(stream)

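Taken together, the two files implement a two-step resume protocol: the first POST streams newline-delimited JSON updates until the graph pauses with a "human_input" event, and a second POST with query_type="human_input" and the selected option index resumes the same thread_id. The sketch below exercises that round trip without Streamlit; the base URL, the example query, and the choice of option index are illustrative assumptions.

# Minimal client sketch for the /stream protocol shown in this commit.
import json
import uuid

import requests

API_BASE_URL = "http://localhost:8000"  # assumed local deployment
thread_id = str(uuid.uuid4())


def stream_events(payload: dict):
    # Each response line is one JSON message produced by pack() on the server.
    with requests.post(f"{API_BASE_URL}/stream", json=payload, stream=True) as r:
        for line in r.iter_lines():
            if line:
                yield json.loads(line.decode("utf-8"))


# 1. Initial question; the graph pauses when it needs a location choice.
options = []
for event in stream_events({"query": "Disturbance alerts near Lisbon",
                            "thread_id": thread_id,
                            "query_type": "query"}):
    print(event["type"])
    if event["type"] == "human_input":
        options = event["options"]

# 2. Resume the same thread with the chosen option's index (here: the first one).
for event in stream_events({"query": "0",
                            "thread_id": thread_id,
                            "query_type": "human_input"}):
    print(event)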