From 3542d06fa986059dcea11b5cbec5fe738c37ba72 Mon Sep 17 00:00:00 2001 From: Matt Vallillo Date: Fri, 25 Oct 2024 14:16:10 -0500 Subject: [PATCH] Add default for getting run_id from metadata (#1289) --- .../griptape_cloud_conversation_memory_driver.py | 11 ++++++----- .../test_griptape_cloud_conversation_memory_driver.py | 8 +++++++- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py b/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py index 6c1783519..fed70fc87 100644 --- a/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/griptape_cloud_conversation_memory_driver.py @@ -128,13 +128,14 @@ def load(self) -> tuple[list[Run], dict[str, Any]]: runs = [ Run( - id=m["metadata"].pop("run_id"), - meta=m["metadata"], - input=BaseArtifact.from_json(m["input"]), - output=BaseArtifact.from_json(m["output"]), + **({"id": message["metadata"].pop("run_id", None)} if "run_id" in message.get("metadata") else {}), + meta=message["metadata"], + input=BaseArtifact.from_json(message["input"]), + output=BaseArtifact.from_json(message["output"]), ) - for m in messages_response.get("messages", []) + for message in messages_response.get("messages", []) ] + return runs, thread_response.get("metadata", {}) def _get_url(self, path: str) -> str: diff --git a/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py b/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py index 0c76d6ecd..2c376ef4d 100644 --- a/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py +++ b/tests/unit/drivers/memory/conversation/test_griptape_cloud_conversation_memory_driver.py @@ -25,7 +25,7 @@ def request(*args, **kwargs): "message_id": f"{thread_id}_message", "input": '{"type": "TextArtifact", "id": "1234", "value": "Hi There, Hello"}', "output": '{"type": "TextArtifact", "id": "123", "value": "Hello! How can I assist you today?"}', - "metadata": {"run_id": "1234"}, + "metadata": {"run_id": "1234"} if thread_id != "no_meta" else {}, } ] } @@ -118,3 +118,9 @@ def test_load(self, driver): assert len(runs) == 1 assert runs[0].id == "1234" assert metadata == {"foo": "bar"} + + def test_load_no_message_meta(self, driver): + driver.thread_id = "no_meta" + runs, metadata = driver.load() + assert len(runs) == 1 + assert metadata == {"foo": "bar"}