Skip to content

Commit

Permalink
Merge pull request #35 from wri/rag-for-location
Browse files (browse the repository at this point in the history)
RAG for location and improved Streamlit
  • Loading branch information
yellowcap authored Dec 5, 2024
2 parents 8e4c215 + 80dad02 commit 4a6e860
Show file tree
Hide file tree
Showing 14 changed files with 240 additions and 74 deletions.
73 changes: 55 additions & 18 deletions api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from langfuse.callback import CallbackHandler
from zeno.agents.maingraph.agent import graph
from zeno.agents.maingraph.utils.state import GraphState

from langchain_core.messages import ToolMessage, AIMessage
app = FastAPI()
langfuse_handler = CallbackHandler()

Expand All @@ -35,32 +35,69 @@ def event_stream(query: str, thread_id: Optional[str]=None):
initial_state = GraphState(question=query)

config = {
"callbacks": [langfuse_handler],
# "callbacks": [langfuse_handler],
"configurable": {"thread_id": thread_id},
}

for namespace, data in graph.stream(
for namespace, chunk in graph.stream(
# for data in graph.stream(
initial_state,
stream_mode="updates",
subgraphs=True,
config=config,
):
node_name = list(chunk.keys())[0]
print(f"Namespace {namespace}")
for key, val in data.items():
print(f"Messenger is {key}")
if key in ["agent", "assistant"]:
continue
if not val:
continue
for key2, val2 in val.items():
if key2 == "messages":
for msg in val2:
if msg.content:
yield pack({"message": msg.content})
if hasattr(msg, "tool_calls") and msg.tool_calls:
yield pack({"tool_calls": msg.tool_calls})
if hasattr(msg, "artifact") and msg.artifact:
yield pack({"artifact": msg.artifact})
if not namespace:
continue
print(f"Node {node_name}")
if not chunk[node_name]:
continue
messages = chunk[node_name].get("messages")
if not messages:
continue
for msg in messages:
# print(msg)
# yield pack({
# "type":
# })
if isinstance(msg, ToolMessage):
yield pack({
"type": "tool",
"tool_name": msg.name,
"message": msg.content,
"artifact": msg.artifact if hasattr(msg, "artifact") else None,
})
elif isinstance(msg, AIMessage):
if msg.content:
yield pack({
"type": "assistant",
"message": msg.content
})



# node_name = list(chunk.keys())[0]
# yield pack(chunk[node_name])
# print(f"Namespace {namespace}")
# for key, val in data.items():
# print(f"Messenger is {key}")
# # if key in ["agent", "assistant"]:
# # continue
# if not val:
# continue
# for key2, val2 in val.items():
# print("Messenger2", key2)
# if key2 == "messages":
# for msg in val2:
# if msg.content:
# yield pack({"message": msg.content})
# if hasattr(msg, "tool_calls") and msg.tool_calls:
# yield pack({"tool_calls": msg.tool_calls})
# if hasattr(msg, "artifact") and msg.artifact:
# yield pack({"artifact": msg.artifact})
# else:
# print("NANANA", key, msg)


@app.post("/stream")
Expand Down
38 changes: 31 additions & 7 deletions frontend/app.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import json
import os
import uuid
import uuid

import folium
import pandas as pd
import requests
import streamlit as st
from streamlit_folium import folium_static

API_BASE_URL = os.environ.get("API_BASE_URL")

if 'zeno_session_id' not in st.session_state:
if "zeno_session_id" not in st.session_state:
st.session_state.zeno_session_id = str(uuid.uuid4())

st.header("Zeno")
Expand Down Expand Up @@ -49,17 +51,39 @@
) as stream:
for chunk in stream.iter_lines():
data = json.loads(chunk.decode("utf-8"))
if data.get("artifact", {}).get("type") == "FeatureCollection":
geom = data.get("artifact")["features"][0]["geometry"]
artifact = data.pop("artifact", {})
print(data)
if data.get("tool_name") == "dist-alerts-tool":
st.markdown("#### Dist alerts statistics")
table = json.loads(data["message"])
st.bar_chart(pd.DataFrame(table).T)
st.markdown("#### Map of dist alerts")
elif data.get("tool_name") == "context-layer-tool":
st.markdown("#### Context layer")
st.markdown(f"Using context layer **{data['message']}**")
elif data.get("tool_name") == "location-tool":
st.markdown("#### Matched location")
for feat in artifact["features"]:
st.markdown(
f'Found area **{feat["properties"]["gadmid"]}** {feat["properties"]["name"]}'
)
st.markdown("#### Map of location")
elif data.get("type") == "assistant":
st.markdown("#### Assistant message")
st.markdown(data["message"])
else:
st.write(data)
if artifact and artifact.get("type") == "FeatureCollection":
geom = artifact["features"][0]["geometry"]
if geom["type"] == "Polygon":
pnt = geom["coordinates"][0][0]
else:
pnt = geom["coordinates"][0][0][0]

m = folium.Map(location=[pnt[1], pnt[0]], zoom_start=11)
g = folium.GeoJson(
data.get("artifact"),
artifact,
).add_to(m)
folium_static(m, width=700, height=500)
else:
st.write(data)
elif artifact:
st.write(artifact)
1 change: 1 addition & 0 deletions frontend/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
streamlit==1.40.1
streamlit_folium==0.23.0
pandas==0.19.0
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies = [
"streamlit-folium>=0.23.2",
"earthengine-api>=1.4.0",
"geojson-pydantic>=1.1.2",
"fiona>=1.10.1",
]

[dependency-groups]
Expand Down
19 changes: 15 additions & 4 deletions tests/test_dist_agent.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,27 @@
from zeno.agents.distalert.agent import graph
from zeno.agents.maingraph.utils.state import GraphState

from langchain_core.messages import ToolMessage, AIMessage

def test_distalert_agent():
initial_state = GraphState(
question="Provide data about disturbance alerts in Aveiro summarized by landcover"
)
for level, data in graph.stream(
for namespace, chunk in graph.stream(
initial_state, stream_mode="updates", subgraphs=True
):
print(f"Level {level}")
for key, val in data.items():
node_name = list(chunk.keys())[0]
print(f"Namespace {namespace}")
print(f"Node {node_name}")
messages = chunk[node_name].get("messages")
if not messages:
continue
msg = messages[0]
if isinstance(msg, ToolMessage):
yield pack({msg.name, msg.content})
elif isinstance(msg, AIMessage):
yield pack

for key, val in chunk.items():
print(f"Messager is {key}")
for key2, val2 in val.items():
if key2 == "messages":
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dist_alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def test_dist_alert_tool():

natural_lands = "WRI/SBTN/naturalLands/v1/2020"
features = ["PRT.6.2.5_1"]
features = ["23"]
result = dist_alerts_tool.invoke(
input={"features": features, "landcover": natural_lands, "threshold": 5}
)
Expand Down
21 changes: 12 additions & 9 deletions tests/test_location_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,25 @@


def test_location_matcher():
matcher = LocationMatcher("/Users/tam/Downloads/gadm41_PRT.gpkg")
matcher = LocationMatcher("data/gadm_410_small.gpkg")
# Test queries demonstrating priority order
test_queries = {
"lisboa": ["PRT.12.7.18_1", "PRT.12.7.49_1", "PRT.12.7.17_1"],
"Lamego": ["PRT.20.5.11_1", "PRT.9.6.4_1", "PRT.20.5.9_1"],
"Sao Joao": ["PRT.12.7.41_1", "PRT.12.7.45_1", "PRT.12.7.49_1"],
"Castelo Branco": ["PRT.6.2.5_1", "PRT.6.2.7_1", "PRT.6.2.10_1"],
"lisboa portugal": ["PRT.12.7.52_1"],
"Liisboa portugal": ["PRT.6.2.5_1"],
"Lisbon portugal": ["PRT.6.2.5_1"],
"Lamego viseu portugal": ['PRT.20.5.11_1'],
"Sao Joao Porto": ["PRT.12.7.41_1"],
"Bern Switzerland": ["PRT.6.2.5_1"],
}

for query, expected in test_queries.items():
matches = matcher.find_matches(query)
assert list(matches.GID_3) == expected
print(query, matches.name, matches.gadmid)
# assert list(matches.gadmid) == expected


def test_location_matcher_bbox():
matcher = LocationMatcher("/Users/tam/Downloads/gadm41_PRT.gpkg")
matcher = LocationMatcher("data/gadm_410_small.gpkg")
coords = (
-28.759502410999914,
38.517414093000184,
Expand All @@ -31,6 +34,6 @@ def test_location_matcher_bbox():


def test_location_matcher_id():
matcher = LocationMatcher("/Users/tam/Downloads/gadm41_PRT.gpkg")
matcher = LocationMatcher("data/gadm_410_small.gpkg")
result = matcher.get_by_id("PRT.6.2.5_1")
assert result["features"][0]["properties"]["GID_3"] == "PRT.6.2.5_1"
assert result["features"][0]["properties"]["gadmid"] == "PRT.6.2.5_1"
53 changes: 53 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 1 addition & 3 deletions zeno/agents/docfinder/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
retrieve = ToolNode([retriever_tool])
workflow.add_node("retrieve", retrieve)
workflow.add_node("rewrite", rewrite)
workflow.add_node(
"generate", generate
)
workflow.add_node("generate", generate)
workflow.add_edge(START, "agent")

workflow.add_conditional_edges(
Expand Down
2 changes: 1 addition & 1 deletion zeno/agents/docfinder/utils/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from zeno.agents.maingraph.models import ModelFactory
from zeno.tools.docretrieve.document_retrieve_tool import retriever_tool

model_name = "claude-3-5-sonnet-latest"

model_name = "llama3.2"

def grade_documents(state, config: RunnableConfig) -> Literal["generate", "rewrite"]:
"""
Expand Down
Loading

0 comments on commit 4a6e860

Please sign in to comment.