diff --git a/CHANGELOG.md b/CHANGELOG.md index 94fcf0abd..9857e60df 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `AstraDbVectorStoreDriver` to support DataStax Astra DB as a vector store. - Ability to set custom schema properties on Tool Activities via `extra_schema_properties`. +- `extract_json` and `extract_csv` Activities to `TaskMemoryClient`. +- `extract_json_namespace` and `extract_csv_namespace` methods to `TaskMemory`. + +### Changed +- **BREAKING**: Split parameter `JsonExtractionEngine.template_generator` into `system_template_generator` and `user_template_generator`. +- **BREAKING**: Split parameter `CsvExtractionEngine.template_generator` into `system_template_generator` and `user_template_generator`. +- **BREAKING**: Split `JsonExtractionEngine.extract` into `extract_text` and `extract_artifacts`. +- **BREAKING**: Split `CsvExtractionEngine.extract` into `extract_text` and `extract_artifacts`. +- Parse JSON from LLM output before loading in `JsonExtractionEngine`. + +### Fixed +- Missing implementations of `csv_extraction_engine` and `json_extraction_engine` in `TextArtifactStorage`. ## [0.29.0] - 2024-07-30 diff --git a/docs/griptape-framework/engines/extraction-engines.md b/docs/griptape-framework/engines/extraction-engines.md index 496560968..e93c96a23 100644 --- a/docs/griptape-framework/engines/extraction-engines.md +++ b/docs/griptape-framework/engines/extraction-engines.md @@ -32,7 +32,7 @@ Charlie is 40 and lives in Texas. 
""" # Extract CSV rows using the engine -result = csv_engine.extract(sample_text, column_names=["name", "age", "location"]) +result = csv_engine.extract_text(sample_text, column_names=["name", "age", "location"]) for row in result.value: print(row.to_text()) @@ -73,7 +73,7 @@ user_schema = Schema( ).json_schema("UserSchema") # Extract data using the engine -result = json_engine.extract(sample_json_text, template_schema=user_schema) +result = json_engine.extract_text(sample_json_text, template_schema=user_schema) for artifact in result.value: print(artifact.value) diff --git a/griptape/engines/extraction/base_extraction_engine.py b/griptape/engines/extraction/base_extraction_engine.py index f263ee0aa..1bd831cf5 100644 --- a/griptape/engines/extraction/base_extraction_engine.py +++ b/griptape/engines/extraction/base_extraction_engine.py @@ -5,10 +5,11 @@ from attrs import Attribute, Factory, define, field +from griptape.artifacts import ListArtifact, TextArtifact from griptape.chunkers import BaseChunker, TextChunker if TYPE_CHECKING: - from griptape.artifacts import ErrorArtifact, ListArtifact + from griptape.artifacts import ErrorArtifact from griptape.drivers import BasePromptDriver from griptape.rules import Ruleset @@ -45,10 +46,15 @@ def min_response_tokens(self) -> int: ) @abstractmethod - def extract( + def extract_artifacts( self, - text: str | ListArtifact, + artifacts: ListArtifact, *, rulesets: Optional[list[Ruleset]] = None, **kwargs, ) -> ListArtifact | ErrorArtifact: ... 
+ + def extract_text( + self, text: str, *, rulesets: Optional[list[Ruleset]] = None, **kwargs + ) -> ListArtifact | ErrorArtifact: + return self.extract_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets, **kwargs) diff --git a/griptape/engines/extraction/csv_extraction_engine.py b/griptape/engines/extraction/csv_extraction_engine.py index 3184654b1..50d37d813 100644 --- a/griptape/engines/extraction/csv_extraction_engine.py +++ b/griptape/engines/extraction/csv_extraction_engine.py @@ -18,11 +18,12 @@ @define class CsvExtractionEngine(BaseExtractionEngine): - template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv_extraction.j2")), kw_only=True) + system_template_generator: J2 = field(default=Factory(lambda: J2("engines/csv_extraction/system.j2")), kw_only=True) + user_template_generator: J2 = field(default=Factory(lambda: J2("engines/csv_extraction/user.j2")), kw_only=True) - def extract( + def extract_artifacts( self, - text: str | ListArtifact, + artifacts: ListArtifact, *, rulesets: Optional[list[Ruleset]] = None, column_names: Optional[list[str]] = None, @@ -33,7 +34,7 @@ def extract( try: return ListArtifact( self._extract_rec( - cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)], + cast(list[TextArtifact], artifacts.value), column_names, [], rulesets=rulesets, @@ -60,16 +61,28 @@ def _extract_rec( rulesets: Optional[list[Ruleset]] = None, ) -> list[CsvRowArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) - full_text = self.template_generator.render( + system_prompt = self.system_template_generator.render( column_names=column_names, - text=artifacts_text, rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) + user_prompt = self.user_template_generator.render( + text=artifacts_text, + ) - if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens: + if ( + 
self.prompt_driver.tokenizer.count_input_tokens_left(system_prompt + user_prompt) + >= self.min_response_tokens + ): rows.extend( self.text_to_csv_rows( - self.prompt_driver.run(PromptStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value, + self.prompt_driver.run( + PromptStack( + messages=[ + Message(system_prompt, role=Message.SYSTEM_ROLE), + Message(user_prompt, role=Message.USER_ROLE), + ] + ) + ).value, column_names, ), ) @@ -77,15 +90,20 @@ def _extract_rec( return rows else: chunks = self.chunker.chunk(artifacts_text) - partial_text = self.template_generator.render( - column_names=column_names, + partial_text = self.user_template_generator.render( text=chunks[0].value, - rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) rows.extend( self.text_to_csv_rows( - self.prompt_driver.run(PromptStack(messages=[Message(partial_text, role=Message.USER_ROLE)])).value, + self.prompt_driver.run( + PromptStack( + messages=[ + Message(system_prompt, role=Message.SYSTEM_ROLE), + Message(partial_text, role=Message.USER_ROLE), + ] + ) + ).value, column_names, ), ) diff --git a/griptape/engines/extraction/json_extraction_engine.py b/griptape/engines/extraction/json_extraction_engine.py index 436fc093f..165dbc4f1 100644 --- a/griptape/engines/extraction/json_extraction_engine.py +++ b/griptape/engines/extraction/json_extraction_engine.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import re from typing import TYPE_CHECKING, Optional, cast from attrs import Factory, define, field @@ -17,11 +18,16 @@ @define class JsonExtractionEngine(BaseExtractionEngine): - template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json_extraction.j2")), kw_only=True) + JSON_PATTERN = r"(?s)[^\[]*(\[.*\])" - def extract( + system_template_generator: J2 = field( + default=Factory(lambda: J2("engines/json_extraction/system.j2")), kw_only=True + ) + user_template_generator: J2 = field(default=Factory(lambda: 
J2("engines/json_extraction/user.j2")), kw_only=True) + + def extract_artifacts( self, - text: str | ListArtifact, + artifacts: ListArtifact, *, rulesets: Optional[list[Ruleset]] = None, template_schema: Optional[list[dict]] = None, @@ -34,7 +40,7 @@ def extract( return ListArtifact( self._extract_rec( - cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)], + cast(list[TextArtifact], artifacts.value), json_schema, [], rulesets=rulesets, @@ -45,7 +51,12 @@ def extract( return ErrorArtifact(f"error extracting JSON: {e}") def json_to_text_artifacts(self, json_input: str) -> list[TextArtifact]: - return [TextArtifact(json.dumps(e)) for e in json.loads(json_input)] + json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL) + + if json_matches: + return [TextArtifact(json.dumps(e)) for e in json.loads(json_matches[-1])] + else: + return [] def _extract_rec( self, @@ -55,31 +66,48 @@ def _extract_rec( rulesets: Optional[list[Ruleset]] = None, ) -> list[TextArtifact]: artifacts_text = self.chunk_joiner.join([a.value for a in artifacts]) - full_text = self.template_generator.render( + system_prompt = self.system_template_generator.render( json_template_schema=json_template_schema, - text=artifacts_text, rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) + user_prompt = self.user_template_generator.render( + text=artifacts_text, + ) - if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens: + if ( + self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt) + >= self.min_response_tokens + ): extractions.extend( self.json_to_text_artifacts( - self.prompt_driver.run(PromptStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value, + self.prompt_driver.run( + PromptStack( + messages=[ + Message(system_prompt, role=Message.SYSTEM_ROLE), + Message(user_prompt, role=Message.USER_ROLE), + ] + ) + ).value ), ) return extractions else: chunks = 
self.chunker.chunk(artifacts_text) - partial_text = self.template_generator.render( - template_schema=json_template_schema, + partial_text = self.user_template_generator.render( text=chunks[0].value, - rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets), ) extractions.extend( self.json_to_text_artifacts( - self.prompt_driver.run(PromptStack(messages=[Message(partial_text, role=Message.USER_ROLE)])).value, + self.prompt_driver.run( + PromptStack( + messages=[ + Message(system_prompt, role=Message.SYSTEM_ROLE), + Message(partial_text, role=Message.USER_ROLE), + ] + ) + ).value, ), ) diff --git a/griptape/memory/task/storage/base_artifact_storage.py b/griptape/memory/task/storage/base_artifact_storage.py index 866df19da..2d7eeb47a 100644 --- a/griptape/memory/task/storage/base_artifact_storage.py +++ b/griptape/memory/task/storage/base_artifact_storage.py @@ -6,7 +6,7 @@ from attrs import define if TYPE_CHECKING: - from griptape.artifacts import BaseArtifact, InfoArtifact, ListArtifact, TextArtifact + from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact @define @@ -25,3 +25,9 @@ def summarize(self, namespace: str) -> TextArtifact | InfoArtifact: ... @abstractmethod def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact: ... + + @abstractmethod + def extract_csv(self, namespace: str) -> ListArtifact | InfoArtifact | ErrorArtifact: ... + + @abstractmethod + def extract_json(self, namespace: str) -> ListArtifact | InfoArtifact | ErrorArtifact: ... 
diff --git a/griptape/memory/task/storage/blob_artifact_storage.py b/griptape/memory/task/storage/blob_artifact_storage.py index 6199dc3a3..16dfbd97a 100644 --- a/griptape/memory/task/storage/blob_artifact_storage.py +++ b/griptape/memory/task/storage/blob_artifact_storage.py @@ -32,3 +32,9 @@ def summarize(self, namespace: str) -> InfoArtifact: def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact: return InfoArtifact("can't query artifacts") + + def extract_csv(self, namespace: str) -> InfoArtifact: + return InfoArtifact("can't extract csv") + + def extract_json(self, namespace: str) -> InfoArtifact: + return InfoArtifact("can't extract json") diff --git a/griptape/memory/task/storage/text_artifact_storage.py b/griptape/memory/task/storage/text_artifact_storage.py index 8e66c5aba..aa696ef83 100644 --- a/griptape/memory/task/storage/text_artifact_storage.py +++ b/griptape/memory/task/storage/text_artifact_storage.py @@ -9,6 +9,7 @@ from griptape.memory.task.storage import BaseArtifactStorage if TYPE_CHECKING: + from griptape.artifacts import ErrorArtifact from griptape.drivers import BaseVectorStoreDriver from griptape.engines import BaseSummaryEngine, CsvExtractionEngine, JsonExtractionEngine @@ -45,6 +46,18 @@ def summarize(self, namespace: str) -> TextArtifact: return self.summary_engine.summarize_artifacts(self.load_artifacts(namespace)) + def extract_csv(self, namespace: str) -> ListArtifact | ErrorArtifact: + if self.csv_extraction_engine is None: + raise ValueError("Csv extraction engine is not set.") + + return self.csv_extraction_engine.extract_artifacts(self.load_artifacts(namespace)) + + def extract_json(self, namespace: str) -> ListArtifact | ErrorArtifact: + if self.json_extraction_engine is None: + raise ValueError("Json extraction engine is not set.") + + return self.json_extraction_engine.extract_artifacts(self.load_artifacts(namespace)) + def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact: 
if self.rag_engine is None: raise ValueError("rag_engine is not set") diff --git a/griptape/memory/task/task_memory.py b/griptape/memory/task/task_memory.py index e2131d1f0..e77a0ed01 100644 --- a/griptape/memory/task/task_memory.py +++ b/griptape/memory/task/task_memory.py @@ -139,3 +139,19 @@ def query_namespace(self, namespace: str, query: str) -> BaseArtifact: return storage.query(namespace=namespace, query=query, metadata=self.namespace_metadata.get(namespace)) else: return InfoArtifact("Can't find memory content") + + def extract_json_namespace(self, namespace: str) -> ListArtifact | InfoArtifact | ErrorArtifact: + storage = self.namespace_storage.get(namespace) + + if storage: + return storage.extract_json(namespace) + else: + return ErrorArtifact("Can't find memory content") + + def extract_csv_namespace(self, namespace: str) -> ListArtifact | InfoArtifact | ErrorArtifact: + storage = self.namespace_storage.get(namespace) + + if storage: + return storage.extract_csv(namespace) + else: + return ErrorArtifact("Can't find memory content") diff --git a/griptape/tasks/extraction_task.py b/griptape/tasks/extraction_task.py index d8f492693..408380d0b 100644 --- a/griptape/tasks/extraction_task.py +++ b/griptape/tasks/extraction_task.py @@ -21,4 +21,4 @@ def extraction_engine(self) -> BaseExtractionEngine: return self._extraction_engine def run(self) -> ListArtifact | ErrorArtifact: - return self.extraction_engine.extract(self.input.to_text(), rulesets=self.all_rulesets, **self.args) + return self.extraction_engine.extract_text(self.input.to_text(), rulesets=self.all_rulesets, **self.args) diff --git a/griptape/templates/engines/csv_extraction/system.j2 b/griptape/templates/engines/csv_extraction/system.j2 new file mode 100644 index 000000000..7c5776257 --- /dev/null +++ b/griptape/templates/engines/csv_extraction/system.j2 @@ -0,0 +1,7 @@ +Don't add the header row. Don't use markdown formatting for output. 
Fields containing line breaks (CRLF), double quotes, and commas should be enclosed in double-quotes. +Column Names: """{{ column_names }}""" + +{% if rulesets %} + +{{ rulesets }} +{% endif %} diff --git a/griptape/templates/engines/csv_extraction/user.j2 b/griptape/templates/engines/csv_extraction/user.j2 new file mode 100644 index 000000000..0f33dadc3 --- /dev/null +++ b/griptape/templates/engines/csv_extraction/user.j2 @@ -0,0 +1,4 @@ +Extract information from the Text based on the Column Names and output it as a CSV file. +Text: """{{ text }}""" + +Answer: diff --git a/griptape/templates/engines/extraction/csv_extraction.j2 b/griptape/templates/engines/extraction/csv_extraction.j2 deleted file mode 100644 index 6f9da346b..000000000 --- a/griptape/templates/engines/extraction/csv_extraction.j2 +++ /dev/null @@ -1,11 +0,0 @@ -Text: """{{ text }}""" - -Column Names: """{{ column_names }}""" - -Extract information from the Text based on the Column Names and output it as a CSV file. Don't add the header row. Don't use markdown formatting for output. Fields containing line breaks (CRLF), double quotes, and commas should be enclosed in double-quotes. 
-{% if rulesets %} - -{{ rulesets }} -{% endif %} - -Answer: diff --git a/griptape/templates/engines/json_extraction/system.j2 b/griptape/templates/engines/json_extraction/system.j2 new file mode 100644 index 000000000..987ff19a9 --- /dev/null +++ b/griptape/templates/engines/json_extraction/system.j2 @@ -0,0 +1,6 @@ +Extraction Template JSON Schema: """{{ json_template_schema }}""" + +{% if rulesets %} + +{{ rulesets }} +{% endif %} diff --git a/griptape/templates/engines/extraction/json_extraction.j2 b/griptape/templates/engines/json_extraction/user.j2 similarity index 56% rename from griptape/templates/engines/extraction/json_extraction.j2 rename to griptape/templates/engines/json_extraction/user.j2 index 85d95bef9..984977d9a 100644 --- a/griptape/templates/engines/extraction/json_extraction.j2 +++ b/griptape/templates/engines/json_extraction/user.j2 @@ -1,11 +1,4 @@ -Text: """{{ text }}""" - -Extraction Template JSON Schema: """{{ json_template_schema }}""" - Extract information from the Text based on the Extraction Template JSON Schema into an array of JSON objects. 
-{% if rulesets %} - -{{ rulesets }} -{% endif %} +Text: """{{ text }}""" JSON array: diff --git a/griptape/tools/task_memory_client/tool.py b/griptape/tools/task_memory_client/tool.py index 160a54d85..59f281b6d 100644 --- a/griptape/tools/task_memory_client/tool.py +++ b/griptape/tools/task_memory_client/tool.py @@ -3,17 +3,22 @@ from attrs import define from schema import Literal, Schema -from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, TextArtifact +from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.tools import BaseTool from griptape.utils.decorators import activity @define class TaskMemoryClient(BaseTool): + ARTIFACT_REFERENCE_SCHEMA = { + "memory_name": str, + "artifact_namespace": str, + } + @activity( config={ "description": "Can be used to summarize memory content", - "schema": Schema({"memory_name": str, "artifact_namespace": str}), + "schema": Schema(ARTIFACT_REFERENCE_SCHEMA), }, ) def summarize(self, params: dict) -> TextArtifact | InfoArtifact | ErrorArtifact: @@ -30,8 +35,7 @@ def summarize(self, params: dict) -> TextArtifact | InfoArtifact | ErrorArtifact "description": "Can be used to search and query memory content", "schema": Schema( { - "memory_name": str, - "artifact_namespace": str, + **ARTIFACT_REFERENCE_SCHEMA, Literal( "query", description="A natural language search query in the form of a question with enough " @@ -50,3 +54,33 @@ def query(self, params: dict) -> BaseArtifact: return memory.query_namespace(namespace=artifact_namespace, query=query) else: return ErrorArtifact("memory not found") + + @activity( + config={ + "description": "Can be used extract memory content in JSON format", + "schema": Schema(ARTIFACT_REFERENCE_SCHEMA), + }, + ) + def extract_json(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: + memory = self.find_input_memory(params["values"]["memory_name"]) + artifact_namespace = 
params["values"]["artifact_namespace"] + + if memory: + return memory.extract_json_namespace(artifact_namespace) + else: + return ErrorArtifact("memory not found") + + @activity( + config={ + "description": "Can be used extract memory content in CSV format", + "schema": Schema(ARTIFACT_REFERENCE_SCHEMA), + }, + ) + def extract_csv(self, params: dict) -> ListArtifact | InfoArtifact | ErrorArtifact: + memory = self.find_input_memory(params["values"]["memory_name"]) + artifact_namespace = params["values"]["artifact_namespace"] + + if memory: + return memory.extract_csv_namespace(artifact_namespace) + else: + return ErrorArtifact("memory not found") diff --git a/tests/unit/engines/extraction/test_csv_extraction_engine.py b/tests/unit/engines/extraction/test_csv_extraction_engine.py index f69d8a0ba..ff6998581 100644 --- a/tests/unit/engines/extraction/test_csv_extraction_engine.py +++ b/tests/unit/engines/extraction/test_csv_extraction_engine.py @@ -10,7 +10,7 @@ def engine(self): return CsvExtractionEngine(prompt_driver=MockPromptDriver()) def test_extract(self, engine): - result = engine.extract("foo", column_names=["test1"]) + result = engine.extract_text("foo", column_names=["test1"]) assert len(result.value) == 1 assert result.value[0].value == {"test1": "mock output"} diff --git a/tests/unit/engines/extraction/test_json_extraction_engine.py b/tests/unit/engines/extraction/test_json_extraction_engine.py index d95adbb43..34ac0dae5 100644 --- a/tests/unit/engines/extraction/test_json_extraction_engine.py +++ b/tests/unit/engines/extraction/test_json_extraction_engine.py @@ -17,14 +17,14 @@ def engine(self): def test_extract(self, engine): json_schema = Schema({"foo": "bar"}).json_schema("TemplateSchema") - result = engine.extract("foo", template_schema=json_schema) + result = engine.extract_text("foo", template_schema=json_schema) assert len(result.value) == 2 assert result.value[0].value == '{"test_key_1": "test_value_1"}' assert result.value[1].value == 
'{"test_key_2": "test_value_2"}' def test_extract_error(self, engine): - assert isinstance(engine.extract("foo", template_schema=lambda: "non serializable"), ErrorArtifact) + assert isinstance(engine.extract_text("foo", template_schema=lambda: "non serializable"), ErrorArtifact) def test_json_to_text_artifacts(self, engine): assert [ diff --git a/tests/unit/memory/tool/storage/test_blob_artifact_storage.py b/tests/unit/memory/tool/storage/test_blob_artifact_storage.py index c7f2cfcbd..a9df191ce 100644 --- a/tests/unit/memory/tool/storage/test_blob_artifact_storage.py +++ b/tests/unit/memory/tool/storage/test_blob_artifact_storage.py @@ -35,3 +35,13 @@ def test_query(self, storage): storage.store_artifact("foo", BlobArtifact(b"test")) assert storage.query("foo", "query").value == "can't query artifacts" + + def test_json_extraction_namespace(self, storage): + storage.store_artifact("foo", BlobArtifact(b"test")) + + assert storage.query("foo", "query").value == "can't query artifacts" + + def test_csv_extraction_namespace(self, storage): + storage.store_artifact("foo", BlobArtifact(b"test")) + + assert storage.query("foo", "query").value == "can't query artifacts" diff --git a/tests/unit/memory/tool/storage/test_text_artifact_storage.py b/tests/unit/memory/tool/storage/test_text_artifact_storage.py index 64f44c581..01c954467 100644 --- a/tests/unit/memory/tool/storage/test_text_artifact_storage.py +++ b/tests/unit/memory/tool/storage/test_text_artifact_storage.py @@ -35,3 +35,13 @@ def test_query(self, storage): storage.store_artifact("foo", TextArtifact("test")) assert storage.query("foo", "query").value == "mock output" + + def test_json_extraction_namespace(self, storage): + storage.store_artifact("foo", TextArtifact("test")) + + assert storage.extract_json("foo").value == [] + + def test_csv_extraction_namespace(self, storage): + storage.store_artifact("foo", TextArtifact("test")) + + assert storage.extract_csv("foo").value[0].value == {} diff --git 
a/tests/unit/memory/tool/test_task_memory.py b/tests/unit/memory/tool/test_task_memory.py index 53e4703a6..826dcc86d 100644 --- a/tests/unit/memory/tool/test_task_memory.py +++ b/tests/unit/memory/tool/test_task_memory.py @@ -1,6 +1,6 @@ import pytest -from griptape.artifacts import BlobArtifact, CsvRowArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact +from griptape.artifacts import BlobArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact from griptape.memory import TaskMemory from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage from griptape.structures import Agent @@ -10,10 +10,6 @@ class TestTaskMemory: - @pytest.fixture(autouse=True) - def _mock_griptape(self, mocker): - mocker.patch("griptape.engines.CsvExtractionEngine.extract", return_value=[CsvRowArtifact({"foo": "bar"})]) - @pytest.fixture() def memory(self): return defaults.text_task_memory("MyMemory") diff --git a/tests/unit/tools/test_task_memory_client.py b/tests/unit/tools/test_task_memory_client.py index 4276b89ec..60495425c 100644 --- a/tests/unit/tools/test_task_memory_client.py +++ b/tests/unit/tools/test_task_memory_client.py @@ -27,3 +27,21 @@ def test_query(self, tool): ).value == "mock output" ) + + def test_extract_json(self, tool): + tool.input_memory[0].store_artifact("foo", TextArtifact("test")) + + assert ( + tool.extract_json({"values": {"memory_name": tool.input_memory[0].name, "artifact_namespace": "foo"}}).value + == [] + ) + + def test_extract_csv(self, tool): + tool.input_memory[0].store_artifact("foo", TextArtifact("test")) + + assert ( + tool.extract_csv({"values": {"memory_name": tool.input_memory[0].name, "artifact_namespace": "foo"}}) + .value[0] + .value + == {} + )