From db492c95434288ce53708416a8b7876f6f8bd9cc Mon Sep 17 00:00:00 2001 From: William Price <82848178+william-price01@users.noreply.github.com> Date: Thu, 17 Oct 2024 09:32:41 -0700 Subject: [PATCH] Structures, Tools, Tasks, are now serializable. (#1261) Co-authored-by: Collin Dutter --- CHANGELOG.md | 1 + griptape/memory/task/task_memory.py | 15 +++++-- griptape/mixins/serializable_mixin.py | 13 ++++-- griptape/schemas/base_schema.py | 10 ++++- griptape/schemas/polymorphic_schema.py | 2 +- griptape/structures/agent.py | 4 +- griptape/structures/structure.py | 14 +++--- griptape/tasks/actions_subtask.py | 8 ++-- griptape/tasks/base_audio_input_task.py | 4 +- griptape/tasks/base_task.py | 15 ++++--- griptape/tasks/base_text_input_task.py | 4 +- griptape/tasks/image_query_task.py | 22 +++++----- .../tasks/inpainting_image_generation_task.py | 8 ++-- .../outpainting_image_generation_task.py | 8 ++-- .../tasks/prompt_image_generation_task.py | 4 +- griptape/tasks/prompt_task.py | 4 +- griptape/tasks/text_to_speech_task.py | 4 +- griptape/tasks/tool_task.py | 2 +- .../tasks/variation_image_generation_task.py | 6 +-- griptape/tools/base_tool.py | 23 ++++++---- tests/mocks/mock_prompt_driver.py | 6 +-- tests/mocks/mock_tool/tool.py | 5 +++ tests/unit/memory/tool/test_task_memory.py | 24 ++++++++++ tests/unit/mixins/test_seriliazable_mixin.py | 15 +++++++ tests/unit/structures/test_agent.py | 44 +++++++++++++++++++ tests/unit/structures/test_pipeline.py | 44 +++++++++++++++++++ tests/unit/structures/test_workflow.py | 44 +++++++++++++++++++ tests/unit/tasks/test_base_task.py | 32 ++++++++++++++ tests/unit/tasks/test_tool_task.py | 36 +++++++++++++++ tests/unit/tools/test_base_tool.py | 29 ++++++++++++ 30 files changed, 380 insertions(+), 70 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a88fb709..37afa4b41 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Chat` input now uses a slightly customized version of `Rich.prompt.Prompt` by default. - `Chat` output now uses `Rich.print` by default. - `Chat.output_fn`'s now takes an optional kwarg parameter, `stream`. +- Implemented `SerializableMixin` in `Structure`, `BaseTask`, `BaseTool`, and `TaskMemory` ### Fixed diff --git a/griptape/memory/task/task_memory.py b/griptape/memory/task/task_memory.py index 1aa60dba3..b5bc35378 100644 --- a/griptape/memory/task/task_memory.py +++ b/griptape/memory/task/task_memory.py @@ -8,6 +8,7 @@ from griptape.memory.meta import ActionSubtaskMetaEntry from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage from griptape.mixins.activity_mixin import ActivityMixin +from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: from griptape.memory.task.storage import BaseArtifactStorage @@ -15,8 +16,12 @@ @define -class TaskMemory(ActivityMixin): - name: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True) +class TaskMemory(ActivityMixin, SerializableMixin): + name: str = field( + default=Factory(lambda self: self.__class__.__name__, takes_self=True), + kw_only=True, + metadata={"serializable": True}, + ) artifact_storages: dict[type, BaseArtifactStorage] = field( default=Factory( lambda: { @@ -26,8 +31,10 @@ class TaskMemory(ActivityMixin): ), kw_only=True, ) - namespace_storage: dict[str, BaseArtifactStorage] = field(factory=dict, kw_only=True) - namespace_metadata: dict[str, Any] = field(factory=dict, kw_only=True) + namespace_storage: dict[str, BaseArtifactStorage] = field( + factory=dict, kw_only=True, metadata={"serializable": True} + ) + namespace_metadata: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True}) @artifact_storages.validator # pyright: ignore[reportAttributeAccessIssue] def validate_artifact_storages(self, _: Attribute, artifact_storage: dict[type, BaseArtifactStorage]) -> None: diff --git a/griptape/mixins/serializable_mixin.py b/griptape/mixins/serializable_mixin.py index e8f772cab..35269b36e 100644 --- a/griptape/mixins/serializable_mixin.py +++ b/griptape/mixins/serializable_mixin.py @@ -22,19 +22,26 @@ class SerializableMixin(Generic[T]): kw_only=True, metadata={"serializable": True}, ) + module_name: str = field( + default=Factory(lambda self: self.__class__.__module__, takes_self=True), + kw_only=True, + metadata={"serializable": False}, + ) @classmethod - def get_schema(cls: type[T], subclass_name: Optional[str] = None) -> Schema: + def get_schema(cls: type[T], subclass_name: Optional[str] = None, *, module_name: Optional[str] = None) -> Schema: """Generates a Marshmallow schema for the class. Args: subclass_name: An optional subclass name. Required if the class is abstract. + module_name: An optional module name. Defaults to the class's module. """ if ABC in cls.__bases__: if subclass_name is None: raise ValueError(f"Type field is required for abstract class: {cls.__name__}") - subclass_cls = cls._import_cls_rec(cls.__module__, subclass_name) + module_name = module_name or cls.__module__ + subclass_cls = cls._import_cls_rec(module_name, subclass_name) schema_class = BaseSchema.from_attrs_cls(subclass_cls) else: @@ -44,7 +51,7 @@ def get_schema(cls: type[T], subclass_name: Optional[str] = None) -> Schema: @classmethod def from_dict(cls: type[T], data: dict) -> T: - return cast(T, cls.get_schema(subclass_name=data.get("type")).load(data)) + return cast(T, cls.get_schema(subclass_name=data.get("type"), module_name=data.get("module_name")).load(data)) @classmethod def from_json(cls: type[T], data: str) -> T: diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index e6ef47a37..9762bf83d 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -173,8 +173,11 @@ def _resolve_types(cls, attrs_cls: type) -> None: BaseVectorStoreDriver, ) from griptape.events import EventListener - from griptape.memory.structure import Run + from griptape.memory import TaskMemory + from griptape.memory.structure import BaseConversationMemory, Run + from griptape.memory.task.storage import BaseArtifactStorage from griptape.structures import Structure + from griptape.tasks import BaseTask from griptape.tokenizers import BaseTokenizer from griptape.tools import BaseTool from griptape.utils import import_optional_dependency, is_dependency_installed @@ -198,6 +201,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: "BaseMessageContent": BaseMessageContent, "BaseDeltaMessageContent": BaseDeltaMessageContent, "BaseTool": BaseTool, + "BaseTask": BaseTask, "Usage": Message.Usage, "Structure": Structure, "BaseTokenizer": BaseTokenizer, @@ -205,6 +209,10 @@ def _resolve_types(cls, attrs_cls: type) -> None: "Reference": Reference, "Run": Run, "Sequence": Sequence, + "TaskMemory": TaskMemory, + "State": BaseTask.State, + "BaseConversationMemory": BaseConversationMemory, + "BaseArtifactStorage": BaseArtifactStorage, # Third party modules "Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any, "GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel diff --git a/griptape/schemas/polymorphic_schema.py b/griptape/schemas/polymorphic_schema.py index 2e556b2c7..39749a431 100644 --- a/griptape/schemas/polymorphic_schema.py +++ b/griptape/schemas/polymorphic_schema.py @@ -116,7 +116,7 @@ def _load(self, data: Any, *, partial: Any = None, unknown: Any = None, **kwargs if data_type is None: raise ValidationError({self.type_field: ["Missing data for required field."]}) - type_schema = self.inner_class.get_schema(data_type) + type_schema = self.inner_class.get_schema(data_type, module_name=data.get("module_name")) if not type_schema: raise ValidationError({self.type_field: [f"Unsupported value: {data_type}"]}) diff --git a/griptape/structures/agent.py b/griptape/structures/agent.py index 77f3e0618..121220bc2 100644 --- a/griptape/structures/agent.py +++ b/griptape/structures/agent.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional, Union from attrs import Attribute, Factory, define, field @@ -19,7 +19,7 @@ @define class Agent(Structure): - input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( + input: Union[str, list, tuple, BaseArtifact, Callable[[BaseTask], BaseArtifact]] = field( default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), ) stream: bool = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver.stream), kw_only=True) diff --git a/griptape/structures/structure.py b/griptape/structures/structure.py index 29ee6281c..c2702fdbf 100644 --- a/griptape/structures/structure.py +++ b/griptape/structures/structure.py @@ -2,7 +2,7 @@ import uuid from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Optional, Union from attrs import Factory, define, field @@ -12,6 +12,7 @@ from griptape.memory.meta import MetaMemory from griptape.memory.structure import ConversationMemory, Run from griptape.mixins.rule_mixin import RuleMixin +from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -20,19 +21,22 @@ @define -class Structure(ABC, RuleMixin): - id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) - _tasks: list[BaseTask | list[BaseTask]] = field(factory=list, kw_only=True, alias="tasks") +class Structure(ABC, RuleMixin, SerializableMixin): + id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) + _tasks: list[Union[BaseTask, list[BaseTask]]] = field( + factory=list, kw_only=True, alias="tasks", metadata={"serializable": True} + ) conversation_memory: Optional[BaseConversationMemory] = field( default=Factory(lambda: ConversationMemory()), kw_only=True, + metadata={"serializable": True}, ) task_memory: TaskMemory = field( default=Factory(lambda self: TaskMemory(), takes_self=True), kw_only=True, ) meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True) - fail_fast: bool = field(default=True, kw_only=True) + fail_fast: bool = field(default=True, kw_only=True, metadata={"serializable": True}) _execution_args: tuple = () def __attrs_post_init__(self) -> None: diff --git a/griptape/tasks/actions_subtask.py b/griptape/tasks/actions_subtask.py index 9057fc127..8d68ae996 100644 --- a/griptape/tasks/actions_subtask.py +++ b/griptape/tasks/actions_subtask.py @@ -3,7 +3,7 @@ import json import logging import re -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional, Union import schema from attrs import define, field @@ -33,7 +33,7 @@ class ActionsSubtask(BaseTask): thought: Optional[str] = field(default=None, kw_only=True) actions: list[ToolAction] = field(factory=list, kw_only=True) output: Optional[BaseArtifact] = field(default=None, init=False) - _input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( + _input: Union[str, list, tuple, BaseArtifact, Callable[[BaseTask], BaseArtifact]] = field( default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), alias="input", ) @@ -197,8 +197,8 @@ def actions_to_json(self) -> str: def _process_task_input( self, - task_input: str | tuple | list | BaseArtifact | Callable[[BaseTask], BaseArtifact], - ) -> TextArtifact | ListArtifact: + task_input: Union[str, tuple, list, BaseArtifact, Callable[[BaseTask], BaseArtifact]], + ) -> Union[TextArtifact, ListArtifact]: if isinstance(task_input, (TextArtifact, ListArtifact)): return task_input elif isinstance(task_input, ActionArtifact): diff --git a/griptape/tasks/base_audio_input_task.py b/griptape/tasks/base_audio_input_task.py index 8a834db56..0459fed03 100644 --- a/griptape/tasks/base_audio_input_task.py +++ b/griptape/tasks/base_audio_input_task.py @@ -2,7 +2,7 @@ import logging from abc import ABC -from typing import Callable +from typing import Callable, Union from attrs import define, field @@ -16,7 +16,7 @@ @define class BaseAudioInputTask(RuleMixin, BaseTask, ABC): - _input: AudioArtifact | Callable[[BaseTask], AudioArtifact] = field(alias="input") + _input: Union[AudioArtifact, Callable[[BaseTask], AudioArtifact]] = field(alias="input") @property def input(self) -> AudioArtifact: diff --git a/griptape/tasks/base_task.py b/griptape/tasks/base_task.py index ff1f6a11d..b6012c7e1 100644 --- a/griptape/tasks/base_task.py +++ b/griptape/tasks/base_task.py @@ -12,6 +12,7 @@ from griptape.configs import Defaults from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin +from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: from griptape.artifacts import BaseArtifact @@ -22,21 +23,21 @@ @define -class BaseTask(FuturesExecutorMixin, ABC): +class BaseTask(FuturesExecutorMixin, SerializableMixin, ABC): class State(Enum): PENDING = 1 EXECUTING = 2 FINISHED = 3 - id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True) - state: State = field(default=State.PENDING, kw_only=True) - parent_ids: list[str] = field(factory=list, kw_only=True) - child_ids: list[str] = field(factory=list, kw_only=True) - max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True) + id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True}) + state: State = field(default=State.PENDING, kw_only=True, metadata={"serializable": True}) + parent_ids: list[str] = field(factory=list, kw_only=True, metadata={"serializable": True}) + child_ids: list[str] = field(factory=list, kw_only=True, metadata={"serializable": True}) + max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True, metadata={"serializable": True}) structure: Optional[Structure] = field(default=None, kw_only=True) output: Optional[BaseArtifact] = field(default=None, init=False) - context: dict[str, Any] = field(factory=dict, kw_only=True) + context: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True}) def __rshift__(self, other: BaseTask) -> BaseTask: self.add_child(other) diff --git a/griptape/tasks/base_text_input_task.py b/griptape/tasks/base_text_input_task.py index dfed85bcf..b8321b4f4 100644 --- a/griptape/tasks/base_text_input_task.py +++ b/griptape/tasks/base_text_input_task.py @@ -2,7 +2,7 @@ import logging from abc import ABC -from typing import Callable +from typing import Callable, Union from attrs import define, field @@ -19,7 +19,7 @@ class BaseTextInputTask(RuleMixin, BaseTask, ABC): DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" - _input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field( + _input: Union[str, TextArtifact, Callable[[BaseTask], TextArtifact]] = field( default=DEFAULT_INPUT_TEMPLATE, alias="input", ) diff --git a/griptape/tasks/image_query_task.py b/griptape/tasks/image_query_task.py index 1c77bbc0a..5d1fcc79a 100644 --- a/griptape/tasks/image_query_task.py +++ b/griptape/tasks/image_query_task.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, Union from attrs import Factory, define, field @@ -25,12 +25,12 @@ class ImageQueryTask(BaseTask): """ image_query_engine: ImageQueryEngine = field(default=Factory(lambda: ImageQueryEngine()), kw_only=True) - _input: ( - tuple[str, list[ImageArtifact]] - | tuple[TextArtifact, list[ImageArtifact]] - | Callable[[BaseTask], ListArtifact] - | ListArtifact - ) = field(default=None, alias="input") + _input: Union[ + tuple[str, list[ImageArtifact]], + tuple[TextArtifact, list[ImageArtifact]], + Callable[[BaseTask], ListArtifact], + ListArtifact, + ] = field(default=None, alias="input") @property def input(self) -> ListArtifact: @@ -55,9 +55,11 @@ def input(self) -> ListArtifact: def input( self, value: ( - tuple[str, list[ImageArtifact]] - | tuple[TextArtifact, list[ImageArtifact]] - | Callable[[BaseTask], ListArtifact] + Union[ + tuple[str, list[ImageArtifact]], + tuple[TextArtifact, list[ImageArtifact]], + Callable[[BaseTask], ListArtifact], + ] ), ) -> None: self._input = value diff --git a/griptape/tasks/inpainting_image_generation_task.py b/griptape/tasks/inpainting_image_generation_task.py index 649f9e3fb..a00e345fb 100644 --- a/griptape/tasks/inpainting_image_generation_task.py +++ b/griptape/tasks/inpainting_image_generation_task.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, Union from attrs import Factory, define, field @@ -32,9 +32,9 @@ class InpaintingImageGenerationTask(BaseImageGenerationTask): default=Factory(lambda: InpaintingImageGenerationEngine()), kw_only=True, ) - _input: ( - tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact - ) = field(default=None, alias="input") + _input: Union[ + tuple[Union[str, TextArtifact], ImageArtifact, ImageArtifact], Callable[[BaseTask], ListArtifact], ListArtifact + ] = field(default=None, alias="input") @property def input(self) -> ListArtifact: diff --git a/griptape/tasks/outpainting_image_generation_task.py b/griptape/tasks/outpainting_image_generation_task.py index 019f74fa1..ee928c800 100644 --- a/griptape/tasks/outpainting_image_generation_task.py +++ b/griptape/tasks/outpainting_image_generation_task.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, Union from attrs import Factory, define, field @@ -32,9 +32,9 @@ class OutpaintingImageGenerationTask(BaseImageGenerationTask): default=Factory(lambda: OutpaintingImageGenerationEngine()), kw_only=True, ) - _input: ( - tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact - ) = field(default=None, alias="input") + _input: Union[ + tuple[Union[str, TextArtifact], ImageArtifact, ImageArtifact], Callable[[BaseTask], ListArtifact], ListArtifact + ] = field(default=None, alias="input") @property def input(self) -> ListArtifact: diff --git a/griptape/tasks/prompt_image_generation_task.py b/griptape/tasks/prompt_image_generation_task.py index d2ebf79c2..5676f4d65 100644 --- a/griptape/tasks/prompt_image_generation_task.py +++ b/griptape/tasks/prompt_image_generation_task.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, Union from attrs import Factory, define, field @@ -29,7 +29,7 @@ class PromptImageGenerationTask(BaseImageGenerationTask): DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" - _input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field( + _input: Union[str, TextArtifact, Callable[[BaseTask], TextArtifact]] = field( default=DEFAULT_INPUT_TEMPLATE, alias="input" ) image_generation_engine: PromptImageGenerationEngine = field( diff --git a/griptape/tasks/prompt_task.py b/griptape/tasks/prompt_task.py index 127fabf48..ed3ffa452 100644 --- a/griptape/tasks/prompt_task.py +++ b/griptape/tasks/prompt_task.py @@ -1,7 +1,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Callable, Optional +from typing import TYPE_CHECKING, Callable, Optional, Union from attrs import Factory, define, field @@ -28,7 +28,7 @@ class PromptTask(RuleMixin, BaseTask): default=Factory(lambda self: self.default_system_template_generator, takes_self=True), kw_only=True, ) - _input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field( + _input: Union[str, list, tuple, BaseArtifact, Callable[[BaseTask], BaseArtifact]] = field( default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""), alias="input", ) diff --git a/griptape/tasks/text_to_speech_task.py b/griptape/tasks/text_to_speech_task.py index c131d69bc..5f897164c 100644 --- a/griptape/tasks/text_to_speech_task.py +++ b/griptape/tasks/text_to_speech_task.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, Union from attrs import Factory, define, field @@ -18,7 +18,7 @@ class TextToSpeechTask(BaseAudioGenerationTask): DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}" - _input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field(default=DEFAULT_INPUT_TEMPLATE) + _input: Union[str, TextArtifact, Callable[[BaseTask], TextArtifact]] = field(default=DEFAULT_INPUT_TEMPLATE) text_to_speech_engine: TextToSpeechEngine = field(default=Factory(lambda: TextToSpeechEngine()), kw_only=True) @property diff --git a/griptape/tasks/tool_task.py b/griptape/tasks/tool_task.py index 38d6e1512..a9a36ddb6 100644 --- a/griptape/tasks/tool_task.py +++ b/griptape/tasks/tool_task.py @@ -25,7 +25,7 @@ class ToolTask(PromptTask, ActionsSubtaskOriginMixin): ACTION_PATTERN = r"(?s)[^{]*({.*})" - tool: BaseTool = field(kw_only=True) + tool: BaseTool = field(kw_only=True, metadata={"serializable": True}) subtask: Optional[ActionsSubtask] = field(default=None, kw_only=True) task_memory: Optional[TaskMemory] = field(default=None, kw_only=True) diff --git a/griptape/tasks/variation_image_generation_task.py b/griptape/tasks/variation_image_generation_task.py index ddc16178b..c443cd08b 100644 --- a/griptape/tasks/variation_image_generation_task.py +++ b/griptape/tasks/variation_image_generation_task.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Callable +from typing import Callable, Union from attrs import Factory, define, field @@ -32,8 +32,8 @@ class VariationImageGenerationTask(BaseImageGenerationTask): default=Factory(lambda: VariationImageGenerationEngine()), kw_only=True, ) - _input: tuple[str | TextArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact = field( - default=None, alias="input" + _input: Union[tuple[Union[str, TextArtifact], ImageArtifact], Callable[[BaseTask], ListArtifact], ListArtifact] = ( + field(default=None, alias="input") ) @property diff --git a/griptape/tools/base_tool.py b/griptape/tools/base_tool.py index 81f127791..7efa9f77f 100644 --- a/griptape/tools/base_tool.py +++ b/griptape/tools/base_tool.py @@ -16,6 +16,7 @@ from griptape.artifacts import BaseArtifact, ErrorArtifact, InfoArtifact, TextArtifact from griptape.common import observable from griptape.mixins.activity_mixin import ActivityMixin +from griptape.mixins.serializable_mixin import SerializableMixin if TYPE_CHECKING: from griptape.common import ToolAction @@ -24,7 +25,7 @@ @define -class BaseTool(ActivityMixin, ABC): +class BaseTool(ActivityMixin, SerializableMixin, ABC): """Abstract class for all tools to inherit from for. Attributes: @@ -39,13 +40,19 @@ class BaseTool(ActivityMixin, ABC): REQUIREMENTS_FILE = "requirements.txt" - name: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True) - input_memory: Optional[list[TaskMemory]] = field(default=None, kw_only=True) - output_memory: Optional[dict[str, list[TaskMemory]]] = field(default=None, kw_only=True) - install_dependencies_on_init: bool = field(default=True, kw_only=True) - dependencies_install_directory: Optional[str] = field(default=None, kw_only=True) - verbose: bool = field(default=False, kw_only=True) - off_prompt: bool = field(default=False, kw_only=True) + name: str = field( + default=Factory(lambda self: self.__class__.__name__, takes_self=True), + kw_only=True, + metadata={"serializable": True}, + ) + input_memory: Optional[list[TaskMemory]] = field(default=None, kw_only=True, metadata={"serializable": True}) + output_memory: Optional[dict[str, list[TaskMemory]]] = field( + default=None, kw_only=True, metadata={"serializable": True} + ) + install_dependencies_on_init: bool = field(default=True, kw_only=True, metadata={"serializable": True}) + dependencies_install_directory: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + verbose: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + off_prompt: bool = field(default=False, kw_only=True, metadata={"serializable": True}) def __attrs_post_init__(self) -> None: if self.install_dependencies_on_init: diff --git a/tests/mocks/mock_prompt_driver.py b/tests/mocks/mock_prompt_driver.py index 70089430d..f308c9804 100644 --- a/tests/mocks/mock_prompt_driver.py +++ b/tests/mocks/mock_prompt_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable +from typing import TYPE_CHECKING, Callable, Union from attrs import define, field @@ -29,8 +29,8 @@ class MockPromptDriver(BasePromptDriver): model: str = "test-model" tokenizer: BaseTokenizer = MockTokenizer(model="test-model", max_input_tokens=4096, max_output_tokens=4096) - mock_input: str | Callable[[], str] = field(default="mock input", kw_only=True) - mock_output: str | Callable[[PromptStack], str] = field(default="mock output", kw_only=True) + mock_input: Union[str, Callable[[], str]] = field(default="mock input", kw_only=True) + mock_output: Union[str, Callable[[PromptStack], str]] = field(default="mock output", kw_only=True) def try_run(self, prompt_stack: PromptStack) -> Message: output = self.mock_output(prompt_stack) if isinstance(self.mock_output, Callable) else self.mock_output diff --git a/tests/mocks/mock_tool/tool.py b/tests/mocks/mock_tool/tool.py index 9c2241636..a9ee86c3c 100644 --- a/tests/mocks/mock_tool/tool.py +++ b/tests/mocks/mock_tool/tool.py @@ -12,6 +12,11 @@ class MockTool(BaseTool): test_int: int = field(default=5, kw_only=True) test_dict: dict = field(factory=dict, kw_only=True) custom_schema: dict = field(default=Factory(lambda: {"test": str}), kw_only=True) + module_name: str = field( + default=Factory(lambda self: self.__class__.__module__, takes_self=True), + kw_only=True, + metadata={"serializable": False}, + ) @activity( config={ diff --git a/tests/unit/memory/tool/test_task_memory.py b/tests/unit/memory/tool/test_task_memory.py index d2575959a..f4ea3579a 100644 --- a/tests/unit/memory/tool/test_task_memory.py +++ b/tests/unit/memory/tool/test_task_memory.py @@ -92,3 +92,27 @@ def test_load_artifacts_for_blob_list_artifact(self, memory): ) assert len(memory.load_artifacts("test")) == 2 + + def test_to_dict(self, memory): + expected_task_memory_dict = { + "type": memory.type, + "name": memory.name, + "namespace_storage": memory.namespace_storage, + "namespace_metadata": memory.namespace_metadata, + } + assert expected_task_memory_dict == memory.to_dict() + + def test_from_dict(self, memory): + serialized_memory = memory.to_dict() + assert isinstance(serialized_memory, dict) + + deserialized_memory = memory.from_dict(serialized_memory) + assert isinstance(deserialized_memory, TaskMemory) + + deserialized_memory.process_output( + MockTool().test, + ActionsSubtask(), + ListArtifact([BlobArtifact(b"foo", name="test1"), BlobArtifact(b"foo", name="test2")], name="test"), + ) + + assert len(deserialized_memory.load_artifacts("test")) == 2 diff --git a/tests/unit/mixins/test_seriliazable_mixin.py b/tests/unit/mixins/test_seriliazable_mixin.py index afb3d1eb4..dc30848f2 100644 --- a/tests/unit/mixins/test_seriliazable_mixin.py +++ b/tests/unit/mixins/test_seriliazable_mixin.py @@ -7,7 +7,11 @@ from griptape.memory import TaskMemory from griptape.memory.structure import ConversationMemory from griptape.schemas import BaseSchema +from griptape.tasks.base_task import BaseTask +from griptape.tasks.tool_task import ToolTask +from griptape.tools.base_tool import BaseTool from tests.mocks.mock_serializable import MockSerializable +from tests.mocks.mock_tool.tool import MockTool class TestSerializableMixin: @@ -15,10 +19,19 @@ def test_get_schema(self): assert isinstance(BaseArtifact.get_schema("TextArtifact"), BaseSchema) assert isinstance(TextArtifact.get_schema(), BaseSchema) + assert isinstance(BaseTool.get_schema("MockTool", module_name="tests.mocks.mock_tool.tool"), BaseSchema) + def test_from_dict(self): assert isinstance(BaseArtifact.from_dict({"type": "TextArtifact", "value": "foobar"}), TextArtifact) assert isinstance(TextArtifact.from_dict({"value": "foobar"}), TextArtifact) + assert isinstance( + BaseTask.from_dict( + {"type": "ToolTask", "tool": {"type": "MockTool", "module_name": "tests.mocks.mock_tool.tool"}}, + ), + ToolTask, + ) + def test_from_json(self): assert isinstance(BaseArtifact.from_json('{"type": "TextArtifact", "value": "foobar"}'), TextArtifact) assert isinstance(TextArtifact.from_json('{"value": "foobar"}'), TextArtifact) @@ -56,6 +69,8 @@ def test_import_class_rec(self): with pytest.raises(ValueError): MockSerializable._import_cls_rec("griptape.memory.task", "ConversationMemory") + assert MockSerializable._import_cls_rec("tests.mocks.mock_tool.tool", "MockTool") == MockTool + def test_nested_optional_serializable(self): assert MockSerializable(nested=None).to_dict().get("nested") is None diff --git a/tests/unit/structures/test_agent.py b/tests/unit/structures/test_agent.py index 97697749e..b8b6bb1b4 100644 --- a/tests/unit/structures/test_agent.py +++ b/tests/unit/structures/test_agent.py @@ -252,3 +252,47 @@ def test_task_outputs(self): assert len(agent.task_outputs) == 1 assert agent.task_outputs[task.id] == task.output + + def test_to_dict(self): + task = PromptTask("test prompt") + agent = Agent(prompt_driver=MockPromptDriver()) + agent.add_task(task) + expected_agent_dict = { + "type": "Agent", + "id": agent.id, + "tasks": [ + { + "type": agent.tasks[0].type, + "id": agent.tasks[0].id, + "state": str(agent.tasks[0].state), + "parent_ids": agent.tasks[0].parent_ids, + "child_ids": agent.tasks[0].child_ids, + "max_meta_memory_entries": agent.tasks[0].max_meta_memory_entries, + "context": agent.tasks[0].context, + } + ], + "conversation_memory": { + "type": agent.conversation_memory.type, + "runs": agent.conversation_memory.runs, + "meta": agent.conversation_memory.meta, + "max_runs": agent.conversation_memory.max_runs, + }, + } + assert agent.to_dict() == expected_agent_dict + + def test_from_dict(self): + task = PromptTask("test prompt") + agent = Agent(prompt_driver=MockPromptDriver()) + agent.add_task(task) + + serialized_agent = agent.to_dict() + assert isinstance(serialized_agent, dict) + + deserialized_agent = Agent.from_dict(serialized_agent) + assert isinstance(deserialized_agent, Agent) + + assert deserialized_agent.task_outputs[task.id] is None + deserialized_agent.run() + + assert len(deserialized_agent.task_outputs) == 1 + assert deserialized_agent.task_outputs[task.id].value == "mock output" diff --git a/tests/unit/structures/test_pipeline.py b/tests/unit/structures/test_pipeline.py index b2580de4c..f86c6330a 100644 --- a/tests/unit/structures/test_pipeline.py +++ b/tests/unit/structures/test_pipeline.py @@ -411,3 +411,47 @@ def test_task_outputs(self): pipeline.run() assert len(pipeline.task_outputs) == 1 assert pipeline.task_outputs[task.id] == task.output + + def test_to_dict(self): + task = PromptTask("test") + pipeline = Pipeline() + pipeline + [task] + expected_pipeline_dict = { + "type": pipeline.type, + "id": pipeline.id, + "tasks": [ + { + "type": pipeline.tasks[0].type, + "id": pipeline.tasks[0].id, + "state": str(pipeline.tasks[0].state), + "parent_ids": pipeline.tasks[0].parent_ids, + "child_ids": pipeline.tasks[0].child_ids, + "max_meta_memory_entries": pipeline.tasks[0].max_meta_memory_entries, + "context": pipeline.tasks[0].context, + } + ], + "conversation_memory": { + "type": pipeline.conversation_memory.type, + "runs": pipeline.conversation_memory.runs, + "meta": pipeline.conversation_memory.meta, + "max_runs": pipeline.conversation_memory.max_runs, + }, + "fail_fast": pipeline.fail_fast, + } + assert pipeline.to_dict() == expected_pipeline_dict + + def test_from_dict(self): + task = PromptTask("test") + pipeline = Pipeline(tasks=[task]) + + serialized_pipeline = pipeline.to_dict() + assert isinstance(serialized_pipeline, dict) + + deserialized_pipeline = Pipeline.from_dict(serialized_pipeline) + assert isinstance(deserialized_pipeline, Pipeline) + + assert deserialized_pipeline.task_outputs[task.id] is None + deserialized_pipeline.run() + + assert len(deserialized_pipeline.task_outputs) == 1 + assert deserialized_pipeline.task_outputs[task.id].value == "mock output" diff --git a/tests/unit/structures/test_workflow.py b/tests/unit/structures/test_workflow.py index 1a9b4e2d1..d3fb17906 100644 --- a/tests/unit/structures/test_workflow.py +++ b/tests/unit/structures/test_workflow.py @@ -979,3 +979,47 @@ def test_task_outputs(self): assert len(workflow.task_outputs) == 1 assert workflow.task_outputs[task.id].value == "mock output" + + def test_to_dict(self): + task = PromptTask("test") + workflow = Workflow(tasks=[task]) + + expected_workflow_dict = { + "type": workflow.type, + "id": workflow.id, + "tasks": [ + { + "type": workflow.tasks[0].type, + "id": workflow.tasks[0].id, + "state": str(workflow.tasks[0].state), + "parent_ids": workflow.tasks[0].parent_ids, + "child_ids": workflow.tasks[0].child_ids, + "max_meta_memory_entries": workflow.tasks[0].max_meta_memory_entries, + "context": workflow.tasks[0].context, + } + ], + "conversation_memory": { + "type": workflow.conversation_memory.type, + "runs": workflow.conversation_memory.runs, + "meta": workflow.conversation_memory.meta, + "max_runs": workflow.conversation_memory.max_runs, + }, + "fail_fast": workflow.fail_fast, + } + assert workflow.to_dict() == expected_workflow_dict + + def test_from_dict(self): + task = PromptTask("test") + workflow = Workflow(tasks=[task]) + + serialized_workflow = workflow.to_dict() + assert isinstance(serialized_workflow, dict) + + deserialized_workflow = Workflow.from_dict(serialized_workflow) + assert isinstance(deserialized_workflow, Workflow) + + assert deserialized_workflow.task_outputs[task.id] is None + deserialized_workflow.run() + + assert len(deserialized_workflow.task_outputs) == 1 + assert deserialized_workflow.task_outputs[task.id].value == "mock output" diff --git a/tests/unit/tasks/test_base_task.py b/tests/unit/tasks/test_base_task.py index cd94aeef6..3437eb117 100644 --- a/tests/unit/tasks/test_base_task.py +++ b/tests/unit/tasks/test_base_task.py @@ -159,3 +159,35 @@ def test_add_child_bitshift(self, task): assert child.id in task.child_ids assert task.id in child.parent_ids assert added_task == child + + def test_to_dict(self, task): + expected_task_dict = { + "type": task.type, + "id": task.id, + "state": str(task.state), + "parent_ids": task.parent_ids, + "child_ids": task.child_ids, + "max_meta_memory_entries": task.max_meta_memory_entries, + "context": task.context, + } + assert expected_task_dict == task.to_dict() + + def test_from_dict(self): + task = MockTask("Foobar2", id="Foobar2") + + serialized_task = task.to_dict() + assert isinstance(serialized_task, dict) + + deserialized_task = MockTask.from_dict(serialized_task) + assert isinstance(deserialized_task, MockTask) + + workflow = Workflow() + workflow.add_task(deserialized_task) + + assert workflow.tasks == [deserialized_task] + + workflow.run() + + assert str(workflow.tasks[0].state) == "State.FINISHED" + assert workflow.tasks[0].id == deserialized_task.id + assert workflow.tasks[0].output.value == "foobar" diff --git a/tests/unit/tasks/test_tool_task.py b/tests/unit/tasks/test_tool_task.py index f92f6a887..cb2a6b341 100644 --- a/tests/unit/tasks/test_tool_task.py +++ b/tests/unit/tasks/test_tool_task.py @@ -237,3 +237,39 @@ def test_actions_schema(self): Agent().add_task(task) assert task.actions_schema().json_schema("Actions Schema") == self.TARGET_TOOLS_SCHEMA + + def test_to_dict(self): + tool = MockTool() + task = ToolTask("test", tool=tool) + + expected_tool_task_dict = { + "type": task.type, + "id": task.id, + "state": str(task.state), + "parent_ids": task.parent_ids, + "child_ids": task.child_ids, + "max_meta_memory_entries": task.max_meta_memory_entries, + "context": task.context, + "tool": { + "type": task.tool.type, + "name": task.tool.name, + "input_memory": task.tool.input_memory, + "output_memory": task.tool.output_memory, + "install_dependencies_on_init": task.tool.install_dependencies_on_init, + "dependencies_install_directory": task.tool.dependencies_install_directory, + "verbose": task.tool.verbose, + "off_prompt": task.tool.off_prompt, + }, + } + assert expected_tool_task_dict == task.to_dict() + + def test_from_dict(self): + tool = MockTool() + task = ToolTask("test", tool=tool) + + serialized_tool_task = task.to_dict() + serialized_tool_task["tool"]["module_name"] = "tests.mocks.mock_tool.tool" + assert isinstance(serialized_tool_task, dict) + + deserialized_tool_task = ToolTask.from_dict(serialized_tool_task) + assert isinstance(deserialized_tool_task, ToolTask) diff --git a/tests/unit/tools/test_base_tool.py b/tests/unit/tools/test_base_tool.py index 60c9f6825..9f28acf02 100644 --- a/tests/unit/tools/test_base_tool.py +++ b/tests/unit/tools/test_base_tool.py @@ -6,6 +6,7 @@ from griptape.common import ToolAction from griptape.tasks import ActionsSubtask, ToolkitTask +from griptape.tools import BaseTool from tests.mocks.mock_tool.tool import MockTool from tests.utils import defaults @@ -279,3 +280,31 @@ def test_to_native_tool_name(self, tool, mocker): tool.name = "MockTool" with pytest.raises(ValueError, match="Activity name"): tool.to_native_tool_name(tool.test) + + def test_to_dict(self, tool): + tool = MockTool() + + expected_tool_dict = { + "type": tool.type, + "name": tool.name, + "input_memory": tool.input_memory, + "output_memory": tool.output_memory, + "install_dependencies_on_init": tool.install_dependencies_on_init, + "dependencies_install_directory": tool.dependencies_install_directory, + "verbose": tool.verbose, + "off_prompt": tool.off_prompt, + } + + assert expected_tool_dict == tool.to_dict() + + def test_from_dict(self, tool): + tool = MockTool() + action = ToolAction(input={}, name="", tag="") + + serialized_tool = tool.to_dict() + assert isinstance(serialized_tool, dict) + + deserialized_tool = MockTool.from_dict(serialized_tool) + assert isinstance(deserialized_tool, BaseTool) + + assert deserialized_tool.execute(tool.test_list_output, ActionsSubtask("foo"), action).to_text() == "foo\n\nbar"