Skip to content

Commit

Permalink
Merge branch 'dev' into fix/toolkit
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 21, 2024
2 parents d99b7cc + dab865a commit b5025d8
Show file tree
Hide file tree
Showing 43 changed files with 959 additions and 100 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/docs-integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ jobs:
ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.INTEG_ASTRA_DB_APPLICATION_TOKEN }}
TAVILY_API_KEY: ${{ secrets.INTEG_TAVILY_API_KEY }}
EXA_API_KEY: ${{ secrets.INTEG_EXA_API_KEY }}
AMAZON_S3_BUCKET: ${{ secrets.INTEG_AMAZON_S3_BUCKET }}
AMAZON_S3_KEY: ${{ secrets.INTEG_AMAZON_S3_KEY }}
GT_CLOUD_BUCKET_ID: ${{ secrets.INTEG_GT_CLOUD_BUCKET_ID }}
GT_CLOUD_ASSET_NAME: ${{ secrets.INTEG_GT_CLOUD_ASSET_NAME }}

services:
postgres:
image: ankane/pgvector:v0.5.0
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Exponential backoff to `BaseEventListenerDriver` for retrying failed event publishing.
- `BaseTask.task_outputs` to get a dictionary of all task outputs. This has been added to `Workflow.context` and `Pipeline.context`.
- `Chat.input_fn` for customizing the input to the Chat utility.
- `GriptapeCloudFileManagerDriver` for managing files on Griptape Cloud.
- `BaseFileManagerDriver.load_artifact()` & `BaseFileManagerDriver.save_artifact()` for loading & saving artifacts as files.
- Events `BaseChunkEvent`, `TextChunkEvent`, `ActionChunkEvent`.

### Changed

- **BREAKING**: Removed `BaseEventListener.publish_event` `flush` argument. Use `BaseEventListener.flush_events()` instead.
- **BREAKING**: Renamed parameter `driver` on `EventListener` to `event_listener_driver`.
- **BREAKING**: Changed default value of parameter `handler` on `EventListener` to `None`.
- **BREAKING**: Updated `EventListener.handler` return value behavior.
- **BREAKING**: Removed `CompletionChunkEvent`.
- If `EventListener.handler` returns `None`, the event will not be published to the `event_listener_driver`.
- If `EventListener.handler` is None, the event will be published to the `event_listener_driver` as-is.
- Updated `EventListener.handler` return type to `Optional[BaseEvent | dict]`.
Expand All @@ -42,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `@activity` decorated functions can now accept kwargs that are defined in the activity schema.
- Updated `ToolkitTask` system prompt to no longer mention `memory_name` and `artifact_namespace`.
- Models in `ToolkitTask` with native tool calling no longer need to provide their final answer as `Answer:`.
- `EventListener.event_types` will now listen on child types of any provided type.

### Fixed

Expand Down
40 changes: 40 additions & 0 deletions MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,46 @@ This document provides instructions for migrating your codebase to accommodate b

## 0.33.X to 0.34.X

### Removed `CompletionChunkEvent`

`CompletionChunkEvent` has been removed. There is now `BaseChunkEvent` with children `TextChunkEvent` and `ActionChunkEvent`. `BaseChunkEvent` can replace `completion_chunk_event.token` by doing `str(base_chunk_event)`.

#### Before

```python
def handler_fn_stream(event: CompletionChunkEvent) -> None:
print(f"CompletionChunkEvent: {event.to_json()}")

def handler_fn_stream_text(event: CompletionChunkEvent) -> None:
# This prints out Tool actions with no easy way
# to filter them out
print(event.token, end="", flush=True)

EventListener(handler=handler_fn_stream, event_types=[CompletionChunkEvent])
EventListener(handler=handler_fn_stream_text, event_types=[CompletionChunkEvent])
```

#### After

```python
def handler_fn_stream(event: BaseChunkEvent) -> None:
print(str(e), end="", flush=True)
# print out each child event type
if isinstance(event, TextChunkEvent):
print(f"TextChunkEvent: {event.to_json()}")
if isinstance(event, ActionChunkEvent):
print(f"ActionChunkEvent: {event.to_json()}")


def handler_fn_stream_text(event: TextChunkEvent) -> None:
# This will only be text coming from the
# prompt driver, not Tool actions
print(event.token, end="", flush=True)

EventListener(handler=handler_fn_stream, event_types=[BaseChunkEvent])
EventListener(handler=handler_fn_stream_text, event_types=[TextChunkEvent])
```

### `EventListener.handler` behavior, `driver` parameter rename

Returning `None` from the `handler` function now causes the event to not be published to the `EventListenerDriver`.
Expand Down
4 changes: 4 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ test: test/unit test/integration
test/unit: ## Run unit tests.
@poetry run pytest -n auto tests/unit

.PHONY: test/unit/%
test/unit/%: ## Run specific unit tests.
@poetry run pytest -n auto tests/unit -k $*

.PHONY: test/unit/coverage
test/unit/coverage:
@poetry run pytest -n auto --cov=griptape tests/unit
Expand Down
48 changes: 48 additions & 0 deletions docs/griptape-framework/drivers/file-manager-drivers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
---
search:
boost: 2
---

## Overview

File Manager Drivers can be used to load and save files with local or external file systems.

You can use File Manager Drivers with Loaders:

```python
--8<-- "docs/griptape-framework/drivers/src/file_manager_driver.py"
```

Or use them independently as shown below for each driver:

## File Manager Drivers

### Griptape Cloud

!!! info
This driver requires the `drivers-file-manager-griptape-cloud` [extra](../index.md#extras).

The [GriptapeCloudFileManagerDriver](../../reference/griptape/drivers/file_manager/griptape_cloud_file_manager_driver.md) allows you to load and save files sourced from Griptape Cloud Asset and Bucket resources.

```python
--8<-- "docs/griptape-framework/drivers/src/griptape_cloud_file_manager_driver.py"
```

### Local

The [LocalFileManagerDriver](../../reference/griptape/drivers/file_manager/local_file_manager_driver.md) allows you to load and save files sourced from a local directory.

```python
--8<-- "docs/griptape-framework/drivers/src/local_file_manager_driver.py"
```

### Amazon S3

!!! info
This driver requires the `drivers-file-manager-amazon-s3` [extra](../index.md#extras).

The [LocalFile ManagerDriver](../../reference/griptape/drivers/file_manager/amazon_s3_file_manager_driver.md) allows you to load and save files sourced from an Amazon S3 bucket.

```python
--8<-- "docs/griptape-framework/drivers/src/amazon_s3_file_manager_driver.py"
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os

import boto3

from griptape.drivers import AmazonS3FileManagerDriver

amazon_s3_file_manager_driver = AmazonS3FileManagerDriver(
bucket=os.environ["AMAZON_S3_BUCKET"],
session=boto3.Session(
region_name=os.environ["AWS_DEFAULT_REGION"],
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
),
)

# Download File
file_contents = amazon_s3_file_manager_driver.load_file(os.environ["AMAZON_S3_KEY"])

print(file_contents)

# Upload File
response = amazon_s3_file_manager_driver.save_file(os.environ["AMAZON_S3_KEY"], file_contents.value)

print(response)
9 changes: 9 additions & 0 deletions docs/griptape-framework/drivers/src/file_manager_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from griptape.drivers import LocalFileManagerDriver
from griptape.loaders import TextLoader

local_file_manager_driver = LocalFileManagerDriver()

loader = TextLoader(file_manager_driver=local_file_manager_driver)
text_artifact = loader.load("tests/resources/test.txt")

print(text_artifact.value)
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os

from griptape.drivers import GriptapeCloudFileManagerDriver

gtc_file_manager_driver = GriptapeCloudFileManagerDriver(
api_key=os.environ["GT_CLOUD_API_KEY"],
bucket_id=os.environ["GT_CLOUD_BUCKET_ID"],
)

# Download File
file_contents = gtc_file_manager_driver.load_file(os.environ["GT_CLOUD_ASSET_NAME"])

print(file_contents)

# Upload File
response = gtc_file_manager_driver.save_file(os.environ["GT_CLOUD_ASSET_NAME"], file_contents.value)

print(response)
13 changes: 13 additions & 0 deletions docs/griptape-framework/drivers/src/local_file_manager_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from griptape.drivers import LocalFileManagerDriver

local_file_manager_driver = LocalFileManagerDriver()

# Download File
file_contents = local_file_manager_driver.load_file("tests/resources/test.txt")

print(file_contents)

# Upload File
response = local_file_manager_driver.save_file("tests/resources/test.txt", file_contents.value)

print(response)
12 changes: 9 additions & 3 deletions docs/griptape-framework/misc/events.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,20 @@ The `EventListener` will automatically be added and removed from the [EventBus](

## Streaming

You can use the [CompletionChunkEvent](../../reference/griptape/events/completion_chunk_event.md) to stream the completion results from Prompt Drivers.
You can use the [BaseChunkEvent](../../reference/griptape/events/base_chunk_event.md) to stream the completion results from Prompt Drivers.

```python
--8<-- "docs/griptape-framework/misc/src/events_3.py"
```

You can also use the [Stream](../../reference/griptape/utils/stream.md) utility to automatically wrap
[CompletionChunkEvent](../../reference/griptape/events/completion_chunk_event.md)s in a Python iterator.
You can also use the [TextChunkEvent](../../reference/griptape/events/text_chunk_event.md) and [ActionChunkEvent](../../reference/griptape/events/action_chunk_event.md) to further differentiate the different types of chunks for more customized output.

```python
--8<-- "docs/griptape-framework/misc/src/events_chunk_stream.py"
```

If you want Griptape to handle the chunk events for you, use the [Stream](../../reference/griptape/utils/stream.md) utility to automatically wrap
[BaseChunkEvent](../../reference/griptape/events/base_chunk_event.md)s in a Python iterator.

```python
--8<-- "docs/griptape-framework/misc/src/events_4.py"
Expand Down
10 changes: 4 additions & 6 deletions docs/griptape-framework/misc/src/events_3.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
from typing import cast

from griptape.drivers import OpenAiChatPromptDriver
from griptape.events import CompletionChunkEvent, EventBus, EventListener
from griptape.events import BaseChunkEvent, EventBus, EventListener
from griptape.structures import Pipeline
from griptape.tasks import ToolkitTask
from griptape.tools import PromptSummaryTool, WebScraperTool

EventBus.add_event_listeners(
[
EventListener(
lambda e: print(cast(CompletionChunkEvent, e).token, end="", flush=True),
event_types=[CompletionChunkEvent],
)
lambda e: print(str(e), end="", flush=True),
event_types=[BaseChunkEvent],
),
]
)

Expand Down
29 changes: 29 additions & 0 deletions docs/griptape-framework/misc/src/events_chunk_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from griptape.drivers import OpenAiChatPromptDriver
from griptape.events import ActionChunkEvent, EventBus, EventListener, TextChunkEvent
from griptape.structures import Pipeline
from griptape.tasks import ToolkitTask
from griptape.tools import PromptSummaryTool, WebScraperTool

EventBus.add_event_listeners(
[
EventListener(
lambda e: print(str(e), end="", flush=True),
event_types=[TextChunkEvent],
),
EventListener(
lambda e: print(str(e), end="", flush=True),
event_types=[ActionChunkEvent],
),
]
)

pipeline = Pipeline()
pipeline.add_tasks(
ToolkitTask(
"Based on https://griptape.ai, tell me what griptape is.",
prompt_driver=OpenAiChatPromptDriver(model="gpt-4o", stream=True),
tools=[WebScraperTool(off_prompt=True), PromptSummaryTool(off_prompt=False)],
)
)

pipeline.run()
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
from .file_manager.base_file_manager_driver import BaseFileManagerDriver
from .file_manager.local_file_manager_driver import LocalFileManagerDriver
from .file_manager.amazon_s3_file_manager_driver import AmazonS3FileManagerDriver
from .file_manager.griptape_cloud_file_manager_driver import GriptapeCloudFileManagerDriver

from .rerank.base_rerank_driver import BaseRerankDriver
from .rerank.cohere_rerank_driver import CohereRerankDriver
Expand Down Expand Up @@ -230,6 +231,7 @@
"BaseFileManagerDriver",
"LocalFileManagerDriver",
"AmazonS3FileManagerDriver",
"GriptapeCloudFileManagerDriver",
"BaseRerankDriver",
"CohereRerankDriver",
"BaseRulesetDriver",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

from attrs import define, field

Expand All @@ -21,7 +21,7 @@ class PusherEventListenerDriver(BaseEventListenerDriver):
channel: str = field(kw_only=True, metadata={"serializable": True})
event_name: str = field(kw_only=True, metadata={"serializable": True})
ssl: bool = field(default=True, kw_only=True, metadata={"serializable": True})
_client: Pusher = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
_client: Optional[Pusher] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
def client(self) -> Pusher:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,12 @@ def try_load_file(self, path: str) -> bytes:
raise FileNotFoundError from e
raise e

def try_save_file(self, path: str, value: bytes) -> None:
def try_save_file(self, path: str, value: bytes) -> str:
full_key = self._to_full_key(path)
if self._is_a_directory(full_key):
raise IsADirectoryError
self.client.put_object(Bucket=self.bucket, Key=full_key, Body=value)
return f"s3://{self.bucket}/{full_key}"

def _to_full_key(self, path: str) -> str:
path = path.lstrip("/")
Expand Down
22 changes: 18 additions & 4 deletions griptape/drivers/file_manager/base_file_manager_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from attrs import define, field

from griptape.artifacts import BlobArtifact, InfoArtifact, TextArtifact
from griptape.artifacts import BaseArtifact, BlobArtifact, InfoArtifact, TextArtifact


@define
Expand Down Expand Up @@ -42,9 +42,23 @@ def save_file(self, path: str, value: bytes | str) -> InfoArtifact:
elif isinstance(value, (bytearray, memoryview)):
raise ValueError(f"Unsupported type: {type(value)}")

self.try_save_file(path, value)
location = self.try_save_file(path, value)

return InfoArtifact("Successfully saved file")
return InfoArtifact(f"Successfully saved file at: {location}")

@abstractmethod
def try_save_file(self, path: str, value: bytes) -> None: ...
def try_save_file(self, path: str, value: bytes) -> str: ...

def load_artifact(self, path: str) -> BaseArtifact:
response = self.try_load_file(path)
return BaseArtifact.from_json(
response.decode() if self.encoding is None else response.decode(encoding=self.encoding)
)

def save_artifact(self, path: str, artifact: BaseArtifact) -> InfoArtifact:
artifact_json = artifact.to_json()
value = artifact_json.encode() if self.encoding is None else artifact_json.encode(encoding=self.encoding)

location = self.try_save_file(path, value)

return InfoArtifact(f"Successfully saved artifact at: {location}")
Loading

0 comments on commit b5025d8

Please sign in to comment.