Toolkit Task Improvements (#1268)
collindutter authored Oct 22, 2024
1 parent dab865a commit 6c838b6
Showing 6 changed files with 82 additions and 28 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -44,12 +44,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - `Chat.output_fn` now takes an optional kwarg parameter, `stream`.
 - Implemented `SerializableMixin` in `Structure`, `BaseTask`, `BaseTool`, and `TaskMemory`
 - `@activity` decorated functions can now accept kwargs that are defined in the activity schema.
+- Updated `ToolkitTask` system prompt to no longer mention `memory_name` and `artifact_namespace`.
+- Models in `ToolkitTask` with native tool calling no longer need to provide their final answer as `Answer:`.
 - `EventListener.event_types` will now listen on child types of any provided type.
 
 ### Fixed
 
 - Structures not flushing events when not listening for `FinishStructureRunEvent`.
 - `EventListener.event_types` and the argument to `BaseEventListenerDriver.handler` being out of sync.
+- Models occasionally hallucinating `memory_name` and `artifact_namespace` into Tool schemas when using `ToolkitTask`.
+- Models occasionally providing overly succinct final answers when using `ToolkitTask`.
 
 ## [0.33.1] - 2024-10-11
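A minimal sketch of what the two new `ToolkitTask` changelog entries mean in practice, assuming a prompt driver that supports native tool calling; the tool choice and prompt are illustrative, not taken from this commit:

```python
from griptape.structures import Agent
from griptape.tools import CalculatorTool

# Agent wraps the provided tools in a ToolkitTask under the hood.
agent = Agent(tools=[CalculatorTool()])

# With native tool calling, the model's plain final response is accepted as
# the task output -- it no longer has to be prefixed with "Answer:".
agent.run("What is 13 * 7?")
```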
3 changes: 0 additions & 3 deletions docs/griptape-framework/structures/rulesets.md
@@ -29,9 +29,6 @@ A [Ruleset](../../reference/griptape/rules/ruleset.md) can be used to define [Ru
 [JsonSchemaRule](../../reference/griptape/rules/json_schema_rule.md)s define a structured format for the LLM's output by providing a JSON schema.
 This is particularly useful when you need the LLM to return well-formed data, such as JSON objects, with specific fields and data types.
 
-!!! warning
-    `JsonSchemaRule` may break [ToolkitTask](../structures/tasks.md#toolkit-task) which relies on a specific [output token](https://github.com/griptape-ai/griptape/blob/e6a04c7b88cf9fa5d6bcf4c833ffebfab89a3258/griptape/tasks/toolkit_task.py#L28).
-
 ```python
 --8<-- "docs/griptape-framework/structures/src/json_schema_rule.py"
 ```
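With the warning deleted above, `JsonSchemaRule` no longer conflicts with `ToolkitTask`'s output parsing. A hedged sketch of the usage the warning referred to; the schema fields are illustrative, and the `schema` library (a griptape dependency) is assumed for generating the JSON schema:

```python
import schema

from griptape.rules import JsonSchemaRule
from griptape.structures import Agent

# JsonSchemaRule takes a JSON-schema dict and steers the LLM toward emitting
# output that matches it.
agent = Agent(
    rules=[
        JsonSchemaRule(schema.Schema({"answer": str, "confidence": float}).json_schema("Output"))
    ]
)
agent.run("What is the capital of France?")
```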
35 changes: 23 additions & 12 deletions griptape/tasks/actions_subtask.py
@@ -217,14 +217,18 @@ def __init_from_prompt(self, value: str) -> None:
         actions_matches = re.findall(self.ACTIONS_PATTERN, value, re.DOTALL)
         answer_matches = re.findall(self.ANSWER_PATTERN, value, re.MULTILINE)
 
-        if self.thought is None and thought_matches:
-            self.thought = thought_matches[-1]
+        self.actions = self.__parse_actions(actions_matches)
 
-        self.__parse_actions(actions_matches)
+        if thought_matches:
+            self.thought = thought_matches[-1]
 
-        # If there are no actions to take but an answer is provided, set the answer as the output.
-        if len(self.actions) == 0 and self.output is None and answer_matches:
-            self.output = TextArtifact(answer_matches[-1])
+        if not self.actions and self.output is None:
+            if answer_matches:
+                # A direct answer is provided, set it as the output.
+                self.output = TextArtifact(answer_matches[-1])
+            else:
+                # The LLM failed to follow the ReAct prompt, set the LLM's raw response as the output.
+                self.output = TextArtifact(value)
 
     def __init_from_artifacts(self, artifacts: ListArtifact) -> None:
         """Parses the input Artifacts to extract the thought and actions.
@@ -243,23 +247,30 @@ def __init_from_artifacts(self, artifacts: ListArtifact) -> None:
             if isinstance(artifact, ActionArtifact)
         ]
 
-        thoughts = [artifact.value for artifact in artifacts.value if isinstance(artifact, TextArtifact)]
-        if thoughts:
-            self.thought = thoughts[0]
+        # When parsing from Artifacts we can't determine the thought unless there are also Actions
+        if self.actions:
+            thoughts = [artifact.value for artifact in artifacts.value if isinstance(artifact, TextArtifact)]
+            if thoughts:
+                self.thought = thoughts[0]
+        else:
+            if self.output is None:
+                self.output = TextArtifact(artifacts.to_text())
 
-    def __parse_actions(self, actions_matches: list[str]) -> None:
+    def __parse_actions(self, actions_matches: list[str]) -> list[ToolAction]:
         if len(actions_matches) == 0:
-            return
+            return []
         try:
             data = actions_matches[-1]
             actions_list: list[dict] = json.loads(data, strict=False)
 
-            self.actions = [self.__process_action_object(action_object) for action_object in actions_list]
+            return [self.__process_action_object(action_object) for action_object in actions_list]
         except json.JSONDecodeError as e:
             logger.exception("Subtask %s\nInvalid actions JSON: %s", self.origin_task.id, e)
 
             self.output = ErrorArtifact(f"Actions JSON decoding error: {e}", exception=e)
 
+            return []
+
     def __process_action_object(self, action_object: dict) -> ToolAction:
         # Load action tag; throw exception if the key is not present
         action_tag = action_object["tag"]
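The net effect of these parsing changes is easiest to see through `ActionsSubtask` itself. A small sketch mirroring this commit's unit tests further down; `MockTool` is the repo's test-only helper, not public API:

```python
from griptape.structures import Agent
from griptape.tasks import ActionsSubtask, ToolkitTask
from tests.mocks.mock_tool.tool import MockTool  # test-only helper in this repo

task = ToolkitTask(tools=[MockTool()])
Agent().add_task(task)

# A response with no Actions and no "Answer:" prefix used to stall the ReAct
# loop; __init_from_prompt now falls back to the raw response as the output,
# which also lets the fallback in ToolkitTask.run (removed below) go away.
subtask = task.add_subtask(ActionsSubtask("just some plain text"))

assert subtask.actions == []
assert subtask.output.value == "just some plain text"
```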
3 changes: 0 additions & 3 deletions griptape/tasks/toolkit_task.py
@@ -180,9 +180,6 @@ def run(self) -> BaseArtifact:
             if subtask.output is None:
                 if len(self.subtasks) >= self.max_subtasks:
                     subtask.output = ErrorArtifact(f"Exceeded tool limit of {self.max_subtasks} subtasks per task")
-                elif not subtask.actions:
-                    # handle case when the LLM failed to follow the ReAct prompt and didn't return a proper action
-                    subtask.output = subtask.input
                 else:
                     subtask.before_run()
                     subtask.run()
6 changes: 2 additions & 4 deletions griptape/templates/tasks/toolkit_task/system.j2
@@ -8,16 +8,14 @@ Actions: <JSON array of actions that MUST follow this schema: {{ actions_schema
 {{ stop_sequence }}: <action outputs>
 "Thought", "Actions", "{{ stop_sequence }}" must always start on a new line.
 
+You must use the following format when providing your final answer:
+Answer: <final answer>
 {% endif %}
 Repeat executing actions as many times as you need.
 If an action's output contains an error, you MUST ALWAYS try to fix the error by executing another action.
-You must use the following format when providing your final answer:
-Answer: <final answer>
-
 Be truthful. ALWAYS be proactive and NEVER ask the user for more information input. Keep using actions until you have your final answer.
 NEVER make up actions, action names, or action paths. NEVER make up facts. NEVER reference tags in other action input values.
 
-Actions might store their output in memory as artifacts (with `memory_name` and `artifact_namespace`). If action output is stored in memory, ALWAYS try to pass it to another action. NEVER make up memory names or artifact namespaces.
 {% if meta_memory %}
 
 {{ meta_memory }}
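One way to check what these template changes produce is to render the task's system prompt directly. A hedged sketch, assuming `generate_system_template` keeps the callable signature used elsewhere in the framework and that `Agent` initializes its default task eagerly:

```python
from griptape.structures import Agent
from griptape.tools import CalculatorTool

agent = Agent(tools=[CalculatorTool()])
task = agent.tasks[0]

# Renders system.j2 for this task; with native tool calling the ReAct block
# (and with it the "Answer:" instructions) should no longer appear.
print(task.generate_system_template(task))
```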
59 changes: 53 additions & 6 deletions tests/unit/tasks/test_actions_subtask.py
@@ -9,7 +9,7 @@
 
 
 class TestActionsSubtask:
-    def test_basic_input(self):
+    def test_prompt_input(self):
         valid_input = (
             "Thought: need to test\n"
             'Actions: [{"tag": "foo", "name": "MockTool", "path": "test", "input": {"values": {"test": "value"}}}]\n'
@@ -25,22 +25,31 @@ def test_basic_input(self)
         assert json_dict[0]["name"] == "MockTool"
         assert json_dict[0]["path"] == "test"
         assert json_dict[0]["input"] == {"values": {"test": "value"}}
+        assert subtask.thought == "need to test"
+        assert subtask.output is None
 
-    def test_action_input(self):
-        valid_input = ActionArtifact(
-            ToolAction(tag="foo", name="MockTool", path="test", input={"values": {"test": "value"}})
+    def test_artifact_input(self):
+        valid_input = ListArtifact(
+            [
+                TextArtifact("need to test"),
+                ActionArtifact(
+                    ToolAction(tag="foo", name="MockTool", path="test", input={"values": {"test": "value"}})
+                ),
+                TextArtifact("answer"),
+            ]
         )
         task = ToolkitTask(tools=[MockTool()])
         Agent().add_task(task)
         subtask = task.add_subtask(ActionsSubtask(valid_input))
         json_dict = json.loads(subtask.actions_to_json())
 
-        assert subtask.thought is None
         assert json_dict[0]["name"] == "MockTool"
         assert json_dict[0]["path"] == "test"
         assert json_dict[0]["input"] == {"values": {"test": "value"}}
+        assert subtask.thought == "need to test"
+        assert subtask.output is None
 
-    def test_action_and_thought_input(self):
+    def test_artifact_action_and_thought_input(self):
         valid_input = ListArtifact(
             [
                 TextArtifact("thought"),
@@ -59,6 +68,42 @@ def test_action_and_thought_input(self):
         assert json_dict[0]["path"] == "test"
         assert json_dict[0]["input"] == {"values": {"test": "value"}}
 
+    def test_prompt_answer(self):
+        valid_input = "Answer: test output"
+
+        task = ToolkitTask(tools=[MockTool()])
+        Agent().add_task(task)
+        subtask = task.add_subtask(ActionsSubtask(valid_input))
+
+        assert subtask.thought is None
+        assert subtask.actions == []
+        assert subtask.output.value == "test output"
+
+    def test_prompt_implicit_answer(self):
+        valid_input = "test output"
+
+        task = ToolkitTask(tools=[MockTool()])
+        Agent().add_task(task)
+        subtask = task.add_subtask(ActionsSubtask(valid_input))
+
+        assert subtask.thought is None
+        assert subtask.actions == []
+        assert subtask.output.value == "test output"
+
+    def test_artifact_answer(self):
+        valid_input = ListArtifact(
+            [
+                TextArtifact("answer"),
+            ]
+        )
+        task = ToolkitTask(tools=[MockTool()])
+        Agent().add_task(task)
+        subtask = task.add_subtask(ActionsSubtask(valid_input))
+
+        assert subtask.thought is None
+        assert subtask.actions == []
+        assert subtask.output.value == "answer"
+
     def test_callable_input(self):
         valid_input = ListArtifact(
             [
@@ -146,6 +191,8 @@ def test_invalid_actions(self):
 
         assert isinstance(subtask.output, ErrorArtifact)
         assert "Actions JSON decoding error" in subtask.output.value
+        assert subtask.thought == "need to test"
+        assert subtask.actions == []
 
     def test_implicit_values(self):
         valid_input = (
