Skip to content

Commit

Permalink
Merge pull request #41 from wri/feature/dist-alerts-rag
Browse files Browse the repository at this point in the history
Add notebook for generating embeddings and initial RAG implementation for fetching the appropriate layer
  • Loading branch information
yellowcap authored Dec 19, 2024
2 parents 78baa3b + e20380b commit 84d97e6
Show file tree
Hide file tree
Showing 18 changed files with 2,477 additions and 221 deletions.
4 changes: 4 additions & 0 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ ENV PYTHONDONTWRITEBYTECODE=1 \
RUN apt-get update && apt-get install -y --no-install-recommends curl ca-certificates libexpat1 \
&& rm -rf /var/lib/apt/lists/*

RUN apt-get update && apt-get install -y build-essential libgdal-dev

ADD https://astral.sh/uv/0.5.4/install.sh /uv-installer.sh

# Run the installer then remove it
Expand All @@ -19,6 +21,8 @@ ENV PATH="/root/.local/bin/:$PATH"
# Copy the project into the image
ADD . /app

COPY ./ee-zeno-service-account.json /app/ee-zeno-service-account.json

# Sync the project into a new environment, using the frozen lockfile
WORKDIR /app

Expand Down
61 changes: 32 additions & 29 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from zeno.agents.maingraph.agent import graph
from zeno.agents.maingraph.utils.state import GraphState
from langchain_core.messages import ToolMessage, AIMessage

app = FastAPI()
langfuse_handler = CallbackHandler()

Expand All @@ -28,7 +29,9 @@ def pack(data):


# Streams the response from the graph
def event_stream(query: str, thread_id: Optional[str]=None, query_type: Optional[str]=None):
def event_stream(
query: str, thread_id: Optional[str] = None, query_type: Optional[str] = None
):
if not thread_id:
thread_id = uuid.uuid4()

Expand All @@ -41,10 +44,7 @@ def event_stream(query: str, thread_id: Optional[str]=None, query_type: Optional
selected_index = int(query)
current_state = graph.get_state(config)
stream = graph.stream(
Command(resume={
"action": "update",
"option": selected_index
}),
Command(resume={"action": "update", "option": selected_index}),
stream_mode="updates",
subgraphs=True,
config=config,
Expand All @@ -70,18 +70,20 @@ def event_stream(query: str, thread_id: Optional[str]=None, query_type: Optional

if node_name == "__interrupt__":
print(f"Waiting for human input")
interrupt_msg = chunk[node_name][0].value
interrupt_msg = chunk[node_name][0].value
question = interrupt_msg["question"]
options = interrupt_msg["options"]
artifact = interrupt_msg["artifact"]

print("Waiting for human input")
yield pack({
"type": "human_input",
"options": options,
"artifact": artifact,
"question": question
})
yield pack(
{
"type": "human_input",
"options": options,
"artifact": artifact,
"question": question,
}
)
elif node_name == "slasher":
pass
else:
Expand All @@ -92,23 +94,22 @@ def event_stream(query: str, thread_id: Optional[str]=None, query_type: Optional
continue
for msg in messages:
if isinstance(msg, ToolMessage):
yield pack({
"type": "tool_call",
"tool_name": msg.name,
"content": msg.content,
"artifact": msg.artifact if hasattr(msg, "artifact") else None,
})
yield pack(
{
"type": "tool_call",
"tool_name": msg.name,
"content": msg.content,
"artifact": (
msg.artifact if hasattr(msg, "artifact") else None
),
}
)
elif isinstance(msg, AIMessage):
if msg.content:
yield pack({
"type": "update",
"content": msg.content
})
yield pack({"type": "update", "content": msg.content})
else:
raise ValueError(f"Unknown message type: {type(msg)}")



# node_name = list(chunk.keys())[0]
# yield pack(chunk[node_name])
# print(f"Namespace {namespace}")
Expand Down Expand Up @@ -136,20 +137,20 @@ def event_stream(query: str, thread_id: Optional[str]=None, query_type: Optional
async def stream(
query: Annotated[str, Body(embed=True)],
thread_id: Optional[str] = Body(None),
query_type: Optional[str] = Body(None)):
query_type: Optional[str] = Body(None),
):

print("\n\nPOST...\n\n")
print(query, thread_id, query_type)
print("=" * 30)

return StreamingResponse(
event_stream(query, thread_id, query_type),
media_type="application/x-ndjson"
event_stream(query, thread_id, query_type), media_type="application/x-ndjson"
)


# Processes the query and returns the response
def process_query(query: str, thread_id: Optional[str]=None):
def process_query(query: str, thread_id: Optional[str] = None):

if not thread_id:
thread_id = uuid.uuid4()
Expand All @@ -166,5 +167,7 @@ def process_query(query: str, thread_id: Optional[str]=None):


@app.post("/query")
async def query(query: Annotated[str, Body(embed=True)], thread_id: Optional[str]=None):
async def query(
query: Annotated[str, Body(embed=True)], thread_id: Optional[str] = None
):
return process_query(query, thread_id)
4 changes: 3 additions & 1 deletion docker-compose.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@ services:
- LANGFUSE_SECRET_KEY=lf_sk_1234567890
- ANTHROPIC_API_KEY=${ANTHROPIC_API_KEY}
- OPENAI_API_KEY=${OPENAI_API_KEY}
- GEE_SERVICE_ACCOUNT_PATH=${GEE_SERVICE_ACCOUNT_PATH}
env_file:
- .env

volumes:
- ./zeno:/app/zeno
- .:/app
#- ./data:/app/data

frontend:
build:
Expand Down
36 changes: 27 additions & 9 deletions frontend/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
st.subheader("🧐 Try asking:")
st.write(
"""
- Provide data about disturbance alerts in Aveiro summarized by landcover
- Provide data about disturbance alerts in Aveiro summarized by natural lands
- What is happening with Gold Mining Deforestation?
- What do you know about Forest Protection in remote islands in Indonesia?
- How many users are using GFW and how long did it take to get there?
Expand All @@ -50,6 +50,7 @@
"""
)


def display_message(message):
"""Helper function to display a single message"""
if message["role"] == "user":
Expand All @@ -60,7 +61,9 @@ def display_message(message):
data = message["content"]
artifact = data.get("artifact", {})
for feature in artifact["features"]:
st.chat_message("assistant").write(f"Found {feature['properties']['name']} in {feature['properties']['gadmid']}")
st.chat_message("assistant").write(
f"Found {feature['properties']['name']} in {feature['properties']['gadmid']}"
)

geometry = artifact["features"][0]["geometry"]
if geometry["type"] == "Polygon":
Expand All @@ -72,28 +75,34 @@ def display_message(message):
g = folium.GeoJson(artifact).add_to(m)
folium_static(m, width=700, height=500)
elif message["type"] == "alerts":
st.chat_message("assistant").write("Computing distributed alerts statistics...")
st.chat_message("assistant").write(
"Computing distributed alerts statistics..."
)
table = json.loads(message["content"]["content"])
st.bar_chart(pd.DataFrame(table).T)
elif message["type"] == "context":
st.chat_message("assistant").write(f"Adding context layer {message['content']}")
st.chat_message("assistant").write(
f"Adding context layer {message['content']}"
)
else:
st.chat_message("assistant").write(message["content"])


def handle_human_input_submission(selected_index):
if st.session_state.current_options and selected_index is not None:
with requests.post(
f"{API_BASE_URL}/stream",
json={
"query": str(selected_index),
"thread_id": st.session_state.zeno_session_id,
"query_type": "human_input"
"query_type": "human_input",
},
stream=True,
) as response:
print("\n POST HUMAN INPUT...\n")
handle_stream_response(response)


def handle_stream_response(stream):
for chunk in stream.iter_lines():
data = json.loads(chunk.decode("utf-8"))
Expand All @@ -112,9 +121,17 @@ def handle_stream_response(stream):
elif data.get("tool_name") == "dist-alerts-tool":
message = {"role": "assistant", "type": "alerts", "content": data}
elif data.get("tool_name") == "context-layer-tool":
message = {"role": "assistant", "type": "context", "content": data["content"]}
message = {
"role": "assistant",
"type": "context",
"content": data["content"],
}
else:
message = {"role": "assistant", "type": "text", "content": data["content"]}
message = {
"role": "assistant",
"type": "text",
"content": data["content"],
}

if message:
st.session_state.messages.append(message)
Expand All @@ -127,6 +144,7 @@ def handle_stream_response(stream):
else:
raise ValueError(f"Unknown message type: {data.get('type')}")


# Display chat history
for message in st.session_state.messages:
display_message(message)
Expand All @@ -136,7 +154,7 @@ def handle_stream_response(stream):
selected_option = st.selectbox(
st.session_state.current_question,
st.session_state.current_options,
key="selected_option"
key="selected_option",
)
selected_index = st.session_state.current_options.index(selected_option)
if st.button("Submit"):
Expand All @@ -157,7 +175,7 @@ def handle_stream_response(stream):
json={
"query": user_input,
"thread_id": st.session_state.zeno_session_id,
"query_type": "query"
"query_type": "query",
},
stream=True,
) as stream:
Expand Down
2 changes: 1 addition & 1 deletion frontend/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
streamlit==1.40.1
streamlit_folium==0.23.0
pandas==0.19.0
pandas==2.2.3
2 changes: 1 addition & 1 deletion nbs/04_rag-datasets.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.4"
"version": "3.10.13"
}
},
"nbformat": 4,
Expand Down
Loading

0 comments on commit 84d97e6

Please sign in to comment.