diff --git a/.github/workflows/docs-integration-tests.yml b/.github/workflows/docs-integration-tests.yml index d8e2162ed..81807be59 100644 --- a/.github/workflows/docs-integration-tests.yml +++ b/.github/workflows/docs-integration-tests.yml @@ -125,6 +125,11 @@ jobs: ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.INTEG_ASTRA_DB_APPLICATION_TOKEN }} TAVILY_API_KEY: ${{ secrets.INTEG_TAVILY_API_KEY }} EXA_API_KEY: ${{ secrets.INTEG_EXA_API_KEY }} + AMAZON_S3_BUCKET: ${{ secrets.INTEG_AMAZON_S3_BUCKET }} + AMAZON_S3_KEY: ${{ secrets.INTEG_AMAZON_S3_KEY }} + GT_CLOUD_BUCKET_ID: ${{ secrets.INTEG_GT_CLOUD_BUCKET_ID }} + GT_CLOUD_ASSET_NAME: ${{ secrets.INTEG_GT_CLOUD_ASSET_NAME }} + services: postgres: image: ankane/pgvector:v0.5.0 diff --git a/CHANGELOG.md b/CHANGELOG.md index a6bab4e61..e8930532f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Exponential backoff to `BaseEventListenerDriver` for retrying failed event publishing. - `BaseTask.task_outputs` to get a dictionary of all task outputs. This has been added to `Workflow.context` and `Pipeline.context`. - `Chat.input_fn` for customizing the input to the Chat utility. +- `GriptapeCloudFileManagerDriver` for managing files on Griptape Cloud. +- `BaseFileManagerDriver.load_artifact()` & `BaseFileManagerDriver.save_artifact()` for loading & saving artifacts as files. +- Events `BaseChunkEvent`, `TextChunkEvent`, `ActionChunkEvent`. ### Changed @@ -23,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - **BREAKING**: Renamed parameter `driver` on `EventListener` to `event_listener_driver`. - **BREAKING**: Changed default value of parameter `handler` on `EventListener` to `None`. - **BREAKING**: Updated `EventListener.handler` return value behavior. +- **BREAKING**: Removed `CompletionChunkEvent`. 
- If `EventListener.handler` returns `None`, the event will not be published to the `event_listener_driver`. - If `EventListener.handler` is None, the event will be published to the `event_listener_driver` as-is. - Updated `EventListener.handler` return type to `Optional[BaseEvent | dict]`. @@ -42,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `@activity` decorated functions can now accept kwargs that are defined in the activity schema. - Updated `ToolkitTask` system prompt to no longer mention `memory_name` and `artifact_namespace`. - Models in `ToolkitTask` with native tool calling no longer need to provide their final answer as `Answer:`. +- `EventListener.event_types` will now listen on child types of any provided type. ### Fixed diff --git a/MIGRATION.md b/MIGRATION.md index 956611de8..41b501870 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -4,6 +4,46 @@ This document provides instructions for migrating your codebase to accommodate b ## 0.33.X to 0.34.X +### Removed `CompletionChunkEvent` + +`CompletionChunkEvent` has been removed. There is now `BaseChunkEvent` with children `TextChunkEvent` and `ActionChunkEvent`. `BaseChunkEvent` can replace `completion_chunk_event.token` by doing `str(base_chunk_event)`. 
+ +#### Before + +```python +def handler_fn_stream(event: CompletionChunkEvent) -> None: + print(f"CompletionChunkEvent: {event.to_json()}") + +def handler_fn_stream_text(event: CompletionChunkEvent) -> None: + # This prints out Tool actions with no easy way + # to filter them out + print(event.token, end="", flush=True) + +EventListener(handler=handler_fn_stream, event_types=[CompletionChunkEvent]) +EventListener(handler=handler_fn_stream_text, event_types=[CompletionChunkEvent]) +``` + +#### After + +```python +def handler_fn_stream(event: BaseChunkEvent) -> None: + print(str(event), end="", flush=True) + # print out each child event type + if isinstance(event, TextChunkEvent): + print(f"TextChunkEvent: {event.to_json()}") + if isinstance(event, ActionChunkEvent): + print(f"ActionChunkEvent: {event.to_json()}") + + +def handler_fn_stream_text(event: TextChunkEvent) -> None: + # This will only be text coming from the + # prompt driver, not Tool actions + print(event.token, end="", flush=True) + +EventListener(handler=handler_fn_stream, event_types=[BaseChunkEvent]) +EventListener(handler=handler_fn_stream_text, event_types=[TextChunkEvent]) +``` + +### `EventListener.handler` behavior, `driver` parameter rename + +Returning `None` from the `handler` function now causes the event to not be published to the `EventListenerDriver`. diff --git a/Makefile b/Makefile index 0344c8c19..65a674291 100644 --- a/Makefile +++ b/Makefile @@ -40,6 +40,10 @@ test: test/unit test/integration test/unit: ## Run unit tests. @poetry run pytest -n auto tests/unit +.PHONY: test/unit/% +test/unit/%: ## Run specific unit tests. 
+ @poetry run pytest -n auto tests/unit -k $* + .PHONY: test/unit/coverage test/unit/coverage: @poetry run pytest -n auto --cov=griptape tests/unit diff --git a/docs/griptape-framework/drivers/file-manager-drivers.md b/docs/griptape-framework/drivers/file-manager-drivers.md new file mode 100644 index 000000000..37012c29f --- /dev/null +++ b/docs/griptape-framework/drivers/file-manager-drivers.md @@ -0,0 +1,48 @@ +--- +search: + boost: 2 +--- + +## Overview + +File Manager Drivers can be used to load and save files with local or external file systems. + +You can use File Manager Drivers with Loaders: + +```python +--8<-- "docs/griptape-framework/drivers/src/file_manager_driver.py" +``` + +Or use them independently as shown below for each driver: + +## File Manager Drivers + +### Griptape Cloud + +!!! info + This driver requires the `drivers-file-manager-griptape-cloud` [extra](../index.md#extras). + +The [GriptapeCloudFileManagerDriver](../../reference/griptape/drivers/file_manager/griptape_cloud_file_manager_driver.md) allows you to load and save files sourced from Griptape Cloud Asset and Bucket resources. + +```python +--8<-- "docs/griptape-framework/drivers/src/griptape_cloud_file_manager_driver.py" +``` + +### Local + +The [LocalFileManagerDriver](../../reference/griptape/drivers/file_manager/local_file_manager_driver.md) allows you to load and save files sourced from a local directory. + +```python +--8<-- "docs/griptape-framework/drivers/src/local_file_manager_driver.py" +``` + +### Amazon S3 + +!!! info + This driver requires the `drivers-file-manager-amazon-s3` [extra](../index.md#extras). + +The [AmazonS3FileManagerDriver](../../reference/griptape/drivers/file_manager/amazon_s3_file_manager_driver.md) allows you to load and save files sourced from an Amazon S3 bucket. 
+ +```python +--8<-- "docs/griptape-framework/drivers/src/amazon_s3_file_manager_driver.py" +``` diff --git a/docs/griptape-framework/drivers/src/amazon_s3_file_manager_driver.py b/docs/griptape-framework/drivers/src/amazon_s3_file_manager_driver.py new file mode 100644 index 000000000..5fa9324cb --- /dev/null +++ b/docs/griptape-framework/drivers/src/amazon_s3_file_manager_driver.py @@ -0,0 +1,24 @@ +import os + +import boto3 + +from griptape.drivers import AmazonS3FileManagerDriver + +amazon_s3_file_manager_driver = AmazonS3FileManagerDriver( + bucket=os.environ["AMAZON_S3_BUCKET"], + session=boto3.Session( + region_name=os.environ["AWS_DEFAULT_REGION"], + aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], + aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], + ), +) + +# Download File +file_contents = amazon_s3_file_manager_driver.load_file(os.environ["AMAZON_S3_KEY"]) + +print(file_contents) + +# Upload File +response = amazon_s3_file_manager_driver.save_file(os.environ["AMAZON_S3_KEY"], file_contents.value) + +print(response) diff --git a/docs/griptape-framework/drivers/src/file_manager_driver.py b/docs/griptape-framework/drivers/src/file_manager_driver.py new file mode 100644 index 000000000..0ba2e26c7 --- /dev/null +++ b/docs/griptape-framework/drivers/src/file_manager_driver.py @@ -0,0 +1,9 @@ +from griptape.drivers import LocalFileManagerDriver +from griptape.loaders import TextLoader + +local_file_manager_driver = LocalFileManagerDriver() + +loader = TextLoader(file_manager_driver=local_file_manager_driver) +text_artifact = loader.load("tests/resources/test.txt") + +print(text_artifact.value) diff --git a/docs/griptape-framework/drivers/src/griptape_cloud_file_manager_driver.py b/docs/griptape-framework/drivers/src/griptape_cloud_file_manager_driver.py new file mode 100644 index 000000000..b222b5d4a --- /dev/null +++ b/docs/griptape-framework/drivers/src/griptape_cloud_file_manager_driver.py @@ -0,0 +1,18 @@ +import os + +from griptape.drivers 
import GriptapeCloudFileManagerDriver + +gtc_file_manager_driver = GriptapeCloudFileManagerDriver( + api_key=os.environ["GT_CLOUD_API_KEY"], + bucket_id=os.environ["GT_CLOUD_BUCKET_ID"], +) + +# Download File +file_contents = gtc_file_manager_driver.load_file(os.environ["GT_CLOUD_ASSET_NAME"]) + +print(file_contents) + +# Upload File +response = gtc_file_manager_driver.save_file(os.environ["GT_CLOUD_ASSET_NAME"], file_contents.value) + +print(response) diff --git a/docs/griptape-framework/drivers/src/local_file_manager_driver.py b/docs/griptape-framework/drivers/src/local_file_manager_driver.py new file mode 100644 index 000000000..a53378060 --- /dev/null +++ b/docs/griptape-framework/drivers/src/local_file_manager_driver.py @@ -0,0 +1,13 @@ +from griptape.drivers import LocalFileManagerDriver + +local_file_manager_driver = LocalFileManagerDriver() + +# Download File +file_contents = local_file_manager_driver.load_file("tests/resources/test.txt") + +print(file_contents) + +# Upload File +response = local_file_manager_driver.save_file("tests/resources/test.txt", file_contents.value) + +print(response) diff --git a/docs/griptape-framework/misc/events.md b/docs/griptape-framework/misc/events.md index da62caefc..b97b9de98 100644 --- a/docs/griptape-framework/misc/events.md +++ b/docs/griptape-framework/misc/events.md @@ -85,14 +85,20 @@ The `EventListener` will automatically be added and removed from the [EventBus]( ## Streaming -You can use the [CompletionChunkEvent](../../reference/griptape/events/completion_chunk_event.md) to stream the completion results from Prompt Drivers. +You can use the [BaseChunkEvent](../../reference/griptape/events/base_chunk_event.md) to stream the completion results from Prompt Drivers. 
```python --8<-- "docs/griptape-framework/misc/src/events_3.py" ``` -You can also use the [Stream](../../reference/griptape/utils/stream.md) utility to automatically wrap -[CompletionChunkEvent](../../reference/griptape/events/completion_chunk_event.md)s in a Python iterator. +You can also use the [TextChunkEvent](../../reference/griptape/events/text_chunk_event.md) and [ActionChunkEvent](../../reference/griptape/events/action_chunk_event.md) to further differentiate the different types of chunks for more customized output. + +```python +--8<-- "docs/griptape-framework/misc/src/events_chunk_stream.py" +``` + +If you want Griptape to handle the chunk events for you, use the [Stream](../../reference/griptape/utils/stream.md) utility to automatically wrap +[BaseChunkEvent](../../reference/griptape/events/base_chunk_event.md)s in a Python iterator. ```python --8<-- "docs/griptape-framework/misc/src/events_4.py" diff --git a/docs/griptape-framework/misc/src/events_3.py b/docs/griptape-framework/misc/src/events_3.py index 7adac812f..beacf814a 100644 --- a/docs/griptape-framework/misc/src/events_3.py +++ b/docs/griptape-framework/misc/src/events_3.py @@ -1,7 +1,5 @@ -from typing import cast - from griptape.drivers import OpenAiChatPromptDriver -from griptape.events import CompletionChunkEvent, EventBus, EventListener +from griptape.events import BaseChunkEvent, EventBus, EventListener from griptape.structures import Pipeline from griptape.tasks import ToolkitTask from griptape.tools import PromptSummaryTool, WebScraperTool @@ -9,9 +7,9 @@ EventBus.add_event_listeners( [ EventListener( - lambda e: print(cast(CompletionChunkEvent, e).token, end="", flush=True), - event_types=[CompletionChunkEvent], - ) + lambda e: print(str(e), end="", flush=True), + event_types=[BaseChunkEvent], + ), ] ) diff --git a/docs/griptape-framework/misc/src/events_chunk_stream.py b/docs/griptape-framework/misc/src/events_chunk_stream.py new file mode 100644 index 000000000..3ab5517f4 --- /dev/null 
+++ b/docs/griptape-framework/misc/src/events_chunk_stream.py @@ -0,0 +1,29 @@ +from griptape.drivers import OpenAiChatPromptDriver +from griptape.events import ActionChunkEvent, EventBus, EventListener, TextChunkEvent +from griptape.structures import Pipeline +from griptape.tasks import ToolkitTask +from griptape.tools import PromptSummaryTool, WebScraperTool + +EventBus.add_event_listeners( + [ + EventListener( + lambda e: print(str(e), end="", flush=True), + event_types=[TextChunkEvent], + ), + EventListener( + lambda e: print(str(e), end="", flush=True), + event_types=[ActionChunkEvent], + ), + ] +) + +pipeline = Pipeline() +pipeline.add_tasks( + ToolkitTask( + "Based on https://griptape.ai, tell me what griptape is.", + prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True), + tools=[WebScraperTool(off_prompt=True), PromptSummaryTool(off_prompt=False)], + ) +) + +pipeline.run() diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index a4806fc72..4acbc9a19 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -112,6 +112,7 @@ from .file_manager.base_file_manager_driver import BaseFileManagerDriver from .file_manager.local_file_manager_driver import LocalFileManagerDriver from .file_manager.amazon_s3_file_manager_driver import AmazonS3FileManagerDriver +from .file_manager.griptape_cloud_file_manager_driver import GriptapeCloudFileManagerDriver from .rerank.base_rerank_driver import BaseRerankDriver from .rerank.cohere_rerank_driver import CohereRerankDriver @@ -230,6 +231,7 @@ "BaseFileManagerDriver", "LocalFileManagerDriver", "AmazonS3FileManagerDriver", + "GriptapeCloudFileManagerDriver", "BaseRerankDriver", "CohereRerankDriver", "BaseRulesetDriver", diff --git a/griptape/drivers/event_listener/pusher_event_listener_driver.py b/griptape/drivers/event_listener/pusher_event_listener_driver.py index 33d160b46..263876777 100644 --- a/griptape/drivers/event_listener/pusher_event_listener_driver.py +++ 
b/griptape/drivers/event_listener/pusher_event_listener_driver.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional from attrs import define, field @@ -21,7 +21,7 @@ class PusherEventListenerDriver(BaseEventListenerDriver): channel: str = field(kw_only=True, metadata={"serializable": True}) event_name: str = field(kw_only=True, metadata={"serializable": True}) ssl: bool = field(default=True, kw_only=True, metadata={"serializable": True}) - _client: Pusher = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) + _client: Optional[Pusher] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() def client(self) -> Pusher: diff --git a/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py b/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py index 1e841866a..ec9037fd8 100644 --- a/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py +++ b/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py @@ -64,11 +64,12 @@ def try_load_file(self, path: str) -> bytes: raise FileNotFoundError from e raise e - def try_save_file(self, path: str, value: bytes) -> None: + def try_save_file(self, path: str, value: bytes) -> str: full_key = self._to_full_key(path) if self._is_a_directory(full_key): raise IsADirectoryError self.client.put_object(Bucket=self.bucket, Key=full_key, Body=value) + return f"s3://{self.bucket}/{full_key}" def _to_full_key(self, path: str) -> str: path = path.lstrip("/") diff --git a/griptape/drivers/file_manager/base_file_manager_driver.py b/griptape/drivers/file_manager/base_file_manager_driver.py index c904f1532..3c8a680da 100644 --- a/griptape/drivers/file_manager/base_file_manager_driver.py +++ b/griptape/drivers/file_manager/base_file_manager_driver.py @@ -5,7 +5,7 @@ from attrs import define, field -from griptape.artifacts import BlobArtifact, InfoArtifact, 
TextArtifact +from griptape.artifacts import BaseArtifact, BlobArtifact, InfoArtifact, TextArtifact @define @@ -42,9 +42,23 @@ def save_file(self, path: str, value: bytes | str) -> InfoArtifact: elif isinstance(value, (bytearray, memoryview)): raise ValueError(f"Unsupported type: {type(value)}") - self.try_save_file(path, value) + location = self.try_save_file(path, value) - return InfoArtifact("Successfully saved file") + return InfoArtifact(f"Successfully saved file at: {location}") @abstractmethod - def try_save_file(self, path: str, value: bytes) -> None: ... + def try_save_file(self, path: str, value: bytes) -> str: ... + + def load_artifact(self, path: str) -> BaseArtifact: + response = self.try_load_file(path) + return BaseArtifact.from_json( + response.decode() if self.encoding is None else response.decode(encoding=self.encoding) + ) + + def save_artifact(self, path: str, artifact: BaseArtifact) -> InfoArtifact: + artifact_json = artifact.to_json() + value = artifact_json.encode() if self.encoding is None else artifact_json.encode(encoding=self.encoding) + + location = self.try_save_file(path, value) + + return InfoArtifact(f"Successfully saved artifact at: {location}") diff --git a/griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py b/griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py new file mode 100644 index 000000000..5138a1fe4 --- /dev/null +++ b/griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +import logging +import os +from typing import TYPE_CHECKING, Optional +from urllib.parse import urljoin + +import requests +from attrs import Attribute, Factory, define, field + +from griptape.drivers import BaseFileManagerDriver +from griptape.utils import import_optional_dependency + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + from azure.storage.blob import BlobClient + + +@define +class 
GriptapeCloudFileManagerDriver(BaseFileManagerDriver): + """GriptapeCloudFileManagerDriver can be used to list, load, and save files as Assets in Griptape Cloud Buckets. + + Attributes: + bucket_id: The ID of the Bucket to list, load, and save Assets in. If not provided, the driver will attempt to + retrieve the ID from the environment variable `GT_CLOUD_BUCKET_ID`. + workdir: The working directory. List, load, and save operations will be performed relative to this directory. + base_url: The base URL of the Griptape Cloud API. Defaults to the value of the environment variable + `GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`. + api_key: The API key to use for authenticating with the Griptape Cloud API. If not provided, the driver will + attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`. + + Raises: + ValueError: If `api_key` is not provided, if `workdir` does not start with "/", or if an invalid `bucket_id` value is provided. + """ + + bucket_id: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_BUCKET_ID")), kw_only=True) + workdir: str = field(default="/", kw_only=True) + base_url: str = field( + default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), + ) + api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY"))) + headers: dict = field( + default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), + init=False, + ) + + @workdir.validator # pyright: ignore[reportAttributeAccessIssue] + def validate_workdir(self, _: Attribute, workdir: str) -> None: + if not workdir.startswith("/"): + raise ValueError(f"{self.__class__.__name__} requires 'workdir' to be an absolute path, starting with `/`") + + @api_key.validator # pyright: ignore[reportAttributeAccessIssue] + def validate_api_key(self, _: Attribute, value: Optional[str]) -> str: + if value is None: + raise ValueError(f"{self.__class__.__name__} requires 
an API key") + return value + + @bucket_id.validator # pyright: ignore[reportAttributeAccessIssue] + def validate_bucket_id(self, _: Attribute, value: Optional[str]) -> str: + if value is None: + raise ValueError(f"{self.__class__.__name__} requires a Bucket ID") + return value + + def __attrs_post_init__(self) -> None: + try: + self._call_api(method="get", path=f"/buckets/{self.bucket_id}").json() + except requests.exceptions.HTTPError as e: + if e.response.status_code == 404: + raise ValueError(f"No Bucket found with ID: {self.bucket_id}") from e + raise ValueError(f"Unexpected error when retrieving Bucket with ID: {self.bucket_id}") from e + + def try_list_files(self, path: str, postfix: str = "") -> list[str]: + full_key = self._to_full_key(path) + + if not self._is_a_directory(full_key): + raise NotADirectoryError + + data = {"prefix": full_key} + if postfix: + data["postfix"] = postfix + # TODO: GTC SDK: Pagination + list_assets_response = self._call_api( + method="list", path=f"/buckets/{self.bucket_id}/assets", json=data, raise_for_status=False + ).json() + + return [asset["name"] for asset in list_assets_response.get("assets", [])] + + def try_load_file(self, path: str) -> bytes: + full_key = self._to_full_key(path) + + if self._is_a_directory(full_key): + raise IsADirectoryError + + try: + blob_client = self._get_blob_client(full_key=full_key) + except requests.exceptions.HTTPError as e: + if e.response.status_code == 404: + raise FileNotFoundError from e + raise e + + try: + return blob_client.download_blob().readall() + except import_optional_dependency("azure.core.exceptions").ResourceNotFoundError as e: + raise FileNotFoundError from e + + def try_save_file(self, path: str, value: bytes) -> str: + full_key = self._to_full_key(path) + + if self._is_a_directory(full_key): + raise IsADirectoryError + + try: + self._call_api(method="get", path=f"/buckets/{self.bucket_id}/assets/{full_key}", raise_for_status=True) + except requests.exceptions.HTTPError as 
e: + if e.response.status_code == 404: + logger.info("Asset '%s' not found, attempting to create", full_key) + data = {"name": full_key} + self._call_api(method="put", path=f"/buckets/{self.bucket_id}/assets", json=data, raise_for_status=True) + else: + raise e + + blob_client = self._get_blob_client(full_key=full_key) + + blob_client.upload_blob(data=value, overwrite=True) + return f"buckets/{self.bucket_id}/assets/{full_key}" + + def _get_blob_client(self, full_key: str) -> BlobClient: + url_response = self._call_api( + method="post", path=f"/buckets/{self.bucket_id}/asset-urls/{full_key}", raise_for_status=True + ).json() + sas_url = url_response["url"] + return import_optional_dependency("azure.storage.blob").BlobClient.from_blob_url(blob_url=sas_url) + + def _get_url(self, path: str) -> str: + path = path.lstrip("/") + return urljoin(self.base_url, f"/api/{path}") + + def _call_api( + self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True + ) -> requests.Response: + res = requests.request(method, self._get_url(path), json=json, headers=self.headers) + if raise_for_status: + res.raise_for_status() + return res + + def _is_a_directory(self, path: str) -> bool: + return path == "" or path.endswith("/") + + def _to_full_key(self, path: str) -> str: + path = path.lstrip("/") + full_key = f"{self.workdir}/{path}" + return full_key.lstrip("/") diff --git a/griptape/drivers/file_manager/local_file_manager_driver.py b/griptape/drivers/file_manager/local_file_manager_driver.py index b383ff7d7..69ef3ae1f 100644 --- a/griptape/drivers/file_manager/local_file_manager_driver.py +++ b/griptape/drivers/file_manager/local_file_manager_driver.py @@ -34,12 +34,13 @@ def try_load_file(self, path: str) -> bytes: raise IsADirectoryError return Path(full_path).read_bytes() - def try_save_file(self, path: str, value: bytes) -> None: + def try_save_file(self, path: str, value: bytes) -> str: full_path = self._full_path(path) if 
self._is_dir(full_path): raise IsADirectoryError os.makedirs(os.path.dirname(full_path), exist_ok=True) Path(full_path).write_bytes(value) + return full_path def _full_path(self, path: str) -> str: full_path = path if self.workdir is None else os.path.join(self.workdir, path.lstrip("/")) diff --git a/griptape/drivers/prompt/base_prompt_driver.py b/griptape/drivers/prompt/base_prompt_driver.py index 778b6f474..9af43f082 100644 --- a/griptape/drivers/prompt/base_prompt_driver.py +++ b/griptape/drivers/prompt/base_prompt_driver.py @@ -16,7 +16,13 @@ TextMessageContent, observable, ) -from griptape.events import CompletionChunkEvent, EventBus, FinishPromptEvent, StartPromptEvent +from griptape.events import ( + ActionChunkEvent, + EventBus, + FinishPromptEvent, + StartPromptEvent, + TextChunkEvent, +) from griptape.mixins.exponential_backoff_mixin import ExponentialBackoffMixin from griptape.mixins.serializable_mixin import SerializableMixin @@ -127,12 +133,17 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message: else: delta_contents[content.index] = [content] if isinstance(content, TextDeltaMessageContent): - EventBus.publish_event(CompletionChunkEvent(token=content.text)) + EventBus.publish_event(TextChunkEvent(token=content.text, index=content.index)) elif isinstance(content, ActionCallDeltaMessageContent): - if content.tag is not None and content.name is not None and content.path is not None: - EventBus.publish_event(CompletionChunkEvent(token=str(content))) - elif content.partial_input is not None: - EventBus.publish_event(CompletionChunkEvent(token=content.partial_input)) + EventBus.publish_event( + ActionChunkEvent( + partial_input=content.partial_input, + tag=content.tag, + name=content.name, + path=content.path, + index=content.index, + ), + ) # Build a complete content from the content deltas return self.__build_message(list(delta_contents.values()), usage) diff --git a/griptape/events/__init__.py b/griptape/events/__init__.py index 
b3e2f3a79..e8a14d750 100644 --- a/griptape/events/__init__.py +++ b/griptape/events/__init__.py @@ -10,7 +10,9 @@ from .finish_prompt_event import FinishPromptEvent from .start_structure_run_event import StartStructureRunEvent from .finish_structure_run_event import FinishStructureRunEvent -from .completion_chunk_event import CompletionChunkEvent +from .base_chunk_event import BaseChunkEvent +from .text_chunk_event import TextChunkEvent +from .action_chunk_event import ActionChunkEvent from .event_listener import EventListener from .start_image_generation_event import StartImageGenerationEvent from .finish_image_generation_event import FinishImageGenerationEvent @@ -37,7 +39,9 @@ "FinishPromptEvent", "StartStructureRunEvent", "FinishStructureRunEvent", - "CompletionChunkEvent", + "BaseChunkEvent", + "TextChunkEvent", + "ActionChunkEvent", "EventListener", "StartImageGenerationEvent", "FinishImageGenerationEvent", diff --git a/griptape/events/action_chunk_event.py b/griptape/events/action_chunk_event.py new file mode 100644 index 000000000..d51bc017f --- /dev/null +++ b/griptape/events/action_chunk_event.py @@ -0,0 +1,33 @@ +from __future__ import annotations + +from typing import Optional + +from attrs import define, field + +from griptape.events.base_chunk_event import BaseChunkEvent + + +@define +class ActionChunkEvent(BaseChunkEvent): + partial_input: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + tag: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + path: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) + + def __str__(self) -> str: + parts = [] + + if self.name: + parts.append(self.name) + if self.path: + parts.append(f".{self.path}") + if self.tag: + parts.append(f" ({self.tag})") + + if self.partial_input: + if parts: + parts.append(f"\n{self.partial_input}") + 
else: + parts.append(self.partial_input) + + return "".join(parts) diff --git a/griptape/events/base_chunk_event.py b/griptape/events/base_chunk_event.py new file mode 100644 index 000000000..c94fc9e2d --- /dev/null +++ b/griptape/events/base_chunk_event.py @@ -0,0 +1,13 @@ +from abc import abstractmethod + +from attrs import define, field + +from griptape.events.base_event import BaseEvent + + +@define +class BaseChunkEvent(BaseEvent): + index: int = field(default=0, metadata={"serializable": True}) + + @abstractmethod + def __str__(self) -> str: ... diff --git a/griptape/events/completion_chunk_event.py b/griptape/events/completion_chunk_event.py deleted file mode 100644 index 48b479625..000000000 --- a/griptape/events/completion_chunk_event.py +++ /dev/null @@ -1,8 +0,0 @@ -from attrs import define, field - -from griptape.events.base_event import BaseEvent - - -@define -class CompletionChunkEvent(BaseEvent): - token: str = field(kw_only=True, metadata={"serializable": True}) diff --git a/griptape/events/event_listener.py b/griptape/events/event_listener.py index 704e20d32..df4a2668a 100644 --- a/griptape/events/event_listener.py +++ b/griptape/events/event_listener.py @@ -49,7 +49,7 @@ def __exit__(self, type, value, traceback) -> None: # noqa: ANN001, A002 def publish_event(self, event: T, *, flush: bool = False) -> None: event_types = self.event_types - if event_types is None or type(event) in event_types: + if event_types is None or any(isinstance(event, event_type) for event_type in event_types): handled_event = event if self.handler is not None: handled_event = self.handler(event) diff --git a/griptape/events/text_chunk_event.py b/griptape/events/text_chunk_event.py new file mode 100644 index 000000000..7d3880bf2 --- /dev/null +++ b/griptape/events/text_chunk_event.py @@ -0,0 +1,11 @@ +from attrs import define, field + +from griptape.events.base_chunk_event import BaseChunkEvent + + +@define +class TextChunkEvent(BaseChunkEvent): + token: str = 
field(kw_only=True, metadata={"serializable": True}) + + def __str__(self) -> str: + return self.token diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index 8a764e85a..f722db33d 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -1,5 +1,6 @@ from __future__ import annotations +import json from queue import Queue from threading import Thread from typing import TYPE_CHECKING @@ -7,7 +8,15 @@ from attrs import Attribute, Factory, define, field from griptape.artifacts.text_artifact import TextArtifact -from griptape.events import CompletionChunkEvent, EventBus, EventListener, FinishPromptEvent, FinishStructureRunEvent +from griptape.events import ( + ActionChunkEvent, + BaseChunkEvent, + EventBus, + EventListener, + FinishPromptEvent, + FinishStructureRunEvent, + TextChunkEvent, +) if TYPE_CHECKING: from collections.abc import Iterator @@ -18,7 +27,7 @@ @define class Stream: - """A wrapper for Structures that converts `CompletionChunkEvent`s into an iterator of TextArtifacts. + """A wrapper for Structures that converts `BaseChunkEvent`s into an iterator of TextArtifacts. It achieves this by running the Structure in a separate thread, listening for events from the Structure, and yielding those events. 
@@ -48,14 +57,25 @@ def run(self, *args) -> Iterator[TextArtifact]: t = Thread(target=self._run_structure, args=args) t.start() + action_str = "" while True: event = self._event_queue.get() if isinstance(event, FinishStructureRunEvent): break elif isinstance(event, FinishPromptEvent): yield TextArtifact(value="\n") - elif isinstance(event, CompletionChunkEvent): + elif isinstance(event, TextChunkEvent): yield TextArtifact(value=event.token) + elif isinstance(event, ActionChunkEvent): + if event.tag is not None and event.name is not None and event.path is not None: + yield TextArtifact(f"{event.name}.{event.tag} ({event.path})") + if event.partial_input is not None: + action_str += event.partial_input + try: + yield TextArtifact(json.dumps(json.loads(action_str), indent=2)) + action_str = "" + except Exception: + pass t.join() def _run_structure(self, *args) -> None: @@ -64,7 +84,7 @@ def event_handler(event: BaseEvent) -> None: stream_event_listener = EventListener( handler=event_handler, - event_types=[CompletionChunkEvent, FinishPromptEvent, FinishStructureRunEvent], + event_types=[BaseChunkEvent, FinishPromptEvent, FinishStructureRunEvent], ) EventBus.add_event_listener(stream_event_listener) diff --git a/mkdocs.yml b/mkdocs.yml index 6a8dd6fea..f43b9e1f7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -133,6 +133,7 @@ nav: - Web Search Drivers: "griptape-framework/drivers/web-search-drivers.md" - Observability Drivers: "griptape-framework/drivers/observability-drivers.md" - Ruleset Drivers: "griptape-framework/drivers/ruleset-drivers.md" + - File Manager Drivers: "griptape-framework/drivers/file-manager-drivers.md" - Data: - Overview: "griptape-framework/data/index.md" - Artifacts: "griptape-framework/data/artifacts.md" diff --git a/poetry.lock b/poetry.lock index e0a525eb6..287a21f21 100644 --- a/poetry.lock +++ b/poetry.lock @@ -265,6 +265,45 @@ docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphi tests = ["cloudpickle", "hypothesis", 
"mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +[[package]] +name = "azure-core" +version = "1.31.0" +description = "Microsoft Azure Core Library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "azure_core-1.31.0-py3-none-any.whl", hash = "sha256:22954de3777e0250029360ef31d80448ef1be13b80a459bff80ba7073379e2cd"}, + {file = "azure_core-1.31.0.tar.gz", hash = "sha256:656a0dd61e1869b1506b7c6a3b31d62f15984b1a573d6326f6aa2f3e4123284b"}, +] + +[package.dependencies] +requests = ">=2.21.0" +six = ">=1.11.0" +typing-extensions = ">=4.6.0" + +[package.extras] +aio = ["aiohttp (>=3.0)"] + +[[package]] +name = "azure-storage-blob" +version = "12.23.1" +description = "Microsoft Azure Blob Storage Client Library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "azure_storage_blob-12.23.1-py3-none-any.whl", hash = "sha256:1c2238aa841d1545f42714a5017c010366137a44a0605da2d45f770174bfc6b4"}, + {file = "azure_storage_blob-12.23.1.tar.gz", hash = "sha256:a587e54d4e39d2a27bd75109db164ffa2058fe194061e5446c5a89bca918272f"}, +] + +[package.dependencies] +azure-core = ">=1.30.0" +cryptography = ">=2.1.4" +isodate = ">=0.6.1" +typing-extensions = ">=4.6.0" + +[package.extras] +aio = ["azure-core[aio] (>=1.30.0)"] + [[package]] name = "babel" version = "2.16.0" @@ -2280,6 +2319,17 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "isodate" +version = "0.7.2" +description = "An ISO 8601 date/time/duration parser and formatter" +optional = false +python-versions = ">=3.7" +files = [ + {file = "isodate-0.7.2-py3-none-any.whl", hash = "sha256:28009937d8031054830160fce6d409ed342816b543597cece116d966c6d99e15"}, + {file = "isodate-0.7.2.tar.gz", hash = "sha256:4cd1aa0f43ca76f4a6c6c0292a85f40b35ec2e43e315b59f06e6d32171a953e6"}, +] + 
[[package]] name = "jaraco-classes" version = "3.4.0" @@ -5508,29 +5558,29 @@ files = [ [[package]] name = "ruff" -version = "0.6.9" +version = "0.7.0" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.6.9-py3-none-linux_armv6l.whl", hash = "sha256:064df58d84ccc0ac0fcd63bc3090b251d90e2a372558c0f057c3f75ed73e1ccd"}, - {file = "ruff-0.6.9-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:140d4b5c9f5fc7a7b074908a78ab8d384dd7f6510402267bc76c37195c02a7ec"}, - {file = "ruff-0.6.9-py3-none-macosx_11_0_arm64.whl", hash = "sha256:53fd8ca5e82bdee8da7f506d7b03a261f24cd43d090ea9db9a1dc59d9313914c"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645d7d8761f915e48a00d4ecc3686969761df69fb561dd914a773c1a8266e14e"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eae02b700763e3847595b9d2891488989cac00214da7f845f4bcf2989007d577"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d5ccc9e58112441de8ad4b29dcb7a86dc25c5f770e3c06a9d57e0e5eba48829"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:417b81aa1c9b60b2f8edc463c58363075412866ae4e2b9ab0f690dc1e87ac1b5"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3c866b631f5fbce896a74a6e4383407ba7507b815ccc52bcedabb6810fdb3ef7"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7b118afbb3202f5911486ad52da86d1d52305b59e7ef2031cea3425142b97d6f"}, - {file = "ruff-0.6.9-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a67267654edc23c97335586774790cde402fb6bbdb3c2314f1fc087dee320bfa"}, - {file = "ruff-0.6.9-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3ef0cc774b00fec123f635ce5c547dac263f6ee9fb9cc83437c5904183b55ceb"}, - {file = 
"ruff-0.6.9-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:12edd2af0c60fa61ff31cefb90aef4288ac4d372b4962c2864aeea3a1a2460c0"}, - {file = "ruff-0.6.9-py3-none-musllinux_1_2_i686.whl", hash = "sha256:55bb01caeaf3a60b2b2bba07308a02fca6ab56233302406ed5245180a05c5625"}, - {file = "ruff-0.6.9-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:925d26471fa24b0ce5a6cdfab1bb526fb4159952385f386bdcc643813d472039"}, - {file = "ruff-0.6.9-py3-none-win32.whl", hash = "sha256:eb61ec9bdb2506cffd492e05ac40e5bc6284873aceb605503d8494180d6fc84d"}, - {file = "ruff-0.6.9-py3-none-win_amd64.whl", hash = "sha256:785d31851c1ae91f45b3d8fe23b8ae4b5170089021fbb42402d811135f0b7117"}, - {file = "ruff-0.6.9-py3-none-win_arm64.whl", hash = "sha256:a9641e31476d601f83cd602608739a0840e348bda93fec9f1ee816f8b6798b93"}, - {file = "ruff-0.6.9.tar.gz", hash = "sha256:b076ef717a8e5bc819514ee1d602bbdca5b4420ae13a9cf61a0c0a4f53a2baa2"}, + {file = "ruff-0.7.0-py3-none-linux_armv6l.whl", hash = "sha256:0cdf20c2b6ff98e37df47b2b0bd3a34aaa155f59a11182c1303cce79be715628"}, + {file = "ruff-0.7.0-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:496494d350c7fdeb36ca4ef1c9f21d80d182423718782222c29b3e72b3512737"}, + {file = "ruff-0.7.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:214b88498684e20b6b2b8852c01d50f0651f3cc6118dfa113b4def9f14faaf06"}, + {file = "ruff-0.7.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:630fce3fefe9844e91ea5bbf7ceadab4f9981f42b704fae011bb8efcaf5d84be"}, + {file = "ruff-0.7.0-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:211d877674e9373d4bb0f1c80f97a0201c61bcd1e9d045b6e9726adc42c156aa"}, + {file = "ruff-0.7.0-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:194d6c46c98c73949a106425ed40a576f52291c12bc21399eb8f13a0f7073495"}, + {file = "ruff-0.7.0-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:82c2579b82b9973a110fab281860403b397c08c403de92de19568f32f7178598"}, + {file = 
"ruff-0.7.0-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9af971fe85dcd5eaed8f585ddbc6bdbe8c217fb8fcf510ea6bca5bdfff56040e"}, + {file = "ruff-0.7.0-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b641c7f16939b7d24b7bfc0be4102c56562a18281f84f635604e8a6989948914"}, + {file = "ruff-0.7.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d71672336e46b34e0c90a790afeac8a31954fd42872c1f6adaea1dff76fd44f9"}, + {file = "ruff-0.7.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:ab7d98c7eed355166f367597e513a6c82408df4181a937628dbec79abb2a1fe4"}, + {file = "ruff-0.7.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:1eb54986f770f49edb14f71d33312d79e00e629a57387382200b1ef12d6a4ef9"}, + {file = "ruff-0.7.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:dc452ba6f2bb9cf8726a84aa877061a2462afe9ae0ea1d411c53d226661c601d"}, + {file = "ruff-0.7.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:4b406c2dce5be9bad59f2de26139a86017a517e6bcd2688da515481c05a2cb11"}, + {file = "ruff-0.7.0-py3-none-win32.whl", hash = "sha256:f6c968509f767776f524a8430426539587d5ec5c662f6addb6aa25bc2e8195ec"}, + {file = "ruff-0.7.0-py3-none-win_amd64.whl", hash = "sha256:ff4aabfbaaba880e85d394603b9e75d32b0693152e16fa659a3064a85df7fce2"}, + {file = "ruff-0.7.0-py3-none-win_arm64.whl", hash = "sha256:10842f69c245e78d6adec7e1db0a7d9ddc2fff0621d730e61657b64fa36f207e"}, + {file = "ruff-0.7.0.tar.gz", hash = "sha256:47a86360cf62d9cd53ebfb0b5eb0e882193fc191c6d717e8bef4462bc3b9ea2b"}, ] [[package]] @@ -6507,6 +6557,11 @@ files = [ {file = "triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb"}, {file = "triton-3.0.0-1-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bcbf3b1c48af6a28011a5c40a5b3b9b5330530c3827716b5fbf6d7adcc1e53e9"}, {file = 
"triton-3.0.0-1-cp39-cp39-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6e5727202f7078c56f91ff13ad0c1abab14a0e7f2c87e91b12b6f64f3e8ae609"}, + {file = "triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230"}, + {file = "triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e"}, + {file = "triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253"}, + {file = "triton-3.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8903767951bf86ec960b4fe4e21bc970055afc65e9d57e916d79ae3c93665e3"}, + {file = "triton-3.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:41004fb1ae9a53fcb3e970745feb87f0e3c94c6ce1ba86e95fa3b8537894bef7"}, ] [package.dependencies] @@ -7113,7 +7168,7 @@ doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linke test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy", "pytest-ruff (>=0.2.1)"] [extras] -all = ["anthropic", "astrapy", "beautifulsoup4", "boto3", "cohere", "diffusers", "duckduckgo-search", "elevenlabs", "exa-py", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "snowflake-sqlalchemy", "sqlalchemy", "tavily-python", "trafilatura", 
"transformers", "voyageai"] +all = ["anthropic", "astrapy", "azure-core", "azure-storage-blob", "beautifulsoup4", "boto3", "cohere", "diffusers", "duckduckgo-search", "elevenlabs", "exa-py", "google-generativeai", "mail-parser", "markdownify", "marqo", "ollama", "opensearch-py", "opentelemetry-api", "opentelemetry-exporter-otlp-proto-http", "opentelemetry-instrumentation", "opentelemetry-instrumentation-threading", "opentelemetry-sdk", "pandas", "pgvector", "pillow", "pinecone-client", "playwright", "psycopg2-binary", "pusher", "pymongo", "pypdf", "qdrant-client", "redis", "snowflake-sqlalchemy", "sqlalchemy", "tavily-python", "trafilatura", "transformers", "voyageai"] drivers-embedding-amazon-bedrock = ["boto3"] drivers-embedding-amazon-sagemaker = ["boto3"] drivers-embedding-cohere = ["cohere"] @@ -7124,6 +7179,8 @@ drivers-embedding-voyageai = ["voyageai"] drivers-event-listener-amazon-iot = ["boto3"] drivers-event-listener-amazon-sqs = ["boto3"] drivers-event-listener-pusher = ["pusher"] +drivers-file-manager-amazon-s3 = ["boto3"] +drivers-file-manager-griptape-cloud = ["azure-core", "azure-storage-blob"] drivers-image-generation-huggingface = ["diffusers", "pillow"] drivers-memory-conversation-amazon-dynamodb = ["boto3"] drivers-memory-conversation-redis = ["redis"] @@ -7165,4 +7222,4 @@ loaders-sql = ["sqlalchemy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "a3d496dda28ff7dee5c794cce2c38d849e00487031f9cdfd1dea684ca2c2587f" +content-hash = "05d5c24b38a0675a407077758ef162a35b96d79083d7202c9928b592ef23187e" diff --git a/pyproject.toml b/pyproject.toml index 79e060b24..69311c8ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,8 @@ opentelemetry-exporter-otlp-proto-http = {version = "^1.25.0", optional = true} diffusers = {version = "^0.30.3", optional = true} tavily-python = {version = "^0.5.0", optional = true} exa-py = {version = "^1.1.4", optional = true} +azure-core = "^1.31.0" +azure-storage-blob = "^12.23.1" # 
loaders pandas = {version = "^1.3", optional = true} @@ -145,6 +147,9 @@ drivers-observability-datadog = [ drivers-image-generation-huggingface = ["diffusers", "pillow"] +drivers-file-manager-amazon-s3 = ["boto3"] +drivers-file-manager-griptape-cloud = ["azure-core", "azure-storage-blob"] + loaders-pdf = ["pypdf"] loaders-image = ["pillow"] loaders-email = ["mail-parser"] @@ -188,6 +193,8 @@ all = [ "opentelemetry-exporter-otlp-proto-http", "diffusers", "pillow", + "azure-core", + "azure-storage-blob", # loaders "pandas", @@ -217,7 +224,7 @@ torch = "^2.4.1" optional = true [tool.poetry.group.dev.dependencies] -ruff = "^0.6.0" +ruff = "^0.7.0" pyright = "^1.1.376" pre-commit = "^4.0.0" boto3-stubs = {extras = ["bedrock", "iam", "opensearch", "s3", "sagemaker", "sqs", "iot-data", "dynamodb", "redshift-data"], version = "^1.34.105"} diff --git a/tests/mocks/mock_chunk_event.py b/tests/mocks/mock_chunk_event.py new file mode 100644 index 000000000..4017dcd0a --- /dev/null +++ b/tests/mocks/mock_chunk_event.py @@ -0,0 +1,11 @@ +from attrs import define, field + +from griptape.events.base_chunk_event import BaseChunkEvent + + +@define +class MockChunkEvent(BaseChunkEvent): + token: str = field(kw_only=True, metadata={"serializable": True}) + + def __str__(self) -> str: + return "mock " + self.token diff --git a/tests/unit/configs/drivers/test_drivers_config.py b/tests/unit/configs/drivers/test_drivers_config.py index 8eba0cb6a..74055f4e5 100644 --- a/tests/unit/configs/drivers/test_drivers_config.py +++ b/tests/unit/configs/drivers/test_drivers_config.py @@ -57,7 +57,7 @@ def test_context_manager(self): assert Defaults.drivers_config == old_drivers_config - @pytest.mark.skip_mock_config() + @pytest.mark.skip_mock_config def test_lazy_init(self): from griptape.configs import Defaults diff --git a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py index 2240dee58..efeb14dc6 100644 
--- a/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_amazon_s3_file_manager_driver.py @@ -207,7 +207,7 @@ def test_save_file(self, workdir, path, content, driver, get_s3_value): result = driver.save_file(path, content) assert isinstance(result, InfoArtifact) - assert result.value == "Successfully saved file" + assert result.value.startswith("Successfully saved file at:") expected_s3_key = f"{workdir}/{path}".lstrip("/") content_str = content if isinstance(content, str) else content.decode() assert get_s3_value(expected_s3_key) == content_str @@ -245,7 +245,7 @@ def test_save_file_with_encoding(self, session, bucket, get_s3_value): expected_s3_key = f"{workdir}/{path}".lstrip("/") assert get_s3_value(expected_s3_key) == "foobar" - assert result.value == "Successfully saved file" + assert result.value.startswith("Successfully saved file at:") def test_save_and_load_file_with_encoding(self, session, bucket, get_s3_value): workdir = "/sub-folder" @@ -256,7 +256,7 @@ def test_save_and_load_file_with_encoding(self, session, bucket, get_s3_value): expected_s3_key = f"{workdir}/{path}".lstrip("/") assert get_s3_value(expected_s3_key) == "foobar" - assert result.value == "Successfully saved file" + assert result.value.startswith("Successfully saved file at:") driver = AmazonS3FileManagerDriver(session=session, bucket=bucket, encoding="ascii", workdir=workdir) path = "test/foobar.txt" diff --git a/tests/unit/drivers/file_manager/test_base_file_manager_driver.py b/tests/unit/drivers/file_manager/test_base_file_manager_driver.py new file mode 100644 index 000000000..41bda51df --- /dev/null +++ b/tests/unit/drivers/file_manager/test_base_file_manager_driver.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import pytest + +from griptape.artifacts import BaseArtifact, TextArtifact +from griptape.drivers import BaseFileManagerDriver + + +class MockFileManagerDriver(BaseFileManagerDriver): + def 
try_list_files(self, path: str) -> list[str]: + return ["foo", "bar"] + + def try_save_file(self, path: str, value: bytes) -> str: + assert path == "foo" + assert BaseArtifact.from_json(value.decode()).value == TextArtifact(value="value").value + + return "mock_save_location" + + def try_load_file(self, path: str) -> bytes: + assert path == "foo" + + return TextArtifact(value="value").to_json().encode() + + +class TestBaseFileManagerDriver: + @pytest.fixture() + def driver(self): + return MockFileManagerDriver(workdir="/") + + def test_load_artifact(self, driver): + response = driver.load_artifact("foo") + + assert response.value == "value" + + def test_save_artifact(self, driver): + response = driver.save_artifact("foo", TextArtifact(value="value")) + + assert response.value == "Successfully saved artifact at: mock_save_location" diff --git a/tests/unit/drivers/file_manager/test_griptape_cloud_file_manager_driver.py b/tests/unit/drivers/file_manager/test_griptape_cloud_file_manager_driver.py new file mode 100644 index 000000000..0ce837dc1 --- /dev/null +++ b/tests/unit/drivers/file_manager/test_griptape_cloud_file_manager_driver.py @@ -0,0 +1,192 @@ +from unittest import mock + +import pytest +import requests +from azure.core.exceptions import ResourceNotFoundError + + +class TestGriptapeCloudFileManagerDriver: + @pytest.fixture() + def driver(self, mocker): + from griptape.drivers import GriptapeCloudFileManagerDriver + + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {} + mocker.patch("requests.request", return_value=mock_response) + + return GriptapeCloudFileManagerDriver(base_url="https://api.griptape.ai", api_key="foo bar", bucket_id="1") + + def test_instantiate_bucket_id(self, mocker): + from griptape.drivers import GriptapeCloudFileManagerDriver + + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {} + mocker.patch("requests.request", 
return_value=mock_response) + + GriptapeCloudFileManagerDriver(base_url="https://api.griptape.ai", api_key="foo bar", bucket_id="1") + + def test_instantiate_no_bucket_id(self): + from griptape.drivers import GriptapeCloudFileManagerDriver + + with pytest.raises(ValueError, match="GriptapeCloudFileManagerDriver requires an Bucket ID"): + GriptapeCloudFileManagerDriver(api_key="foo bar") + + def test_instantiate_bucket_not_found(self, mocker): + from griptape.drivers import GriptapeCloudFileManagerDriver + + mocker.patch("requests.request", side_effect=requests.exceptions.HTTPError(response=mock.Mock(status_code=404))) + + with pytest.raises(ValueError, match="No Bucket found with ID: 1"): + return GriptapeCloudFileManagerDriver(api_key="foo bar", bucket_id="1") + + def test_instantiate_bucket_500(self, mocker): + from griptape.drivers import GriptapeCloudFileManagerDriver + + mocker.patch("requests.request", side_effect=requests.exceptions.HTTPError(response=mock.Mock(status_code=500))) + + with pytest.raises(ValueError, match="Unexpected error when retrieving Bucket with ID: 1"): + return GriptapeCloudFileManagerDriver(api_key="foo bar", bucket_id="1") + + def test_instantiate_no_api_key(self): + from griptape.drivers import GriptapeCloudFileManagerDriver + + with pytest.raises(ValueError, match="GriptapeCloudFileManagerDriver requires an API key"): + GriptapeCloudFileManagerDriver(bucket_id="1") + + def test_instantiate_invalid_work_dir(self): + from griptape.drivers import GriptapeCloudFileManagerDriver + + with pytest.raises( + ValueError, + match="GriptapeCloudFileManagerDriver requires 'workdir' to be an absolute path, starting with `/`", + ): + GriptapeCloudFileManagerDriver(api_key="foo bar", bucket_id="1", workdir="no_slash") + + def test_try_list_files(self, mocker, driver): + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"assets": [{"name": "foo/bar.pdf"}, {"name": "foo/baz.pdf"}]} + 
mocker.patch("requests.request", return_value=mock_response) + + files = driver.try_list_files("foo/") + + assert len(files) == 2 + assert files[0] == "foo/bar.pdf" + assert files[1] == "foo/baz.pdf" + + def test_try_list_files_postfix(self, mocker, driver): + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"assets": [{"name": "foo/bar.pdf"}, {"name": "foo/baz.pdf"}]} + mocker.patch("requests.request", return_value=mock_response) + + files = driver.try_list_files("foo/", ".pdf") + + assert len(files) == 2 + assert files[0] == "foo/bar.pdf" + assert files[1] == "foo/baz.pdf" + + def test_try_list_files_not_directory(self, mocker, driver): + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"assets": [{"name": "foo/bar"}, {"name": "foo/baz"}]} + mocker.patch("requests.request", return_value=mock_response) + + with pytest.raises(NotADirectoryError): + driver.try_list_files("foo") + + def test_try_load_file(self, mocker, driver): + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"url": "https://foo.bar"} + mocker.patch("requests.request", return_value=mock_response) + + mock_bytes = b"bytes" + mock_blob_client = mocker.Mock() + mock_blob_client.download_blob.return_value.readall.return_value = mock_bytes + mocker.patch("azure.storage.blob.BlobClient.from_blob_url", return_value=mock_blob_client) + + response = driver.try_load_file("foo") + + assert response == mock_bytes + + def test_try_load_file_directory(self, mocker, driver): + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"url": "https://foo.bar"} + mocker.patch("requests.request", return_value=mock_response) + + with pytest.raises(IsADirectoryError): + driver.try_load_file("foo/") + + def test_try_load_file_sas_404(self, mocker, driver): + mocker.patch("requests.request", 
side_effect=requests.exceptions.HTTPError(response=mock.Mock(status_code=404))) + + with pytest.raises(FileNotFoundError): + driver.try_load_file("foo") + + def test_try_load_file_sas_500(self, mocker, driver): + mocker.patch("requests.request", side_effect=requests.exceptions.HTTPError(response=mock.Mock(status_code=500))) + + with pytest.raises(requests.exceptions.HTTPError): + driver.try_load_file("foo") + + def test_try_load_file_blob_404(self, mocker, driver): + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"url": "https://foo.bar"} + mocker.patch("requests.request", return_value=mock_response) + + mock_blob_client = mocker.Mock() + mock_blob_client.download_blob.side_effect = ResourceNotFoundError() + mocker.patch("azure.storage.blob.BlobClient.from_blob_url", return_value=mock_blob_client) + + with pytest.raises(FileNotFoundError): + driver.try_load_file("foo") + + def test_try_save_files(self, mocker, driver): + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"url": "https://foo.bar"} + mocker.patch("requests.request", return_value=mock_response) + + mock_blob_client = mocker.Mock() + mocker.patch("azure.storage.blob.BlobClient.from_blob_url", return_value=mock_blob_client) + + response = driver.try_save_file("foo", b"value") + + assert response == "buckets/1/assets/foo" + + def test_try_save_file_directory(self, mocker, driver): + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = {"url": "https://foo.bar"} + mocker.patch("requests.request", return_value=mock_response) + + with pytest.raises(IsADirectoryError): + driver.try_save_file("foo/", b"value") + + def test_try_save_file_sas_404(self, mocker, driver): + mock_response = mocker.Mock() + mock_response.json.return_value = {"url": "https://foo.bar"} + mock_response.raise_for_status.side_effect = [ + 
requests.exceptions.HTTPError(response=mock.Mock(status_code=404)), + None, + None, + ] + mocker.patch("requests.request", return_value=mock_response) + + mock_blob_client = mocker.Mock() + mocker.patch("azure.storage.blob.BlobClient.from_blob_url", return_value=mock_blob_client) + + response = driver.try_save_file("foo", b"value") + + assert response == "buckets/1/assets/foo" + + def test_try_save_file_sas_500(self, mocker, driver): + mocker.patch("requests.request", side_effect=requests.exceptions.HTTPError(response=mock.Mock(status_code=500))) + + with pytest.raises(requests.exceptions.HTTPError): + driver.try_save_file("foo", b"value") diff --git a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py index 99f0285bc..b772941b8 100644 --- a/tests/unit/drivers/file_manager/test_local_file_manager_driver.py +++ b/tests/unit/drivers/file_manager/test_local_file_manager_driver.py @@ -176,7 +176,7 @@ def test_save_file(self, workdir, path, content, temp_dir, driver): result = driver.save_file(path, content) assert isinstance(result, InfoArtifact) - assert result.value == "Successfully saved file" + assert result.value.startswith("Successfully saved file at:") content_bytes = content if isinstance(content, str) else content.decode() assert Path(driver.workdir, path).read_text() == content_bytes @@ -210,14 +210,14 @@ def test_save_file_with_encoding(self, temp_dir): result = driver.save_file(os.path.join("test", "foobar.txt"), "foobar") assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" - assert result.value == "Successfully saved file" + assert result.value.startswith("Successfully saved file at:") def test_save_and_load_file_with_encoding(self, temp_dir): driver = LocalFileManagerDriver(encoding="ascii", workdir=temp_dir) result = driver.save_file(os.path.join("test", "foobar.txt"), "foobar") assert Path(os.path.join(temp_dir, "test", 
"foobar.txt")).read_text() == "foobar" - assert result.value == "Successfully saved file" + assert result.value.startswith("Successfully saved file at:") driver = LocalFileManagerDriver(encoding="ascii", workdir=temp_dir) result = driver.load_file(os.path.join("test", "foobar.txt")) diff --git a/tests/unit/events/test_action_chunk_event.py b/tests/unit/events/test_action_chunk_event.py new file mode 100644 index 000000000..6c242a475 --- /dev/null +++ b/tests/unit/events/test_action_chunk_event.py @@ -0,0 +1,38 @@ +import pytest + +from griptape.events import ActionChunkEvent + + +class TestCompletionChunkEvent: + TEST_PARAMS = [ + {"name": "foo", "tag": None, "path": None, "partial_input": None}, + {"name": "foo", "tag": "bar", "path": None, "partial_input": None}, + {"name": "foo", "tag": "bar", "path": "baz", "partial_input": None}, + {"name": "foo", "tag": None, "path": "baz", "partial_input": None}, + {"name": "foo", "tag": "bar", "path": "baz", "partial_input": "qux"}, + {"name": None, "tag": None, "path": None, "partial_input": "qux"}, + ] + + @pytest.fixture() + def action_chunk_event(self): + return ActionChunkEvent( + partial_input="foo bar", + tag="foo", + name="bar", + path="baz", + ) + + def test_token(self, action_chunk_event): + assert action_chunk_event.partial_input == "foo bar" + assert action_chunk_event.index == 0 + assert action_chunk_event.tag == "foo" + assert action_chunk_event.name == "bar" + assert action_chunk_event.path == "baz" + + def test_to_dict(self, action_chunk_event): + assert action_chunk_event.to_dict()["partial_input"] == "foo bar" + + @pytest.mark.parametrize("params", TEST_PARAMS) + def test_str(self, params): + event = ActionChunkEvent(**params) + assert str(event) == event.__str__() diff --git a/tests/unit/events/test_base_chunk_event.py b/tests/unit/events/test_base_chunk_event.py new file mode 100644 index 000000000..80cedf353 --- /dev/null +++ b/tests/unit/events/test_base_chunk_event.py @@ -0,0 +1,18 @@ +import pytest + 
+from tests.mocks.mock_chunk_event import MockChunkEvent + + +class TestBaseChunkEvent: + @pytest.fixture() + def base_chunk_event(self): + return MockChunkEvent(token="foo", index=1) + + def test_token(self, base_chunk_event): + assert base_chunk_event.index == 1 + assert base_chunk_event.token == "foo" + assert str(base_chunk_event) == "mock foo" + + def test_to_dict(self, base_chunk_event): + assert base_chunk_event.to_dict()["index"] == 1 + assert base_chunk_event.to_dict()["token"] == "foo" diff --git a/tests/unit/events/test_base_event.py b/tests/unit/events/test_base_event.py index 6ce010ee9..58535eaac 100644 --- a/tests/unit/events/test_base_event.py +++ b/tests/unit/events/test_base_event.py @@ -4,8 +4,8 @@ from griptape.artifacts.base_artifact import BaseArtifact from griptape.events import ( + ActionChunkEvent, BaseEvent, - CompletionChunkEvent, FinishActionsSubtaskEvent, FinishPromptEvent, FinishStructureRunEvent, @@ -14,6 +14,7 @@ StartPromptEvent, StartStructureRunEvent, StartTaskEvent, + TextChunkEvent, ) from tests.mocks.mock_event import MockEvent @@ -244,15 +245,38 @@ def test_finish_structure_run_event_from_dict(self): assert event.output_task_output.value == "bar" assert event.meta == {"foo": "bar"} - def test_completion_chunk_event_from_dict(self): - dict_value = {"type": "CompletionChunkEvent", "timestamp": 123.0, "token": "foo", "meta": {}} + def test_text_chunk_event_from_dict(self): + dict_value = {"type": "TextChunkEvent", "timestamp": 123.0, "token": "foo", "index": 0, "meta": {}} event = BaseEvent.from_dict(dict_value) - assert isinstance(event, CompletionChunkEvent) + assert isinstance(event, TextChunkEvent) + assert event.index == 0 assert event.token == "foo" assert event.meta == {} + def test_action_chunk_event_from_dict(self): + dict_value = { + "type": "ActionChunkEvent", + "timestamp": 123.0, + "partial_input": "foo", + "tag": None, + "index": 1, + "name": "bar", + "path": "foobar", + "meta": {}, + } + + event = 
BaseEvent.from_dict(dict_value) + + assert isinstance(event, ActionChunkEvent) + assert event.partial_input == "foo" + assert event.tag is None + assert event.index == 1 + assert event.name == "bar" + assert event.path == "foobar" + assert event.meta == {} + def test_unsupported_from_dict(self): dict_value = {"type": "foo", "value": "foobar"} with pytest.raises(ValueError): diff --git a/tests/unit/events/test_completion_chunk_event.py b/tests/unit/events/test_completion_chunk_event.py deleted file mode 100644 index 943ea483f..000000000 --- a/tests/unit/events/test_completion_chunk_event.py +++ /dev/null @@ -1,15 +0,0 @@ -import pytest - -from griptape.events import CompletionChunkEvent - - -class TestCompletionChunkEvent: - @pytest.fixture() - def completion_chunk_event(self): - return CompletionChunkEvent(token="foo bar") - - def test_token(self, completion_chunk_event): - assert completion_chunk_event.token == "foo bar" - - def test_to_dict(self, completion_chunk_event): - assert completion_chunk_event.to_dict()["token"] == "foo bar" diff --git a/tests/unit/events/test_event_listener.py b/tests/unit/events/test_event_listener.py index a6c7e2919..b3aee2891 100644 --- a/tests/unit/events/test_event_listener.py +++ b/tests/unit/events/test_event_listener.py @@ -3,7 +3,8 @@ import pytest from griptape.events import ( - CompletionChunkEvent, + ActionChunkEvent, + BaseChunkEvent, EventBus, EventListener, FinishActionsSubtaskEvent, @@ -14,6 +15,7 @@ StartPromptEvent, StartStructureRunEvent, StartTaskEvent, + TextChunkEvent, ) from griptape.events.base_event import BaseEvent from griptape.structures import Pipeline @@ -27,7 +29,7 @@ class TestEventListener: @pytest.fixture() def pipeline(self, mock_config): - mock_config.drivers_config.prompt_driver = MockPromptDriver(stream=True) + mock_config.drivers_config.prompt_driver = MockPromptDriver(stream=True, use_native_tools=True) task = ToolkitTask("test", tools=[MockTool(name="Tool1")]) pipeline = Pipeline() @@ -47,8 +49,8 
@@ def test_untyped_listeners(self, pipeline, mock_config): pipeline.tasks[0].subtasks[0].after_run() pipeline.run() - assert event_handler_1.call_count == 9 - assert event_handler_2.call_count == 9 + assert event_handler_1.call_count == 10 + assert event_handler_2.call_count == 10 def test_typed_listeners(self, pipeline, mock_config): start_prompt_event_handler = Mock() @@ -59,7 +61,9 @@ def test_typed_listeners(self, pipeline, mock_config): finish_subtask_event_handler = Mock() start_structure_run_event_handler = Mock() finish_structure_run_event_handler = Mock() - completion_chunk_handler = Mock() + base_chunk_handler = Mock() + text_chunk_handler = Mock() + action_chunk_handler = Mock() EventBus.add_event_listeners( [ @@ -71,7 +75,9 @@ def test_typed_listeners(self, pipeline, mock_config): EventListener(finish_subtask_event_handler, event_types=[FinishActionsSubtaskEvent]), EventListener(start_structure_run_event_handler, event_types=[StartStructureRunEvent]), EventListener(finish_structure_run_event_handler, event_types=[FinishStructureRunEvent]), - EventListener(completion_chunk_handler, event_types=[CompletionChunkEvent]), + EventListener(base_chunk_handler, event_types=[BaseChunkEvent]), + EventListener(text_chunk_handler, event_types=[TextChunkEvent]), + EventListener(action_chunk_handler, event_types=[ActionChunkEvent]), ] ) @@ -88,7 +94,12 @@ def test_typed_listeners(self, pipeline, mock_config): finish_subtask_event_handler.assert_called_once() start_structure_run_event_handler.assert_called_once() finish_structure_run_event_handler.assert_called_once() - completion_chunk_handler.assert_called_once() + assert base_chunk_handler.call_count == 2 + assert action_chunk_handler.call_count == 2 + + pipeline.tasks[0].prompt_driver.use_native_tools = False + pipeline.run() + text_chunk_handler.assert_called_once() def test_add_remove_event_listener(self, pipeline): EventBus.clear_event_listeners() diff --git a/tests/unit/events/test_text_chunk_event.py 
b/tests/unit/events/test_text_chunk_event.py new file mode 100644 index 000000000..582de3ca1 --- /dev/null +++ b/tests/unit/events/test_text_chunk_event.py @@ -0,0 +1,16 @@ +import pytest + + +from griptape.events import TextChunkEvent + + + + +class TestTextChunkEvent: + @pytest.fixture() + def text_chunk_event(self): + return TextChunkEvent(token="foo bar") + + def test_token(self, text_chunk_event): + assert text_chunk_event.token == "foo bar" + assert str(text_chunk_event) == "foo bar" + + def test_to_dict(self, text_chunk_event): + assert text_chunk_event.to_dict()["token"] == "foo bar" diff --git a/tests/unit/tools/test_file_manager.py b/tests/unit/tools/test_file_manager.py index 4e035bdee..9cbfe6859 100644 --- a/tests/unit/tools/test_file_manager.py +++ b/tests/unit/tools/test_file_manager.py @@ -111,7 +111,7 @@ def test_save_content_to_file(self, temp_dir): ) assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" - assert result.value == "Successfully saved file" + assert result.value.startswith("Successfully saved file at:") def test_save_content_to_file_with_encoding(self, temp_dir): file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(encoding="utf-8", workdir=temp_dir)) @@ -120,7 +120,7 @@ def test_save_content_to_file_with_encoding(self, temp_dir): ) assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" - assert result.value == "Successfully saved file" + assert result.value.startswith("Successfully saved file at:") def test_save_and_load_content_to_file_with_encoding(self, temp_dir): file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(encoding="ascii", workdir=temp_dir)) @@ -129,7 +129,7 @@ def test_save_and_load_content_to_file_with_encoding(self, temp_dir): ) assert Path(os.path.join(temp_dir, "test", "foobar.txt")).read_text() == "foobar" - assert result.value == "Successfully saved file" + assert result.value.startswith("Successfully saved file at:") 
file_manager = FileManagerTool(file_manager_driver=LocalFileManagerDriver(encoding="ascii", workdir=temp_dir)) result = file_manager.load_files_from_disk({"values": {"paths": [os.path.join("test", "foobar.txt")]}}) diff --git a/tests/unit/utils/test_stream.py b/tests/unit/utils/test_stream.py index caddbb1a3..e16403a06 100644 --- a/tests/unit/utils/test_stream.py +++ b/tests/unit/utils/test_stream.py @@ -1,15 +1,21 @@ +import json from collections.abc import Iterator import pytest from griptape.structures import Agent, Pipeline from griptape.utils import Stream +from tests.mocks.mock_prompt_driver import MockPromptDriver +from tests.mocks.mock_tool.tool import MockTool class TestStream: @pytest.fixture(params=[True, False]) def agent(self, request): - return Agent(stream=request.param) + driver = MockPromptDriver( + use_native_tools=request.param, + ) + return Agent(stream=request.param, tools=[MockTool()], prompt_driver=driver) def test_init(self, agent): if agent.stream: @@ -18,9 +24,10 @@ def test_init(self, agent): assert chat_stream.structure == agent chat_stream_run = chat_stream.run() assert isinstance(chat_stream_run, Iterator) - chat_stream_artifact = next(chat_stream_run) - assert chat_stream_artifact.value == "mock output" - + assert next(chat_stream_run).value == "MockTool.mock-tag (test)" + assert next(chat_stream_run).value == json.dumps({"values": {"test": "test-value"}}, indent=2) + next(chat_stream_run) + assert next(chat_stream_run).value == "Answer: mock output" next(chat_stream_run) with pytest.raises(StopIteration): next(chat_stream_run)