Skip to content

Commit

Permalink
Kwarg injection with @activity decorated functions (#1265)
Browse files Browse the repository at this point in the history
  • Loading branch information
vachillo authored Oct 17, 2024
1 parent db492c9 commit bf96755
Show file tree
Hide file tree
Showing 13 changed files with 131 additions and 25 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Chat` output now uses `Rich.print` by default.
- `Chat.output_fn`'s now takes an optional kwarg parameter, `stream`.
- Implemented `SerializableMixin` in `Structure`, `BaseTask`, `BaseTool`, and `TaskMemory`
- `@activity` decorated functions can now accept kwargs that are defined in the activity schema.

### Fixed

Expand Down
11 changes: 9 additions & 2 deletions docs/griptape-tools/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@ Tools give the LLM abilities to invoke outside APIs, reference data sets, and ge

Griptape tools are special Python classes that LLMs can use to accomplish specific goals. Here is an example custom tool for generating a random number:

A tool can have many "activities" as denoted by the `@activity` decorator. Each activity has a description (used to provide context to the LLM), and the input schema that the LLM must follow in order to use the tool.

When a function is decorated with `@activity`, the decorator injects keyword arguments into the function according to the schema. There are also two Griptape-provided keyword arguments: `params: dict` and `values: dict`.

!!! info
If your schema defines any parameters named `params` or `values`, they will be overwritten by the Griptape-provided arguments.

In the following example, all `@activity` decorated functions will result in the same value, but the method signature is defined in different ways.

```python
--8<-- "docs/griptape-tools/src/index_1.py"
```

A tool can have many "activities" as denoted by the `@activity` decorator. Each activity has a description (used to provide context to the LLM), and the input schema that the LLM must follow in order to use the tool.

Output artifacts from all tool activities (except for `InfoArtifact` and `ErrorArtifact`) go to short-term `TaskMemory`. To disable that behavior set the `off_prompt` tool parameter to `False`:

We provide a set of official Griptape Tools for accessing and processing data. You can also [build your own tools](./custom-tools/index.md).
36 changes: 36 additions & 0 deletions docs/griptape-tools/src/index_1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

import random
import typing

from schema import Literal, Optional, Schema

Expand All @@ -19,5 +22,38 @@ class RandomNumberGenerator(BaseTool):
def generate(self, params: dict) -> TextArtifact:
return TextArtifact(str(round(random.random(), params["values"].get("decimals"))))

@activity(
config={
"description": "Can be used to generate random numbers",
"schema": Schema(
{Optional(Literal("decimals", description="Number of decimals to round the random number to")): int}
),
}
)
def generate_with_decimals(self, decimals: typing.Optional[int]) -> TextArtifact:
return TextArtifact(str(round(random.random(), decimals)))

@activity(
config={
"description": "Can be used to generate random numbers",
"schema": Schema(
{Optional(Literal("decimals", description="Number of decimals to round the random number to")): int}
),
}
)
def generate_with_values(self, values: dict) -> TextArtifact:
return TextArtifact(str(round(random.random(), values.get("decimals"))))

@activity(
config={
"description": "Can be used to generate random numbers",
"schema": Schema(
{Optional(Literal("decimals", description="Number of decimals to round the random number to")): int}
),
}
)
def generate_with_kwargs(self, **kwargs) -> TextArtifact:
return TextArtifact(str(round(random.random(), kwargs.get("decimals"))))


RandomNumberGenerator()
4 changes: 2 additions & 2 deletions griptape/tools/aws_iam/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_user_policy(self, params: dict) -> TextArtifact | ErrorArtifact:
return ErrorArtifact(f"error returning policy document: {e}")

@activity(config={"description": "Can be used to list AWS MFA Devices"})
def list_mfa_devices(self, _: dict) -> ListArtifact | ErrorArtifact:
def list_mfa_devices(self) -> ListArtifact | ErrorArtifact:
try:
devices = self.client.list_mfa_devices()
return ListArtifact([TextArtifact(str(d)) for d in devices["MFADevices"]])
Expand Down Expand Up @@ -76,7 +76,7 @@ def list_user_policies(self, params: dict) -> ListArtifact | ErrorArtifact:
return ErrorArtifact(f"error listing iam user policies: {e}")

@activity(config={"description": "Can be used to list AWS IAM users."})
def list_users(self, _: dict) -> ListArtifact | ErrorArtifact:
def list_users(self) -> ListArtifact | ErrorArtifact:
try:
users = self.client.list_users()
return ListArtifact([TextArtifact(str(u)) for u in users["Users"]])
Expand Down
2 changes: 1 addition & 1 deletion griptape/tools/aws_s3/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def get_object_acl(self, params: dict) -> TextArtifact | ErrorArtifact:
return ErrorArtifact(f"error getting object acl: {e}")

@activity(config={"description": "Can be used to list all AWS S3 buckets."})
def list_s3_buckets(self, _: dict) -> ListArtifact | ErrorArtifact:
def list_s3_buckets(self) -> ListArtifact | ErrorArtifact:
try:
buckets = self.client.list_buckets()

Expand Down
2 changes: 1 addition & 1 deletion griptape/tools/date_time/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class DateTimeTool(BaseTool):
@activity(config={"description": "Can be used to return current date and time."})
def get_current_datetime(self, _: dict) -> BaseArtifact:
def get_current_datetime(self) -> BaseArtifact:
try:
current_datetime = datetime.now()

Expand Down
8 changes: 3 additions & 5 deletions griptape/tools/web_search/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,10 @@ class WebSearchTool(BaseTool):
),
},
)
def search(self, props: dict) -> ListArtifact | ErrorArtifact:
values = props["values"]
query = values["query"]
extra_keys = {k: values[k] for k in values.keys() - {"query"}}
def search(self, values: dict) -> ListArtifact | ErrorArtifact:
query = values.pop("query")

try:
return self.web_search_driver.search(query, **extra_keys)
return self.web_search_driver.search(query, **values)
except Exception as e:
return ErrorArtifact(f"Error searching '{query}' with {self.web_search_driver.__class__.__name__}: {e}")
36 changes: 34 additions & 2 deletions griptape/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import functools
import inspect
from typing import Any, Callable, Optional

import schema
Expand All @@ -24,8 +25,8 @@ def activity(config: dict) -> Any:

def decorator(func: Callable) -> Any:
@functools.wraps(func)
def wrapper(self: Any, *args, **kwargs) -> Any:
return func(self, *args, **kwargs)
def wrapper(self: Any, params: dict) -> Any:
return func(self, **_build_kwargs(func, params))

setattr(wrapper, "name", func.__name__)
setattr(wrapper, "config", validated_config)
Expand Down Expand Up @@ -54,3 +55,34 @@ def lazy_attr(self: Any, value: Any) -> None:
return lazy_attr

return decorator


def _build_kwargs(func: Callable, params: dict) -> dict:
func_params = inspect.signature(func).parameters.copy()
func_params.pop("self")

kwarg_var = None
for param in func_params.values():
# if there is a **kwargs parameter, we can safely
# pass all the params to the function
if param.kind == inspect.Parameter.VAR_KEYWORD:
kwarg_var = func_params.pop(param.name).name
break

# only pass the values that are in the function signature
# or if there is a **kwargs parameter, pass all the values
kwargs = {k: v for k, v in params.get("values", {}).items() if k in func_params or kwarg_var is not None}

# add 'params' and 'values' if they are in the signature
# or if there is a **kwargs parameter
if "params" in func_params or kwarg_var is not None:
kwargs["params"] = params
if "values" in func_params or kwarg_var is not None:
kwargs["values"] = params.get("values")

# set any missing parameters to None
for param_name in func_params:
if param_name not in kwargs:
kwargs[param_name] = None

return kwargs
24 changes: 12 additions & 12 deletions tests/mocks/mock_tool/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,38 +24,38 @@ class MockTool(BaseTool):
"schema": Schema({Literal("test"): str}, description="Test input"),
}
)
def test(self, value: dict) -> BaseArtifact:
return TextArtifact(f"ack {value['values']['test']}")
def test(self, test: str) -> BaseArtifact:
return TextArtifact(f"ack {test}")

@activity(
config={
"description": "test description: {{ _self.foo() }}",
"schema": Schema({Literal("test"): str}, description="Test input"),
}
)
def test_error(self, value: dict) -> BaseArtifact:
return ErrorArtifact(f"error {value['values']['test']}")
def test_error(self, params: dict) -> BaseArtifact:
return ErrorArtifact(f"error {params['values']['test']}")

@activity(
config={
"description": "test description: {{ _self.foo() }}",
"schema": Schema({Literal("test"): str}, description="Test input"),
}
)
def test_exception(self, value: dict) -> BaseArtifact:
raise Exception(f"error {value['values']['test']}")
def test_exception(self, params: dict) -> BaseArtifact:
raise Exception(f"error {params['values']['test']}")

@activity(
config={
"description": "test description: {{ _self.foo() }}",
"schema": Schema({Literal("test"): str}, description="Test input"),
}
)
def test_str_output(self, value: dict) -> str:
return f"ack {value['values']['test']}"
def test_str_output(self, params: dict) -> str:
return f"ack {params['values']['test']}"

@activity(config={"description": "test description"})
def test_no_schema(self, value: dict) -> str:
def test_no_schema(self) -> str:
return "no schema"

@activity(
Expand All @@ -68,14 +68,14 @@ def test_callable_schema(self) -> TextArtifact:
return TextArtifact("ack")

@activity(config={"description": "test description"})
def test_list_output(self, value: dict) -> ListArtifact:
def test_list_output(self) -> ListArtifact:
return ListArtifact([TextArtifact("foo"), TextArtifact("bar")])

@activity(
config={"description": "test description", "schema": Schema({Literal("test"): str}, description="Test input")}
)
def test_without_default_memory(self, value: dict) -> str:
return f"ack {value['values']['test']}"
def test_without_default_memory(self, params: dict) -> str:
return f"ack {params['values']['test']}"

def foo(self) -> str:
return "foo"
Expand Down
Empty file.
Empty file.
25 changes: 25 additions & 0 deletions tests/mocks/mock_tool_kwargs/tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from attrs import define
from schema import Literal, Schema

from griptape.tools import BaseTool
from griptape.utils.decorators import activity


@define
class MockToolKwargs(BaseTool):
@activity(
config={
"description": "test description",
"schema": Schema({Literal("test_kwarg"): str}, description="Test input"),
}
)
def test_with_kwargs(self, params: dict, test_kwarg: str, test_kwarg_none: None, **kwargs) -> str:
if test_kwarg_none is not None:
raise ValueError("test_kwarg_none should be None")
if "test_kwarg_kwargs" not in kwargs:
raise ValueError("test_kwarg_kwargs not in kwargs")
if "values" not in kwargs:
raise ValueError("values not in params")
if "test_kwarg" not in params["values"]:
raise ValueError("test_kwarg not in params")
return f"ack {test_kwarg}"
7 changes: 7 additions & 0 deletions tests/unit/tools/test_base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from griptape.tasks import ActionsSubtask, ToolkitTask
from griptape.tools import BaseTool
from tests.mocks.mock_tool.tool import MockTool
from tests.mocks.mock_tool_kwargs.tool import MockToolKwargs
from tests.utils import defaults


Expand Down Expand Up @@ -308,3 +309,9 @@ def test_from_dict(self, tool):
assert isinstance(deserialized_tool, BaseTool)

assert deserialized_tool.execute(tool.test_list_output, ActionsSubtask("foo"), action).to_text() == "foo\n\nbar"

def test_method_kwargs_var_injection(self, tool):
tool = MockToolKwargs()

params = {"values": {"test_kwarg": "foo", "test_kwarg_kwargs": "bar"}}
assert tool.test_with_kwargs(params) == "ack foo"

0 comments on commit bf96755

Please sign in to comment.