Skip to content

Commit

Permalink
Structures, Tools, Tasks, are now serializable. (#1261)
Browse files Browse the repository at this point in the history
Co-authored-by: Collin Dutter <[email protected]>
  • Loading branch information
william-price01 and collindutter authored Oct 17, 2024
1 parent 829456b commit db492c9
Show file tree
Hide file tree
Showing 30 changed files with 380 additions and 70 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Chat` input now uses a slightly customized version of `Rich.prompt.Prompt` by default.
- `Chat` output now uses `Rich.print` by default.
- `Chat.output_fn`'s now takes an optional kwarg parameter, `stream`.
- Implemented `SerializableMixin` in `Structure`, `BaseTask`, `BaseTool`, and `TaskMemory`

### Fixed

Expand Down
15 changes: 11 additions & 4 deletions griptape/memory/task/task_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,20 @@
from griptape.memory.meta import ActionSubtaskMetaEntry
from griptape.memory.task.storage import BlobArtifactStorage, TextArtifactStorage
from griptape.mixins.activity_mixin import ActivityMixin
from griptape.mixins.serializable_mixin import SerializableMixin

if TYPE_CHECKING:
from griptape.memory.task.storage import BaseArtifactStorage
from griptape.tasks import ActionsSubtask


@define
class TaskMemory(ActivityMixin):
name: str = field(default=Factory(lambda self: self.__class__.__name__, takes_self=True), kw_only=True)
class TaskMemory(ActivityMixin, SerializableMixin):
name: str = field(
default=Factory(lambda self: self.__class__.__name__, takes_self=True),
kw_only=True,
metadata={"serializable": True},
)
artifact_storages: dict[type, BaseArtifactStorage] = field(
default=Factory(
lambda: {
Expand All @@ -26,8 +31,10 @@ class TaskMemory(ActivityMixin):
),
kw_only=True,
)
namespace_storage: dict[str, BaseArtifactStorage] = field(factory=dict, kw_only=True)
namespace_metadata: dict[str, Any] = field(factory=dict, kw_only=True)
namespace_storage: dict[str, BaseArtifactStorage] = field(
factory=dict, kw_only=True, metadata={"serializable": True}
)
namespace_metadata: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})

@artifact_storages.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_artifact_storages(self, _: Attribute, artifact_storage: dict[type, BaseArtifactStorage]) -> None:
Expand Down
13 changes: 10 additions & 3 deletions griptape/mixins/serializable_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,26 @@ class SerializableMixin(Generic[T]):
kw_only=True,
metadata={"serializable": True},
)
module_name: str = field(
default=Factory(lambda self: self.__class__.__module__, takes_self=True),
kw_only=True,
metadata={"serializable": False},
)

@classmethod
def get_schema(cls: type[T], subclass_name: Optional[str] = None) -> Schema:
def get_schema(cls: type[T], subclass_name: Optional[str] = None, *, module_name: Optional[str] = None) -> Schema:
"""Generates a Marshmallow schema for the class.
Args:
subclass_name: An optional subclass name. Required if the class is abstract.
module_name: An optional module name. Defaults to the class's module.
"""
if ABC in cls.__bases__:
if subclass_name is None:
raise ValueError(f"Type field is required for abstract class: {cls.__name__}")

subclass_cls = cls._import_cls_rec(cls.__module__, subclass_name)
module_name = module_name or cls.__module__
subclass_cls = cls._import_cls_rec(module_name, subclass_name)

schema_class = BaseSchema.from_attrs_cls(subclass_cls)
else:
Expand All @@ -44,7 +51,7 @@ def get_schema(cls: type[T], subclass_name: Optional[str] = None) -> Schema:

@classmethod
def from_dict(cls: type[T], data: dict) -> T:
return cast(T, cls.get_schema(subclass_name=data.get("type")).load(data))
return cast(T, cls.get_schema(subclass_name=data.get("type"), module_name=data.get("module_name")).load(data))

@classmethod
def from_json(cls: type[T], data: str) -> T:
Expand Down
10 changes: 9 additions & 1 deletion griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,11 @@ def _resolve_types(cls, attrs_cls: type) -> None:
BaseVectorStoreDriver,
)
from griptape.events import EventListener
from griptape.memory.structure import Run
from griptape.memory import TaskMemory
from griptape.memory.structure import BaseConversationMemory, Run
from griptape.memory.task.storage import BaseArtifactStorage
from griptape.structures import Structure
from griptape.tasks import BaseTask
from griptape.tokenizers import BaseTokenizer
from griptape.tools import BaseTool
from griptape.utils import import_optional_dependency, is_dependency_installed
Expand All @@ -198,13 +201,18 @@ def _resolve_types(cls, attrs_cls: type) -> None:
"BaseMessageContent": BaseMessageContent,
"BaseDeltaMessageContent": BaseDeltaMessageContent,
"BaseTool": BaseTool,
"BaseTask": BaseTask,
"Usage": Message.Usage,
"Structure": Structure,
"BaseTokenizer": BaseTokenizer,
"ToolAction": ToolAction,
"Reference": Reference,
"Run": Run,
"Sequence": Sequence,
"TaskMemory": TaskMemory,
"State": BaseTask.State,
"BaseConversationMemory": BaseConversationMemory,
"BaseArtifactStorage": BaseArtifactStorage,
# Third party modules
"Client": import_optional_dependency("cohere").Client if is_dependency_installed("cohere") else Any,
"GenerativeModel": import_optional_dependency("google.generativeai").GenerativeModel
Expand Down
2 changes: 1 addition & 1 deletion griptape/schemas/polymorphic_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _load(self, data: Any, *, partial: Any = None, unknown: Any = None, **kwargs
if data_type is None:
raise ValidationError({self.type_field: ["Missing data for required field."]})

type_schema = self.inner_class.get_schema(data_type)
type_schema = self.inner_class.get_schema(data_type, module_name=data.get("module_name"))
if not type_schema:
raise ValidationError({self.type_field: [f"Unsupported value: {data_type}"]})

Expand Down
4 changes: 2 additions & 2 deletions griptape/structures/agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, Union

from attrs import Attribute, Factory, define, field

Expand All @@ -19,7 +19,7 @@

@define
class Agent(Structure):
input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field(
input: Union[str, list, tuple, BaseArtifact, Callable[[BaseTask], BaseArtifact]] = field(
default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""),
)
stream: bool = field(default=Factory(lambda: Defaults.drivers_config.prompt_driver.stream), kw_only=True)
Expand Down
14 changes: 9 additions & 5 deletions griptape/structures/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import uuid
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Optional
from typing import TYPE_CHECKING, Any, Optional, Union

from attrs import Factory, define, field

Expand All @@ -12,6 +12,7 @@
from griptape.memory.meta import MetaMemory
from griptape.memory.structure import ConversationMemory, Run
from griptape.mixins.rule_mixin import RuleMixin
from griptape.mixins.serializable_mixin import SerializableMixin

if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact
Expand All @@ -20,19 +21,22 @@


@define
class Structure(ABC, RuleMixin):
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
_tasks: list[BaseTask | list[BaseTask]] = field(factory=list, kw_only=True, alias="tasks")
class Structure(ABC, RuleMixin, SerializableMixin):
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
_tasks: list[Union[BaseTask, list[BaseTask]]] = field(
factory=list, kw_only=True, alias="tasks", metadata={"serializable": True}
)
conversation_memory: Optional[BaseConversationMemory] = field(
default=Factory(lambda: ConversationMemory()),
kw_only=True,
metadata={"serializable": True},
)
task_memory: TaskMemory = field(
default=Factory(lambda self: TaskMemory(), takes_self=True),
kw_only=True,
)
meta_memory: MetaMemory = field(default=Factory(lambda: MetaMemory()), kw_only=True)
fail_fast: bool = field(default=True, kw_only=True)
fail_fast: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_execution_args: tuple = ()

def __attrs_post_init__(self) -> None:
Expand Down
8 changes: 4 additions & 4 deletions griptape/tasks/actions_subtask.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import json
import logging
import re
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Callable, Optional, Union

import schema
from attrs import define, field
Expand Down Expand Up @@ -33,7 +33,7 @@ class ActionsSubtask(BaseTask):
thought: Optional[str] = field(default=None, kw_only=True)
actions: list[ToolAction] = field(factory=list, kw_only=True)
output: Optional[BaseArtifact] = field(default=None, init=False)
_input: str | list | tuple | BaseArtifact | Callable[[BaseTask], BaseArtifact] = field(
_input: Union[str, list, tuple, BaseArtifact, Callable[[BaseTask], BaseArtifact]] = field(
default=lambda task: task.full_context["args"][0] if task.full_context["args"] else TextArtifact(value=""),
alias="input",
)
Expand Down Expand Up @@ -197,8 +197,8 @@ def actions_to_json(self) -> str:

def _process_task_input(
self,
task_input: str | tuple | list | BaseArtifact | Callable[[BaseTask], BaseArtifact],
) -> TextArtifact | ListArtifact:
task_input: Union[str, tuple, list, BaseArtifact, Callable[[BaseTask], BaseArtifact]],
) -> Union[TextArtifact, ListArtifact]:
if isinstance(task_input, (TextArtifact, ListArtifact)):
return task_input
elif isinstance(task_input, ActionArtifact):
Expand Down
4 changes: 2 additions & 2 deletions griptape/tasks/base_audio_input_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from abc import ABC
from typing import Callable
from typing import Callable, Union

from attrs import define, field

Expand All @@ -16,7 +16,7 @@

@define
class BaseAudioInputTask(RuleMixin, BaseTask, ABC):
_input: AudioArtifact | Callable[[BaseTask], AudioArtifact] = field(alias="input")
_input: Union[AudioArtifact, Callable[[BaseTask], AudioArtifact]] = field(alias="input")

@property
def input(self) -> AudioArtifact:
Expand Down
15 changes: 8 additions & 7 deletions griptape/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from griptape.configs import Defaults
from griptape.events import EventBus, FinishTaskEvent, StartTaskEvent
from griptape.mixins.futures_executor_mixin import FuturesExecutorMixin
from griptape.mixins.serializable_mixin import SerializableMixin

if TYPE_CHECKING:
from griptape.artifacts import BaseArtifact
Expand All @@ -22,21 +23,21 @@


@define
class BaseTask(FuturesExecutorMixin, ABC):
class BaseTask(FuturesExecutorMixin, SerializableMixin, ABC):
class State(Enum):
PENDING = 1
EXECUTING = 2
FINISHED = 3

id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True)
state: State = field(default=State.PENDING, kw_only=True)
parent_ids: list[str] = field(factory=list, kw_only=True)
child_ids: list[str] = field(factory=list, kw_only=True)
max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True)
id: str = field(default=Factory(lambda: uuid.uuid4().hex), kw_only=True, metadata={"serializable": True})
state: State = field(default=State.PENDING, kw_only=True, metadata={"serializable": True})
parent_ids: list[str] = field(factory=list, kw_only=True, metadata={"serializable": True})
child_ids: list[str] = field(factory=list, kw_only=True, metadata={"serializable": True})
max_meta_memory_entries: Optional[int] = field(default=20, kw_only=True, metadata={"serializable": True})
structure: Optional[Structure] = field(default=None, kw_only=True)

output: Optional[BaseArtifact] = field(default=None, init=False)
context: dict[str, Any] = field(factory=dict, kw_only=True)
context: dict[str, Any] = field(factory=dict, kw_only=True, metadata={"serializable": True})

def __rshift__(self, other: BaseTask) -> BaseTask:
self.add_child(other)
Expand Down
4 changes: 2 additions & 2 deletions griptape/tasks/base_text_input_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from abc import ABC
from typing import Callable
from typing import Callable, Union

from attrs import define, field

Expand All @@ -19,7 +19,7 @@
class BaseTextInputTask(RuleMixin, BaseTask, ABC):
DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}"

_input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field(
_input: Union[str, TextArtifact, Callable[[BaseTask], TextArtifact]] = field(
default=DEFAULT_INPUT_TEMPLATE,
alias="input",
)
Expand Down
22 changes: 12 additions & 10 deletions griptape/tasks/image_query_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable
from typing import Callable, Union

from attrs import Factory, define, field

Expand All @@ -25,12 +25,12 @@ class ImageQueryTask(BaseTask):
"""

image_query_engine: ImageQueryEngine = field(default=Factory(lambda: ImageQueryEngine()), kw_only=True)
_input: (
tuple[str, list[ImageArtifact]]
| tuple[TextArtifact, list[ImageArtifact]]
| Callable[[BaseTask], ListArtifact]
| ListArtifact
) = field(default=None, alias="input")
_input: Union[
tuple[str, list[ImageArtifact]],
tuple[TextArtifact, list[ImageArtifact]],
Callable[[BaseTask], ListArtifact],
ListArtifact,
] = field(default=None, alias="input")

@property
def input(self) -> ListArtifact:
Expand All @@ -55,9 +55,11 @@ def input(self) -> ListArtifact:
def input(
self,
value: (
tuple[str, list[ImageArtifact]]
| tuple[TextArtifact, list[ImageArtifact]]
| Callable[[BaseTask], ListArtifact]
Union[
tuple[str, list[ImageArtifact]],
tuple[TextArtifact, list[ImageArtifact]],
Callable[[BaseTask], ListArtifact],
]
),
) -> None:
self._input = value
Expand Down
8 changes: 4 additions & 4 deletions griptape/tasks/inpainting_image_generation_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable
from typing import Callable, Union

from attrs import Factory, define, field

Expand Down Expand Up @@ -32,9 +32,9 @@ class InpaintingImageGenerationTask(BaseImageGenerationTask):
default=Factory(lambda: InpaintingImageGenerationEngine()),
kw_only=True,
)
_input: (
tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact
) = field(default=None, alias="input")
_input: Union[
tuple[Union[str, TextArtifact], ImageArtifact, ImageArtifact], Callable[[BaseTask], ListArtifact], ListArtifact
] = field(default=None, alias="input")

@property
def input(self) -> ListArtifact:
Expand Down
8 changes: 4 additions & 4 deletions griptape/tasks/outpainting_image_generation_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable
from typing import Callable, Union

from attrs import Factory, define, field

Expand Down Expand Up @@ -32,9 +32,9 @@ class OutpaintingImageGenerationTask(BaseImageGenerationTask):
default=Factory(lambda: OutpaintingImageGenerationEngine()),
kw_only=True,
)
_input: (
tuple[str | TextArtifact, ImageArtifact, ImageArtifact] | Callable[[BaseTask], ListArtifact] | ListArtifact
) = field(default=None, alias="input")
_input: Union[
tuple[Union[str, TextArtifact], ImageArtifact, ImageArtifact], Callable[[BaseTask], ListArtifact], ListArtifact
] = field(default=None, alias="input")

@property
def input(self) -> ListArtifact:
Expand Down
4 changes: 2 additions & 2 deletions griptape/tasks/prompt_image_generation_task.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Callable
from typing import Callable, Union

from attrs import Factory, define, field

Expand Down Expand Up @@ -29,7 +29,7 @@ class PromptImageGenerationTask(BaseImageGenerationTask):

DEFAULT_INPUT_TEMPLATE = "{{ args[0] }}"

_input: str | TextArtifact | Callable[[BaseTask], TextArtifact] = field(
_input: Union[str, TextArtifact, Callable[[BaseTask], TextArtifact]] = field(
default=DEFAULT_INPUT_TEMPLATE, alias="input"
)
image_generation_engine: PromptImageGenerationEngine = field(
Expand Down
Loading

0 comments on commit db492c9

Please sign in to comment.