Skip to content

Commit

Permalink
Merge pull request #100 from wri/add_tms_url_to_contextlayer
Browse files Browse the repository at this point in the history
Add tms url generation to context layer tool
  • Loading branch information
yellowcap authored Jan 10, 2025
2 parents a8bf96e + 6766e43 commit fccacf9
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 18 deletions.
19 changes: 9 additions & 10 deletions tests/test_context_layer_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@


def test_context_layer_tool_cereal():
result = context_layer_tool.invoke(
input={"question": "Summarize disturbance alerts by type of cereal"}
msg = context_layer_tool.invoke(
{
"name": "context-layer-tool",
"args": {"question": "Summarize disturbance alerts by type of cereal"},
"id": "42",
"type": "tool_call",
}
)
assert result == "ESA/WorldCereal/2021/MODELS/v100"


def test_context_layer_tool_null():
result = context_layer_tool.invoke(
input={"question": "Provide disturbances for Aveiro Portugal"}
)
assert result == ""
assert msg.content == "ESA/WorldCereal/2021/MODELS/v100"
assert "{z}/{x}/{y}" in msg.artifact["tms_url"]
37 changes: 29 additions & 8 deletions zeno/tools/contextlayer/context_layer_retriever_tool.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import os

import ee
import lancedb
from langchain_community.vectorstores import LanceDB
from langchain_core.tools import tool
from langchain_core.tools.retriever import create_retriever_tool
from langchain_ollama import OllamaEmbeddings
from langchain_ollama.embeddings import OllamaEmbeddings
from pandas import Series
from pydantic import BaseModel, Field

from zeno.agents.maingraph.models import ModelFactory
from zeno.tools.contextlayer.layers import DatasetNames
from zeno.tools.distalert.gee import init_gee

init_gee()

embedder = OllamaEmbeddings(
model="nomic-embed-text", base_url=os.environ["OLLAMA_BASE_URL"]
Expand Down Expand Up @@ -38,7 +41,7 @@
class grade(BaseModel):
"""Choice of landcover."""

choice: DatasetNames = Field(description="Choice of context layer to use")
choice: str = Field(description="Choice of context layer to use")


class ContextLayerInput(BaseModel):
Expand All @@ -50,8 +53,8 @@ class ContextLayerInput(BaseModel):
model = ModelFactory().get("claude-3-5-sonnet-latest").with_structured_output(grade)


@tool("context-layer-tool", args_schema=ContextLayerInput, return_direct=False)
def context_layer_tool(question: str) -> DatasetNames:
@tool("context-layer-tool", args_schema=ContextLayerInput, response_format="content_and_artifact")
def context_layer_tool(question: str) -> dict:
"""
Determines whether the question asks for summarizing by land cover.
"""
Expand All @@ -66,11 +69,29 @@ def context_layer_tool(question: str) -> DatasetNames:
# if the results set contains multiple datasets with the same name
# as the top result, then we collect them all, and sort them by
# year to return the most recent, by default
results = (
result = (
results[results["name"] == results.iloc[0]["name"]]
.sort_values(by="year", ascending=False)
.iloc[0]
)

# return matches.dataset.value
return results.dataset
tms_url = get_tms_url(result)

result = result.to_dict()
result["tms_url"] = tms_url

dataset = result.pop("dataset")

return dataset, result


def get_tms_url(result: Series):
if result.type == "ImageCollection":
image = ee.ImageCollection(result.dataset).mosaic()
else:
image = ee.Image(result.dataset)

# TODO: add dynamic viz parameters
map_id = image.select(result.band).getMapId()

return map_id["tile_fetcher"].url_format

0 comments on commit fccacf9

Please sign in to comment.