Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve JSON extraction, add extraction to Task Memory #1044

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- `AstraDbVectorStoreDriver` to support DataStax Astra DB as a vector store.
- Ability to set custom schema properties on Tool Activities via `extra_schema_properties`.
- `extract_json` and `extract_csv` Activities to `TaskMemoryClient`.
- `extract_json_namespace` and `extract_csv_namespace` methods to `TaskMemory`.

### Changed
- **BREAKING**: Split parameter `JsonExtractionEngine.template_generator` into `system_template_generator` and `user_template_generator`.
- **BREAKING**: Split parameter `CsvExtractionEngine.template_generator` into `system_template_generator` and `user_template_generator`.
- **BREAKING**: Split `JsonExtractionEngine.extract` into `extract_text` and `extract_artifacts`.
- **BREAKING**: Split `CsvExtractionEngine.extract` into `extract_text` and `extract_artifacts`.
- Parse JSON from LLM output before loading in `JsonExtractionEngine`.

### Fixed
- Missing implementations of `csv_extraction_engine` and `json_extraction_engine` in `TextArtifactStorage`.

## [0.29.0] - 2024-07-30

Expand Down
4 changes: 2 additions & 2 deletions docs/griptape-framework/engines/extraction-engines.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Charlie is 40 and lives in Texas.
"""

# Extract CSV rows using the engine
result = csv_engine.extract(sample_text, column_names=["name", "age", "location"])
result = csv_engine.extract_text(sample_text, column_names=["name", "age", "location"])

for row in result.value:
print(row.to_text())
Expand Down Expand Up @@ -73,7 +73,7 @@ user_schema = Schema(
).json_schema("UserSchema")

# Extract data using the engine
result = json_engine.extract(sample_json_text, template_schema=user_schema)
result = json_engine.extract_text(sample_json_text, template_schema=user_schema)

for artifact in result.value:
print(artifact.value)
Expand Down
12 changes: 9 additions & 3 deletions griptape/engines/extraction/base_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@

from attrs import Attribute, Factory, define, field

from griptape.artifacts import ListArtifact, TextArtifact
from griptape.chunkers import BaseChunker, TextChunker

if TYPE_CHECKING:
from griptape.artifacts import ErrorArtifact, ListArtifact
from griptape.artifacts import ErrorArtifact
from griptape.drivers import BasePromptDriver
from griptape.rules import Ruleset

Expand Down Expand Up @@ -45,10 +46,15 @@ def min_response_tokens(self) -> int:
)

@abstractmethod
def extract(
def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact,
*,
rulesets: Optional[list[Ruleset]] = None,
**kwargs,
) -> ListArtifact | ErrorArtifact: ...

# Convenience entry point for extracting from a plain string.
# Wraps `text` in a single TextArtifact inside a ListArtifact and delegates to
# `extract_artifacts`, forwarding `rulesets` and any extra keyword arguments
# unchanged. The return value is whatever `extract_artifacts` returns.
def extract_text(
self, text: str, *, rulesets: Optional[list[Ruleset]] = None, **kwargs
) -> ListArtifact | ErrorArtifact:
return self.extract_artifacts(ListArtifact([TextArtifact(text)]), rulesets=rulesets, **kwargs)
42 changes: 30 additions & 12 deletions griptape/engines/extraction/csv_extraction_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@

@define
class CsvExtractionEngine(BaseExtractionEngine):
template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/csv_extraction.j2")), kw_only=True)
system_template_generator: J2 = field(default=Factory(lambda: J2("engines/csv_extraction/system.j2")), kw_only=True)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/csv_extraction/user.j2")), kw_only=True)

def extract(
def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact,
*,
rulesets: Optional[list[Ruleset]] = None,
column_names: Optional[list[str]] = None,
Expand All @@ -33,7 +34,7 @@
try:
return ListArtifact(
self._extract_rec(
cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)],
cast(list[TextArtifact], artifacts.value),
column_names,
[],
rulesets=rulesets,
Expand All @@ -60,32 +61,49 @@
rulesets: Optional[list[Ruleset]] = None,
) -> list[CsvRowArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
full_text = self.template_generator.render(
system_prompt = self.system_template_generator.render(
column_names=column_names,
text=artifacts_text,
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)
user_prompt = self.user_template_generator.render(
text=artifacts_text,
)

if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens:
if (
self.prompt_driver.tokenizer.count_input_tokens_left(system_prompt + user_prompt)
>= self.min_response_tokens
):
rows.extend(
self.text_to_csv_rows(
self.prompt_driver.run(PromptStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value,
self.prompt_driver.run(
PromptStack(
messages=[
Message(system_prompt, role=Message.SYSTEM_ROLE),
Message(user_prompt, role=Message.USER_ROLE),
]
)
).value,
column_names,
),
)

return rows
else:
chunks = self.chunker.chunk(artifacts_text)
partial_text = self.template_generator.render(
column_names=column_names,
partial_text = self.user_template_generator.render(

Check warning on line 93 in griptape/engines/extraction/csv_extraction_engine.py

View check run for this annotation

Codecov / codecov/patch

griptape/engines/extraction/csv_extraction_engine.py#L93

Added line #L93 was not covered by tests
text=chunks[0].value,
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)

rows.extend(
self.text_to_csv_rows(
self.prompt_driver.run(PromptStack(messages=[Message(partial_text, role=Message.USER_ROLE)])).value,
self.prompt_driver.run(
PromptStack(
messages=[
Message(system_prompt, role=Message.SYSTEM_ROLE),
Message(partial_text, role=Message.USER_ROLE),
]
)
).value,
column_names,
),
)
Expand Down
54 changes: 41 additions & 13 deletions griptape/engines/extraction/json_extraction_engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
import re
from typing import TYPE_CHECKING, Optional, cast

from attrs import Factory, define, field
Expand All @@ -17,11 +18,16 @@

@define
class JsonExtractionEngine(BaseExtractionEngine):
template_generator: J2 = field(default=Factory(lambda: J2("engines/extraction/json_extraction.j2")), kw_only=True)
JSON_PATTERN = r"(?s)[^\[]*(\[.*\])"

def extract(
system_template_generator: J2 = field(
default=Factory(lambda: J2("engines/json_extraction/system.j2")), kw_only=True
)
user_template_generator: J2 = field(default=Factory(lambda: J2("engines/json_extraction/user.j2")), kw_only=True)

def extract_artifacts(
self,
text: str | ListArtifact,
artifacts: ListArtifact,
*,
rulesets: Optional[list[Ruleset]] = None,
template_schema: Optional[list[dict]] = None,
Expand All @@ -34,7 +40,7 @@

return ListArtifact(
self._extract_rec(
cast(list[TextArtifact], text.value) if isinstance(text, ListArtifact) else [TextArtifact(text)],
cast(list[TextArtifact], artifacts.value),
json_schema,
[],
rulesets=rulesets,
Expand All @@ -45,7 +51,12 @@
return ErrorArtifact(f"error extracting JSON: {e}")

def json_to_text_artifacts(self, json_input: str) -> list[TextArtifact]:
return [TextArtifact(json.dumps(e)) for e in json.loads(json_input)]
json_matches = re.findall(self.JSON_PATTERN, json_input, re.DOTALL)

if json_matches:
return [TextArtifact(json.dumps(e)) for e in json.loads(json_matches[-1])]
else:
return []

def _extract_rec(
self,
Expand All @@ -55,31 +66,48 @@
rulesets: Optional[list[Ruleset]] = None,
) -> list[TextArtifact]:
artifacts_text = self.chunk_joiner.join([a.value for a in artifacts])
full_text = self.template_generator.render(
system_prompt = self.system_template_generator.render(
json_template_schema=json_template_schema,
text=artifacts_text,
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)
user_prompt = self.user_template_generator.render(
text=artifacts_text,
)

if self.prompt_driver.tokenizer.count_input_tokens_left(full_text) >= self.min_response_tokens:
if (
self.prompt_driver.tokenizer.count_input_tokens_left(user_prompt + system_prompt)
>= self.min_response_tokens
):
extractions.extend(
self.json_to_text_artifacts(
self.prompt_driver.run(PromptStack(messages=[Message(full_text, role=Message.USER_ROLE)])).value,
self.prompt_driver.run(
PromptStack(
messages=[
Message(system_prompt, role=Message.SYSTEM_ROLE),
Message(user_prompt, role=Message.USER_ROLE),
]
)
).value
),
)

return extractions
else:
chunks = self.chunker.chunk(artifacts_text)
partial_text = self.template_generator.render(
template_schema=json_template_schema,
partial_text = self.user_template_generator.render(

Check warning on line 97 in griptape/engines/extraction/json_extraction_engine.py

View check run for this annotation

Codecov / codecov/patch

griptape/engines/extraction/json_extraction_engine.py#L97

Added line #L97 was not covered by tests
text=chunks[0].value,
rulesets=J2("rulesets/rulesets.j2").render(rulesets=rulesets),
)

extractions.extend(
self.json_to_text_artifacts(
self.prompt_driver.run(PromptStack(messages=[Message(partial_text, role=Message.USER_ROLE)])).value,
self.prompt_driver.run(
PromptStack(
messages=[
Message(system_prompt, role=Message.SYSTEM_ROLE),
Message(partial_text, role=Message.USER_ROLE),
]
)
).value,
),
)

Expand Down
8 changes: 7 additions & 1 deletion griptape/memory/task/storage/base_artifact_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from attrs import define

if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact, InfoArtifact, ListArtifact, TextArtifact
from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, ListArtifact, TextArtifact


@define
Expand All @@ -25,3 +25,9 @@ def summarize(self, namespace: str) -> TextArtifact | InfoArtifact: ...

@abstractmethod
def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact: ...

# Abstract hook: concrete storages implement CSV extraction for `namespace`.
# Per the annotation, implementations return a ListArtifact, an InfoArtifact,
# or an ErrorArtifact.
@abstractmethod
def extract_csv(self, namespace: str) -> ListArtifact | InfoArtifact | ErrorArtifact: ...

# Abstract hook: concrete storages implement JSON extraction for `namespace`.
# Per the annotation, implementations return a ListArtifact, an InfoArtifact,
# or an ErrorArtifact.
@abstractmethod
def extract_json(self, namespace: str) -> ListArtifact | InfoArtifact | ErrorArtifact: ...
6 changes: 6 additions & 0 deletions griptape/memory/task/storage/blob_artifact_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@

def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact:
return InfoArtifact("can't query artifacts")

# CSV extraction is not performed by this storage: the method ignores
# `namespace` and always returns an InfoArtifact with a fixed message.
def extract_csv(self, namespace: str) -> InfoArtifact:
return InfoArtifact("can't extract csv")

Check warning on line 37 in griptape/memory/task/storage/blob_artifact_storage.py

View check run for this annotation

Codecov / codecov/patch

griptape/memory/task/storage/blob_artifact_storage.py#L37

Added line #L37 was not covered by tests

# JSON extraction is not performed by this storage: the method ignores
# `namespace` and always returns an InfoArtifact with a fixed message.
def extract_json(self, namespace: str) -> InfoArtifact:
return InfoArtifact("can't extract json")

Check warning on line 40 in griptape/memory/task/storage/blob_artifact_storage.py

View check run for this annotation

Codecov / codecov/patch

griptape/memory/task/storage/blob_artifact_storage.py#L40

Added line #L40 was not covered by tests
13 changes: 13 additions & 0 deletions griptape/memory/task/storage/text_artifact_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from griptape.memory.task.storage import BaseArtifactStorage

if TYPE_CHECKING:
from griptape.artifacts import ErrorArtifact
from griptape.drivers import BaseVectorStoreDriver
from griptape.engines import BaseSummaryEngine, CsvExtractionEngine, JsonExtractionEngine

Expand Down Expand Up @@ -45,6 +46,18 @@

return self.summary_engine.summarize_artifacts(self.load_artifacts(namespace))

def extract_csv(self, namespace: str) -> ListArtifact | ErrorArtifact:
if self.csv_extraction_engine is None:
raise ValueError("Csv extraction engine is not set.")

Check warning on line 51 in griptape/memory/task/storage/text_artifact_storage.py

View check run for this annotation

Codecov / codecov/patch

griptape/memory/task/storage/text_artifact_storage.py#L51

Added line #L51 was not covered by tests

return self.csv_extraction_engine.extract_artifacts(self.load_artifacts(namespace))

def extract_json(self, namespace: str) -> ListArtifact | ErrorArtifact:
if self.json_extraction_engine is None:
raise ValueError("Json extraction engine is not set.")

Check warning on line 57 in griptape/memory/task/storage/text_artifact_storage.py

View check run for this annotation

Codecov / codecov/patch

griptape/memory/task/storage/text_artifact_storage.py#L57

Added line #L57 was not covered by tests

return self.json_extraction_engine.extract_artifacts(self.load_artifacts(namespace))

def query(self, namespace: str, query: str, metadata: Any = None) -> BaseArtifact:
if self.rag_engine is None:
raise ValueError("rag_engine is not set")
Expand Down
16 changes: 16 additions & 0 deletions griptape/memory/task/task_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,19 @@
return storage.query(namespace=namespace, query=query, metadata=self.namespace_metadata.get(namespace))
else:
return InfoArtifact("Can't find memory content")

# Run JSON extraction over the content stored under `namespace`.
# Looks up the storage registered for the namespace in `namespace_storage` and
# delegates to that storage's `extract_json`. When the lookup yields nothing
# (falsy), an ErrorArtifact is returned instead of raising.
def extract_json_namespace(self, namespace: str) -> ListArtifact | InfoArtifact | ErrorArtifact:
storage = self.namespace_storage.get(namespace)

if storage:
return storage.extract_json(namespace)
else:
return ErrorArtifact("Can't find memory content")

Check warning on line 149 in griptape/memory/task/task_memory.py

View check run for this annotation

Codecov / codecov/patch

griptape/memory/task/task_memory.py#L149

Added line #L149 was not covered by tests

# Run CSV extraction over the content stored under `namespace`.
# Looks up the storage registered for the namespace in `namespace_storage` and
# delegates to that storage's `extract_csv`. When the lookup yields nothing
# (falsy), an ErrorArtifact is returned instead of raising.
def extract_csv_namespace(self, namespace: str) -> ListArtifact | InfoArtifact | ErrorArtifact:
storage = self.namespace_storage.get(namespace)

if storage:
return storage.extract_csv(namespace)
else:
return ErrorArtifact("Can't find memory content")

Check warning on line 157 in griptape/memory/task/task_memory.py

View check run for this annotation

Codecov / codecov/patch

griptape/memory/task/task_memory.py#L157

Added line #L157 was not covered by tests
2 changes: 1 addition & 1 deletion griptape/tasks/extraction_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ def extraction_engine(self) -> BaseExtractionEngine:
return self._extraction_engine

def run(self) -> ListArtifact | ErrorArtifact:
return self.extraction_engine.extract(self.input.to_text(), rulesets=self.all_rulesets, **self.args)
return self.extraction_engine.extract_text(self.input.to_text(), rulesets=self.all_rulesets, **self.args)
7 changes: 7 additions & 0 deletions griptape/templates/engines/csv_extraction/system.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Don't add the header row. Don't use markdown formatting for output. Fields containing line breaks (CRLF), double quotes, and commas should be enclosed in double-quotes.
Column Names: """{{ column_names }}"""

{% if rulesets %}

{{ rulesets }}
{% endif %}
4 changes: 4 additions & 0 deletions griptape/templates/engines/csv_extraction/user.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Extract information from the Text based on the Column Names and output it as a CSV file.
Text: """{{ text }}"""

Answer:
11 changes: 0 additions & 11 deletions griptape/templates/engines/extraction/csv_extraction.j2

This file was deleted.

6 changes: 6 additions & 0 deletions griptape/templates/engines/json_extraction/system.j2
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Extraction Template JSON Schema: """{{ json_template_schema }}"""

{% if rulesets %}

{{ rulesets }}
{% endif %}
Original file line number Diff line number Diff line change
@@ -1,11 +1,4 @@
Text: """{{ text }}"""

Extraction Template JSON Schema: """{{ json_template_schema }}"""

Extract information from the Text based on the Extraction Template JSON Schema into an array of JSON objects.
{% if rulesets %}

{{ rulesets }}
{% endif %}
Text: """{{ text }}"""

JSON array:
Loading
Loading